diff --git a/README.md b/README.md new file mode 100644 index 0000000..a1b0c50 --- /dev/null +++ b/README.md @@ -0,0 +1,14 @@ +Running the tests +----------------- + +In order to run the tests, you will need to have the `grpcio`, `gevent`, and +`grpcio-tools` python packages installed. You can install them using +`pip`. It is recommended that you use a python virtualenv to do this. + +``` +$ virtualenv path/to/virtualenv # to create a virtualenv +$ . path/to/virtual/env/bin/activate # to use an existing virtualenv +$ pip install grpcio-tools gevent +$ pip install grpcio # Need to install grpcio-tools first to avoid a versioning problem +``` + diff --git a/grpc-haskell.cabal b/grpc-haskell.cabal index e846638..ba9a961 100644 --- a/grpc-haskell.cabal +++ b/grpc-haskell.cabal @@ -40,6 +40,7 @@ library , tasty >= 0.11 && <0.12 , tasty-hunit >= 0.9 && <0.10 , safe + , vector c-sources: cbits/grpc_haskell.c @@ -139,13 +140,17 @@ test-suite test , containers ==0.5.* , managed >= 1.0.5 && < 1.1 , pipes ==4.1.* + , protobuf-wire , transformers , safe , clock ==0.6.* + , turtle >= 1.2.0 + , text other-modules: LowLevelTests, LowLevelTests.Op, - UnsafeTests + UnsafeTests, + GeneratedTests default-language: Haskell2010 ghc-options: -Wall -fwarn-incomplete-patterns -fno-warn-unused-do-bind -g -threaded -rtsopts hs-source-dirs: tests diff --git a/stack.yaml b/stack.yaml index 8b091d0..b37a35b 100644 --- a/stack.yaml +++ b/stack.yaml @@ -9,7 +9,7 @@ packages: - '.' - location: git: git@github.mv.awakenetworks.net:awakenetworks/protobuf-wire.git - commit: 57490f59e90bd0881f28c7898a152d4bd25c9fcb + commit: 62aa5b92f21883d14bf8d3beed5645f84da01ad6 extra-dep: true - location: git: git@github.com:awakenetworks/proto3-wire.git diff --git a/tests/GeneratedTests.hs b/tests/GeneratedTests.hs new file mode 100644 index 0000000..c5e41f9 --- /dev/null +++ b/tests/GeneratedTests.hs @@ -0,0 +1,58 @@ +{-# LANGUAGE OverloadedStrings #-} + +module GeneratedTests where + +import Test.Tasty +import Test.Tasty.HUnit (testCase, (@?=)) + +import Data.String +import Data.Protobuf.Wire.DotProto.Generate +import qualified Data.Text as T + +import Turtle + +generatedTests :: TestTree +generatedTests = testGroup "Code generator tests" + [ testServerGeneration ] + +testServerGeneration :: TestTree +testServerGeneration = testCase "server generation" $ do + mktree hsTmpDir + mktree pyTmpDir + + compileSimpleDotProto + + exitCode <- shell (T.concat ["stack ghc -- --make -threaded -odir ", hsTmpDir, " -hidir ", hsTmpDir, " -o ", hsTmpDir, "/simple-server ", hsTmpDir, "/Simple.hs tests/TestServer.hs > /dev/null"]) empty + exitCode @?= ExitSuccess + + exitCode <- shell (T.concat ["python -m grpc.tools.protoc -I tests --python_out=", pyTmpDir, " --grpc_python_out=", pyTmpDir, " tests/simple.proto"]) empty + exitCode @?= ExitSuccess + + runManaged $ do + serverExitCodeA <- fork (shell (hsTmpDir <> "/simple-server") empty) + clientExitCodeA <- fork + (export "PYTHONPATH" pyTmpDir >> shell "python tests/test-client.py" empty) + + liftIO $ do + serverExitCode <- liftIO (wait serverExitCodeA) + clientExitCode <- liftIO (wait clientExitCodeA) + + serverExitCode @?= ExitSuccess + clientExitCode @?= ExitSuccess + + rmtree hsTmpDir + rmtree pyTmpDir + +hsTmpDir, pyTmpDir :: IsString a => a +hsTmpDir = "tests/tmp" +pyTmpDir = "tests/py-tmp" + +compileSimpleDotProto :: IO () +compileSimpleDotProto = + do dpRes <- readDotProtoWithContext "tests/simple.proto" + case dpRes of + Left err -> fail (show err) + Right (dp, ctxt) -> + case renderHsModuleForDotProto dp ctxt of + Left err -> fail ("compileSimpleDotProto: Error compiling test.proto: " <> show err) + Right hsSrc -> writeFile (hsTmpDir ++ "/Simple.hs") hsSrc diff --git a/tests/Properties.hs b/tests/Properties.hs index f7bf152..1e8a977 100644 --- a/tests/Properties.hs +++ b/tests/Properties.hs @@ -2,10 +2,12 @@ import LowLevelTests import LowLevelTests.Op import Test.Tasty import UnsafeTests +import GeneratedTests main :: IO () main = defaultMain $ testGroup "GRPC Unit Tests" [ unsafeTests , lowLevelOpTests , lowLevelTests + , generatedTests ] diff --git a/tests/TestServer.hs b/tests/TestServer.hs new file mode 100644 index 0000000..58ae506 --- /dev/null +++ b/tests/TestServer.hs @@ -0,0 +1,71 @@ +{-# LANGUAGE OverloadedStrings #-} +module Main where + +import Prelude hiding (sum) + +import Simple + +import Control.Concurrent +import Control.Concurrent.MVar +import Control.Monad +import Control.Monad.IO.Class + +import Data.Monoid +import Data.Foldable (sum) +import Data.String + +import Network.GRPC.LowLevel + +handleNormalCall :: ServerCall SimpleServiceRequest -> IO (SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails) +handleNormalCall call = + pure (SimpleServiceResponse request result, mempty, StatusOk, StatusDetails "") + where SimpleServiceRequest request nums = payload call + + result = sum nums + +handleClientStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> Streaming (Maybe SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails) +handleClientStreamingCall call recvRequest = go 0 "" + where go sumAccum nameAccum = + recvRequest >>= \req -> + case req of + Left ioError -> pure (Nothing, mempty, StatusCancelled, StatusDetails ("handleClientStreamingCall: IO error: " <> fromString (show ioError))) + Right Nothing -> + pure (Just (SimpleServiceResponse nameAccum sumAccum), mempty, StatusOk, StatusDetails "") + Right (Just (SimpleServiceRequest name nums)) -> + go (sumAccum + sum nums) (nameAccum <> name) + +handleServerStreamingCall :: ServerCall SimpleServiceRequest -> StreamSend SimpleServiceResponse -> Streaming (MetadataMap, StatusCode, StatusDetails) +handleServerStreamingCall call sendResponse = go + where go = do forM_ nums $ \num -> + sendResponse (SimpleServiceResponse requestName num) + pure (mempty, StatusOk, StatusDetails "") + + SimpleServiceRequest requestName nums = payload call + +handleBiDiStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> StreamSend SimpleServiceResponse -> Streaming (MetadataMap, StatusCode, StatusDetails) +handleBiDiStreamingCall call recvRequest sendResponse = go + where go = recvRequest >>= \req -> + case req of + Left ioError -> pure (mempty, StatusCancelled, StatusDetails ("handleBiDiStreamingCall: IO error: " <> fromString (show ioError))) + Right Nothing -> + pure (mempty, StatusOk, StatusDetails "") + Right (Just (SimpleServiceRequest name nums)) -> + do sendResponse (SimpleServiceResponse name (sum nums)) + go + +handleDone :: MVar () -> ServerCall SimpleServiceDone -> IO (SimpleServiceDone, MetadataMap, StatusCode, StatusDetails) +handleDone exitVar req = + do forkIO (threadDelay 5000 >> putMVar exitVar ()) + pure (payload req, mempty, StatusOk, StatusDetails "") + +main :: IO () +main = do exitVar <- newEmptyMVar + + forkIO $ simpleServiceServer SimpleService + { simpleServiceDone = handleDone exitVar + , simpleServiceNormalCall = handleNormalCall + , simpleServiceClientStreamingCall = handleClientStreamingCall + , simpleServiceServerStreamingCall = handleServerStreamingCall + , simpleServiceBiDiStreamingCall = handleBiDiStreamingCall } + + takeMVar exitVar diff --git a/tests/simple.proto b/tests/simple.proto new file mode 100644 index 0000000..5a18c7f --- /dev/null +++ b/tests/simple.proto @@ -0,0 +1,24 @@ +syntax = "proto3"; +package simple; + +service SimpleService { + rpc normalCall (SimpleServiceRequest) returns (SimpleServiceResponse) {} + rpc clientStreamingCall (stream SimpleServiceRequest) returns (SimpleServiceResponse) {} + rpc serverStreamingCall (SimpleServiceRequest) returns (stream SimpleServiceResponse) {} + rpc biDiStreamingCall (stream SimpleServiceRequest) returns (stream SimpleServiceResponse) {} + + rpc done (SimpleServiceDone) returns (SimpleServiceDone) {} +} + +message SimpleServiceDone { +} + +message SimpleServiceRequest { + string request = 1; + repeated fixed32 num = 2; +} + +message SimpleServiceResponse { + string response = 1; + fixed32 num = 2; +} \ No newline at end of file diff --git a/tests/test-client.py b/tests/test-client.py new file mode 100644 index 0000000..5dd34d3 --- /dev/null +++ b/tests/test-client.py @@ -0,0 +1,78 @@ +from simple_pb2 import * +from uuid import uuid4 +import random +import Queue + +print "Starting python client" + +channel = beta_implementations.insecure_channel('localhost', 50051) +stub = beta_create_SimpleService_stub(channel) + +# Test normal call: return a sum of all numbers sent to it +print "Test 100 random sums" +for i in xrange(100): + randints = [random.randint(0, 1000) for _ in xrange(random.randint(10, 1000))] + name = "test%d" % i + response = stub.normalCall(SimpleServiceRequest(request = name, num = randints), 10) + + assert response.response == name + assert response.num == sum(randints) + +# Test streaming call: The server response will be the sum of all numbers sent in the request along with a concatenation of the request name +print "Test 100 random sums (client streaming)" +for i in xrange(100): + expected_sum = 0 + expected_response_name = '' + + def send_requests(): + global expected_sum + global expected_response_name + + for _ in xrange(random.randint(5, 50)): + nums = [random.randint(0, 1000) for _ in xrange(random.randint(10, 100))] + name = str(uuid4()) + + expected_sum += sum(nums) + expected_response_name += name + + yield SimpleServiceRequest(request = name, num = nums) + + response = stub.clientStreamingCall(send_requests(), 10) + assert response.response == expected_response_name + assert response.num == expected_sum + +# Test server streaming call: The server should respond once for each number in the request +print "Test 100 random server streaming calls" +for i in xrange(100): + nums = [random.randint(0, 1000) for _ in xrange(random.randint(0, 1000))] + + for response in stub.serverStreamingCall(SimpleServiceRequest(request = "server streaming", num = nums), 60): + assert response.num == nums[0] + assert response.response == "server streaming" + nums = nums[1:] + +# Test bidirectional streaming: for each request, we should get a response indicating the sum of all numbers sent in the last request +print "Test bidirectional streaming" +for i in xrange(100): + requests = Queue.Queue() + def send_requests(): + global cur_request + global cur_nums + global requests + + for _ in xrange(random.randint(5, 50)): + nums = [random.randint(0, 1000) for _ in xrange(random.randint(10, 100))] + name = str(uuid4()) + + requests.put((name, sum(nums))) + + yield SimpleServiceRequest(request = name, num = nums) + + for response in stub.biDiStreamingCall(send_requests(), 10): + (exp_name, exp_sum) = requests.get() + + assert response.response == exp_name + assert response.num == exp_sum + +# Signal the ending of the test +stub.done(SimpleServiceDone(), 10)