Added tests for grpc server generation

This commit is contained in:
Travis Athougies 2016-07-25 16:37:06 -07:00
parent 86bad965ce
commit 1d7526da70
8 changed files with 254 additions and 2 deletions

14
README.md Normal file
View file

@ -0,0 +1,14 @@
Running the tests
-----------------
In order to run the tests, you will need to have the `grpcio`, `gevent`, and
`grpcio-tools` python packages installed. You can install them using
`pip`. It is recommended that you use a python virtualenv to do this.
```
$ virtualenv path/to/virtualenv # to create a virtualenv
$ . path/to/virtual/env/bin/activate # to use an existing virtualenv
$ pip install grpcio-tools gevent
$ pip install grpcio # Need to install grpcio-tools first to avoid a versioning problem
```

View file

@ -40,6 +40,7 @@ library
, tasty >= 0.11 && <0.12 , tasty >= 0.11 && <0.12
, tasty-hunit >= 0.9 && <0.10 , tasty-hunit >= 0.9 && <0.10
, safe , safe
, vector
c-sources: c-sources:
cbits/grpc_haskell.c cbits/grpc_haskell.c
@ -139,13 +140,17 @@ test-suite test
, containers ==0.5.* , containers ==0.5.*
, managed >= 1.0.5 && < 1.1 , managed >= 1.0.5 && < 1.1
, pipes ==4.1.* , pipes ==4.1.*
, protobuf-wire
, transformers , transformers
, safe , safe
, clock ==0.6.* , clock ==0.6.*
, turtle >= 1.2.0
, text
other-modules: other-modules:
LowLevelTests, LowLevelTests,
LowLevelTests.Op, LowLevelTests.Op,
UnsafeTests UnsafeTests,
GeneratedTests
default-language: Haskell2010 default-language: Haskell2010
ghc-options: -Wall -fwarn-incomplete-patterns -fno-warn-unused-do-bind -g -threaded -rtsopts ghc-options: -Wall -fwarn-incomplete-patterns -fno-warn-unused-do-bind -g -threaded -rtsopts
hs-source-dirs: tests hs-source-dirs: tests

View file

@ -9,7 +9,7 @@ packages:
- '.' - '.'
- location: - location:
git: git@github.mv.awakenetworks.net:awakenetworks/protobuf-wire.git git: git@github.mv.awakenetworks.net:awakenetworks/protobuf-wire.git
commit: 57490f59e90bd0881f28c7898a152d4bd25c9fcb commit: 62aa5b92f21883d14bf8d3beed5645f84da01ad6
extra-dep: true extra-dep: true
- location: - location:
git: git@github.com:awakenetworks/proto3-wire.git git: git@github.com:awakenetworks/proto3-wire.git

58
tests/GeneratedTests.hs Normal file
View file

