GADT-based high level interface

This commit is contained in:
travis 2016-11-29 13:34:34 -08:00 committed by GitHub Enterprise
commit 7d5df1d204
10 changed files with 440 additions and 35 deletions

View file

@ -72,6 +72,7 @@ library
Network.GRPC.HighLevel.Generated Network.GRPC.HighLevel.Generated
Network.GRPC.HighLevel.Server Network.GRPC.HighLevel.Server
Network.GRPC.HighLevel.Server.Unregistered Network.GRPC.HighLevel.Server.Unregistered
Network.GRPC.HighLevel.Client
extra-libraries: extra-libraries:
grpc grpc
includes: includes:

View file

@ -0,0 +1,122 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
module Network.GRPC.HighLevel.Client
( RegisteredMethod
, TimeoutSeconds
, MetadataMap(..)
, StatusDetails(..)
, GRPCMethodType(..)
, StreamRecv
, StreamSend
, WritesDone
, LL.Client
, ServiceClient
, ClientRequest(..)
, ClientResult(..)
-- , ClientResponse, response, initMD, trailMD, rspCode, details
, ClientRegisterable(..)
, clientRequest ) where
import qualified Network.GRPC.LowLevel.Client as LL
import qualified Network.GRPC.LowLevel.Call as LL
import Network.GRPC.LowLevel.CompletionQueue (TimeoutSeconds)
import Network.GRPC.LowLevel ( GRPCMethodType(..)
, StatusCode(..)
, StatusDetails(..)
, MetadataMap(..)
, GRPCIOError(..)
, StreamRecv
, StreamSend )
import Network.GRPC.LowLevel.Op (WritesDone)
import Network.GRPC.HighLevel.Server (convertRecv, convertSend)
import Data.Protobuf.Wire (Message, toLazyByteString, fromByteString)
import Proto3.Wire.Decode (ParseError)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy as BL
newtype RegisteredMethod (mt :: GRPCMethodType) request response
= RegisteredMethod (LL.RegisteredMethod mt)
deriving Show
type ServiceClient service = service ClientRequest ClientResult
data ClientError
= ClientErrorNoParse ParseError
| ClientIOError GRPCIOError
deriving (Show, Eq)
data ClientRequest (streamType :: GRPCMethodType) request response where
ClientNormalRequest :: request -> TimeoutSeconds -> MetadataMap -> ClientRequest 'Normal request response
ClientWriterRequest :: TimeoutSeconds -> MetadataMap -> (StreamSend request -> IO ()) -> ClientRequest 'ClientStreaming request response
ClientReaderRequest :: request -> TimeoutSeconds -> MetadataMap -> (MetadataMap -> StreamRecv response -> IO ()) -> ClientRequest 'ServerStreaming request response
ClientBiDiRequest :: TimeoutSeconds -> MetadataMap -> (MetadataMap -> StreamRecv response -> StreamSend request -> WritesDone -> IO ()) -> ClientRequest 'BiDiStreaming request response
data ClientResult (streamType :: GRPCMethodType) response where
ClientNormalResponse :: response -> MetadataMap -> MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'Normal response
ClientWriterResponse :: Maybe response -> MetadataMap -> MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'ClientStreaming response
ClientReaderResponse :: MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'ServerStreaming response
ClientBiDiResponse :: MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'BiDiStreaming response
ClientError :: ClientError -> ClientResult streamType response
class ClientRegisterable (methodType :: GRPCMethodType) where
clientRegisterMethod :: LL.Client
-> LL.MethodName
-> IO (RegisteredMethod methodType request response)
instance ClientRegisterable 'Normal where
clientRegisterMethod client methodName =
RegisteredMethod <$> LL.clientRegisterMethodNormal client methodName
instance ClientRegisterable 'ClientStreaming where
clientRegisterMethod client methodName =
RegisteredMethod <$> LL.clientRegisterMethodClientStreaming client methodName
instance ClientRegisterable 'ServerStreaming where
clientRegisterMethod client methodName =
RegisteredMethod <$> LL.clientRegisterMethodServerStreaming client methodName
instance ClientRegisterable 'BiDiStreaming where
clientRegisterMethod client methodName =
RegisteredMethod <$> LL.clientRegisterMethodBiDiStreaming client methodName
clientRequest :: (Message request, Message response) =>
LL.Client -> RegisteredMethod streamType request response
-> ClientRequest streamType request response -> IO (ClientResult streamType response)
clientRequest client (RegisteredMethod method) (ClientNormalRequest req timeout meta) =
mkResponse <$> LL.clientRequest client method timeout (BL.toStrict (toLazyByteString req)) meta
where
mkResponse (Left ioError_) = ClientError (ClientIOError ioError_)
mkResponse (Right rsp) =
case fromByteString (LL.rspBody rsp) of
Left err -> ClientError (ClientErrorNoParse err)
Right parsedRsp ->
ClientNormalResponse parsedRsp (LL.initMD rsp) (LL.trailMD rsp) (LL.rspCode rsp) (LL.details rsp)
clientRequest client (RegisteredMethod method) (ClientWriterRequest timeout meta handler) =
mkResponse <$> LL.clientWriter client method timeout meta (handler . convertSend)
where
mkResponse (Left ioError_) = ClientError (ClientIOError ioError_)
mkResponse (Right (rsp_, initMD_, trailMD_, rspCode_, details_)) =
case maybe (Right Nothing) (fmap Just . fromByteString) rsp_ of
Left err -> ClientError (ClientErrorNoParse err)
Right parsedRsp ->
ClientWriterResponse parsedRsp initMD_ trailMD_ rspCode_ details_
clientRequest client (RegisteredMethod method) (ClientReaderRequest req timeout meta handler) =
mkResponse <$> LL.clientReader client method timeout (BL.toStrict (toLazyByteString req)) meta (\m recv -> handler m (convertRecv recv))
where
mkResponse (Left ioError_) = ClientError (ClientIOError ioError_)
mkResponse (Right (meta_, rspCode_, details_)) =
ClientReaderResponse meta_ rspCode_ details_
clientRequest client (RegisteredMethod method) (ClientBiDiRequest timeout meta handler) =
mkResponse <$> LL.clientRW client method timeout meta (\m recv send writesDone -> handler meta (convertRecv recv) (convertSend send) writesDone)
where
mkResponse (Left ioError_) = ClientError (ClientIOError ioError_)
mkResponse (Right (meta_, rspCode_, details_)) =
ClientBiDiResponse meta_ rspCode_ details_

