mirror of
https://github.com/unclechu/gRPC-haskell.git
synced 2024-11-14 07:09:41 +01:00
GADT-based high level interface
This commit is contained in:
parent
507edf803f
commit
26dc36dc64
10 changed files with 440 additions and 35 deletions
|
@ -72,6 +72,7 @@ library
|
|||
Network.GRPC.HighLevel.Generated
|
||||
Network.GRPC.HighLevel.Server
|
||||
Network.GRPC.HighLevel.Server.Unregistered
|
||||
Network.GRPC.HighLevel.Client
|
||||
extra-libraries:
|
||||
grpc
|
||||
includes:
|
||||
|
|
122
src/Network/GRPC/HighLevel/Client.hs
Normal file
122
src/Network/GRPC/HighLevel/Client.hs
Normal file
|
@ -0,0 +1,122 @@
|
|||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
|
||||
module Network.GRPC.HighLevel.Client
|
||||
( RegisteredMethod
|
||||
, TimeoutSeconds
|
||||
, MetadataMap(..)
|
||||
, StatusDetails(..)
|
||||
, GRPCMethodType(..)
|
||||
, StreamRecv
|
||||
, StreamSend
|
||||
, WritesDone
|
||||
, LL.Client
|
||||
|
||||
, ServiceClient
|
||||
, ClientRequest(..)
|
||||
, ClientResult(..)
|
||||
-- , ClientResponse, response, initMD, trailMD, rspCode, details
|
||||
|
||||
, ClientRegisterable(..)
|
||||
|
||||
, clientRequest ) where
|
||||
|
||||
import qualified Network.GRPC.LowLevel.Client as LL
|
||||
import qualified Network.GRPC.LowLevel.Call as LL
|
||||
import Network.GRPC.LowLevel.CompletionQueue (TimeoutSeconds)
|
||||
import Network.GRPC.LowLevel ( GRPCMethodType(..)
|
||||
, StatusCode(..)
|
||||
, StatusDetails(..)
|
||||
, MetadataMap(..)
|
||||
, GRPCIOError(..)
|
||||
, StreamRecv
|
||||
, StreamSend )
|
||||
import Network.GRPC.LowLevel.Op (WritesDone)
|
||||
import Network.GRPC.HighLevel.Server (convertRecv, convertSend)
|
||||
|
||||
import Data.Protobuf.Wire (Message, toLazyByteString, fromByteString)
|
||||
import Proto3.Wire.Decode (ParseError)
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString.Lazy as BL
|
||||
|
||||
newtype RegisteredMethod (mt :: GRPCMethodType) request response
|
||||
= RegisteredMethod (LL.RegisteredMethod mt)
|
||||
deriving Show
|
||||
|
||||
type ServiceClient service = service ClientRequest ClientResult
|
||||
|
||||
data ClientError
|
||||
= ClientErrorNoParse ParseError
|
||||
| ClientIOError GRPCIOError
|
||||
deriving (Show, Eq)
|
||||
|
||||
data ClientRequest (streamType :: GRPCMethodType) request response where
|
||||
ClientNormalRequest :: request -> TimeoutSeconds -> MetadataMap -> ClientRequest 'Normal request response
|
||||
ClientWriterRequest :: TimeoutSeconds -> MetadataMap -> (StreamSend request -> IO ()) -> ClientRequest 'ClientStreaming request response
|
||||
ClientReaderRequest :: request -> TimeoutSeconds -> MetadataMap -> (MetadataMap -> StreamRecv response -> IO ()) -> ClientRequest 'ServerStreaming request response
|
||||
ClientBiDiRequest :: TimeoutSeconds -> MetadataMap -> (MetadataMap -> StreamRecv response -> StreamSend request -> WritesDone -> IO ()) -> ClientRequest 'BiDiStreaming request response
|
||||
|
||||
data ClientResult (streamType :: GRPCMethodType) response where
|
||||
ClientNormalResponse :: response -> MetadataMap -> MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'Normal response
|
||||
ClientWriterResponse :: Maybe response -> MetadataMap -> MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'ClientStreaming response
|
||||
ClientReaderResponse :: MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'ServerStreaming response
|
||||
ClientBiDiResponse :: MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'BiDiStreaming response
|
||||
|
||||
ClientError :: ClientError -> ClientResult streamType response
|
||||
|
||||
class ClientRegisterable (methodType :: GRPCMethodType) where
|
||||
clientRegisterMethod :: LL.Client
|
||||
-> LL.MethodName
|
||||
-> IO (RegisteredMethod methodType request response)
|
||||
|
||||
instance ClientRegisterable 'Normal where
|
||||
clientRegisterMethod client methodName =
|
||||
RegisteredMethod <$> LL.clientRegisterMethodNormal client methodName
|
||||
|
||||
instance ClientRegisterable 'ClientStreaming where
|
||||
clientRegisterMethod client methodName =
|
||||
RegisteredMethod <$> LL.clientRegisterMethodClientStreaming client methodName
|
||||
|
||||
instance ClientRegisterable 'ServerStreaming where
|
||||
clientRegisterMethod client methodName =
|
||||
RegisteredMethod <$> LL.clientRegisterMethodServerStreaming client methodName
|
||||
|
||||
instance ClientRegisterable 'BiDiStreaming where
|
||||
clientRegisterMethod client methodName =
|
||||
RegisteredMethod <$> LL.clientRegisterMethodBiDiStreaming client methodName
|
||||
|
||||
clientRequest :: (Message request, Message response) =>
|
||||
LL.Client -> RegisteredMethod streamType request response
|
||||
-> ClientRequest streamType request response -> IO (ClientResult streamType response)
|
||||
clientRequest client (RegisteredMethod method) (ClientNormalRequest req timeout meta) =
|
||||
mkResponse <$> LL.clientRequest client method timeout (BL.toStrict (toLazyByteString req)) meta
|
||||
where
|
||||
mkResponse (Left ioError_) = ClientError (ClientIOError ioError_)
|
||||
mkResponse (Right rsp) =
|
||||
case fromByteString (LL.rspBody rsp) of
|
||||
Left err -> ClientError (ClientErrorNoParse err)
|
||||
Right parsedRsp ->
|
||||
ClientNormalResponse parsedRsp (LL.initMD rsp) (LL.trailMD rsp) (LL.rspCode rsp) (LL.details rsp)
|
||||
clientRequest client (RegisteredMethod method) (ClientWriterRequest timeout meta handler) =
|
||||
mkResponse <$> LL.clientWriter client method timeout meta (handler . convertSend)
|
||||
where
|
||||
mkResponse (Left ioError_) = ClientError (ClientIOError ioError_)
|
||||
mkResponse (Right (rsp_, initMD_, trailMD_, rspCode_, details_)) =
|
||||
case maybe (Right Nothing) (fmap Just . fromByteString) rsp_ of
|
||||
Left err -> ClientError (ClientErrorNoParse err)
|
||||
Right parsedRsp ->
|
||||
ClientWriterResponse parsedRsp initMD_ trailMD_ rspCode_ details_
|
||||
clientRequest client (RegisteredMethod method) (ClientReaderRequest req timeout meta handler) =
|
||||
mkResponse <$> LL.clientReader client method timeout (BL.toStrict (toLazyByteString req)) meta (\m recv -> handler m (convertRecv recv))
|
||||
where
|
||||
mkResponse (Left ioError_) = ClientError (ClientIOError ioError_)
|
||||
mkResponse (Right (meta_, rspCode_, details_)) =
|
||||
ClientReaderResponse meta_ rspCode_ details_
|
||||
clientRequest client (RegisteredMethod method) (ClientBiDiRequest timeout meta handler) =
|
||||
mkResponse <$> LL.clientRW client method timeout meta (\m recv send writesDone -> handler meta (convertRecv recv) (convertSend send) writesDone)
|
||||
where
|
||||
mkResponse (Left ioError_) = ClientError (ClientIOError ioError_)
|
||||
mkResponse (Right (meta_, rspCode_, details_)) =
|
||||
ClientBiDiResponse meta_ rspCode_ details_
|
|
@ -1,3 +1,4 @@
|
|||
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
|
|
|
@ -16,16 +16,46 @@ import Data.Protobuf.Wire.Class
|
|||
import Network.GRPC.LowLevel
|
||||
import System.IO
|
||||
|
||||
type ServerHandler a b
|
||||
= ServerCall a
|
||||
type ServerCallMetadata = ServerCall ()
|
||||
|
||||
type ServiceServer service = service ServerRequest ServerResponse
|
||||
|
||||
data ServerRequest (streamType :: GRPCMethodType) request response where
|
||||
ServerNormalRequest :: ServerCallMetadata -> request -> ServerRequest 'Normal request response
|
||||
ServerReaderRequest :: ServerCallMetadata -> StreamRecv request -> ServerRequest 'ClientStreaming request response
|
||||
ServerWriterRequest :: ServerCallMetadata -> request -> StreamSend response -> ServerRequest 'ServerStreaming request response
|
||||
ServerBiDiRequest :: ServerCallMetadata -> StreamRecv request -> StreamSend response -> ServerRequest 'BiDiStreaming request response
|
||||
|
||||
data ServerResponse (streamType :: GRPCMethodType) response where
|
||||
ServerNormalResponse :: response -> MetadataMap -> StatusCode -> StatusDetails
|
||||
-> ServerResponse 'Normal response
|
||||
ServerReaderResponse :: Maybe response -> MetadataMap -> StatusCode -> StatusDetails
|
||||
-> ServerResponse 'ClientStreaming response
|
||||
ServerWriterResponse :: MetadataMap -> StatusCode -> StatusDetails
|
||||
-> ServerResponse 'ServerStreaming response
|
||||
ServerBiDiResponse :: MetadataMap -> StatusCode -> StatusDetails
|
||||
-> ServerResponse 'BiDiStreaming response
|
||||
|
||||
type ServerHandler a b =
|
||||
ServerCall a
|
||||
-> IO (b, MetadataMap, StatusCode, StatusDetails)
|
||||
|
||||
convertGeneratedServerHandler ::
|
||||
(Message request, Message response)
|
||||
=> (ServerRequest 'Normal request response -> IO (ServerResponse 'Normal response))
|
||||
-> ServerHandler request response
|
||||
convertGeneratedServerHandler handler call =
|
||||
do let call' = call { payload = () }
|
||||
ServerNormalResponse rsp meta stsCode stsDetails <-
|
||||
handler (ServerNormalRequest call' (payload call))
|
||||
return (rsp, meta, stsCode, stsDetails)
|
||||
|
||||
convertServerHandler :: (Message a, Message b)
|
||||
=> ServerHandler a b
|
||||
-> ServerHandlerLL
|
||||
convertServerHandler f c = case fromByteString (payload c) of
|
||||
Left x -> CE.throw (GRPCIODecodeError x)
|
||||
Right x -> do (y, tm, sc, sd) <- f (const x <$> c)
|
||||
Right x -> do (y, tm, sc, sd) <- f (fmap (const x) c)
|
||||
return (toBS y, tm, sc, sd)
|
||||
|
||||
type ServerReaderHandler a b
|
||||
|
@ -33,10 +63,20 @@ type ServerReaderHandler a b
|
|||
-> StreamRecv a
|
||||
-> IO (Maybe b, MetadataMap, StatusCode, StatusDetails)
|
||||
|
||||
convertGeneratedServerReaderHandler ::
|
||||
(Message request, Message response)
|
||||
=> (ServerRequest 'ClientStreaming request response -> IO (ServerResponse 'ClientStreaming response))
|
||||
-> ServerReaderHandler request response
|
||||
convertGeneratedServerReaderHandler handler call recv =
|
||||
do ServerReaderResponse rsp meta stsCode stsDetails <-
|
||||
handler (ServerReaderRequest call recv)
|
||||
return (rsp, meta, stsCode, stsDetails)
|
||||
|
||||
convertServerReaderHandler :: (Message a, Message b)
|
||||
=> ServerReaderHandler a b
|
||||
-> ServerReaderHandlerLL
|
||||
convertServerReaderHandler f c recv = serialize <$> f c (convertRecv recv)
|
||||
convertServerReaderHandler f c recv =
|
||||
serialize <$> f c (convertRecv recv)
|
||||
where
|
||||
serialize (mmsg, m, sc, sd) = (toBS <$> mmsg, m, sc, sd)
|
||||
|
||||
|
@ -45,10 +85,21 @@ type ServerWriterHandler a b =
|
|||
-> StreamSend b
|
||||
-> IO (MetadataMap, StatusCode, StatusDetails)
|
||||
|
||||
convertServerWriterHandler :: (Message a, Message b)
|
||||
=> ServerWriterHandler a b
|
||||
-> ServerWriterHandlerLL
|
||||
convertServerWriterHandler f c send = f (convert <$> c) (convertSend send)
|
||||
convertGeneratedServerWriterHandler ::
|
||||
(Message request, Message response)
|
||||
=> (ServerRequest 'ServerStreaming request response -> IO (ServerResponse 'ServerStreaming response))
|
||||
-> ServerWriterHandler request response
|
||||
convertGeneratedServerWriterHandler handler call send =
|
||||
do let call' = call { payload = () }
|
||||
ServerWriterResponse meta stsCode stsDetails <-
|
||||
handler (ServerWriterRequest call' (payload call) send)
|
||||
return (meta, stsCode, stsDetails)
|
||||
|
||||
convertServerWriterHandler :: (Message a, Message b) =>
|
||||
ServerWriterHandler a b
|
||||
-> ServerWriterHandlerLL
|
||||
convertServerWriterHandler f c send =
|
||||
f (convert <$> c) (convertSend send)
|
||||
where
|
||||
convert bs = case fromByteString bs of
|
||||
Left x -> CE.throw (GRPCIODecodeError x)
|
||||
|
@ -60,10 +111,20 @@ type ServerRWHandler a b
|
|||
-> StreamSend b
|
||||
-> IO (MetadataMap, StatusCode, StatusDetails)
|
||||
|
||||
convertGeneratedServerRWHandler ::
|
||||
(Message request, Message response)
|
||||
=> (ServerRequest 'BiDiStreaming request response -> IO (ServerResponse 'BiDiStreaming response))
|
||||
-> ServerRWHandler request response
|
||||
convertGeneratedServerRWHandler handler call recv send =
|
||||
do ServerBiDiResponse meta stsCode stsDetails <-
|
||||
handler (ServerBiDiRequest call recv send)
|
||||
return (meta, stsCode, stsDetails)
|
||||
|
||||
convertServerRWHandler :: (Message a, Message b)
|
||||
=> ServerRWHandler a b
|
||||
-> ServerRWHandlerLL
|
||||
convertServerRWHandler f c r s = f c (convertRecv r) (convertSend s)
|
||||
convertServerRWHandler f c recv send =
|
||||
f c (convertRecv recv) (convertSend send)
|
||||
|
||||
convertRecv :: Message a => StreamRecv ByteString -> StreamRecv a
|
||||
convertRecv =
|
||||
|
|
|
@ -7,9 +7,10 @@ resolver: lts-7.4
|
|||
# Local packages, usually specified by relative directory name
|
||||
packages:
|
||||
- '.'
|
||||
- location:
|
||||
git: git@github.mv.awakenetworks.net:awakenetworks/protobuf-wire.git
|
||||
commit: 2eda7a9e33e8a2f32c3ab8c4ace338b08fb79daa
|
||||
#- location:
|
||||
# git: git@github.mv.awakenetworks.net:awakenetworks/protobuf-wire.git
|
||||
# commit: 2eda7a9e33e8a2f32c3ab8c4ace338b08fb79daa
|
||||
- location: '../protobuf-wire'
|
||||
extra-dep: true
|
||||
- location:
|
||||
git: git@github.com:awakenetworks/proto3-wire.git
|
||||
|
|
|
@ -13,7 +13,8 @@ import Turtle
|
|||
|
||||
generatedTests :: TestTree
|
||||
generatedTests = testGroup "Code generator tests"
|
||||
[ testServerGeneration ]
|
||||
[ testServerGeneration
|
||||
, testClientGeneration ]
|
||||
|
||||
testServerGeneration :: TestTree
|
||||
testServerGeneration = testCase "server generation" $ do
|
||||
|
@ -43,6 +44,34 @@ testServerGeneration = testCase "server generation" $ do
|
|||
rmtree hsTmpDir
|
||||
rmtree pyTmpDir
|
||||
|
||||
testClientGeneration :: TestTree
|
||||
testClientGeneration = testCase "client generation" $ do
|
||||
mktree hsTmpDir
|
||||
mktree pyTmpDir
|
||||
|
||||
compileSimpleDotProto
|
||||
|
||||
exitCode <- proc "tests/simple-client.sh" [hsTmpDir] empty
|
||||
exitCode @?= ExitSuccess
|
||||
|
||||
exitCode <- proc "tests/protoc.sh" [pyTmpDir] empty
|
||||
exitCode @?= ExitSuccess
|
||||
|
||||
runManaged $ do
|
||||
serverExitCodeA <- fork
|
||||
(export "PYTHONPATH" pyTmpDir >> shell "python tests/test-server.py" empty)
|
||||
clientExitCodeA <- fork (shell (hsTmpDir <> "/simple-client") empty)
|
||||
|
||||
liftIO $ do
|
||||
serverExitCode <- liftIO (wait serverExitCodeA)
|
||||
clientExitCode <- liftIO (wait clientExitCodeA)
|
||||
|
||||
serverExitCode @?= ExitSuccess
|
||||
clientExitCode @?= ExitSuccess
|
||||
|
||||
rmtree hsTmpDir
|
||||
rmtree pyTmpDir
|
||||
|
||||
hsTmpDir, pyTmpDir :: IsString a => a
|
||||
hsTmpDir = "tests/tmp"
|
||||
pyTmpDir = "tests/py-tmp"
|
||||
|
|
135
tests/TestClient.hs
Normal file
135
tests/TestClient.hs
Normal file
|
@ -0,0 +1,135 @@
|
|||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Main where
|
||||
|
||||
import Prelude hiding (sum)
|
||||
|
||||
import Simple
|
||||
|
||||
import Control.Concurrent
|
||||
import Control.Concurrent.MVar
|
||||
import Control.Monad
|
||||
import Control.Monad.IO.Class
|
||||
import Control.Exception
|
||||
import Control.Arrow
|
||||
|
||||
import Data.Monoid
|
||||
import Data.Foldable (sum)
|
||||
import Data.String
|
||||
import Data.Word
|
||||
import Data.Vector (fromList)
|
||||
import Data.Protobuf.Wire
|
||||
|
||||
import Network.GRPC.LowLevel
|
||||
import Network.GRPC.HighLevel.Client
|
||||
|
||||
import System.Random
|
||||
|
||||
import Test.Tasty
|
||||
import Test.Tasty.HUnit ((@?=), assertString, testCase)
|
||||
|
||||
testNormalCall client = testCase "Normal call" $
|
||||
do randoms <- fromList <$> replicateM 1000 (Fixed <$> randomRIO (1, 1000))
|
||||
let req = SimpleServiceRequest "NormalRequest" randoms
|
||||
res <- simpleServiceNormalCall client
|
||||
(ClientNormalRequest req 10 mempty)
|
||||
case res of
|
||||
ClientError err -> assertString ("ClientError: " <> show err)
|
||||
ClientNormalResponse res _ _ stsCode _ ->
|
||||
do stsCode @?= StatusOk
|
||||
simpleServiceResponseResponse res @?= "NormalRequest"
|
||||
simpleServiceResponseNum res @?= sum randoms
|
||||
|
||||
testClientStreamingCall client = testCase "Client-streaming call" $
|
||||
do iterationCount <- randomRIO (5, 50)
|
||||
v <- newEmptyMVar
|
||||
res <- simpleServiceClientStreamingCall client . ClientWriterRequest 10 mempty $ \send ->
|
||||
do (finalName, totalSum) <-
|
||||
fmap ((mconcat *** (sum . mconcat)) . unzip) .
|
||||
replicateM iterationCount $
|
||||
do randoms <- fromList <$> replicateM 1000 (Fixed <$> randomRIO (1, 1000))
|
||||
name <- fromString <$> replicateM 10 (randomRIO ('a', 'z'))
|
||||
send (SimpleServiceRequest name randoms)
|
||||
pure (name, randoms)
|
||||
putMVar v (finalName, totalSum)
|
||||
|
||||
(finalName, totalSum) <- readMVar v
|
||||
case res of
|
||||
ClientError err -> assertString ("ClientError: " <> show err)
|
||||
ClientWriterResponse Nothing _ _ _ _ -> assertString "No response received"
|
||||
ClientWriterResponse (Just res) _ _ stsCode _ ->
|
||||
do stsCode @?= StatusOk
|
||||
simpleServiceResponseResponse res @?= finalName
|
||||
simpleServiceResponseNum res @?= totalSum
|
||||
|
||||
testServerStreamingCall client = testCase "Server-streaming call" $
|
||||
do numCount <- randomRIO (50, 500)
|
||||
nums <- replicateM numCount (Fixed <$> randomIO)
|
||||
|
||||
let checkResults [] recv =
|
||||
do res <- recv
|
||||
case res of
|
||||
Left err -> assertString ("recv error: " <> show err)
|
||||
Right Nothing -> pure ()
|
||||
Right (Just _) -> assertString "recv: elements past end of stream"
|
||||
checkResults (expNum:nums) recv =
|
||||
do res <- recv
|
||||
case res of
|
||||
Left err -> assertString ("recv error: " <> show err)
|
||||
Right Nothing -> assertString ("recv: stream ended earlier than expected")
|
||||
Right (Just (SimpleServiceResponse response num)) ->
|
||||
do response @?= "Test"
|
||||
num @?= expNum
|
||||
checkResults nums recv
|
||||
res <- simpleServiceServerStreamingCall client $
|
||||
ClientReaderRequest (SimpleServiceRequest "Test" (fromList nums)) 10 mempty
|
||||
(\_ -> checkResults nums)
|
||||
case res of
|
||||
ClientError err -> assertString ("ClientError: " <> show err)
|
||||
ClientReaderResponse _ sts _ ->
|
||||
sts @?= StatusOk
|
||||
|
||||
testBiDiStreamingCall client = testCase "Bidi-streaming call" $
|
||||
do let handleRequests (0 :: Int) _ _ done = done >> pure ()
|
||||
handleRequests n recv send done =
|
||||
do numCount <- randomRIO (10, 1000)
|
||||
nums <- fromList <$> replicateM numCount (Fixed <$> randomRIO (1, 1000))
|
||||
testName <- fromString <$> replicateM 10 (randomRIO ('a', 'z'))
|
||||
send (SimpleServiceRequest testName nums)
|
||||
|
||||
res <- recv
|
||||
case res of
|
||||
Left err -> assertString ("recv error: " <> show err)
|
||||
Right Nothing -> pure ()
|
||||
Right (Just (SimpleServiceResponse name total)) ->
|
||||
do name @?= testName
|
||||
total @?= sum nums
|
||||
handleRequests (n - 1) recv send done
|
||||
|
||||
iterations <- randomRIO (50, 500)
|
||||
|
||||
res <- simpleServiceBiDiStreamingCall client $
|
||||
ClientBiDiRequest 10 mempty (\_ -> handleRequests iterations)
|
||||
case res of
|
||||
ClientError err -> assertString ("ClientError: " <> show err)
|
||||
ClientBiDiResponse _ sts _ ->
|
||||
sts @?= StatusOk
|
||||
|
||||
main :: IO ()
|
||||
main = do
|
||||
threadDelay 10000000
|
||||
withGRPC $ \grpc ->
|
||||
withClient grpc (ClientConfig "localhost" 50051 [] Nothing) $ \client ->
|
||||
do service <- simpleServiceClient client
|
||||
|
||||
(defaultMain $ testGroup "Send gRPC requests"
|
||||
[ testNormalCall service
|
||||
, testClientStreamingCall service
|
||||
, testServerStreamingCall service
|
||||
, testBiDiStreamingCall service ]) `finally`
|
||||
(simpleServiceDone service (ClientNormalRequest SimpleServiceDone 10 mempty))
|
||||
|
|
@ -1,4 +1,7 @@
|
|||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
|
||||
module Main where
|
||||
|
||||
import Prelude hiding (sum)
|
||||
|
@ -15,48 +18,46 @@ import Data.Foldable (sum)
|
|||
import Data.String
|
||||
|
||||
import Network.GRPC.LowLevel
|
||||
import Network.GRPC.HighLevel.Server
|
||||
|
||||
handleNormalCall :: ServerCall SimpleServiceRequest -> IO (SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails)
|
||||
handleNormalCall call =
|
||||
pure (SimpleServiceResponse request result, mempty, StatusOk, StatusDetails "")
|
||||
where SimpleServiceRequest request nums = payload call
|
||||
handleNormalCall :: ServerRequest 'Normal SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'Normal SimpleServiceResponse)
|
||||
handleNormalCall (ServerNormalRequest meta (SimpleServiceRequest request nums)) =
|
||||
pure (ServerNormalResponse (SimpleServiceResponse request result) mempty StatusOk (StatusDetails ""))
|
||||
where result = sum nums
|
||||
|
||||
result = sum nums
|
||||
|
||||
handleClientStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> IO (Maybe SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails)
|
||||
handleClientStreamingCall call recvRequest = go 0 ""
|
||||
handleClientStreamingCall :: ServerRequest 'ClientStreaming SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'ClientStreaming SimpleServiceResponse)
|
||||
handleClientStreamingCall (ServerReaderRequest call recvRequest) = go 0 ""
|
||||
where go sumAccum nameAccum =
|
||||
recvRequest >>= \req ->
|
||||
case req of
|
||||
Left ioError -> pure (Nothing, mempty, StatusCancelled, StatusDetails ("handleClientStreamingCall: IO error: " <> fromString (show ioError)))
|
||||
Left ioError -> pure (ServerReaderResponse Nothing mempty StatusCancelled (StatusDetails ("handleClientStreamingCall: IO error: " <> fromString (show ioError))))
|
||||
Right Nothing ->
|
||||
pure (Just (SimpleServiceResponse nameAccum sumAccum), mempty, StatusOk, StatusDetails "")
|
||||
pure (ServerReaderResponse (Just (SimpleServiceResponse nameAccum sumAccum)) mempty StatusOk (StatusDetails ""))
|
||||
Right (Just (SimpleServiceRequest name nums)) ->
|
||||
go (sumAccum + sum nums) (nameAccum <> name)
|
||||
|
||||
handleServerStreamingCall :: ServerCall SimpleServiceRequest -> StreamSend SimpleServiceResponse -> IO (MetadataMap, StatusCode, StatusDetails)
|
||||
handleServerStreamingCall call sendResponse = go
|
||||
handleServerStreamingCall :: ServerRequest 'ServerStreaming SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'ServerStreaming SimpleServiceResponse)
|
||||
handleServerStreamingCall (ServerWriterRequest call (SimpleServiceRequest requestName nums) sendResponse) = go
|
||||
where go = do forM_ nums $ \num ->
|
||||
sendResponse (SimpleServiceResponse requestName num)
|
||||
pure (mempty, StatusOk, StatusDetails "")
|
||||
pure (ServerWriterResponse mempty StatusOk (StatusDetails ""))
|
||||
|
||||
SimpleServiceRequest requestName nums = payload call
|
||||
|
||||
handleBiDiStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> StreamSend SimpleServiceResponse -> IO (MetadataMap, StatusCode, StatusDetails)
|
||||
handleBiDiStreamingCall call recvRequest sendResponse = go
|
||||
handleBiDiStreamingCall :: ServerRequest 'BiDiStreaming SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'BiDiStreaming SimpleServiceResponse)
|
||||
handleBiDiStreamingCall (ServerBiDiRequest call recvRequest sendResponse) = go
|
||||
where go = recvRequest >>= \req ->
|
||||
case req of
|
||||
Left ioError -> pure (mempty, StatusCancelled, StatusDetails ("handleBiDiStreamingCall: IO error: " <> fromString (show ioError)))
|
||||
Left ioError ->
|
||||
pure (ServerBiDiResponse mempty StatusCancelled (StatusDetails ("handleBiDiStreamingCall: IO error: " <> fromString (show ioError))))
|
||||
Right Nothing ->
|
||||
pure (mempty, StatusOk, StatusDetails "")
|
||||
pure (ServerBiDiResponse mempty StatusOk (StatusDetails ""))
|
||||
Right (Just (SimpleServiceRequest name nums)) ->
|
||||
do sendResponse (SimpleServiceResponse name (sum nums))
|
||||
go
|
||||
|
||||
handleDone :: MVar () -> ServerCall SimpleServiceDone -> IO (SimpleServiceDone, MetadataMap, StatusCode, StatusDetails)
|
||||
handleDone exitVar req =
|
||||
handleDone :: MVar () -> ServerRequest 'Normal SimpleServiceDone SimpleServiceDone -> IO (ServerResponse 'Normal SimpleServiceDone)
|
||||
handleDone exitVar (ServerNormalRequest _ req) =
|
||||
do forkIO (threadDelay 5000 >> putMVar exitVar ())
|
||||
pure (payload req, mempty, StatusOk, StatusDetails "")
|
||||
pure (ServerNormalResponse req mempty StatusOk (StatusDetails ""))
|
||||
|
||||
main :: IO ()
|
||||
main = do exitVar <- newEmptyMVar
|
||||
|
|
13
tests/simple-client.sh
Executable file
13
tests/simple-client.sh
Executable file
|
@ -0,0 +1,13 @@
|
|||
#!/bin/bash -eu
|
||||
|
||||
hsTmpDir=$1
|
||||
|
||||
stack ghc -- \
|
||||
--make \
|
||||
-threaded \
|
||||
-odir $hsTmpDir \
|
||||
-hidir $hsTmpDir \
|
||||
-o $hsTmpDir/simple-client \
|
||||
$hsTmpDir/Simple.hs \
|
||||
tests/TestClient.hs \
|
||||
> /dev/null
|
41
tests/test-server.py
Normal file
41
tests/test-server.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
from simple_pb2 import *
|
||||
from uuid import uuid4
|
||||
import random
|
||||
import Queue
|
||||
import grpc
|
||||
|
||||
print "Starting python server"
|
||||
|
||||
done_queue = Queue.Queue()
|
||||
|
||||
class SimpleServiceServer(BetaSimpleServiceServicer):
|
||||
def done(self, request, context):
|
||||
global server
|
||||
done_queue.put_nowait(())
|
||||
|
||||
return SimpleServiceDone()
|
||||
|
||||
def normalCall(self, request, context):
|
||||
return SimpleServiceResponse(response = "NormalRequest", num = sum(request.num))
|
||||
|
||||
def clientStreamingCall(self, requests, context):
|
||||
cur_name = ""
|
||||
cur_sum = 0
|
||||
for request in requests:
|
||||
cur_name += request.request
|
||||
cur_sum += sum(request.num)
|
||||
return SimpleServiceResponse(response = cur_name, num = cur_sum)
|
||||
|
||||
def serverStreamingCall(self, request, context):
|
||||
for num in request.num:
|
||||
yield SimpleServiceResponse(response = request.request, num = num)
|
||||
|
||||
def biDiStreamingCall(self, requests, context):
|
||||
for request in requests:
|
||||
yield SimpleServiceResponse(response = request.request, num = sum(request.num))
|
||||
|
||||
server = beta_create_SimpleService_server(SimpleServiceServer())
|
||||
server.add_insecure_port('[::]:50051')
|
||||
server.start()
|
||||
|
||||
done_queue.get()
|
Loading…
Reference in a new issue