@ -0,0 +1,58 @@
{-# LANGUAGE OverloadedStrings #-}
module GeneratedTests where
import Test.Tasty
import Test.Tasty.HUnit (testCase, (@?=))
import Data.String
import Data.Protobuf.Wire.DotProto.Generate
import qualified Data.Text as T
import Turtle
generatedTests :: TestTree
generatedTests = testGroup "Code generator tests"
[ testServerGeneration ]
testServerGeneration :: TestTree
testServerGeneration = testCase "server generation" $ do
mktree hsTmpDir
mktree pyTmpDir
compileSimpleDotProto
exitCode <- shell (T.concat ["stack ghc -- --make -threaded -odir ", hsTmpDir, " -hidir ", hsTmpDir, " -o ", hsTmpDir, "/simple-server ", hsTmpDir, "/Simple.hs tests/TestServer.hs > /dev/null"]) empty
exitCode @?= ExitSuccess
exitCode <- shell (T.concat ["python -m grpc.tools.protoc -I tests --python_out=", pyTmpDir, " --grpc_python_out=", pyTmpDir, " tests/simple.proto"]) empty
exitCode @?= ExitSuccess
runManaged $ do
serverExitCodeA <- fork (shell (hsTmpDir <> "/simple-server") empty)
clientExitCodeA <- fork
(export "PYTHONPATH" pyTmpDir >> shell "python tests/test-client.py" empty)
liftIO $ do
serverExitCode <- liftIO (wait serverExitCodeA)
clientExitCode <- liftIO (wait clientExitCodeA)
serverExitCode @?= ExitSuccess
clientExitCode @?= ExitSuccess
rmtree hsTmpDir
rmtree pyTmpDir
hsTmpDir, pyTmpDir :: IsString a => a
hsTmpDir = "tests/tmp"
pyTmpDir = "tests/py-tmp"
compileSimpleDotProto :: IO ()
compileSimpleDotProto =
do dpRes <- readDotProtoWithContext "tests/simple.proto"
case dpRes of
Left err -> fail (show err)
Right (dp, ctxt) ->
case renderHsModuleForDotProto dp ctxt of
Left err -> fail ("compileSimpleDotProto: Error compiling test.proto: " <> show err)
Right hsSrc -> writeFile (hsTmpDir ++ "/Simple.hs") hsSrc

View file

@ -2,10 +2,12 @@ import LowLevelTests
import LowLevelTests.Op import LowLevelTests.Op
import Test.Tasty import Test.Tasty
import UnsafeTests import UnsafeTests
import GeneratedTests
main :: IO () main :: IO ()
main = defaultMain $ testGroup "GRPC Unit Tests" main = defaultMain $ testGroup "GRPC Unit Tests"
[ unsafeTests [ unsafeTests
, lowLevelOpTests , lowLevelOpTests
, lowLevelTests , lowLevelTests
, generatedTests
] ]

71
tests/TestServer.hs Normal file
View file

@ -0,0 +1,71 @@
{-# LANGUAGE OverloadedStrings #-}
module Main where
import Prelude hiding (sum)
import Simple
import Control.Concurrent
import Control.Concurrent.MVar
import Control.Monad
import Control.Monad.IO.Class
import Data.Monoid
import Data.Foldable (sum)
import Data.String
import Network.GRPC.LowLevel
handleNormalCall :: ServerCall SimpleServiceRequest -> IO (SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails)
handleNormalCall call =
pure (SimpleServiceResponse request result, mempty, StatusOk, StatusDetails "")
where SimpleServiceRequest request nums = payload call
result = sum nums
handleClientStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> Streaming (Maybe SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails)
handleClientStreamingCall call recvRequest = go 0 ""
where go sumAccum nameAccum =
recvRequest >>= \req ->
case req of
Left ioError -> pure (Nothing, mempty, StatusCancelled, StatusDetails ("handleClientStreamingCall: IO error: " <> fromString (show ioError)))
Right Nothing ->
pure (Just (SimpleServiceResponse nameAccum sumAccum), mempty, StatusOk, StatusDetails "")
Right (Just (SimpleServiceRequest name nums)) ->
go (sumAccum + sum nums) (nameAccum <> name)
handleServerStreamingCall :: ServerCall SimpleServiceRequest -> StreamSend SimpleServiceResponse -> Streaming (MetadataMap, StatusCode, StatusDetails)
handleServerStreamingCall call sendResponse = go
where go = do forM_ nums $ \num ->
sendResponse (SimpleServiceResponse requestName num)
pure (mempty, StatusOk, StatusDetails "")
SimpleServiceRequest requestName nums = payload call
handleBiDiStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> StreamSend SimpleServiceResponse -> Streaming (MetadataMap, StatusCode, StatusDetails)
handleBiDiStreamingCall call recvRequest sendResponse = go
where go = recvRequest >>= \req ->
case req of
Left ioError -> pure (mempty, StatusCancelled, StatusDetails ("handleBiDiStreamingCall: IO error: " <> fromString (show ioError)))
Right Nothing ->
pure (mempty, StatusOk, StatusDetails "")
Right (Just (SimpleServiceRequest name nums)) ->
do sendResponse (SimpleServiceResponse name (sum nums))
go
handleDone :: MVar () -> ServerCall SimpleServiceDone -> IO (SimpleServiceDone, MetadataMap, StatusCode, StatusDetails)
handleDone exitVar req =
do forkIO (threadDelay 5000 >> putMVar exitVar ())
pure (payload req, mempty, StatusOk, StatusDetails "")
main :: IO ()
main = do exitVar <- newEmptyMVar
forkIO $ simpleServiceServer SimpleService
{ simpleServiceDone = handleDone exitVar
, simpleServiceNormalCall = handleNormalCall
, simpleServiceClientStreamingCall = handleClientStreamingCall
, simpleServiceServerStreamingCall = handleServerStreamingCall
, simpleServiceBiDiStreamingCall = handleBiDiStreamingCall }
takeMVar exitVar

24
tests/simple.proto Normal file
View file

@ -0,0 +1,24 @@
syntax = "proto3";
package simple;
service SimpleService {
rpc normalCall (SimpleServiceRequest) returns (SimpleServiceResponse) {}
rpc clientStreamingCall (stream SimpleServiceRequest) returns (SimpleServiceResponse) {}
rpc serverStreamingCall (SimpleServiceRequest) returns (stream SimpleServiceResponse) {}
rpc biDiStreamingCall (stream SimpleServiceRequest) returns (stream SimpleServiceResponse) {}
rpc done (SimpleServiceDone) returns (SimpleServiceDone) {}
}
message SimpleServiceDone {
}
message SimpleServiceRequest {
string request = 1;
repeated fixed32 num = 2;
}
message SimpleServiceResponse {
string response = 1;
fixed32 num = 2;
}

78
tests/test-client.py Normal file
View file

@ -0,0 +1,78 @@
from simple_pb2 import *
from uuid import uuid4
import random
import Queue
print "Starting python client"
channel = beta_implementations.insecure_channel('localhost', 50051)
stub = beta_create_SimpleService_stub(channel)
# Test normal call: return a sum of all numbers sent to it
print "Test 100 random sums"
for i in xrange(100):
randints = [random.randint(0, 1000) for _ in xrange(random.randint(10, 1000))]
name = "test%d" % i
response = stub.normalCall(SimpleServiceRequest(request = name, num = randints), 10)
assert response.response == name
assert response.num == sum(randints)
# Test streaming call: The server response will be the sum of all numbers sent in the request along with a concatenation of the request name
print "Test 100 random sums (client streaming)"
for i in xrange(100):
expected_sum = 0
expected_response_name = ''
def send_requests():
global expected_sum
global expected_response_name
for _ in xrange(random.randint(5, 50)):
nums = [random.randint(0, 1000) for _ in xrange(random.randint(10, 100))]
name = str(uuid4())
expected_sum += sum(nums)
expected_response_name += name
yield SimpleServiceRequest(request = name, num = nums)
response = stub.clientStreamingCall(send_requests(), 10)
assert response.response == expected_response_name
assert response.num == expected_sum
# Test server streaming call: The server should respond once for each number in the request
print "Test 100 random server streaming calls"
for i in xrange(100):
nums = [random.randint(0, 1000) for _ in xrange(random.randint(0, 1000))]
for response in stub.serverStreamingCall(SimpleServiceRequest(request = "server streaming", num = nums), 60):
assert response.num == nums[0]
assert response.response == "server streaming"
nums = nums[1:]
# Test bidirectional streaming: for each request, we should get a response indicating the sum of all numbers sent in the last request
print "Test bidirectional streaming"
for i in xrange(100):
requests = Queue.Queue()
def send_requests():
global cur_request
global cur_nums
global requests
for _ in xrange(random.randint(5, 50)):
nums = [random.randint(0, 1000) for _ in xrange(random.randint(10, 100))]
name = str(uuid4())
requests.put((name, sum(nums)))
yield SimpleServiceRequest(request = name, num = nums)
for response in stub.biDiStreamingCall(send_requests(), 10):
(exp_name, exp_sum) = requests.get()
assert response.response == exp_name
assert response.num == exp_sum
# Signal the ending of the test
stub.done(SimpleServiceDone(), 10)