View file

@ -1,3 +1,4 @@
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DataKinds #-} {-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-} {-# LANGUAGE RankNTypes #-}

View file

@ -16,16 +16,46 @@ import Data.Protobuf.Wire.Class
import Network.GRPC.LowLevel import Network.GRPC.LowLevel
import System.IO import System.IO
type ServerHandler a b type ServerCallMetadata = ServerCall ()
= ServerCall a
type ServiceServer service = service ServerRequest ServerResponse
data ServerRequest (streamType :: GRPCMethodType) request response where
ServerNormalRequest :: ServerCallMetadata -> request -> ServerRequest 'Normal request response
ServerReaderRequest :: ServerCallMetadata -> StreamRecv request -> ServerRequest 'ClientStreaming request response
ServerWriterRequest :: ServerCallMetadata -> request -> StreamSend response -> ServerRequest 'ServerStreaming request response
ServerBiDiRequest :: ServerCallMetadata -> StreamRecv request -> StreamSend response -> ServerRequest 'BiDiStreaming request response
data ServerResponse (streamType :: GRPCMethodType) response where
ServerNormalResponse :: response -> MetadataMap -> StatusCode -> StatusDetails
-> ServerResponse 'Normal response
ServerReaderResponse :: Maybe response -> MetadataMap -> StatusCode -> StatusDetails
-> ServerResponse 'ClientStreaming response
ServerWriterResponse :: MetadataMap -> StatusCode -> StatusDetails
-> ServerResponse 'ServerStreaming response
ServerBiDiResponse :: MetadataMap -> StatusCode -> StatusDetails
-> ServerResponse 'BiDiStreaming response
type ServerHandler a b =
ServerCall a
-> IO (b, MetadataMap, StatusCode, StatusDetails) -> IO (b, MetadataMap, StatusCode, StatusDetails)
convertGeneratedServerHandler ::
(Message request, Message response)
=> (ServerRequest 'Normal request response -> IO (ServerResponse 'Normal response))
-> ServerHandler request response
convertGeneratedServerHandler handler call =
do let call' = call { payload = () }
ServerNormalResponse rsp meta stsCode stsDetails <-
handler (ServerNormalRequest call' (payload call))
return (rsp, meta, stsCode, stsDetails)
convertServerHandler :: (Message a, Message b) convertServerHandler :: (Message a, Message b)
=> ServerHandler a b => ServerHandler a b
-> ServerHandlerLL -> ServerHandlerLL
convertServerHandler f c = case fromByteString (payload c) of convertServerHandler f c = case fromByteString (payload c) of
Left x -> CE.throw (GRPCIODecodeError x) Left x -> CE.throw (GRPCIODecodeError x)
Right x -> do (y, tm, sc, sd) <- f (const x <$> c) Right x -> do (y, tm, sc, sd) <- f (fmap (const x) c)
return (toBS y, tm, sc, sd) return (toBS y, tm, sc, sd)
type ServerReaderHandler a b type ServerReaderHandler a b
@ -33,10 +63,20 @@ type ServerReaderHandler a b
-> StreamRecv a -> StreamRecv a
-> IO (Maybe b, MetadataMap, StatusCode, StatusDetails) -> IO (Maybe b, MetadataMap, StatusCode, StatusDetails)
convertGeneratedServerReaderHandler ::
(Message request, Message response)
=> (ServerRequest 'ClientStreaming request response -> IO (ServerResponse 'ClientStreaming response))
-> ServerReaderHandler request response
convertGeneratedServerReaderHandler handler call recv =
do ServerReaderResponse rsp meta stsCode stsDetails <-
handler (ServerReaderRequest call recv)
return (rsp, meta, stsCode, stsDetails)
convertServerReaderHandler :: (Message a, Message b) convertServerReaderHandler :: (Message a, Message b)
=> ServerReaderHandler a b => ServerReaderHandler a b
-> ServerReaderHandlerLL -> ServerReaderHandlerLL
convertServerReaderHandler f c recv = serialize <$> f c (convertRecv recv) convertServerReaderHandler f c recv =
serialize <$> f c (convertRecv recv)
where where
serialize (mmsg, m, sc, sd) = (toBS <$> mmsg, m, sc, sd) serialize (mmsg, m, sc, sd) = (toBS <$> mmsg, m, sc, sd)
@ -45,10 +85,21 @@ type ServerWriterHandler a b =
-> StreamSend b -> StreamSend b
-> IO (MetadataMap, StatusCode, StatusDetails) -> IO (MetadataMap, StatusCode, StatusDetails)
convertServerWriterHandler :: (Message a, Message b) convertGeneratedServerWriterHandler ::
=> ServerWriterHandler a b (Message request, Message response)
=> (ServerRequest 'ServerStreaming request response -> IO (ServerResponse 'ServerStreaming response))
-> ServerWriterHandler request response
convertGeneratedServerWriterHandler handler call send =
do let call' = call { payload = () }
ServerWriterResponse meta stsCode stsDetails <-
handler (ServerWriterRequest call' (payload call) send)
return (meta, stsCode, stsDetails)
convertServerWriterHandler :: (Message a, Message b) =>
ServerWriterHandler a b
-> ServerWriterHandlerLL -> ServerWriterHandlerLL
convertServerWriterHandler f c send = f (convert <$> c) (convertSend send) convertServerWriterHandler f c send =
f (convert <$> c) (convertSend send)
where where
convert bs = case fromByteString bs of convert bs = case fromByteString bs of
Left x -> CE.throw (GRPCIODecodeError x) Left x -> CE.throw (GRPCIODecodeError x)
@ -60,10 +111,20 @@ type ServerRWHandler a b
-> StreamSend b -> StreamSend b
-> IO (MetadataMap, StatusCode, StatusDetails) -> IO (MetadataMap, StatusCode, StatusDetails)
convertGeneratedServerRWHandler ::
(Message request, Message response)
=> (ServerRequest 'BiDiStreaming request response -> IO (ServerResponse 'BiDiStreaming response))
-> ServerRWHandler request response
convertGeneratedServerRWHandler handler call recv send =
do ServerBiDiResponse meta stsCode stsDetails <-
handler (ServerBiDiRequest call recv send)
return (meta, stsCode, stsDetails)
convertServerRWHandler :: (Message a, Message b) convertServerRWHandler :: (Message a, Message b)
=> ServerRWHandler a b => ServerRWHandler a b
-> ServerRWHandlerLL -> ServerRWHandlerLL
convertServerRWHandler f c r s = f c (convertRecv r) (convertSend s) convertServerRWHandler f c recv send =
f c (convertRecv recv) (convertSend send)
convertRecv :: Message a => StreamRecv ByteString -> StreamRecv a convertRecv :: Message a => StreamRecv ByteString -> StreamRecv a
convertRecv = convertRecv =

