mirror of
https://github.com/unclechu/gRPC-haskell.git
synced 2024-11-26 21:19:43 +01:00
Added tests for grpc server generation
This commit is contained in:
commit
89aa17b5e6
8 changed files with 254 additions and 2 deletions
14
README.md
Normal file
14
README.md
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
Running the tests
|
||||||
|
-----------------
|
||||||
|
|
||||||
|
In order to run the tests, you will need to have the `grpcio`, `gevent`, and
|
||||||
|
`grpcio-tools` python packages installed. You can install them using
|
||||||
|
`pip`. It is recommended that you use a python virtualenv to do this.
|
||||||
|
|
||||||
|
```
|
||||||
|
$ virtualenv path/to/virtualenv # to create a virtualenv
|
||||||
|
$ . path/to/virtual/env/bin/activate # to use an existing virtualenv
|
||||||
|
$ pip install grpcio-tools gevent
|
||||||
|
$ pip install grpcio # Need to install grpcio-tools first to avoid a versioning problem
|
||||||
|
```
|
||||||
|
|
|
@ -40,6 +40,7 @@ library
|
||||||
, tasty >= 0.11 && <0.12
|
, tasty >= 0.11 && <0.12
|
||||||
, tasty-hunit >= 0.9 && <0.10
|
, tasty-hunit >= 0.9 && <0.10
|
||||||
, safe
|
, safe
|
||||||
|
, vector
|
||||||
|
|
||||||
c-sources:
|
c-sources:
|
||||||
cbits/grpc_haskell.c
|
cbits/grpc_haskell.c
|
||||||
|
@ -139,13 +140,17 @@ test-suite test
|
||||||
, containers ==0.5.*
|
, containers ==0.5.*
|
||||||
, managed >= 1.0.5 && < 1.1
|
, managed >= 1.0.5 && < 1.1
|
||||||
, pipes ==4.1.*
|
, pipes ==4.1.*
|
||||||
|
, protobuf-wire
|
||||||
, transformers
|
, transformers
|
||||||
, safe
|
, safe
|
||||||
, clock ==0.6.*
|
, clock ==0.6.*
|
||||||
|
, turtle >= 1.2.0
|
||||||
|
, text
|
||||||
other-modules:
|
other-modules:
|
||||||
LowLevelTests,
|
LowLevelTests,
|
||||||
LowLevelTests.Op,
|
LowLevelTests.Op,
|
||||||
UnsafeTests
|
UnsafeTests,
|
||||||
|
GeneratedTests
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
ghc-options: -Wall -fwarn-incomplete-patterns -fno-warn-unused-do-bind -g -threaded -rtsopts
|
ghc-options: -Wall -fwarn-incomplete-patterns -fno-warn-unused-do-bind -g -threaded -rtsopts
|
||||||
hs-source-dirs: tests
|
hs-source-dirs: tests
|
||||||
|
|
|
@ -9,7 +9,7 @@ packages:
|
||||||
- '.'
|
- '.'
|
||||||
- location:
|
- location:
|
||||||
git: git@github.mv.awakenetworks.net:awakenetworks/protobuf-wire.git
|
git: git@github.mv.awakenetworks.net:awakenetworks/protobuf-wire.git
|
||||||
commit: 57490f59e90bd0881f28c7898a152d4bd25c9fcb
|
commit: 62aa5b92f21883d14bf8d3beed5645f84da01ad6
|
||||||
extra-dep: true
|
extra-dep: true
|
||||||
- location:
|
- location:
|
||||||
git: git@github.com:awakenetworks/proto3-wire.git
|
git: git@github.com:awakenetworks/proto3-wire.git
|
||||||
|
|
58
tests/GeneratedTests.hs
Normal file
58
tests/GeneratedTests.hs
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
|
module GeneratedTests where
|
||||||
|
|
||||||
|
import Test.Tasty
|
||||||
|
import Test.Tasty.HUnit (testCase, (@?=))
|
||||||
|
|
||||||
|
import Data.String
|
||||||
|
import Data.Protobuf.Wire.DotProto.Generate
|
||||||
|
import qualified Data.Text as T
|
||||||
|
|
||||||
|
import Turtle
|
||||||
|
|
||||||
|
generatedTests :: TestTree
|
||||||
|
generatedTests = testGroup "Code generator tests"
|
||||||
|
[ testServerGeneration ]
|
||||||
|
|
||||||
|
testServerGeneration :: TestTree
|
||||||
|
testServerGeneration = testCase "server generation" $ do
|
||||||
|
mktree hsTmpDir
|
||||||
|
mktree pyTmpDir
|
||||||
|
|
||||||
|
compileSimpleDotProto
|
||||||
|
|
||||||
|
exitCode <- shell (T.concat ["stack ghc -- --make -threaded -odir ", hsTmpDir, " -hidir ", hsTmpDir, " -o ", hsTmpDir, "/simple-server ", hsTmpDir, "/Simple.hs tests/TestServer.hs > /dev/null"]) empty
|
||||||
|
exitCode @?= ExitSuccess
|
||||||
|
|
||||||
|
exitCode <- shell (T.concat ["python -m grpc.tools.protoc -I tests --python_out=", pyTmpDir, " --grpc_python_out=", pyTmpDir, " tests/simple.proto"]) empty
|
||||||
|
exitCode @?= ExitSuccess
|
||||||
|
|
||||||
|
runManaged $ do
|
||||||
|
serverExitCodeA <- fork (shell (hsTmpDir <> "/simple-server") empty)
|
||||||
|
clientExitCodeA <- fork
|
||||||
|
(export "PYTHONPATH" pyTmpDir >> shell "python tests/test-client.py" empty)
|
||||||
|
|
||||||
|
liftIO $ do
|
||||||
|
serverExitCode <- liftIO (wait serverExitCodeA)
|
||||||
|
clientExitCode <- liftIO (wait clientExitCodeA)
|
||||||
|
|
||||||
|
serverExitCode @?= ExitSuccess
|
||||||
|
clientExitCode @?= ExitSuccess
|
||||||
|
|
||||||
|
rmtree hsTmpDir
|
||||||
|
rmtree pyTmpDir
|
||||||
|
|
||||||
|
hsTmpDir, pyTmpDir :: IsString a => a
|
||||||
|
hsTmpDir = "tests/tmp"
|
||||||
|
pyTmpDir = "tests/py-tmp"
|
||||||
|
|
||||||
|
compileSimpleDotProto :: IO ()
|
||||||
|
compileSimpleDotProto =
|
||||||
|
do dpRes <- readDotProtoWithContext "tests/simple.proto"
|
||||||
|
case dpRes of
|
||||||
|
Left err -> fail (show err)
|
||||||
|
Right (dp, ctxt) ->
|
||||||
|
case renderHsModuleForDotProto dp ctxt of
|
||||||
|
Left err -> fail ("compileSimpleDotProto: Error compiling test.proto: " <> show err)
|
||||||
|
Right hsSrc -> writeFile (hsTmpDir ++ "/Simple.hs") hsSrc
|
|
@ -2,10 +2,12 @@ import LowLevelTests
|
||||||
import LowLevelTests.Op
|
import LowLevelTests.Op
|
||||||
import Test.Tasty
|
import Test.Tasty
|
||||||
import UnsafeTests
|
import UnsafeTests
|
||||||
|
import GeneratedTests
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = defaultMain $ testGroup "GRPC Unit Tests"
|
main = defaultMain $ testGroup "GRPC Unit Tests"
|
||||||
[ unsafeTests
|
[ unsafeTests
|
||||||
, lowLevelOpTests
|
, lowLevelOpTests
|
||||||
, lowLevelTests
|
, lowLevelTests
|
||||||
|
, generatedTests
|
||||||
]
|
]
|
||||||
|
|
71
tests/TestServer.hs
Normal file
71
tests/TestServer.hs
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import Prelude hiding (sum)
|
||||||
|
|
||||||
|
import Simple
|
||||||
|
|
||||||
|
import Control.Concurrent
|
||||||
|
import Control.Concurrent.MVar
|
||||||
|
import Control.Monad
|
||||||
|
import Control.Monad.IO.Class
|
||||||
|
|
||||||
|
import Data.Monoid
|
||||||
|
import Data.Foldable (sum)
|
||||||
|
import Data.String
|
||||||
|
|
||||||
|
import Network.GRPC.LowLevel
|
||||||
|
|
||||||
|
handleNormalCall :: ServerCall SimpleServiceRequest -> IO (SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails)
|
||||||
|
handleNormalCall call =
|
||||||
|
pure (SimpleServiceResponse request result, mempty, StatusOk, StatusDetails "")
|
||||||
|
where SimpleServiceRequest request nums = payload call
|
||||||
|
|
||||||
|
result = sum nums
|
||||||
|
|
||||||
|
handleClientStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> Streaming (Maybe SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails)
|
||||||
|
handleClientStreamingCall call recvRequest = go 0 ""
|
||||||
|
where go sumAccum nameAccum =
|
||||||
|
recvRequest >>= \req ->
|
||||||
|
case req of
|
||||||
|
Left ioError -> pure (Nothing, mempty, StatusCancelled, StatusDetails ("handleClientStreamingCall: IO error: " <> fromString (show ioError)))
|
||||||
|
Right Nothing ->
|
||||||
|
pure (Just (SimpleServiceResponse nameAccum sumAccum), mempty, StatusOk, StatusDetails "")
|
||||||
|
Right (Just (SimpleServiceRequest name nums)) ->
|
||||||
|
go (sumAccum + sum nums) (nameAccum <> name)
|
||||||
|
|
||||||
|
handleServerStreamingCall :: ServerCall SimpleServiceRequest -> StreamSend SimpleServiceResponse -> Streaming (MetadataMap, StatusCode, StatusDetails)
|
||||||
|
handleServerStreamingCall call sendResponse = go
|
||||||
|
where go = do forM_ nums $ \num ->
|
||||||
|
sendResponse (SimpleServiceResponse requestName num)
|
||||||
|
pure (mempty, StatusOk, StatusDetails "")
|
||||||
|
|
||||||
|
SimpleServiceRequest requestName nums = payload call
|
||||||
|
|
||||||
|
handleBiDiStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> StreamSend SimpleServiceResponse -> Streaming (MetadataMap, StatusCode, StatusDetails)
|
||||||
|
handleBiDiStreamingCall call recvRequest sendResponse = go
|
||||||
|
where go = recvRequest >>= \req ->
|
||||||
|
case req of
|
||||||
|
Left ioError -> pure (mempty, StatusCancelled, StatusDetails ("handleBiDiStreamingCall: IO error: " <> fromString (show ioError)))
|
||||||
|
Right Nothing ->
|
||||||
|
pure (mempty, StatusOk, StatusDetails "")
|
||||||
|
Right (Just (SimpleServiceRequest name nums)) ->
|
||||||
|
do sendResponse (SimpleServiceResponse name (sum nums))
|
||||||
|
go
|
||||||
|
|
||||||
|
handleDone :: MVar () -> ServerCall SimpleServiceDone -> IO (SimpleServiceDone, MetadataMap, StatusCode, StatusDetails)
|
||||||
|
handleDone exitVar req =
|
||||||
|
do forkIO (threadDelay 5000 >> putMVar exitVar ())
|
||||||
|
pure (payload req, mempty, StatusOk, StatusDetails "")
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = do exitVar <- newEmptyMVar
|
||||||
|
|
||||||
|
forkIO $ simpleServiceServer SimpleService
|
||||||
|
{ simpleServiceDone = handleDone exitVar
|
||||||
|
, simpleServiceNormalCall = handleNormalCall
|
||||||
|
, simpleServiceClientStreamingCall = handleClientStreamingCall
|
||||||
|
, simpleServiceServerStreamingCall = handleServerStreamingCall
|
||||||
|
, simpleServiceBiDiStreamingCall = handleBiDiStreamingCall }
|
||||||
|
|
||||||
|
takeMVar exitVar
|
24
tests/simple.proto
Normal file
24
tests/simple.proto
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
syntax = "proto3";
|
||||||
|
package simple;
|
||||||
|
|
||||||
|
service SimpleService {
|
||||||
|
rpc normalCall (SimpleServiceRequest) returns (SimpleServiceResponse) {}
|
||||||
|
rpc clientStreamingCall (stream SimpleServiceRequest) returns (SimpleServiceResponse) {}
|
||||||
|
rpc serverStreamingCall (SimpleServiceRequest) returns (stream SimpleServiceResponse) {}
|
||||||
|
rpc biDiStreamingCall (stream SimpleServiceRequest) returns (stream SimpleServiceResponse) {}
|
||||||
|
|
||||||
|
rpc done (SimpleServiceDone) returns (SimpleServiceDone) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
message SimpleServiceDone {
|
||||||
|
}
|
||||||
|
|
||||||
|
message SimpleServiceRequest {
|
||||||
|
string request = 1;
|
||||||
|
repeated fixed32 num = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message SimpleServiceResponse {
|
||||||
|
string response = 1;
|
||||||
|
fixed32 num = 2;
|
||||||
|
}
|
78
tests/test-client.py
Normal file
78
tests/test-client.py
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
from simple_pb2 import *
|
||||||
|
from uuid import uuid4
|
||||||
|
import random
|
||||||
|
import Queue
|
||||||
|
|
||||||
|
print "Starting python client"
|
||||||
|
|
||||||
|
channel = beta_implementations.insecure_channel('localhost', 50051)
|
||||||
|
stub = beta_create_SimpleService_stub(channel)
|
||||||
|
|
||||||
|
# Test normal call: return a sum of all numbers sent to it
|
||||||
|
print "Test 100 random sums"
|
||||||
|
for i in xrange(100):
|
||||||
|
randints = [random.randint(0, 1000) for _ in xrange(random.randint(10, 1000))]
|
||||||
|
name = "test%d" % i
|
||||||
|
response = stub.normalCall(SimpleServiceRequest(request = name, num = randints), 10)
|
||||||
|
|
||||||
|
assert response.response == name
|
||||||
|
assert response.num == sum(randints)
|
||||||
|
|
||||||
|
# Test streaming call: The server response will be the sum of all numbers sent in the request along with a concatenation of the request name
|
||||||
|
print "Test 100 random sums (client streaming)"
|
||||||
|
for i in xrange(100):
|
||||||
|
expected_sum = 0
|
||||||
|
expected_response_name = ''
|
||||||
|
|
||||||
|
def send_requests():
|
||||||
|
global expected_sum
|
||||||
|
global expected_response_name
|
||||||
|
|
||||||
|
for _ in xrange(random.randint(5, 50)):
|
||||||
|
nums = [random.randint(0, 1000) for _ in xrange(random.randint(10, 100))]
|
||||||
|
name = str(uuid4())
|
||||||
|
|
||||||
|
expected_sum += sum(nums)
|
||||||
|
expected_response_name += name
|
||||||
|
|
||||||
|
yield SimpleServiceRequest(request = name, num = nums)
|
||||||
|
|
||||||
|
response = stub.clientStreamingCall(send_requests(), 10)
|
||||||
|
assert response.response == expected_response_name
|
||||||
|
assert response.num == expected_sum
|
||||||
|
|
||||||
|
# Test server streaming call: The server should respond once for each number in the request
|
||||||
|
print "Test 100 random server streaming calls"
|
||||||
|
for i in xrange(100):
|
||||||
|
nums = [random.randint(0, 1000) for _ in xrange(random.randint(0, 1000))]
|
||||||
|
|
||||||
|
for response in stub.serverStreamingCall(SimpleServiceRequest(request = "server streaming", num = nums), 60):
|
||||||
|
assert response.num == nums[0]
|
||||||
|
assert response.response == "server streaming"
|
||||||
|
nums = nums[1:]
|
||||||
|
|
||||||
|
# Test bidirectional streaming: for each request, we should get a response indicating the sum of all numbers sent in the last request
|
||||||
|
print "Test bidirectional streaming"
|
||||||
|
for i in xrange(100):
|
||||||
|
requests = Queue.Queue()
|
||||||
|
def send_requests():
|
||||||
|
global cur_request
|
||||||
|
global cur_nums
|
||||||
|
global requests
|
||||||
|
|
||||||
|
for _ in xrange(random.randint(5, 50)):
|
||||||
|
nums = [random.randint(0, 1000) for _ in xrange(random.randint(10, 100))]
|
||||||
|
name = str(uuid4())
|
||||||
|
|
||||||
|
requests.put((name, sum(nums)))
|
||||||
|
|
||||||
|
yield SimpleServiceRequest(request = name, num = nums)
|
||||||
|
|
||||||
|
for response in stub.biDiStreamingCall(send_requests(), 10):
|
||||||
|
(exp_name, exp_sum) = requests.get()
|
||||||
|
|
||||||
|
assert response.response == exp_name
|
||||||
|
assert response.num == exp_sum
|
||||||
|
|
||||||
|
# Signal the ending of the test
|
||||||
|
stub.done(SimpleServiceDone(), 10)
|
Loading…
Reference in a new issue