mirror of
https://github.com/unclechu/gRPC-haskell.git
synced 2024-11-14 23:29:42 +01:00
GADT-based high level interface
This commit is contained in:
parent
507edf803f
commit
26dc36dc64
10 changed files with 440 additions and 35 deletions
|
@ -72,6 +72,7 @@ library
|
||||||
Network.GRPC.HighLevel.Generated
|
Network.GRPC.HighLevel.Generated
|
||||||
Network.GRPC.HighLevel.Server
|
Network.GRPC.HighLevel.Server
|
||||||
Network.GRPC.HighLevel.Server.Unregistered
|
Network.GRPC.HighLevel.Server.Unregistered
|
||||||
|
Network.GRPC.HighLevel.Client
|
||||||
extra-libraries:
|
extra-libraries:
|
||||||
grpc
|
grpc
|
||||||
includes:
|
includes:
|
||||||
|
|
122
src/Network/GRPC/HighLevel/Client.hs
Normal file
122
src/Network/GRPC/HighLevel/Client.hs
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
{-# LANGUAGE KindSignatures #-}
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
|
{-# LANGUAGE GADTs #-}
|
||||||
|
|
||||||
|
module Network.GRPC.HighLevel.Client
|
||||||
|
( RegisteredMethod
|
||||||
|
, TimeoutSeconds
|
||||||
|
, MetadataMap(..)
|
||||||
|
, StatusDetails(..)
|
||||||
|
, GRPCMethodType(..)
|
||||||
|
, StreamRecv
|
||||||
|
, StreamSend
|
||||||
|
, WritesDone
|
||||||
|
, LL.Client
|
||||||
|
|
||||||
|
, ServiceClient
|
||||||
|
, ClientRequest(..)
|
||||||
|
, ClientResult(..)
|
||||||
|
-- , ClientResponse, response, initMD, trailMD, rspCode, details
|
||||||
|
|
||||||
|
, ClientRegisterable(..)
|
||||||
|
|
||||||
|
, clientRequest ) where
|
||||||
|
|
||||||
|
import qualified Network.GRPC.LowLevel.Client as LL
|
||||||
|
import qualified Network.GRPC.LowLevel.Call as LL
|
||||||
|
import Network.GRPC.LowLevel.CompletionQueue (TimeoutSeconds)
|
||||||
|
import Network.GRPC.LowLevel ( GRPCMethodType(..)
|
||||||
|
, StatusCode(..)
|
||||||
|
, StatusDetails(..)
|
||||||
|
, MetadataMap(..)
|
||||||
|
, GRPCIOError(..)
|
||||||
|
, StreamRecv
|
||||||
|
, StreamSend )
|
||||||
|
import Network.GRPC.LowLevel.Op (WritesDone)
|
||||||
|
import Network.GRPC.HighLevel.Server (convertRecv, convertSend)
|
||||||
|
|
||||||
|
import Data.Protobuf.Wire (Message, toLazyByteString, fromByteString)
|
||||||
|
import Proto3.Wire.Decode (ParseError)
|
||||||
|
import Data.ByteString (ByteString)
|
||||||
|
import qualified Data.ByteString.Lazy as BL
|
||||||
|
|
||||||
|
newtype RegisteredMethod (mt :: GRPCMethodType) request response
|
||||||
|
= RegisteredMethod (LL.RegisteredMethod mt)
|
||||||
|
deriving Show
|
||||||
|
|
||||||
|
type ServiceClient service = service ClientRequest ClientResult
|
||||||
|
|
||||||
|
data ClientError
|
||||||
|
= ClientErrorNoParse ParseError
|
||||||
|
| ClientIOError GRPCIOError
|
||||||
|
deriving (Show, Eq)
|
||||||
|
|
||||||
|
data ClientRequest (streamType :: GRPCMethodType) request response where
|
||||||
|
ClientNormalRequest :: request -> TimeoutSeconds -> MetadataMap -> ClientRequest 'Normal request response
|
||||||
|
ClientWriterRequest :: TimeoutSeconds -> MetadataMap -> (StreamSend request -> IO ()) -> ClientRequest 'ClientStreaming request response
|
||||||
|
ClientReaderRequest :: request -> TimeoutSeconds -> MetadataMap -> (MetadataMap -> StreamRecv response -> IO ()) -> ClientRequest 'ServerStreaming request response
|
||||||
|
ClientBiDiRequest :: TimeoutSeconds -> MetadataMap -> (MetadataMap -> StreamRecv response -> StreamSend request -> WritesDone -> IO ()) -> ClientRequest 'BiDiStreaming request response
|
||||||
|
|
||||||
|
data ClientResult (streamType :: GRPCMethodType) response where
|
||||||
|
ClientNormalResponse :: response -> MetadataMap -> MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'Normal response
|
||||||
|
ClientWriterResponse :: Maybe response -> MetadataMap -> MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'ClientStreaming response
|
||||||
|
ClientReaderResponse :: MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'ServerStreaming response
|
||||||
|
ClientBiDiResponse :: MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'BiDiStreaming response
|
||||||
|
|
||||||
|
ClientError :: ClientError -> ClientResult streamType response
|
||||||
|
|
||||||
|
class ClientRegisterable (methodType :: GRPCMethodType) where
|
||||||
|
clientRegisterMethod :: LL.Client
|
||||||
|
-> LL.MethodName
|
||||||
|
-> IO (RegisteredMethod methodType request response)
|
||||||
|
|
||||||
|
instance ClientRegisterable 'Normal where
|
||||||
|
clientRegisterMethod client methodName =
|
||||||
|
RegisteredMethod <$> LL.clientRegisterMethodNormal client methodName
|
||||||
|
|
||||||
|
instance ClientRegisterable 'ClientStreaming where
|
||||||
|
clientRegisterMethod client methodName =
|
||||||
|
RegisteredMethod <$> LL.clientRegisterMethodClientStreaming client methodName
|
||||||
|
|
||||||
|
instance ClientRegisterable 'ServerStreaming where
|
||||||
|
clientRegisterMethod client methodName =
|
||||||
|
RegisteredMethod <$> LL.clientRegisterMethodServerStreaming client methodName
|
||||||
|
|
||||||
|
instance ClientRegisterable 'BiDiStreaming where
|
||||||
|
clientRegisterMethod client methodName =
|
||||||
|
RegisteredMethod <$> LL.clientRegisterMethodBiDiStreaming client methodName
|
||||||
|
|
||||||
|
clientRequest :: (Message request, Message response) =>
|
||||||
|
LL.Client -> RegisteredMethod streamType request response
|
||||||
|
-> ClientRequest streamType request response -> IO (ClientResult streamType response)
|
||||||
|
clientRequest client (RegisteredMethod method) (ClientNormalRequest req timeout meta) =
|
||||||
|
mkResponse <$> LL.clientRequest client method timeout (BL.toStrict (toLazyByteString req)) meta
|
||||||
|
where
|
||||||
|
mkResponse (Left ioError_) = ClientError (ClientIOError ioError_)
|
||||||
|
mkResponse (Right rsp) =
|
||||||
|
case fromByteString (LL.rspBody rsp) of
|
||||||
|
Left err -> ClientError (ClientErrorNoParse err)
|
||||||
|
Right parsedRsp ->
|
||||||
|
ClientNormalResponse parsedRsp (LL.initMD rsp) (LL.trailMD rsp) (LL.rspCode rsp) (LL.details rsp)
|
||||||
|
clientRequest client (RegisteredMethod method) (ClientWriterRequest timeout meta handler) =
|
||||||
|
mkResponse <$> LL.clientWriter client method timeout meta (handler . convertSend)
|
||||||
|
where
|
||||||
|
mkResponse (Left ioError_) = ClientError (ClientIOError ioError_)
|
||||||
|
mkResponse (Right (rsp_, initMD_, trailMD_, rspCode_, details_)) =
|
||||||
|
case maybe (Right Nothing) (fmap Just . fromByteString) rsp_ of
|
||||||
|
Left err -> ClientError (ClientErrorNoParse err)
|
||||||
|
Right parsedRsp ->
|
||||||
|
ClientWriterResponse parsedRsp initMD_ trailMD_ rspCode_ details_
|
||||||
|
clientRequest client (RegisteredMethod method) (ClientReaderRequest req timeout meta handler) =
|
||||||
|
mkResponse <$> LL.clientReader client method timeout (BL.toStrict (toLazyByteString req)) meta (\m recv -> handler m (convertRecv recv))
|
||||||
|
where
|
||||||
|
mkResponse (Left ioError_) = ClientError (ClientIOError ioError_)
|
||||||
|
mkResponse (Right (meta_, rspCode_, details_)) =
|
||||||
|
ClientReaderResponse meta_ rspCode_ details_
|
||||||
|
clientRequest client (RegisteredMethod method) (ClientBiDiRequest timeout meta handler) =
|
||||||
|
mkResponse <$> LL.clientRW client method timeout meta (\m recv send writesDone -> handler meta (convertRecv recv) (convertSend send) writesDone)
|
||||||
|
where
|
||||||
|
mkResponse (Left ioError_) = ClientError (ClientIOError ioError_)
|
||||||
|
mkResponse (Right (meta_, rspCode_, details_)) =
|
||||||
|
ClientBiDiResponse meta_ rspCode_ details_
|
|
@ -1,3 +1,4 @@
|
||||||
|
|
||||||
{-# LANGUAGE TypeFamilies #-}
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
{-# LANGUAGE DataKinds #-}
|
{-# LANGUAGE DataKinds #-}
|
||||||
{-# LANGUAGE RankNTypes #-}
|
{-# LANGUAGE RankNTypes #-}
|
||||||
|
|
|
@ -16,16 +16,46 @@ import Data.Protobuf.Wire.Class
|
||||||
import Network.GRPC.LowLevel
|
import Network.GRPC.LowLevel
|
||||||
import System.IO
|
import System.IO
|
||||||
|
|
||||||
type ServerHandler a b
|
type ServerCallMetadata = ServerCall ()
|
||||||
= ServerCall a
|
|
||||||
|
type ServiceServer service = service ServerRequest ServerResponse
|
||||||
|
|
||||||
|
data ServerRequest (streamType :: GRPCMethodType) request response where
|
||||||
|
ServerNormalRequest :: ServerCallMetadata -> request -> ServerRequest 'Normal request response
|
||||||
|
ServerReaderRequest :: ServerCallMetadata -> StreamRecv request -> ServerRequest 'ClientStreaming request response
|
||||||
|
ServerWriterRequest :: ServerCallMetadata -> request -> StreamSend response -> ServerRequest 'ServerStreaming request response
|
||||||
|
ServerBiDiRequest :: ServerCallMetadata -> StreamRecv request -> StreamSend response -> ServerRequest 'BiDiStreaming request response
|
||||||
|
|
||||||
|
data ServerResponse (streamType :: GRPCMethodType) response where
|
||||||
|
ServerNormalResponse :: response -> MetadataMap -> StatusCode -> StatusDetails
|
||||||
|
-> ServerResponse 'Normal response
|
||||||
|
ServerReaderResponse :: Maybe response -> MetadataMap -> StatusCode -> StatusDetails
|
||||||
|
-> ServerResponse 'ClientStreaming response
|
||||||
|
ServerWriterResponse :: MetadataMap -> StatusCode -> StatusDetails
|
||||||
|
-> ServerResponse 'ServerStreaming response
|
||||||
|
ServerBiDiResponse :: MetadataMap -> StatusCode -> StatusDetails
|
||||||
|
-> ServerResponse 'BiDiStreaming response
|
||||||
|
|
||||||
|
type ServerHandler a b =
|
||||||
|
ServerCall a
|
||||||
-> IO (b, MetadataMap, StatusCode, StatusDetails)
|
-> IO (b, MetadataMap, StatusCode, StatusDetails)
|
||||||
|
|
||||||
|
convertGeneratedServerHandler ::
|
||||||
|
(Message request, Message response)
|
||||||
|
=> (ServerRequest 'Normal request response -> IO (ServerResponse 'Normal response))
|
||||||
|
-> ServerHandler request response
|
||||||
|
convertGeneratedServerHandler handler call =
|
||||||
|
do let call' = call { payload = () }
|
||||||
|
ServerNormalResponse rsp meta stsCode stsDetails <-
|
||||||
|
handler (ServerNormalRequest call' (payload call))
|
||||||
|
return (rsp, meta, stsCode, stsDetails)
|
||||||
|
|
||||||
convertServerHandler :: (Message a, Message b)
|
convertServerHandler :: (Message a, Message b)
|
||||||
=> ServerHandler a b
|
=> ServerHandler a b
|
||||||
-> ServerHandlerLL
|
-> ServerHandlerLL
|
||||||
convertServerHandler f c = case fromByteString (payload c) of
|
convertServerHandler f c = case fromByteString (payload c) of
|
||||||
Left x -> CE.throw (GRPCIODecodeError x)
|
Left x -> CE.throw (GRPCIODecodeError x)
|
||||||
Right x -> do (y, tm, sc, sd) <- f (const x <$> c)
|
Right x -> do (y, tm, sc, sd) <- f (fmap (const x) c)
|
||||||
return (toBS y, tm, sc, sd)
|
return (toBS y, tm, sc, sd)
|
||||||
|
|
||||||
type ServerReaderHandler a b
|
type ServerReaderHandler a b
|
||||||
|
@ -33,10 +63,20 @@ type ServerReaderHandler a b
|
||||||
-> StreamRecv a
|
-> StreamRecv a
|
||||||
-> IO (Maybe b, MetadataMap, StatusCode, StatusDetails)
|
-> IO (Maybe b, MetadataMap, StatusCode, StatusDetails)
|
||||||
|
|
||||||
|
convertGeneratedServerReaderHandler ::
|
||||||
|
(Message request, Message response)
|
||||||
|
=> (ServerRequest 'ClientStreaming request response -> IO (ServerResponse 'ClientStreaming response))
|
||||||
|
-> ServerReaderHandler request response
|
||||||
|
convertGeneratedServerReaderHandler handler call recv =
|
||||||
|
do ServerReaderResponse rsp meta stsCode stsDetails <-
|
||||||
|
handler (ServerReaderRequest call recv)
|
||||||
|
return (rsp, meta, stsCode, stsDetails)
|
||||||
|
|
||||||
convertServerReaderHandler :: (Message a, Message b)
|
convertServerReaderHandler :: (Message a, Message b)
|
||||||
=> ServerReaderHandler a b
|
=> ServerReaderHandler a b
|
||||||
-> ServerReaderHandlerLL
|
-> ServerReaderHandlerLL
|
||||||
convertServerReaderHandler f c recv = serialize <$> f c (convertRecv recv)
|
convertServerReaderHandler f c recv =
|
||||||
|
serialize <$> f c (convertRecv recv)
|
||||||
where
|
where
|
||||||
serialize (mmsg, m, sc, sd) = (toBS <$> mmsg, m, sc, sd)
|
serialize (mmsg, m, sc, sd) = (toBS <$> mmsg, m, sc, sd)
|
||||||
|
|
||||||
|
@ -45,10 +85,21 @@ type ServerWriterHandler a b =
|
||||||
-> StreamSend b
|
-> StreamSend b
|
||||||
-> IO (MetadataMap, StatusCode, StatusDetails)
|
-> IO (MetadataMap, StatusCode, StatusDetails)
|
||||||
|
|
||||||
convertServerWriterHandler :: (Message a, Message b)
|
convertGeneratedServerWriterHandler ::
|
||||||
=> ServerWriterHandler a b
|
(Message request, Message response)
|
||||||
-> ServerWriterHandlerLL
|
=> (ServerRequest 'ServerStreaming request response -> IO (ServerResponse 'ServerStreaming response))
|
||||||
convertServerWriterHandler f c send = f (convert <$> c) (convertSend send)
|
-> ServerWriterHandler request response
|
||||||
|
convertGeneratedServerWriterHandler handler call send =
|
||||||
|
do let call' = call { payload = () }
|
||||||
|
ServerWriterResponse meta stsCode stsDetails <-
|
||||||
|
handler (ServerWriterRequest call' (payload call) send)
|
||||||
|
return (meta, stsCode, stsDetails)
|
||||||
|
|
||||||
|
convertServerWriterHandler :: (Message a, Message b) =>
|
||||||
|
ServerWriterHandler a b
|
||||||
|
-> ServerWriterHandlerLL
|
||||||
|
convertServerWriterHandler f c send =
|
||||||
|
f (convert <$> c) (convertSend send)
|
||||||
where
|
where
|
||||||
convert bs = case fromByteString bs of
|
convert bs = case fromByteString bs of
|
||||||
Left x -> CE.throw (GRPCIODecodeError x)
|
Left x -> CE.throw (GRPCIODecodeError x)
|
||||||
|
@ -60,10 +111,20 @@ type ServerRWHandler a b
|
||||||
-> StreamSend b
|
-> StreamSend b
|
||||||
-> IO (MetadataMap, StatusCode, StatusDetails)
|
-> IO (MetadataMap, StatusCode, StatusDetails)
|
||||||
|
|
||||||
|
convertGeneratedServerRWHandler ::
|
||||||
|
(Message request, Message response)
|
||||||
|
=> (ServerRequest 'BiDiStreaming request response -> IO (ServerResponse 'BiDiStreaming response))
|
||||||
|
-> ServerRWHandler request response
|
||||||
|
convertGeneratedServerRWHandler handler call recv send =
|
||||||
|
do ServerBiDiResponse meta stsCode stsDetails <-
|
||||||
|
handler (ServerBiDiRequest call recv send)
|
||||||
|
return (meta, stsCode, stsDetails)
|
||||||
|
|
||||||
convertServerRWHandler :: (Message a, Message b)
|
convertServerRWHandler :: (Message a, Message b)
|
||||||
=> ServerRWHandler a b
|
=> ServerRWHandler a b
|
||||||
-> ServerRWHandlerLL
|
-> ServerRWHandlerLL
|
||||||
convertServerRWHandler f c r s = f c (convertRecv r) (convertSend s)
|
convertServerRWHandler f c recv send =
|
||||||
|
f c (convertRecv recv) (convertSend send)
|
||||||
|
|
||||||
convertRecv :: Message a => StreamRecv ByteString -> StreamRecv a
|
convertRecv :: Message a => StreamRecv ByteString -> StreamRecv a
|
||||||
convertRecv =
|
convertRecv =
|
||||||
|
|
|
@ -7,9 +7,10 @@ resolver: lts-7.4
|
||||||
# Local packages, usually specified by relative directory name
|
# Local packages, usually specified by relative directory name
|
||||||
packages:
|
packages:
|
||||||
- '.'
|
- '.'
|
||||||
- location:
|
#- location:
|
||||||
git: git@github.mv.awakenetworks.net:awakenetworks/protobuf-wire.git
|
# git: git@github.mv.awakenetworks.net:awakenetworks/protobuf-wire.git
|
||||||
commit: 2eda7a9e33e8a2f32c3ab8c4ace338b08fb79daa
|
# commit: 2eda7a9e33e8a2f32c3ab8c4ace338b08fb79daa
|
||||||
|
- location: '../protobuf-wire'
|
||||||
extra-dep: true
|
extra-dep: true
|
||||||
- location:
|
- location:
|
||||||
git: git@github.com:awakenetworks/proto3-wire.git
|
git: git@github.com:awakenetworks/proto3-wire.git
|
||||||
|
|
|
@ -13,7 +13,8 @@ import Turtle
|
||||||
|
|
||||||
generatedTests :: TestTree
|
generatedTests :: TestTree
|
||||||
generatedTests = testGroup "Code generator tests"
|
generatedTests = testGroup "Code generator tests"
|
||||||
[ testServerGeneration ]
|
[ testServerGeneration
|
||||||
|
, testClientGeneration ]
|
||||||
|
|
||||||
testServerGeneration :: TestTree
|
testServerGeneration :: TestTree
|
||||||
testServerGeneration = testCase "server generation" $ do
|
testServerGeneration = testCase "server generation" $ do
|
||||||
|
@ -43,6 +44,34 @@ testServerGeneration = testCase "server generation" $ do
|
||||||
rmtree hsTmpDir
|
rmtree hsTmpDir
|
||||||
rmtree pyTmpDir
|
rmtree pyTmpDir
|
||||||
|
|
||||||
|
testClientGeneration :: TestTree
|
||||||
|
testClientGeneration = testCase "client generation" $ do
|
||||||
|
mktree hsTmpDir
|
||||||
|
mktree pyTmpDir
|
||||||
|
|
||||||
|
compileSimpleDotProto
|
||||||
|
|
||||||
|
exitCode <- proc "tests/simple-client.sh" [hsTmpDir] empty
|
||||||
|
exitCode @?= ExitSuccess
|
||||||
|
|
||||||
|
exitCode <- proc "tests/protoc.sh" [pyTmpDir] empty
|
||||||
|
exitCode @?= ExitSuccess
|
||||||
|
|
||||||
|
runManaged $ do
|
||||||
|
serverExitCodeA <- fork
|
||||||
|
(export "PYTHONPATH" pyTmpDir >> shell "python tests/test-server.py" empty)
|
||||||
|
clientExitCodeA <- fork (shell (hsTmpDir <> "/simple-client") empty)
|
||||||
|
|
||||||
|
liftIO $ do
|
||||||
|
serverExitCode <- liftIO (wait serverExitCodeA)
|
||||||
|
clientExitCode <- liftIO (wait clientExitCodeA)
|
||||||
|
|
||||||
|
serverExitCode @?= ExitSuccess
|
||||||
|
clientExitCode @?= ExitSuccess
|
||||||
|
|
||||||
|
rmtree hsTmpDir
|
||||||
|
rmtree pyTmpDir
|
||||||
|
|
||||||
hsTmpDir, pyTmpDir :: IsString a => a
|
hsTmpDir, pyTmpDir :: IsString a => a
|
||||||
hsTmpDir = "tests/tmp"
|
hsTmpDir = "tests/tmp"
|
||||||
pyTmpDir = "tests/py-tmp"
|
pyTmpDir = "tests/py-tmp"
|
||||||
|
|
135
tests/TestClient.hs
Normal file
135
tests/TestClient.hs
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
|
{-# LANGUAGE GADTs #-}
|
||||||
|
{-# LANGUAGE RecordWildCards #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import Prelude hiding (sum)
|
||||||
|
|
||||||
|
import Simple
|
||||||
|
|
||||||
|
import Control.Concurrent
|
||||||
|
import Control.Concurrent.MVar
|
||||||
|
import Control.Monad
|
||||||
|
import Control.Monad.IO.Class
|
||||||
|
import Control.Exception
|
||||||
|
import Control.Arrow
|
||||||
|
|
||||||
|
import Data.Monoid
|
||||||
|
import Data.Foldable (sum)
|
||||||
|
import Data.String
|
||||||
|
import Data.Word
|
||||||
|
import Data.Vector (fromList)
|
||||||
|
import Data.Protobuf.Wire
|
||||||
|
|
||||||
|
import Network.GRPC.LowLevel
|
||||||
|
import Network.GRPC.HighLevel.Client
|
||||||
|
|
||||||
|
import System.Random
|
||||||
|
|
||||||
|
import Test.Tasty
|
||||||
|
import Test.Tasty.HUnit ((@?=), assertString, testCase)
|
||||||
|
|
||||||
|
testNormalCall client = testCase "Normal call" $
|
||||||
|
do randoms <- fromList <$> replicateM 1000 (Fixed <$> randomRIO (1, 1000))
|
||||||
|
let req = SimpleServiceRequest "NormalRequest" randoms
|
||||||
|
res <- simpleServiceNormalCall client
|
||||||
|
(ClientNormalRequest req 10 mempty)
|
||||||
|
case res of
|
||||||
|
ClientError err -> assertString ("ClientError: " <> show err)
|
||||||
|
ClientNormalResponse res _ _ stsCode _ ->
|
||||||
|
do stsCode @?= StatusOk
|
||||||
|
simpleServiceResponseResponse res @?= "NormalRequest"
|
||||||
|
simpleServiceResponseNum res @?= sum randoms
|
||||||
|
|
||||||
|
testClientStreamingCall client = testCase "Client-streaming call" $
|
||||||
|
do iterationCount <- randomRIO (5, 50)
|
||||||
|
v <- newEmptyMVar
|
||||||
|
res <- simpleServiceClientStreamingCall client . ClientWriterRequest 10 mempty $ \send ->
|
||||||
|
do (finalName, totalSum) <-
|
||||||
|
fmap ((mconcat *** (sum . mconcat)) . unzip) .
|
||||||
|
replicateM iterationCount $
|
||||||
|
do randoms <- fromList <$> replicateM 1000 (Fixed <$> randomRIO (1, 1000))
|
||||||
|
name <- fromString <$> replicateM 10 (randomRIO ('a', 'z'))
|
||||||
|
send (SimpleServiceRequest name randoms)
|
||||||
|
pure (name, randoms)
|
||||||
|
putMVar v (finalName, totalSum)
|
||||||
|
|
||||||
|
(finalName, totalSum) <- readMVar v
|
||||||
|
case res of
|
||||||
|
ClientError err -> assertString ("ClientError: " <> show err)
|
||||||
|
ClientWriterResponse Nothing _ _ _ _ -> assertString "No response received"
|
||||||
|
ClientWriterResponse (Just res) _ _ stsCode _ ->
|
||||||
|
do stsCode @?= StatusOk
|
||||||
|
simpleServiceResponseResponse res @?= finalName
|
||||||
|
simpleServiceResponseNum res @?= totalSum
|
||||||
|
|
||||||
|
testServerStreamingCall client = testCase "Server-streaming call" $
|
||||||
|
do numCount <- randomRIO (50, 500)
|
||||||
|
nums <- replicateM numCount (Fixed <$> randomIO)
|
||||||
|
|
||||||
|
let checkResults [] recv =
|
||||||
|
do res <- recv
|
||||||
|
case res of
|
||||||
|
Left err -> assertString ("recv error: " <> show err)
|
||||||
|
Right Nothing -> pure ()
|
||||||
|
Right (Just _) -> assertString "recv: elements past end of stream"
|
||||||
|
checkResults (expNum:nums) recv =
|
||||||
|
do res <- recv
|
||||||
|
case res of
|
||||||
|
Left err -> assertString ("recv error: " <> show err)
|
||||||
|
Right Nothing -> assertString ("recv: stream ended earlier than expected")
|
||||||
|
Right (Just (SimpleServiceResponse response num)) ->
|
||||||
|
do response @?= "Test"
|
||||||
|
num @?= expNum
|
||||||
|
checkResults nums recv
|
||||||
|
res <- simpleServiceServerStreamingCall client $
|
||||||
|
ClientReaderRequest (SimpleServiceRequest "Test" (fromList nums)) 10 mempty
|
||||||
|
(\_ -> checkResults nums)
|
||||||
|
case res of
|
||||||
|
ClientError err -> assertString ("ClientError: " <> show err)
|
||||||
|
ClientReaderResponse _ sts _ ->
|
||||||
|
sts @?= StatusOk
|
||||||
|
|
||||||
|
testBiDiStreamingCall client = testCase "Bidi-streaming call" $
|
||||||
|
do let handleRequests (0 :: Int) _ _ done = done >> pure ()
|
||||||
|
handleRequests n recv send done =
|
||||||
|
do numCount <- randomRIO (10, 1000)
|
||||||
|
nums <- fromList <$> replicateM numCount (Fixed <$> randomRIO (1, 1000))
|
||||||
|
testName <- fromString <$> replicateM 10 (randomRIO ('a', 'z'))
|
||||||
|
send (SimpleServiceRequest testName nums)
|
||||||
|
|
||||||
|
res <- recv
|
||||||
|
case res of
|
||||||
|
Left err -> assertString ("recv error: " <> show err)
|
||||||
|
Right Nothing -> pure ()
|
||||||
|
Right (Just (SimpleServiceResponse name total)) ->
|
||||||
|
do name @?= testName
|
||||||
|
total @?= sum nums
|
||||||
|
handleRequests (n - 1) recv send done
|
||||||
|
|
||||||
|
iterations <- randomRIO (50, 500)
|
||||||
|
|
||||||
|
res <- simpleServiceBiDiStreamingCall client $
|
||||||
|
ClientBiDiRequest 10 mempty (\_ -> handleRequests iterations)
|
||||||
|
case res of
|
||||||
|
ClientError err -> assertString ("ClientError: " <> show err)
|
||||||
|
ClientBiDiResponse _ sts _ ->
|
||||||
|
sts @?= StatusOk
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = do
|
||||||
|
threadDelay 10000000
|
||||||
|
withGRPC $ \grpc ->
|
||||||
|
withClient grpc (ClientConfig "localhost" 50051 [] Nothing) $ \client ->
|
||||||
|
do service <- simpleServiceClient client
|
||||||
|
|
||||||
|
(defaultMain $ testGroup "Send gRPC requests"
|
||||||
|
[ testNormalCall service
|
||||||
|
, testClientStreamingCall service
|
||||||
|
, testServerStreamingCall service
|
||||||
|
, testBiDiStreamingCall service ]) `finally`
|
||||||
|
(simpleServiceDone service (ClientNormalRequest SimpleServiceDone 10 mempty))
|
||||||
|
|
|
@ -1,4 +1,7 @@
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
|
{-# LANGUAGE GADTs #-}
|
||||||
|
|
||||||
module Main where
|
module Main where
|
||||||
|
|
||||||
import Prelude hiding (sum)
|
import Prelude hiding (sum)
|
||||||
|
@ -15,48 +18,46 @@ import Data.Foldable (sum)
|
||||||
import Data.String
|
import Data.String
|
||||||
|
|
||||||
import Network.GRPC.LowLevel
|
import Network.GRPC.LowLevel
|
||||||
|
import Network.GRPC.HighLevel.Server
|
||||||
|
|
||||||
handleNormalCall :: ServerCall SimpleServiceRequest -> IO (SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails)
|
handleNormalCall :: ServerRequest 'Normal SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'Normal SimpleServiceResponse)
|
||||||
handleNormalCall call =
|
handleNormalCall (ServerNormalRequest meta (SimpleServiceRequest request nums)) =
|
||||||
pure (SimpleServiceResponse request result, mempty, StatusOk, StatusDetails "")
|
pure (ServerNormalResponse (SimpleServiceResponse request result) mempty StatusOk (StatusDetails ""))
|
||||||
where SimpleServiceRequest request nums = payload call
|
where result = sum nums
|
||||||
|
|
||||||
result = sum nums
|
handleClientStreamingCall :: ServerRequest 'ClientStreaming SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'ClientStreaming SimpleServiceResponse)
|
||||||
|
handleClientStreamingCall (ServerReaderRequest call recvRequest) = go 0 ""
|
||||||
handleClientStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> IO (Maybe SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails)
|
|
||||||
handleClientStreamingCall call recvRequest = go 0 ""
|
|
||||||
where go sumAccum nameAccum =
|
where go sumAccum nameAccum =
|
||||||
recvRequest >>= \req ->
|
recvRequest >>= \req ->
|
||||||
case req of
|
case req of
|
||||||
Left ioError -> pure (Nothing, mempty, StatusCancelled, StatusDetails ("handleClientStreamingCall: IO error: " <> fromString (show ioError)))
|
Left ioError -> pure (ServerReaderResponse Nothing mempty StatusCancelled (StatusDetails ("handleClientStreamingCall: IO error: " <> fromString (show ioError))))
|
||||||
Right Nothing ->
|
Right Nothing ->
|
||||||
pure (Just (SimpleServiceResponse nameAccum sumAccum), mempty, StatusOk, StatusDetails "")
|
pure (ServerReaderResponse (Just (SimpleServiceResponse nameAccum sumAccum)) mempty StatusOk (StatusDetails ""))
|
||||||
Right (Just (SimpleServiceRequest name nums)) ->
|
Right (Just (SimpleServiceRequest name nums)) ->
|
||||||
go (sumAccum + sum nums) (nameAccum <> name)
|
go (sumAccum + sum nums) (nameAccum <> name)
|
||||||
|
|
||||||
handleServerStreamingCall :: ServerCall SimpleServiceRequest -> StreamSend SimpleServiceResponse -> IO (MetadataMap, StatusCode, StatusDetails)
|
handleServerStreamingCall :: ServerRequest 'ServerStreaming SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'ServerStreaming SimpleServiceResponse)
|
||||||
handleServerStreamingCall call sendResponse = go
|
handleServerStreamingCall (ServerWriterRequest call (SimpleServiceRequest requestName nums) sendResponse) = go
|
||||||
where go = do forM_ nums $ \num ->
|
where go = do forM_ nums $ \num ->
|
||||||
sendResponse (SimpleServiceResponse requestName num)
|
sendResponse (SimpleServiceResponse requestName num)
|
||||||
pure (mempty, StatusOk, StatusDetails "")
|
pure (ServerWriterResponse mempty StatusOk (StatusDetails ""))
|
||||||
|
|
||||||
SimpleServiceRequest requestName nums = payload call
|
handleBiDiStreamingCall :: ServerRequest 'BiDiStreaming SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'BiDiStreaming SimpleServiceResponse)
|
||||||
|
handleBiDiStreamingCall (ServerBiDiRequest call recvRequest sendResponse) = go
|
||||||
handleBiDiStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> StreamSend SimpleServiceResponse -> IO (MetadataMap, StatusCode, StatusDetails)
|
|
||||||
handleBiDiStreamingCall call recvRequest sendResponse = go
|
|
||||||
where go = recvRequest >>= \req ->
|
where go = recvRequest >>= \req ->
|
||||||
case req of
|
case req of
|
||||||
Left ioError -> pure (mempty, StatusCancelled, StatusDetails ("handleBiDiStreamingCall: IO error: " <> fromString (show ioError)))
|
Left ioError ->
|
||||||
|
pure (ServerBiDiResponse mempty StatusCancelled (StatusDetails ("handleBiDiStreamingCall: IO error: " <> fromString (show ioError))))
|
||||||
Right Nothing ->
|
Right Nothing ->
|
||||||
pure (mempty, StatusOk, StatusDetails "")
|
pure (ServerBiDiResponse mempty StatusOk (StatusDetails ""))
|
||||||
Right (Just (SimpleServiceRequest name nums)) ->
|
Right (Just (SimpleServiceRequest name nums)) ->
|
||||||
do sendResponse (SimpleServiceResponse name (sum nums))
|
do sendResponse (SimpleServiceResponse name (sum nums))
|
||||||
go
|
go
|
||||||
|
|
||||||
handleDone :: MVar () -> ServerCall SimpleServiceDone -> IO (SimpleServiceDone, MetadataMap, StatusCode, StatusDetails)
|
handleDone :: MVar () -> ServerRequest 'Normal SimpleServiceDone SimpleServiceDone -> IO (ServerResponse 'Normal SimpleServiceDone)
|
||||||
handleDone exitVar req =
|
handleDone exitVar (ServerNormalRequest _ req) =
|
||||||
do forkIO (threadDelay 5000 >> putMVar exitVar ())
|
do forkIO (threadDelay 5000 >> putMVar exitVar ())
|
||||||
pure (payload req, mempty, StatusOk, StatusDetails "")
|
pure (ServerNormalResponse req mempty StatusOk (StatusDetails ""))
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = do exitVar <- newEmptyMVar
|
main = do exitVar <- newEmptyMVar
|
||||||
|
|
13
tests/simple-client.sh
Executable file
13
tests/simple-client.sh
Executable file
|
@ -0,0 +1,13 @@
|
||||||
|
#!/bin/bash -eu
|
||||||
|
|
||||||
|
hsTmpDir=$1
|
||||||
|
|
||||||
|
stack ghc -- \
|
||||||
|
--make \
|
||||||
|
-threaded \
|
||||||
|
-odir $hsTmpDir \
|
||||||
|
-hidir $hsTmpDir \
|
||||||
|
-o $hsTmpDir/simple-client \
|
||||||
|
$hsTmpDir/Simple.hs \
|
||||||
|
tests/TestClient.hs \
|
||||||
|
> /dev/null
|
41
tests/test-server.py
Normal file
41
tests/test-server.py
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
from simple_pb2 import *
|
||||||
|
from uuid import uuid4
|
||||||
|
import random
|
||||||
|
import Queue
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
print "Starting python server"
|
||||||
|
|
||||||
|
done_queue = Queue.Queue()
|
||||||
|
|
||||||
|
class SimpleServiceServer(BetaSimpleServiceServicer):
|
||||||
|
def done(self, request, context):
|
||||||
|
global server
|
||||||
|
done_queue.put_nowait(())
|
||||||
|
|
||||||
|
return SimpleServiceDone()
|
||||||
|
|
||||||
|
def normalCall(self, request, context):
|
||||||
|
return SimpleServiceResponse(response = "NormalRequest", num = sum(request.num))
|
||||||
|
|
||||||
|
def clientStreamingCall(self, requests, context):
|
||||||
|
cur_name = ""
|
||||||
|
cur_sum = 0
|
||||||
|
for request in requests:
|
||||||
|
cur_name += request.request
|
||||||
|
cur_sum += sum(request.num)
|
||||||
|
return SimpleServiceResponse(response = cur_name, num = cur_sum)
|
||||||
|
|
||||||
|
def serverStreamingCall(self, request, context):
|
||||||
|
for num in request.num:
|
||||||
|
yield SimpleServiceResponse(response = request.request, num = num)
|
||||||
|
|
||||||
|
def biDiStreamingCall(self, requests, context):
|
||||||
|
for request in requests:
|
||||||
|
yield SimpleServiceResponse(response = request.request, num = sum(request.num))
|
||||||
|
|
||||||
|
server = beta_create_SimpleService_server(SimpleServiceServer())
|
||||||
|
server.add_insecure_port('[::]:50051')
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
done_queue.get()
|
Loading…
Reference in a new issue