View file

@ -7,9 +7,10 @@ resolver: lts-7.4
# Local packages, usually specified by relative directory name # Local packages, usually specified by relative directory name
packages: packages:
- '.' - '.'
- location: #- location:
git: git@github.mv.awakenetworks.net:awakenetworks/protobuf-wire.git # git: git@github.mv.awakenetworks.net:awakenetworks/protobuf-wire.git
commit: 2eda7a9e33e8a2f32c3ab8c4ace338b08fb79daa # commit: 2eda7a9e33e8a2f32c3ab8c4ace338b08fb79daa
- location: '../protobuf-wire'
extra-dep: true extra-dep: true
- location: - location:
git: git@github.com:awakenetworks/proto3-wire.git git: git@github.com:awakenetworks/proto3-wire.git

View file

@ -13,7 +13,8 @@ import Turtle
generatedTests :: TestTree generatedTests :: TestTree
generatedTests = testGroup "Code generator tests" generatedTests = testGroup "Code generator tests"
[ testServerGeneration ] [ testServerGeneration
, testClientGeneration ]
testServerGeneration :: TestTree testServerGeneration :: TestTree
testServerGeneration = testCase "server generation" $ do testServerGeneration = testCase "server generation" $ do
@ -43,6 +44,34 @@ testServerGeneration = testCase "server generation" $ do
rmtree hsTmpDir rmtree hsTmpDir
rmtree pyTmpDir rmtree pyTmpDir
testClientGeneration :: TestTree
testClientGeneration = testCase "client generation" $ do
mktree hsTmpDir
mktree pyTmpDir
compileSimpleDotProto
exitCode <- proc "tests/simple-client.sh" [hsTmpDir] empty
exitCode @?= ExitSuccess
exitCode <- proc "tests/protoc.sh" [pyTmpDir] empty
exitCode @?= ExitSuccess
runManaged $ do
serverExitCodeA <- fork
(export "PYTHONPATH" pyTmpDir >> shell "python tests/test-server.py" empty)
clientExitCodeA <- fork (shell (hsTmpDir <> "/simple-client") empty)
liftIO $ do
serverExitCode <- liftIO (wait serverExitCodeA)
clientExitCode <- liftIO (wait clientExitCodeA)
serverExitCode @?= ExitSuccess
clientExitCode @?= ExitSuccess
rmtree hsTmpDir
rmtree pyTmpDir
hsTmpDir, pyTmpDir :: IsString a => a hsTmpDir, pyTmpDir :: IsString a => a
hsTmpDir = "tests/tmp" hsTmpDir = "tests/tmp"
pyTmpDir = "tests/py-tmp" pyTmpDir = "tests/py-tmp"

