mirror of
https://github.com/unclechu/gRPC-haskell.git
synced 2024-11-05 02:39:42 +01:00
Added tests for grpc server generation
This commit is contained in:
commit
89aa17b5e6
8 changed files with 254 additions and 2 deletions
14
README.md
Normal file
14
README.md
Normal file
|
@ -0,0 +1,14 @@
|
|||
Running the tests
|
||||
-----------------
|
||||
|
||||
In order to run the tests, you will need to have the `grpcio`, `gevent`, and
|
||||
`grpcio-tools` python packages installed. You can install them using
|
||||
`pip`. It is recommended that you use a python virtualenv to do this.
|
||||
|
||||
```
|
||||
$ virtualenv path/to/virtualenv # to create a virtualenv
|
||||
$ . path/to/virtual/env/bin/activate # to use an existing virtualenv
|
||||
$ pip install grpcio-tools gevent
|
||||
$ pip install grpcio # Need to install grpcio-tools first to avoid a versioning problem
|
||||
```
|
||||
|
|
@ -40,6 +40,7 @@ library
|
|||
, tasty >= 0.11 && <0.12
|
||||
, tasty-hunit >= 0.9 && <0.10
|
||||
, safe
|
||||
, vector
|
||||
|
||||
c-sources:
|
||||
cbits/grpc_haskell.c
|
||||
|
@ -139,13 +140,17 @@ test-suite test
|
|||
, containers ==0.5.*
|
||||
, managed >= 1.0.5 && < 1.1
|
||||
, pipes ==4.1.*
|
||||
, protobuf-wire
|
||||
, transformers
|
||||
, safe
|
||||
, clock ==0.6.*
|
||||
, turtle >= 1.2.0
|
||||
, text
|
||||
other-modules:
|
||||
LowLevelTests,
|
||||
LowLevelTests.Op,
|
||||
UnsafeTests
|
||||
UnsafeTests,
|
||||
GeneratedTests
|
||||
default-language: Haskell2010
|
||||
ghc-options: -Wall -fwarn-incomplete-patterns -fno-warn-unused-do-bind -g -threaded -rtsopts
|
||||
hs-source-dirs: tests
|
||||
|
|
|
@ -9,7 +9,7 @@ packages:
|
|||
- '.'
|
||||
- location:
|
||||
git: git@github.mv.awakenetworks.net:awakenetworks/protobuf-wire.git
|
||||
commit: 57490f59e90bd0881f28c7898a152d4bd25c9fcb
|
||||
commit: 62aa5b92f21883d14bf8d3beed5645f84da01ad6
|
||||
extra-dep: true
|
||||
- location:
|
||||
git: git@github.com:awakenetworks/proto3-wire.git
|
||||
|
|
58
tests/GeneratedTests.hs
Normal file
58
tests/GeneratedTests.hs
Normal file
|
@ -0,0 +1,58 @@
|
|||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module GeneratedTests where
|
||||
|
||||
import Test.Tasty
|
||||
import Test.Tasty.HUnit (testCase, (@?=))
|
||||
|
||||
import Data.String
|
||||
import Data.Protobuf.Wire.DotProto.Generate
|
||||
import qualified Data.Text as T
|
||||
|
||||
import Turtle
|
||||
|
||||
generatedTests :: TestTree
|
||||
generatedTests = testGroup "Code generator tests"
|
||||
[ testServerGeneration ]
|
||||
|
||||
testServerGeneration :: TestTree
|
||||
testServerGeneration = testCase "server generation" $ do
|
||||
mktree hsTmpDir
|
||||
mktree pyTmpDir
|
||||
|
||||
compileSimpleDotProto
|
||||
|
||||
exitCode <- shell (T.concat ["stack ghc -- --make -threaded -odir ", hsTmpDir, " -hidir ", hsTmpDir, " -o ", hsTmpDir, "/simple-server ", hsTmpDir, "/Simple.hs tests/TestServer.hs > /dev/null"]) empty
|
||||
exitCode @?= ExitSuccess
|
||||
|
||||
exitCode <- shell (T.concat ["python -m grpc.tools.protoc -I tests --python_out=", pyTmpDir, " --grpc_python_out=", pyTmpDir, " tests/simple.proto"]) empty
|
||||
exitCode @?= ExitSuccess
|
||||
|
||||
runManaged $ do
|
||||
serverExitCodeA <- fork (shell (hsTmpDir <> "/simple-server") empty)
|
||||
clientExitCodeA <- fork
|
||||
(export "PYTHONPATH" pyTmpDir >> shell "python tests/test-client.py" empty)
|
||||
|
||||
liftIO $ do
|
||||
serverExitCode <- liftIO (wait serverExitCodeA)
|
||||
clientExitCode <- liftIO (wait clientExitCodeA)
|
||||
|
||||
serverExitCode @?= ExitSuccess
|
||||
clientExitCode @?= ExitSuccess
|
||||
|
||||
rmtree hsTmpDir
|
||||
rmtree pyTmpDir
|
||||
|
||||
hsTmpDir, pyTmpDir :: IsString a => a
|
||||
hsTmpDir = "tests/tmp"
|
||||
pyTmpDir = "tests/py-tmp"
|
||||
|
||||
compileSimpleDotProto :: IO ()
|
||||
compileSimpleDotProto =
|
||||
do dpRes <- readDotProtoWithContext "tests/simple.proto"
|
||||
case dpRes of
|
||||
Left err -> fail (show err)
|
||||
Right (dp, ctxt) ->
|
||||
case renderHsModuleForDotProto dp ctxt of
|
||||
Left err -> fail ("compileSimpleDotProto: Error compiling test.proto: " <> show err)
|
||||
Right hsSrc -> writeFile (hsTmpDir ++ "/Simple.hs") hsSrc
|
|
@ -2,10 +2,12 @@ import LowLevelTests
|
|||
import LowLevelTests.Op
|
||||
import Test.Tasty
|
||||
import UnsafeTests
|
||||
import GeneratedTests
|
||||
|
||||
main :: IO ()
|
||||
main = defaultMain $ testGroup "GRPC Unit Tests"
|
||||
[ unsafeTests
|
||||
, lowLevelOpTests
|
||||
, lowLevelTests
|
||||
, generatedTests
|
||||
]
|
||||
|
|
71
tests/TestServer.hs
Normal file
71
tests/TestServer.hs
Normal file
|
@ -0,0 +1,71 @@
|
|||
{-# LANGUAGE OverloadedStrings #-}
|
||||
module Main where
|
||||
|
||||
import Prelude hiding (sum)
|
||||
|
||||
import Simple
|
||||
|
||||
import Control.Concurrent
|
||||
import Control.Concurrent.MVar
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Class
|
||||
|
||||
import Data.Monoid
|
||||
import Data.Foldable (sum)
|
||||
import Data.String
|
||||
|
||||
import Network.GRPC.LowLevel
|
||||
|
||||
handleNormalCall :: ServerCall SimpleServiceRequest -> IO (SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails)
|
||||
handleNormalCall call =
|
||||
pure (SimpleServiceResponse request result, mempty, StatusOk, StatusDetails "")
|
||||
where SimpleServiceRequest request nums = payload call
|
||||
|
||||
result = sum nums
|
||||
|
||||
handleClientStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> Streaming (Maybe SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails)
|
||||
handleClientStreamingCall call recvRequest = go 0 ""
|
||||
where go sumAccum nameAccum =
|
||||
recvRequest >>= \req ->
|
||||
case req of
|
||||
Left ioError -> pure (Nothing, mempty, StatusCancelled, StatusDetails ("handleClientStreamingCall: IO error: " <> fromString (show ioError)))
|
||||
Right Nothing ->
|
||||
pure (Just (SimpleServiceResponse nameAccum sumAccum), mempty, StatusOk, StatusDetails "")
|
||||
Right (Just (SimpleServiceRequest name nums)) ->
|
||||
go (sumAccum + sum nums) (nameAccum <> name)
|
||||
|
||||
handleServerStreamingCall :: ServerCall SimpleServiceRequest -> StreamSend SimpleServiceResponse -> Streaming (MetadataMap, StatusCode, StatusDetails)
|
||||
handleServerStreamingCall call sendResponse = go
|
||||
where go = do forM_ nums $ \num ->
|
||||
sendResponse (SimpleServiceResponse requestName num)
|
||||
pure (mempty, StatusOk, StatusDetails "")
|
||||
|
||||
SimpleServiceRequest requestName nums = payload call
|
||||
|
||||
handleBiDiStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> StreamSend SimpleServiceResponse -> Streaming (MetadataMap, StatusCode, StatusDetails)
|
||||
handleBiDiStreamingCall call recvRequest sendResponse = go
|
||||
where go = recvRequest >>= \req ->
|
||||
case req of
|
||||
Left ioError -> pure (mempty, StatusCancelled, StatusDetails ("handleBiDiStreamingCall: IO error: " <> fromString (show ioError)))
|
||||
Right Nothing ->
|
||||
pure (mempty, StatusOk, StatusDetails "")
|
||||
Right (Just (SimpleServiceRequest name nums)) ->
|
||||
do sendResponse (SimpleServiceResponse name (sum nums))
|
||||
go
|
||||
|
||||
handleDone :: MVar () -> ServerCall SimpleServiceDone -> IO (SimpleServiceDone, MetadataMap, StatusCode, StatusDetails)
|
||||
handleDone exitVar req =
|
||||
do forkIO (threadDelay 5000 >> putMVar exitVar ())
|
||||
pure (payload req, mempty, StatusOk, StatusDetails "")
|
||||
|
||||
main :: IO ()
|
||||
main = do exitVar <- newEmptyMVar
|
||||
|
||||
forkIO $ simpleServiceServer SimpleService
|
||||
{ simpleServiceDone = handleDone exitVar
|
||||
, simpleServiceNormalCall = handleNormalCall
|
||||
, simpleServiceClientStreamingCall = handleClientStreamingCall
|
||||
, simpleServiceServerStreamingCall = handleServerStreamingCall
|
||||
, simpleServiceBiDiStreamingCall = handleBiDiStreamingCall }
|
||||
|
||||
takeMVar exitVar
|
24
tests/simple.proto
Normal file
24
tests/simple.proto
Normal file
|
@ -0,0 +1,24 @@
|
|||
syntax = "proto3";
|
||||
package simple;
|
||||
|
||||
service SimpleService {
|
||||
rpc normalCall (SimpleServiceRequest) returns (SimpleServiceResponse) {}
|
||||
rpc clientStreamingCall (stream SimpleServiceRequest) returns (SimpleServiceResponse) {}
|
||||
rpc serverStreamingCall (SimpleServiceRequest) returns (stream SimpleServiceResponse) {}
|
||||
rpc biDiStreamingCall (stream SimpleServiceRequest) returns (stream SimpleServiceResponse) {}
|
||||
|
||||
rpc done (SimpleServiceDone) returns (SimpleServiceDone) {}
|
||||
}
|
||||
|
||||
message SimpleServiceDone {
|
||||
}
|
||||
|
||||
message SimpleServiceRequest {
|
||||
string request = 1;
|
||||
repeated fixed32 num = 2;
|
||||
}
|
||||
|
||||
message SimpleServiceResponse {
|
||||
string response = 1;
|
||||
fixed32 num = 2;
|
||||
}
|
78
tests/test-client.py
Normal file
78
tests/test-client.py
Normal file
|
@ -0,0 +1,78 @@
|
|||
from simple_pb2 import *
|
||||
from uuid import uuid4
|
||||
import random
|
||||
import Queue
|
||||
|
||||
print "Starting python client"
|
||||
|
||||
channel = beta_implementations.insecure_channel('localhost', 50051)
|
||||
stub = beta_create_SimpleService_stub(channel)
|
||||
|
||||
# Test normal call: return a sum of all numbers sent to it
|
||||
print "Test 100 random sums"
|
||||
for i in xrange(100):
|
||||
randints = [random.randint(0, 1000) for _ in xrange(random.randint(10, 1000))]
|
||||
name = "test%d" % i
|
||||
response = stub.normalCall(SimpleServiceRequest(request = name, num = randints), 10)
|
||||
|
||||
assert response.response == name
|
||||
assert response.num == sum(randints)
|
||||
|
||||
# Test streaming call: The server response will be the sum of all numbers sent in the request along with a concatenation of the request name
|
||||
print "Test 100 random sums (client streaming)"
|
||||
for i in xrange(100):
|
||||
expected_sum = 0
|
||||
expected_response_name = ''
|
||||
|
||||
def send_requests():
|
||||
global expected_sum
|
||||
global expected_response_name
|
||||
|
||||
for _ in xrange(random.randint(5, 50)):
|
||||
nums = [random.randint(0, 1000) for _ in xrange(random.randint(10, 100))]
|
||||
name = str(uuid4())
|
||||
|
||||
expected_sum += sum(nums)
|
||||
expected_response_name += name
|
||||
|
||||
yield SimpleServiceRequest(request = name, num = nums)
|
||||
|
||||
response = stub.clientStreamingCall(send_requests(), 10)
|
||||
assert response.response == expected_response_name
|
||||
assert response.num == expected_sum
|
||||
|
||||
# Test server streaming call: The server should respond once for each number in the request
|
||||
print "Test 100 random server streaming calls"
|
||||
for i in xrange(100):
|
||||
nums = [random.randint(0, 1000) for _ in xrange(random.randint(0, 1000))]
|
||||
|
||||
for response in stub.serverStreamingCall(SimpleServiceRequest(request = "server streaming", num = nums), 60):
|
||||
assert response.num == nums[0]
|
||||
assert response.response == "server streaming"
|
||||
nums = nums[1:]
|
||||
|
||||
# Test bidirectional streaming: for each request, we should get a response indicating the sum of all numbers sent in the last request
|
||||
print "Test bidirectional streaming"
|
||||
for i in xrange(100):
|
||||
requests = Queue.Queue()
|
||||
def send_requests():
|
||||
global cur_request
|
||||
global cur_nums
|
||||
global requests
|
||||
|
||||
for _ in xrange(random.randint(5, 50)):
|
||||
nums = [random.randint(0, 1000) for _ in xrange(random.randint(10, 100))]
|
||||
name = str(uuid4())
|
||||
|
||||
requests.put((name, sum(nums)))
|
||||
|
||||
yield SimpleServiceRequest(request = name, num = nums)
|
||||
|
||||
for response in stub.biDiStreamingCall(send_requests(), 10):
|
||||
(exp_name, exp_sum) = requests.get()
|
||||
|
||||
assert response.response == exp_name
|
||||
assert response.num == exp_sum
|
||||
|
||||
# Signal the ending of the test
|
||||
stub.done(SimpleServiceDone(), 10)
|
Loading…
Reference in a new issue