diff --git a/grpc-haskell.cabal b/grpc-haskell.cabal index c2180e1..39bf62a 100644 --- a/grpc-haskell.cabal +++ b/grpc-haskell.cabal @@ -72,6 +72,7 @@ library Network.GRPC.HighLevel.Generated Network.GRPC.HighLevel.Server Network.GRPC.HighLevel.Server.Unregistered + Network.GRPC.HighLevel.Client extra-libraries: grpc includes: diff --git a/src/Network/GRPC/HighLevel/Client.hs b/src/Network/GRPC/HighLevel/Client.hs new file mode 100644 index 0000000..8134858 --- /dev/null +++ b/src/Network/GRPC/HighLevel/Client.hs @@ -0,0 +1,122 @@ +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} + +module Network.GRPC.HighLevel.Client + ( RegisteredMethod + , TimeoutSeconds + , MetadataMap(..) + , StatusDetails(..) + , GRPCMethodType(..) + , StreamRecv + , StreamSend + , WritesDone + , LL.Client + + , ServiceClient + , ClientRequest(..) + , ClientResult(..) +-- , ClientResponse, response, initMD, trailMD, rspCode, details + + , ClientRegisterable(..) + + , clientRequest ) where + +import qualified Network.GRPC.LowLevel.Client as LL +import qualified Network.GRPC.LowLevel.Call as LL +import Network.GRPC.LowLevel.CompletionQueue (TimeoutSeconds) +import Network.GRPC.LowLevel ( GRPCMethodType(..) + , StatusCode(..) + , StatusDetails(..) + , MetadataMap(..) + , GRPCIOError(..) + , StreamRecv + , StreamSend ) +import Network.GRPC.LowLevel.Op (WritesDone) +import Network.GRPC.HighLevel.Server (convertRecv, convertSend) + +import Data.Protobuf.Wire (Message, toLazyByteString, fromByteString) +import Proto3.Wire.Decode (ParseError) +import Data.ByteString (ByteString) +import qualified Data.ByteString.Lazy as BL + +newtype RegisteredMethod (mt :: GRPCMethodType) request response + = RegisteredMethod (LL.RegisteredMethod mt) + deriving Show + +type ServiceClient service = service ClientRequest ClientResult + +data ClientError + = ClientErrorNoParse ParseError + | ClientIOError GRPCIOError + deriving (Show, Eq) + +data ClientRequest (streamType :: GRPCMethodType) request response where + ClientNormalRequest :: request -> TimeoutSeconds -> MetadataMap -> ClientRequest 'Normal request response + ClientWriterRequest :: TimeoutSeconds -> MetadataMap -> (StreamSend request -> IO ()) -> ClientRequest 'ClientStreaming request response + ClientReaderRequest :: request -> TimeoutSeconds -> MetadataMap -> (MetadataMap -> StreamRecv response -> IO ()) -> ClientRequest 'ServerStreaming request response + ClientBiDiRequest :: TimeoutSeconds -> MetadataMap -> (MetadataMap -> StreamRecv response -> StreamSend request -> WritesDone -> IO ()) -> ClientRequest 'BiDiStreaming request response + +data ClientResult (streamType :: GRPCMethodType) response where + ClientNormalResponse :: response -> MetadataMap -> MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'Normal response + ClientWriterResponse :: Maybe response -> MetadataMap -> MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'ClientStreaming response + ClientReaderResponse :: MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'ServerStreaming response + ClientBiDiResponse :: MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'BiDiStreaming response + + ClientError :: ClientError -> ClientResult streamType response + +class ClientRegisterable (methodType :: GRPCMethodType) where + clientRegisterMethod :: LL.Client + -> LL.MethodName + -> IO (RegisteredMethod methodType request response) + +instance ClientRegisterable 'Normal where + clientRegisterMethod client methodName = + RegisteredMethod <$> LL.clientRegisterMethodNormal client methodName + +instance ClientRegisterable 'ClientStreaming where + clientRegisterMethod client methodName = + RegisteredMethod <$> LL.clientRegisterMethodClientStreaming client methodName + +instance ClientRegisterable 'ServerStreaming where + clientRegisterMethod client methodName = + RegisteredMethod <$> LL.clientRegisterMethodServerStreaming client methodName + +instance ClientRegisterable 'BiDiStreaming where + clientRegisterMethod client methodName = + RegisteredMethod <$> LL.clientRegisterMethodBiDiStreaming client methodName + +clientRequest :: (Message request, Message response) => + LL.Client -> RegisteredMethod streamType request response + -> ClientRequest streamType request response -> IO (ClientResult streamType response) +clientRequest client (RegisteredMethod method) (ClientNormalRequest req timeout meta) = + mkResponse <$> LL.clientRequest client method timeout (BL.toStrict (toLazyByteString req)) meta + where + mkResponse (Left ioError_) = ClientError (ClientIOError ioError_) + mkResponse (Right rsp) = + case fromByteString (LL.rspBody rsp) of + Left err -> ClientError (ClientErrorNoParse err) + Right parsedRsp -> + ClientNormalResponse parsedRsp (LL.initMD rsp) (LL.trailMD rsp) (LL.rspCode rsp) (LL.details rsp) +clientRequest client (RegisteredMethod method) (ClientWriterRequest timeout meta handler) = + mkResponse <$> LL.clientWriter client method timeout meta (handler . convertSend) + where + mkResponse (Left ioError_) = ClientError (ClientIOError ioError_) + mkResponse (Right (rsp_, initMD_, trailMD_, rspCode_, details_)) = + case maybe (Right Nothing) (fmap Just . fromByteString) rsp_ of + Left err -> ClientError (ClientErrorNoParse err) + Right parsedRsp -> + ClientWriterResponse parsedRsp initMD_ trailMD_ rspCode_ details_ +clientRequest client (RegisteredMethod method) (ClientReaderRequest req timeout meta handler) = + mkResponse <$> LL.clientReader client method timeout (BL.toStrict (toLazyByteString req)) meta (\m recv -> handler m (convertRecv recv)) + where + mkResponse (Left ioError_) = ClientError (ClientIOError ioError_) + mkResponse (Right (meta_, rspCode_, details_)) = + ClientReaderResponse meta_ rspCode_ details_ +clientRequest client (RegisteredMethod method) (ClientBiDiRequest timeout meta handler) = + mkResponse <$> LL.clientRW client method timeout meta (\m recv send writesDone -> handler meta (convertRecv recv) (convertSend send) writesDone) + where + mkResponse (Left ioError_) = ClientError (ClientIOError ioError_) + mkResponse (Right (meta_, rspCode_, details_)) = + ClientBiDiResponse meta_ rspCode_ details_ diff --git a/src/Network/GRPC/HighLevel/Generated.hs b/src/Network/GRPC/HighLevel/Generated.hs index fca6c8d..dfcc79e 100644 --- a/src/Network/GRPC/HighLevel/Generated.hs +++ b/src/Network/GRPC/HighLevel/Generated.hs @@ -1,3 +1,4 @@ + {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE RankNTypes #-} diff --git a/src/Network/GRPC/HighLevel/Server.hs b/src/Network/GRPC/HighLevel/Server.hs index a71cd0b..8739df8 100644 --- a/src/Network/GRPC/HighLevel/Server.hs +++ b/src/Network/GRPC/HighLevel/Server.hs @@ -16,16 +16,46 @@ import Data.Protobuf.Wire.Class import Network.GRPC.LowLevel import System.IO -type ServerHandler a b - = ServerCall a +type ServerCallMetadata = ServerCall () + +type ServiceServer service = service ServerRequest ServerResponse + +data ServerRequest (streamType :: GRPCMethodType) request response where + ServerNormalRequest :: ServerCallMetadata -> request -> ServerRequest 'Normal request response + ServerReaderRequest :: ServerCallMetadata -> StreamRecv request -> ServerRequest 'ClientStreaming request response + ServerWriterRequest :: ServerCallMetadata -> request -> StreamSend response -> ServerRequest 'ServerStreaming request response + ServerBiDiRequest :: ServerCallMetadata -> StreamRecv request -> StreamSend response -> ServerRequest 'BiDiStreaming request response + +data ServerResponse (streamType :: GRPCMethodType) response where + ServerNormalResponse :: response -> MetadataMap -> StatusCode -> StatusDetails + -> ServerResponse 'Normal response + ServerReaderResponse :: Maybe response -> MetadataMap -> StatusCode -> StatusDetails + -> ServerResponse 'ClientStreaming response + ServerWriterResponse :: MetadataMap -> StatusCode -> StatusDetails + -> ServerResponse 'ServerStreaming response + ServerBiDiResponse :: MetadataMap -> StatusCode -> StatusDetails + -> ServerResponse 'BiDiStreaming response + +type ServerHandler a b = + ServerCall a -> IO (b, MetadataMap, StatusCode, StatusDetails) +convertGeneratedServerHandler :: + (Message request, Message response) + => (ServerRequest 'Normal request response -> IO (ServerResponse 'Normal response)) + -> ServerHandler request response +convertGeneratedServerHandler handler call = + do let call' = call { payload = () } + ServerNormalResponse rsp meta stsCode stsDetails <- + handler (ServerNormalRequest call' (payload call)) + return (rsp, meta, stsCode, stsDetails) + convertServerHandler :: (Message a, Message b) => ServerHandler a b -> ServerHandlerLL convertServerHandler f c = case fromByteString (payload c) of Left x -> CE.throw (GRPCIODecodeError x) - Right x -> do (y, tm, sc, sd) <- f (const x <$> c) + Right x -> do (y, tm, sc, sd) <- f (fmap (const x) c) return (toBS y, tm, sc, sd) type ServerReaderHandler a b @@ -33,10 +63,20 @@ type ServerReaderHandler a b -> StreamRecv a -> IO (Maybe b, MetadataMap, StatusCode, StatusDetails) +convertGeneratedServerReaderHandler :: + (Message request, Message response) + => (ServerRequest 'ClientStreaming request response -> IO (ServerResponse 'ClientStreaming response)) + -> ServerReaderHandler request response +convertGeneratedServerReaderHandler handler call recv = + do ServerReaderResponse rsp meta stsCode stsDetails <- + handler (ServerReaderRequest call recv) + return (rsp, meta, stsCode, stsDetails) + convertServerReaderHandler :: (Message a, Message b) => ServerReaderHandler a b -> ServerReaderHandlerLL -convertServerReaderHandler f c recv = serialize <$> f c (convertRecv recv) +convertServerReaderHandler f c recv = + serialize <$> f c (convertRecv recv) where serialize (mmsg, m, sc, sd) = (toBS <$> mmsg, m, sc, sd) @@ -45,10 +85,21 @@ type ServerWriterHandler a b = -> StreamSend b -> IO (MetadataMap, StatusCode, StatusDetails) -convertServerWriterHandler :: (Message a, Message b) - => ServerWriterHandler a b - -> ServerWriterHandlerLL -convertServerWriterHandler f c send = f (convert <$> c) (convertSend send) +convertGeneratedServerWriterHandler :: + (Message request, Message response) + => (ServerRequest 'ServerStreaming request response -> IO (ServerResponse 'ServerStreaming response)) + -> ServerWriterHandler request response +convertGeneratedServerWriterHandler handler call send = + do let call' = call { payload = () } + ServerWriterResponse meta stsCode stsDetails <- + handler (ServerWriterRequest call' (payload call) send) + return (meta, stsCode, stsDetails) + +convertServerWriterHandler :: (Message a, Message b) => + ServerWriterHandler a b + -> ServerWriterHandlerLL +convertServerWriterHandler f c send = + f (convert <$> c) (convertSend send) where convert bs = case fromByteString bs of Left x -> CE.throw (GRPCIODecodeError x) @@ -60,10 +111,20 @@ type ServerRWHandler a b -> StreamSend b -> IO (MetadataMap, StatusCode, StatusDetails) +convertGeneratedServerRWHandler :: + (Message request, Message response) + => (ServerRequest 'BiDiStreaming request response -> IO (ServerResponse 'BiDiStreaming response)) + -> ServerRWHandler request response +convertGeneratedServerRWHandler handler call recv send = + do ServerBiDiResponse meta stsCode stsDetails <- + handler (ServerBiDiRequest call recv send) + return (meta, stsCode, stsDetails) + convertServerRWHandler :: (Message a, Message b) => ServerRWHandler a b -> ServerRWHandlerLL -convertServerRWHandler f c r s = f c (convertRecv r) (convertSend s) +convertServerRWHandler f c recv send = + f c (convertRecv recv) (convertSend send) convertRecv :: Message a => StreamRecv ByteString -> StreamRecv a convertRecv = diff --git a/stack.yaml b/stack.yaml index 1a17c67..f6ea078 100644 --- a/stack.yaml +++ b/stack.yaml @@ -7,9 +7,10 @@ resolver: lts-7.4 # Local packages, usually specified by relative directory name packages: - '.' -- location: - git: git@github.mv.awakenetworks.net:awakenetworks/protobuf-wire.git - commit: 2eda7a9e33e8a2f32c3ab8c4ace338b08fb79daa +#- location: +# git: git@github.mv.awakenetworks.net:awakenetworks/protobuf-wire.git +# commit: 2eda7a9e33e8a2f32c3ab8c4ace338b08fb79daa +- location: '../protobuf-wire' extra-dep: true - location: git: git@github.com:awakenetworks/proto3-wire.git diff --git a/tests/GeneratedTests.hs b/tests/GeneratedTests.hs index 17a9d53..0a29118 100644 --- a/tests/GeneratedTests.hs +++ b/tests/GeneratedTests.hs @@ -13,7 +13,8 @@ import Turtle generatedTests :: TestTree generatedTests = testGroup "Code generator tests" - [ testServerGeneration ] + [ testServerGeneration + , testClientGeneration ] testServerGeneration :: TestTree testServerGeneration = testCase "server generation" $ do @@ -43,6 +44,34 @@ testServerGeneration = testCase "server generation" $ do rmtree hsTmpDir rmtree pyTmpDir +testClientGeneration :: TestTree +testClientGeneration = testCase "client generation" $ do + mktree hsTmpDir + mktree pyTmpDir + + compileSimpleDotProto + + exitCode <- proc "tests/simple-client.sh" [hsTmpDir] empty + exitCode @?= ExitSuccess + + exitCode <- proc "tests/protoc.sh" [pyTmpDir] empty + exitCode @?= ExitSuccess + + runManaged $ do + serverExitCodeA <- fork + (export "PYTHONPATH" pyTmpDir >> shell "python tests/test-server.py" empty) + clientExitCodeA <- fork (shell (hsTmpDir <> "/simple-client") empty) + + liftIO $ do + serverExitCode <- liftIO (wait serverExitCodeA) + clientExitCode <- liftIO (wait clientExitCodeA) + + serverExitCode @?= ExitSuccess + clientExitCode @?= ExitSuccess + + rmtree hsTmpDir + rmtree pyTmpDir + hsTmpDir, pyTmpDir :: IsString a => a hsTmpDir = "tests/tmp" pyTmpDir = "tests/py-tmp" diff --git a/tests/TestClient.hs b/tests/TestClient.hs new file mode 100644 index 0000000..d467c19 --- /dev/null +++ b/tests/TestClient.hs @@ -0,0 +1,135 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Main where + +import Prelude hiding (sum) + +import Simple + +import Control.Concurrent +import Control.Concurrent.MVar +import Control.Monad +import Control.Monad.IO.Class +import Control.Exception +import Control.Arrow + +import Data.Monoid +import Data.Foldable (sum) +import Data.String +import Data.Word +import Data.Vector (fromList) +import Data.Protobuf.Wire + +import Network.GRPC.LowLevel +import Network.GRPC.HighLevel.Client + +import System.Random + +import Test.Tasty +import Test.Tasty.HUnit ((@?=), assertString, testCase) + +testNormalCall client = testCase "Normal call" $ + do randoms <- fromList <$> replicateM 1000 (Fixed <$> randomRIO (1, 1000)) + let req = SimpleServiceRequest "NormalRequest" randoms + res <- simpleServiceNormalCall client + (ClientNormalRequest req 10 mempty) + case res of + ClientError err -> assertString ("ClientError: " <> show err) + ClientNormalResponse res _ _ stsCode _ -> + do stsCode @?= StatusOk + simpleServiceResponseResponse res @?= "NormalRequest" + simpleServiceResponseNum res @?= sum randoms + +testClientStreamingCall client = testCase "Client-streaming call" $ + do iterationCount <- randomRIO (5, 50) + v <- newEmptyMVar + res <- simpleServiceClientStreamingCall client . ClientWriterRequest 10 mempty $ \send -> + do (finalName, totalSum) <- + fmap ((mconcat *** (sum . mconcat)) . unzip) . + replicateM iterationCount $ + do randoms <- fromList <$> replicateM 1000 (Fixed <$> randomRIO (1, 1000)) + name <- fromString <$> replicateM 10 (randomRIO ('a', 'z')) + send (SimpleServiceRequest name randoms) + pure (name, randoms) + putMVar v (finalName, totalSum) + + (finalName, totalSum) <- readMVar v + case res of + ClientError err -> assertString ("ClientError: " <> show err) + ClientWriterResponse Nothing _ _ _ _ -> assertString "No response received" + ClientWriterResponse (Just res) _ _ stsCode _ -> + do stsCode @?= StatusOk + simpleServiceResponseResponse res @?= finalName + simpleServiceResponseNum res @?= totalSum + +testServerStreamingCall client = testCase "Server-streaming call" $ + do numCount <- randomRIO (50, 500) + nums <- replicateM numCount (Fixed <$> randomIO) + + let checkResults [] recv = + do res <- recv + case res of + Left err -> assertString ("recv error: " <> show err) + Right Nothing -> pure () + Right (Just _) -> assertString "recv: elements past end of stream" + checkResults (expNum:nums) recv = + do res <- recv + case res of + Left err -> assertString ("recv error: " <> show err) + Right Nothing -> assertString ("recv: stream ended earlier than expected") + Right (Just (SimpleServiceResponse response num)) -> + do response @?= "Test" + num @?= expNum + checkResults nums recv + res <- simpleServiceServerStreamingCall client $ + ClientReaderRequest (SimpleServiceRequest "Test" (fromList nums)) 10 mempty + (\_ -> checkResults nums) + case res of + ClientError err -> assertString ("ClientError: " <> show err) + ClientReaderResponse _ sts _ -> + sts @?= StatusOk + +testBiDiStreamingCall client = testCase "Bidi-streaming call" $ + do let handleRequests (0 :: Int) _ _ done = done >> pure () + handleRequests n recv send done = + do numCount <- randomRIO (10, 1000) + nums <- fromList <$> replicateM numCount (Fixed <$> randomRIO (1, 1000)) + testName <- fromString <$> replicateM 10 (randomRIO ('a', 'z')) + send (SimpleServiceRequest testName nums) + + res <- recv + case res of + Left err -> assertString ("recv error: " <> show err) + Right Nothing -> pure () + Right (Just (SimpleServiceResponse name total)) -> + do name @?= testName + total @?= sum nums + handleRequests (n - 1) recv send done + + iterations <- randomRIO (50, 500) + + res <- simpleServiceBiDiStreamingCall client $ + ClientBiDiRequest 10 mempty (\_ -> handleRequests iterations) + case res of + ClientError err -> assertString ("ClientError: " <> show err) + ClientBiDiResponse _ sts _ -> + sts @?= StatusOk + +main :: IO () +main = do + threadDelay 10000000 + withGRPC $ \grpc -> + withClient grpc (ClientConfig "localhost" 50051 [] Nothing) $ \client -> + do service <- simpleServiceClient client + + (defaultMain $ testGroup "Send gRPC requests" + [ testNormalCall service + , testClientStreamingCall service + , testServerStreamingCall service + , testBiDiStreamingCall service ]) `finally` + (simpleServiceDone service (ClientNormalRequest SimpleServiceDone 10 mempty)) + diff --git a/tests/TestServer.hs b/tests/TestServer.hs index 81a028b..e598a83 100644 --- a/tests/TestServer.hs +++ b/tests/TestServer.hs @@ -1,4 +1,7 @@ {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} + module Main where import Prelude hiding (sum) @@ -15,48 +18,46 @@ import Data.Foldable (sum) import Data.String import Network.GRPC.LowLevel +import Network.GRPC.HighLevel.Server -handleNormalCall :: ServerCall SimpleServiceRequest -> IO (SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails) -handleNormalCall call = - pure (SimpleServiceResponse request result, mempty, StatusOk, StatusDetails "") - where SimpleServiceRequest request nums = payload call +handleNormalCall :: ServerRequest 'Normal SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'Normal SimpleServiceResponse) +handleNormalCall (ServerNormalRequest meta (SimpleServiceRequest request nums)) = + pure (ServerNormalResponse (SimpleServiceResponse request result) mempty StatusOk (StatusDetails "")) + where result = sum nums - result = sum nums - -handleClientStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> IO (Maybe SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails) -handleClientStreamingCall call recvRequest = go 0 "" +handleClientStreamingCall :: ServerRequest 'ClientStreaming SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'ClientStreaming SimpleServiceResponse) +handleClientStreamingCall (ServerReaderRequest call recvRequest) = go 0 "" where go sumAccum nameAccum = recvRequest >>= \req -> case req of - Left ioError -> pure (Nothing, mempty, StatusCancelled, StatusDetails ("handleClientStreamingCall: IO error: " <> fromString (show ioError))) + Left ioError -> pure (ServerReaderResponse Nothing mempty StatusCancelled (StatusDetails ("handleClientStreamingCall: IO error: " <> fromString (show ioError)))) Right Nothing -> - pure (Just (SimpleServiceResponse nameAccum sumAccum), mempty, StatusOk, StatusDetails "") + pure (ServerReaderResponse (Just (SimpleServiceResponse nameAccum sumAccum)) mempty StatusOk (StatusDetails "")) Right (Just (SimpleServiceRequest name nums)) -> go (sumAccum + sum nums) (nameAccum <> name) -handleServerStreamingCall :: ServerCall SimpleServiceRequest -> StreamSend SimpleServiceResponse -> IO (MetadataMap, StatusCode, StatusDetails) -handleServerStreamingCall call sendResponse = go +handleServerStreamingCall :: ServerRequest 'ServerStreaming SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'ServerStreaming SimpleServiceResponse) +handleServerStreamingCall (ServerWriterRequest call (SimpleServiceRequest requestName nums) sendResponse) = go where go = do forM_ nums $ \num -> sendResponse (SimpleServiceResponse requestName num) - pure (mempty, StatusOk, StatusDetails "") + pure (ServerWriterResponse mempty StatusOk (StatusDetails "")) - SimpleServiceRequest requestName nums = payload call - -handleBiDiStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> StreamSend SimpleServiceResponse -> IO (MetadataMap, StatusCode, StatusDetails) -handleBiDiStreamingCall call recvRequest sendResponse = go +handleBiDiStreamingCall :: ServerRequest 'BiDiStreaming SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'BiDiStreaming SimpleServiceResponse) +handleBiDiStreamingCall (ServerBiDiRequest call recvRequest sendResponse) = go where go = recvRequest >>= \req -> case req of - Left ioError -> pure (mempty, StatusCancelled, StatusDetails ("handleBiDiStreamingCall: IO error: " <> fromString (show ioError))) + Left ioError -> + pure (ServerBiDiResponse mempty StatusCancelled (StatusDetails ("handleBiDiStreamingCall: IO error: " <> fromString (show ioError)))) Right Nothing -> - pure (mempty, StatusOk, StatusDetails "") + pure (ServerBiDiResponse mempty StatusOk (StatusDetails "")) Right (Just (SimpleServiceRequest name nums)) -> do sendResponse (SimpleServiceResponse name (sum nums)) go -handleDone :: MVar () -> ServerCall SimpleServiceDone -> IO (SimpleServiceDone, MetadataMap, StatusCode, StatusDetails) -handleDone exitVar req = +handleDone :: MVar () -> ServerRequest 'Normal SimpleServiceDone SimpleServiceDone -> IO (ServerResponse 'Normal SimpleServiceDone) +handleDone exitVar (ServerNormalRequest _ req) = do forkIO (threadDelay 5000 >> putMVar exitVar ()) - pure (payload req, mempty, StatusOk, StatusDetails "") + pure (ServerNormalResponse req mempty StatusOk (StatusDetails "")) main :: IO () main = do exitVar <- newEmptyMVar diff --git a/tests/simple-client.sh b/tests/simple-client.sh new file mode 100755 index 0000000..8feb1f4 --- /dev/null +++ b/tests/simple-client.sh @@ -0,0 +1,13 @@ +#!/bin/bash -eu + +hsTmpDir=$1 + +stack ghc -- \ + --make \ + -threaded \ + -odir $hsTmpDir \ + -hidir $hsTmpDir \ + -o $hsTmpDir/simple-client \ + $hsTmpDir/Simple.hs \ + tests/TestClient.hs \ + > /dev/null diff --git a/tests/test-server.py b/tests/test-server.py new file mode 100644 index 0000000..c6c54b1 --- /dev/null +++ b/tests/test-server.py @@ -0,0 +1,41 @@ +from simple_pb2 import * +from uuid import uuid4 +import random +import Queue +import grpc + +print "Starting python server" + +done_queue = Queue.Queue() + +class SimpleServiceServer(BetaSimpleServiceServicer): + def done(self, request, context): + global server + done_queue.put_nowait(()) + + return SimpleServiceDone() + + def normalCall(self, request, context): + return SimpleServiceResponse(response = "NormalRequest", num = sum(request.num)) + + def clientStreamingCall(self, requests, context): + cur_name = "" + cur_sum = 0 + for request in requests: + cur_name += request.request + cur_sum += sum(request.num) + return SimpleServiceResponse(response = cur_name, num = cur_sum) + + def serverStreamingCall(self, request, context): + for num in request.num: + yield SimpleServiceResponse(response = request.request, num = num) + + def biDiStreamingCall(self, requests, context): + for request in requests: + yield SimpleServiceResponse(response = request.request, num = sum(request.num)) + +server = beta_create_SimpleService_server(SimpleServiceServer()) +server.add_insecure_port('[::]:50051') +server.start() + +done_queue.get()