135
tests/TestClient.hs Normal file
View file

@ -0,0 +1,135 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Main where
import Prelude hiding (sum)
import Simple
import Control.Concurrent
import Control.Concurrent.MVar
import Control.Monad
import Control.Monad.IO.Class
import Control.Exception
import Control.Arrow
import Data.Monoid
import Data.Foldable (sum)
import Data.String
import Data.Word
import Data.Vector (fromList)
import Data.Protobuf.Wire
import Network.GRPC.LowLevel
import Network.GRPC.HighLevel.Client
import System.Random
import Test.Tasty
import Test.Tasty.HUnit ((@?=), assertString, testCase)
testNormalCall client = testCase "Normal call" $
do randoms <- fromList <$> replicateM 1000 (Fixed <$> randomRIO (1, 1000))
let req = SimpleServiceRequest "NormalRequest" randoms
res <- simpleServiceNormalCall client
(ClientNormalRequest req 10 mempty)
case res of
ClientError err -> assertString ("ClientError: " <> show err)
ClientNormalResponse res _ _ stsCode _ ->
do stsCode @?= StatusOk
simpleServiceResponseResponse res @?= "NormalRequest"
simpleServiceResponseNum res @?= sum randoms
testClientStreamingCall client = testCase "Client-streaming call" $
do iterationCount <- randomRIO (5, 50)
v <- newEmptyMVar
res <- simpleServiceClientStreamingCall client . ClientWriterRequest 10 mempty $ \send ->
do (finalName, totalSum) <-
fmap ((mconcat *** (sum . mconcat)) . unzip) .
replicateM iterationCount $
do randoms <- fromList <$> replicateM 1000 (Fixed <$> randomRIO (1, 1000))
name <- fromString <$> replicateM 10 (randomRIO ('a', 'z'))
send (SimpleServiceRequest name randoms)
pure (name, randoms)
putMVar v (finalName, totalSum)
(finalName, totalSum) <- readMVar v
case res of
ClientError err -> assertString ("ClientError: " <> show err)
ClientWriterResponse Nothing _ _ _ _ -> assertString "No response received"
ClientWriterResponse (Just res) _ _ stsCode _ ->
do stsCode @?= StatusOk
simpleServiceResponseResponse res @?= finalName
simpleServiceResponseNum res @?= totalSum
testServerStreamingCall client = testCase "Server-streaming call" $
do numCount <- randomRIO (50, 500)
nums <- replicateM numCount (Fixed <$> randomIO)
let checkResults [] recv =
do res <- recv
case res of
Left err -> assertString ("recv error: " <> show err)
Right Nothing -> pure ()
Right (Just _) -> assertString "recv: elements past end of stream"
checkResults (expNum:nums) recv =
do res <- recv
case res of
Left err -> assertString ("recv error: " <> show err)
Right Nothing -> assertString ("recv: stream ended earlier than expected")
Right (Just (SimpleServiceResponse response num)) ->
do response @?= "Test"
num @?= expNum
checkResults nums recv
res <- simpleServiceServerStreamingCall client $
ClientReaderRequest (SimpleServiceRequest "Test" (fromList nums)) 10 mempty
(\_ -> checkResults nums)
case res of
ClientError err -> assertString ("ClientError: " <> show err)
ClientReaderResponse _ sts _ ->
sts @?= StatusOk
testBiDiStreamingCall client = testCase "Bidi-streaming call" $
do let handleRequests (0 :: Int) _ _ done = done >> pure ()
handleRequests n recv send done =
do numCount <- randomRIO (10, 1000)
nums <- fromList <$> replicateM numCount (Fixed <$> randomRIO (1, 1000))
testName <- fromString <$> replicateM 10 (randomRIO ('a', 'z'))
send (SimpleServiceRequest testName nums)
res <- recv
case res of
Left err -> assertString ("recv error: " <> show err)
Right Nothing -> pure ()
Right (Just (SimpleServiceResponse name total)) ->
do name @?= testName
total @?= sum nums
handleRequests (n - 1) recv send done
iterations <- randomRIO (50, 500)
res <- simpleServiceBiDiStreamingCall client $
ClientBiDiRequest 10 mempty (\_ -> handleRequests iterations)
case res of
ClientError err -> assertString ("ClientError: " <> show err)
ClientBiDiResponse _ sts _ ->
sts @?= StatusOk
main :: IO ()
main = do
threadDelay 10000000
withGRPC $ \grpc ->
withClient grpc (ClientConfig "localhost" 50051 [] Nothing) $ \client ->
do service <- simpleServiceClient client
(defaultMain $ testGroup "Send gRPC requests"
[ testNormalCall service
, testClientStreamingCall service
, testServerStreamingCall service
, testBiDiStreamingCall service ]) `finally`
(simpleServiceDone service (ClientNormalRequest SimpleServiceDone 10 mempty))

View file

@ -1,4 +1,7 @@
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
module Main where module Main where
import Prelude hiding (sum) import Prelude hiding (sum)
@ -15,48 +18,46 @@ import Data.Foldable (sum)
import Data.String import Data.String
import Network.GRPC.LowLevel import Network.GRPC.LowLevel
import Network.GRPC.HighLevel.Server
handleNormalCall :: ServerCall SimpleServiceRequest -> IO (SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails) handleNormalCall :: ServerRequest 'Normal SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'Normal SimpleServiceResponse)
handleNormalCall call = handleNormalCall (ServerNormalRequest meta (SimpleServiceRequest request nums)) =
pure (SimpleServiceResponse request result, mempty, StatusOk, StatusDetails "") pure (ServerNormalResponse (SimpleServiceResponse request result) mempty StatusOk (StatusDetails ""))
where SimpleServiceRequest request nums = payload call where result = sum nums
result = sum nums handleClientStreamingCall :: ServerRequest 'ClientStreaming SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'ClientStreaming SimpleServiceResponse)
handleClientStreamingCall (ServerReaderRequest call recvRequest) = go 0 ""
handleClientStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> IO (Maybe SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails)
handleClientStreamingCall call recvRequest = go 0 ""
where go sumAccum nameAccum = where go sumAccum nameAccum =
recvRequest >>= \req -> recvRequest >>= \req ->
case req of case req of
Left ioError -> pure (Nothing, mempty, StatusCancelled, StatusDetails ("handleClientStreamingCall: IO error: " <> fromString (show ioError))) Left ioError -> pure (ServerReaderResponse Nothing mempty StatusCancelled (StatusDetails ("handleClientStreamingCall: IO error: " <> fromString (show ioError))))
Right Nothing -> Right Nothing ->
pure (Just (SimpleServiceResponse nameAccum sumAccum), mempty, StatusOk, StatusDetails "") pure (ServerReaderResponse (Just (SimpleServiceResponse nameAccum sumAccum)) mempty StatusOk (StatusDetails ""))
Right (Just (SimpleServiceRequest name nums)) -> Right (Just (SimpleServiceRequest name nums)) ->
go (sumAccum + sum nums) (nameAccum <> name) go (sumAccum + sum nums) (nameAccum <> name)
handleServerStreamingCall :: ServerCall SimpleServiceRequest -> StreamSend SimpleServiceResponse -> IO (MetadataMap, StatusCode, StatusDetails) handleServerStreamingCall :: ServerRequest 'ServerStreaming SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'ServerStreaming SimpleServiceResponse)
handleServerStreamingCall call sendResponse = go handleServerStreamingCall (ServerWriterRequest call (SimpleServiceRequest requestName nums) sendResponse) = go
where go = do forM_ nums $ \num -> where go = do forM_ nums $ \num ->
sendResponse (SimpleServiceResponse requestName num) sendResponse (SimpleServiceResponse requestName num)
pure (mempty, StatusOk, StatusDetails "") pure (ServerWriterResponse mempty StatusOk (StatusDetails ""))
SimpleServiceRequest requestName nums = payload call handleBiDiStreamingCall :: ServerRequest 'BiDiStreaming SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'BiDiStreaming SimpleServiceResponse)
handleBiDiStreamingCall (ServerBiDiRequest call recvRequest sendResponse) = go
handleBiDiStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> StreamSend SimpleServiceResponse -> IO (MetadataMap, StatusCode, StatusDetails)
handleBiDiStreamingCall call recvRequest sendResponse = go
where go = recvRequest >>= \req -> where go = recvRequest >>= \req ->
case req of case req of
Left ioError -> pure (mempty, StatusCancelled, StatusDetails ("handleBiDiStreamingCall: IO error: " <> fromString (show ioError))) Left ioError ->
pure (ServerBiDiResponse mempty StatusCancelled (StatusDetails ("handleBiDiStreamingCall: IO error: " <> fromString (show ioError))))
Right Nothing -> Right Nothing ->
pure (mempty, StatusOk, StatusDetails "") pure (ServerBiDiResponse mempty StatusOk (StatusDetails ""))
Right (Just (SimpleServiceRequest name nums)) -> Right (Just (SimpleServiceRequest name nums)) ->
do sendResponse (SimpleServiceResponse name (sum nums)) do sendResponse (SimpleServiceResponse name (sum nums))
go go
handleDone :: MVar () -> ServerCall SimpleServiceDone -> IO (SimpleServiceDone, MetadataMap, StatusCode, StatusDetails) handleDone :: MVar () -> ServerRequest 'Normal SimpleServiceDone SimpleServiceDone -> IO (ServerResponse 'Normal SimpleServiceDone)
handleDone exitVar req = handleDone exitVar (ServerNormalRequest _ req) =
do forkIO (threadDelay 5000 >> putMVar exitVar ()) do forkIO (threadDelay 5000 >> putMVar exitVar ())
pure (payload req, mempty, StatusOk, StatusDetails "") pure (ServerNormalResponse req mempty StatusOk (StatusDetails ""))
main :: IO () main :: IO ()
main = do exitVar <- newEmptyMVar main = do exitVar <- newEmptyMVar

13
tests/simple-client.sh Executable file
View file

@ -0,0 +1,13 @@
#!/bin/bash -eu
hsTmpDir=$1
stack ghc -- \
--make \
-threaded \
-odir $hsTmpDir \
-hidir $hsTmpDir \
-o $hsTmpDir/simple-client \
$hsTmpDir/Simple.hs \
tests/TestClient.hs \
> /dev/null

41
tests/test-server.py Normal file
View file

@ -0,0 +1,41 @@
from simple_pb2 import *
from uuid import uuid4
import random
import Queue
import grpc
print "Starting python server"
done_queue = Queue.Queue()
class SimpleServiceServer(BetaSimpleServiceServicer):
def done(self, request, context):
global server
done_queue.put_nowait(())
return SimpleServiceDone()
def normalCall(self, request, context):
return SimpleServiceResponse(response = "NormalRequest", num = sum(request.num))
def clientStreamingCall(self, requests, context):
cur_name = ""
cur_sum = 0
for request in requests:
cur_name += request.request
cur_sum += sum(request.num)
return SimpleServiceResponse(response = cur_name, num = cur_sum)
def serverStreamingCall(self, request, context):
for num in request.num:
yield SimpleServiceResponse(response = request.request, num = num)
def biDiStreamingCall(self, requests, context):
for request in requests:
yield SimpleServiceResponse(response = request.request, num = sum(request.num))
server = beta_create_SimpleService_server(SimpleServiceServer())
server.add_insecure_port('[::]:50051')
server.start()
done_queue.get()