mirror of
https://github.com/unclechu/gRPC-haskell.git
synced 2024-11-26 21:19:43 +01:00
High-level unregistered concurrent interface (#41)
* remove parent ptr from unregistered calls -- unneeded * begin unregistered high level server loop * undo changes to highlevel server, add mkConfig for unregistered server * move call CQ create/destroy into call create/destroy * async normal call function * preliminary unregistered server loop for non-streaming methods * working unregistered highlevel example * loop counters for benchmarking * changes for benchmarking, add ruby example server for benchmarking * async version of withCall, refactor unregistered server loop to handle all method types * unregistered client streaming * add remaining streaming modes * unregistered server streaming test * unregistered streaming tests * add error logging * fix bug in add example * remove old TODOs * fix bug: don't assume slices are null-terminated * add TODO re: unregistered client streaming functions
This commit is contained in:
parent
9113e416e7
commit
e4a28e9e4b
24 changed files with 548 additions and 131 deletions
|
@ -50,10 +50,10 @@ uint8_t *gpr_slice_start_(gpr_slice *slice){
|
|||
return GPR_SLICE_START_PTR(*slice);
|
||||
}
|
||||
|
||||
gpr_slice* gpr_slice_from_copied_string_(const char *source){
|
||||
gpr_slice* gpr_slice_from_copied_buffer_(const char *source, size_t len){
|
||||
gpr_slice* retval = malloc(sizeof(gpr_slice));
|
||||
//note: 'gpr_slice_from_copied_string' handles allocating space for 'source'.
|
||||
*retval = gpr_slice_from_copied_string(source);
|
||||
*retval = gpr_slice_from_copied_buffer(source, len);
|
||||
return retval;
|
||||
}
|
||||
|
||||
|
|
|
@ -7,13 +7,16 @@
|
|||
import Control.Monad
|
||||
import qualified Data.ByteString.Lazy as BL
|
||||
import Data.Protobuf.Wire.Class
|
||||
import Data.Protobuf.Wire.Types
|
||||
import qualified Data.Text as T
|
||||
import Data.Word
|
||||
import GHC.Generics (Generic)
|
||||
import Network.GRPC.LowLevel
|
||||
import qualified Network.GRPC.LowLevel.Client.Unregistered as U
|
||||
import Proto3.Wire.Decode (ParseError)
|
||||
|
||||
echoMethod = MethodName "/echo.Echo/DoEcho"
|
||||
addMethod = MethodName "/echo.Add/DoAdd"
|
||||
|
||||
_unregistered c = U.clientRequest c echoMethod 1 "hi" mempty
|
||||
|
||||
|
@ -30,9 +33,9 @@ regMain = withGRPC $ \g ->
|
|||
-- TODO: Put these in a common location (or just hack around it until CG is working)
|
||||
data EchoRequest = EchoRequest {message :: T.Text} deriving (Show, Eq, Ord, Generic)
|
||||
instance Message EchoRequest
|
||||
data AddRequest = AddRequest {addX :: Word32, addY :: Word32} deriving (Show, Eq, Ord, Generic)
|
||||
data AddRequest = AddRequest {addX :: Fixed Word32, addY :: Fixed Word32} deriving (Show, Eq, Ord, Generic)
|
||||
instance Message AddRequest
|
||||
data AddResponse = AddResponse {answer :: Word32} deriving (Show, Eq, Ord, Generic)
|
||||
data AddResponse = AddResponse {answer :: Fixed Word32} deriving (Show, Eq, Ord, Generic)
|
||||
instance Message AddResponse
|
||||
|
||||
-- TODO: Create Network.GRPC.HighLevel.Client w/ request variants
|
||||
|
@ -49,5 +52,15 @@ highlevelMain = withGRPC $ \g ->
|
|||
Right dec
|
||||
| dec == pay -> return ()
|
||||
| otherwise -> error $ "Got unexpected payload: " ++ show dec
|
||||
rmAdd <- clientRegisterMethodNormal c addMethod
|
||||
let addPay = AddRequest 1 2
|
||||
addEnc = BL.toStrict . toLazyByteString $ addPay
|
||||
replicateM_ 1 $ clientRequest c rmAdd 5 addEnc mempty >>= \case
|
||||
Left e -> error $ "Got client error on add request: " ++ show e
|
||||
Right r -> case fromByteString (rspBody r) of
|
||||
Left e -> error $ "failed to decode add response: " ++ show e
|
||||
Right dec
|
||||
| dec == AddResponse 3 -> return ()
|
||||
| otherwise -> error $ "Got wrong add answer: " ++ show dec
|
||||
|
||||
main = highlevelMain
|
||||
|
|
|
@ -52,22 +52,23 @@ private:
|
|||
};
|
||||
|
||||
int main(){
|
||||
/*
|
||||
|
||||
EchoClient client(grpc::CreateChannel("localhost:50051",
|
||||
grpc::InsecureChannelCredentials()));
|
||||
string msg("hi");
|
||||
for(int i = 0; i < 100000; i++){
|
||||
/*
|
||||
while(true){
|
||||
Status status = client.DoEcho(msg);
|
||||
if(!status.ok()){
|
||||
cout<<"Error: "<<status.error_code()<<endl;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
*/
|
||||
*/
|
||||
|
||||
AddClient client (grpc::CreateChannel("localhost:50051",
|
||||
AddClient addClient (grpc::CreateChannel("localhost:50051",
|
||||
grpc::InsecureChannelCredentials()));
|
||||
AddResponse answer = client.DoAdd(1,2);
|
||||
AddResponse answer = addClient.DoAdd(1,2);
|
||||
cout<<"Got answer: "<<answer.answer()<<endl;
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
#include <string>
|
||||
#include <iostream>
|
||||
#include <atomic>
|
||||
|
||||
#include <grpc++/grpc++.h>
|
||||
#include "echo.grpc.pb.h"
|
||||
|
@ -13,9 +15,15 @@ using grpc::Status;
|
|||
using echo::EchoRequest;
|
||||
using echo::Echo;
|
||||
|
||||
atomic_int reqCount;
|
||||
|
||||
class EchoServiceImpl final : public Echo::Service {
|
||||
Status DoEcho(ServerContext* ctx, const EchoRequest* req,
|
||||
EchoRequest* resp) override {
|
||||
reqCount++;
|
||||
if(reqCount % 100 == 0){
|
||||
cout<<reqCount<<endl;
|
||||
}
|
||||
resp->set_message(req->message());
|
||||
return Status::OK;
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ from grpc.beta import implementations
|
|||
import echo_pb2
|
||||
|
||||
def main():
|
||||
for _ in xrange(1000):
|
||||
channel = implementations.insecure_channel('localhost', 50051)
|
||||
stub = echo_pb2.beta_create_Echo_stub(channel)
|
||||
message = echo_pb2.EchoRequest(message='foo')
|
||||
|
|
26
examples/echo/echo-ruby/echo-server.rb
Normal file
26
examples/echo/echo-ruby/echo-server.rb
Normal file
|
@ -0,0 +1,26 @@
|
|||
this_dir = File.expand_path(File.dirname(__FILE__))
|
||||
$LOAD_PATH.unshift(this_dir)
|
||||
|
||||
require 'grpc'
|
||||
require 'echo_services'
|
||||
|
||||
$i = 0
|
||||
|
||||
class EchoServer < Echo::Echo::Service
|
||||
def do_echo(echo_req, _unused_call)
|
||||
$i = $i+1
|
||||
if $i % 100 == 0
|
||||
puts($i)
|
||||
end
|
||||
return echo_req
|
||||
end
|
||||
end
|
||||
|
||||
def main
|
||||
s = GRPC::RpcServer.new
|
||||
s.add_http2_port('0.0.0.0:50051', :this_port_is_insecure)
|
||||
s.handle(EchoServer)
|
||||
s.run_till_terminated
|
||||
end
|
||||
|
||||
main
|
23
examples/echo/echo-ruby/echo.rb
Normal file
23
examples/echo/echo-ruby/echo.rb
Normal file
|
@ -0,0 +1,23 @@
|
|||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: echo.proto
|
||||
|
||||
require 'google/protobuf'
|
||||
|
||||
Google::Protobuf::DescriptorPool.generated_pool.build do
|
||||
add_message "echo.EchoRequest" do
|
||||
optional :message, :string, 1
|
||||
end
|
||||
add_message "echo.AddRequest" do
|
||||
optional :addX, :fixed32, 1
|
||||
optional :addY, :fixed32, 2
|
||||
end
|
||||
add_message "echo.AddResponse" do
|
||||
optional :answer, :fixed32, 1
|
||||
end
|
||||
end
|
||||
|
||||
module Echo
|
||||
EchoRequest = Google::Protobuf::DescriptorPool.generated_pool.lookup("echo.EchoRequest").msgclass
|
||||
AddRequest = Google::Protobuf::DescriptorPool.generated_pool.lookup("echo.AddRequest").msgclass
|
||||
AddResponse = Google::Protobuf::DescriptorPool.generated_pool.lookup("echo.AddResponse").msgclass
|
||||
end
|
40
examples/echo/echo-ruby/echo_services.rb
Normal file
40
examples/echo/echo-ruby/echo_services.rb
Normal file
|
@ -0,0 +1,40 @@
|
|||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# Source: echo.proto for package 'echo'
|
||||
|
||||
require 'grpc'
|
||||
require 'echo'
|
||||
|
||||
module Echo
|
||||
module Echo
|
||||
|
||||
# TODO: add proto service documentation here
|
||||
class Service
|
||||
|
||||
include GRPC::GenericService
|
||||
|
||||
self.marshal_class_method = :encode
|
||||
self.unmarshal_class_method = :decode
|
||||
self.service_name = 'echo.Echo'
|
||||
|
||||
rpc :DoEcho, EchoRequest, EchoRequest
|
||||
end
|
||||
|
||||
Stub = Service.rpc_stub_class
|
||||
end
|
||||
module Add
|
||||
|
||||
# TODO: add proto service documentation here
|
||||
class Service
|
||||
|
||||
include GRPC::GenericService
|
||||
|
||||
self.marshal_class_method = :encode
|
||||
self.unmarshal_class_method = :decode
|
||||
self.service_name = 'echo.Add'
|
||||
|
||||
rpc :DoAdd, AddRequest, AddResponse
|
||||
end
|
||||
|
||||
Stub = Service.rpc_stub_class
|
||||
end
|
||||
end
|
|
@ -11,10 +11,12 @@ import Control.Concurrent.Async
|
|||
import Control.Monad
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Protobuf.Wire.Class
|
||||
import Data.Protobuf.Wire.Types
|
||||
import qualified Data.Text as T
|
||||
import Data.Word
|
||||
import GHC.Generics (Generic)
|
||||
import Network.GRPC.HighLevel.Server
|
||||
import qualified Network.GRPC.HighLevel.Server.Unregistered as U
|
||||
import Network.GRPC.LowLevel
|
||||
import qualified Network.GRPC.LowLevel.Call.Unregistered as U
|
||||
import qualified Network.GRPC.LowLevel.Server.Unregistered as U
|
||||
|
@ -80,34 +82,30 @@ regMainThreaded = do
|
|||
-- TODO: Put these in a common location (or just hack around it until CG is working)
|
||||
data EchoRequest = EchoRequest {message :: T.Text} deriving (Show, Eq, Ord, Generic)
|
||||
instance Message EchoRequest
|
||||
data AddRequest = AddRequest {addX :: Word32, addY :: Word32} deriving (Show, Eq, Ord, Generic)
|
||||
instance Message AddRequest
|
||||
data AddResponse = AddResponse {answer :: Word32} deriving (Show, Eq, Ord, Generic)
|
||||
instance Message AddResponse
|
||||
|
||||
highlevelMain :: IO ()
|
||||
highlevelMain =
|
||||
serverLoop defaultOptions{optNormalHandlers = [echoHandler, addHandler]}
|
||||
where echoHandler =
|
||||
echoHandler :: Handler 'Normal
|
||||
echoHandler =
|
||||
UnaryHandler "/echo.Echo/DoEcho" $
|
||||
\_c body m -> do
|
||||
tputStrLn $ "UnaryHandler for DoEcho hit, body=" ++ show body
|
||||
return ( body :: EchoRequest
|
||||
, m
|
||||
, StatusOk
|
||||
, StatusDetails ""
|
||||
)
|
||||
addHandler =
|
||||
--TODO: I can't get this one to execute. Is the generated method
|
||||
--name different?
|
||||
|
||||
-- static const char* Add_method_names[] = {
|
||||
-- "/echo.Add/DoAdd",
|
||||
-- };
|
||||
data AddRequest = AddRequest {addX :: Fixed Word32
|
||||
, addY :: Fixed Word32}
|
||||
deriving (Show, Eq, Ord, Generic)
|
||||
instance Message AddRequest
|
||||
data AddResponse = AddResponse {answer :: Fixed Word32}
|
||||
deriving (Show, Eq, Ord, Generic)
|
||||
instance Message AddResponse
|
||||
|
||||
addHandler :: Handler 'Normal
|
||||
addHandler =
|
||||
UnaryHandler "/echo.Add/DoAdd" $
|
||||
\_c b m -> do
|
||||
tputStrLn $ "UnaryHandler for DoAdd hit, b=" ++ show b
|
||||
--tputStrLn $ "UnaryHandler for DoAdd hit, b=" ++ show b
|
||||
print (addX b)
|
||||
print (addY b)
|
||||
return ( AddResponse $ addX b + addY b
|
||||
|
@ -116,8 +114,16 @@ highlevelMain =
|
|||
, StatusDetails ""
|
||||
)
|
||||
|
||||
highlevelMain :: IO ()
|
||||
highlevelMain =
|
||||
serverLoop defaultOptions{optNormalHandlers = [echoHandler, addHandler]}
|
||||
|
||||
highlevelMainUnregistered :: IO ()
|
||||
highlevelMainUnregistered =
|
||||
U.serverLoop defaultOptions{optNormalHandlers = [echoHandler, addHandler]}
|
||||
|
||||
main :: IO ()
|
||||
main = highlevelMain
|
||||
main = highlevelMainUnregistered
|
||||
|
||||
defConfig :: ServerConfig
|
||||
defConfig = ServerConfig "localhost" 50051 [] [] [] [] []
|
||||
|
|
|
@ -67,6 +67,7 @@ library
|
|||
Network.GRPC.LowLevel.Client
|
||||
Network.GRPC.HighLevel
|
||||
Network.GRPC.HighLevel.Server
|
||||
Network.GRPC.HighLevel.Server.Unregistered
|
||||
extra-libraries:
|
||||
grpc
|
||||
includes:
|
||||
|
|
|
@ -26,7 +26,7 @@ size_t gpr_slice_length_(gpr_slice *slice);
|
|||
|
||||
uint8_t *gpr_slice_start_(gpr_slice *slice);
|
||||
|
||||
gpr_slice* gpr_slice_from_copied_string_(const char *source);
|
||||
gpr_slice* gpr_slice_from_copied_buffer_(const char *source, size_t len);
|
||||
|
||||
void free_slice(gpr_slice *slice);
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
|
@ -28,7 +29,7 @@ convertServerHandler :: (Message a, Message b)
|
|||
=> ServerHandler' a b
|
||||
-> ServerHandler
|
||||
convertServerHandler f c bs m = case fromByteString bs of
|
||||
Left{} -> error "TODO: find a way to keep this from killing the server."
|
||||
Left x -> error $ "Failed to deserialize message: " ++ show x
|
||||
Right x -> do (y, tm, sc, sd) <- f c x m
|
||||
return (toBS y, tm, sc, sd)
|
||||
|
||||
|
@ -88,7 +89,7 @@ convertSend s = s . toBS
|
|||
toBS :: Message a => a -> ByteString
|
||||
toBS = BL.toStrict . toLazyByteString
|
||||
|
||||
data Handler a where
|
||||
data Handler (a :: GRPCMethodType) where
|
||||
UnaryHandler
|
||||
:: (Message c, Message d)
|
||||
=> MethodName
|
||||
|
@ -113,6 +114,11 @@ data Handler a where
|
|||
-> ServerRWHandler' c d
|
||||
-> Handler 'BiDiStreaming
|
||||
|
||||
data AnyHandler = forall (a :: GRPCMethodType) . AnyHandler (Handler a)
|
||||
|
||||
anyHandlerMethodName :: AnyHandler -> MethodName
|
||||
anyHandlerMethodName (AnyHandler m) = handlerMethodName m
|
||||
|
||||
handlerMethodName :: Handler a -> MethodName
|
||||
handlerMethodName (UnaryHandler m _) = m
|
||||
handlerMethodName (ClientStreamHandler m _) = m
|
||||
|
@ -142,21 +148,28 @@ handleCallError (Left GRPCIOShutdown) =
|
|||
return ()
|
||||
handleCallError (Left x) = logAskReport x
|
||||
|
||||
loopWError :: IO (Either GRPCIOError a) -> IO ()
|
||||
loopWError f = forever $ f >>= handleCallError
|
||||
loopWError :: Int
|
||||
-> IO (Either GRPCIOError a)
|
||||
-> IO ()
|
||||
loopWError i f = do
|
||||
when (i `mod` 100 == 0) $ putStrLn $ "i = " ++ show i
|
||||
f >>= handleCallError
|
||||
loopWError (i + 1) f
|
||||
|
||||
--TODO: options for setting initial/trailing metadata
|
||||
handleLoop :: Server -> (Handler a, RegisteredMethod a) -> IO ()
|
||||
handleLoop :: Server
|
||||
-> (Handler a, RegisteredMethod a)
|
||||
-> IO ()
|
||||
handleLoop s (UnaryHandler _ f, rm) =
|
||||
loopWError $ do
|
||||
grpcDebug' "handleLoop about to block on serverHandleNormalCall"
|
||||
loopWError 0 $ do
|
||||
--grpcDebug' "handleLoop about to block on serverHandleNormalCall"
|
||||
serverHandleNormalCall s rm mempty $ convertServerHandler f
|
||||
handleLoop s (ClientStreamHandler _ f, rm) =
|
||||
loopWError $ serverReader s rm mempty $ convertServerReaderHandler f
|
||||
loopWError 0 $ serverReader s rm mempty $ convertServerReaderHandler f
|
||||
handleLoop s (ServerStreamHandler _ f, rm) =
|
||||
loopWError $ serverWriter s rm mempty $ convertServerWriterHandler f
|
||||
loopWError 0 $ serverWriter s rm mempty $ convertServerWriterHandler f
|
||||
handleLoop s (BiDiStreamHandler _ f, rm) =
|
||||
loopWError $ serverRW s rm mempty $ convertServerRWHandler f
|
||||
loopWError 0 $ serverRW s rm mempty $ convertServerRWHandler f
|
||||
|
||||
data ServerOptions = ServerOptions
|
||||
{optNormalHandlers :: [Handler 'Normal],
|
||||
|
@ -194,7 +207,7 @@ serverLoop opts =
|
|||
asyncsCS <- mapM async $ map loop rmsCS
|
||||
asyncsSS <- mapM async $ map loop rmsSS
|
||||
asyncsB <- mapM async $ map loop rmsB
|
||||
asyncUnk <- async $ loopWError $ unknownHandler server
|
||||
asyncUnk <- async $ loopWError 0 $ unknownHandler server
|
||||
waitAnyCancel $ asyncUnk : asyncsN ++ asyncsCS ++ asyncsSS ++ asyncsB
|
||||
return ()
|
||||
where
|
||||
|
|
105
src/Network/GRPC/HighLevel/Server/Unregistered.hs
Normal file
105
src/Network/GRPC/HighLevel/Server/Unregistered.hs
Normal file
|
@ -0,0 +1,105 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
|
||||
module Network.GRPC.HighLevel.Server.Unregistered where
|
||||
|
||||
import Control.Applicative ((<|>))
|
||||
import Control.Concurrent.Async
|
||||
import Control.Monad
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Protobuf.Wire.Class
|
||||
import Data.Foldable (find)
|
||||
import Network.GRPC.HighLevel.Server
|
||||
import Network.GRPC.LowLevel
|
||||
import Network.GRPC.LowLevel.GRPC
|
||||
import Network.GRPC.LowLevel.Call
|
||||
import qualified Network.GRPC.LowLevel.Server.Unregistered as U
|
||||
import qualified Network.GRPC.LowLevel.Call.Unregistered as U
|
||||
|
||||
dispatchLoop :: Server
|
||||
-> [Handler 'Normal]
|
||||
-> [Handler 'ClientStreaming]
|
||||
-> [Handler 'ServerStreaming]
|
||||
-> [Handler 'BiDiStreaming]
|
||||
-> IO ()
|
||||
dispatchLoop server hN hC hS hB =
|
||||
forever $ U.withServerCallAsync server $ \call -> do
|
||||
case findHandler call allHandlers of
|
||||
Just (AnyHandler (UnaryHandler _ h)) -> unaryHandler call h
|
||||
Just (AnyHandler (ClientStreamHandler _ h)) -> csHandler call h
|
||||
Just (AnyHandler (ServerStreamHandler _ h)) -> ssHandler call h
|
||||
Just (AnyHandler (BiDiStreamHandler _ h)) -> bdHandler call h
|
||||
Nothing -> unknownHandler call
|
||||
where allHandlers = map AnyHandler hN
|
||||
++ map AnyHandler hC
|
||||
++ map AnyHandler hS
|
||||
++ map AnyHandler hB
|
||||
findHandler call = find ((== (U.callMethod call))
|
||||
. anyHandlerMethodName)
|
||||
unknownHandler call =
|
||||
void $ U.serverHandleNormalCall' server call mempty $ \_ _ ->
|
||||
return (mempty
|
||||
, mempty
|
||||
, StatusNotFound
|
||||
, StatusDetails "unknown method")
|
||||
handleError f = f >>= handleCallError
|
||||
unaryHandler :: (Message a, Message b) =>
|
||||
U.ServerCall
|
||||
-> ServerHandler' a b
|
||||
-> IO ()
|
||||
unaryHandler call h =
|
||||
handleError $
|
||||
U.serverHandleNormalCall' server call mempty $ \call' bs -> do
|
||||
let h' = convertServerHandler h
|
||||
h' (fmap (const bs) $ U.convertCall call)
|
||||
bs
|
||||
(U.requestMetadataRecv call)
|
||||
csHandler :: (Message a, Message b) =>
|
||||
U.ServerCall
|
||||
-> ServerReaderHandler' a b
|
||||
-> IO ()
|
||||
csHandler call h =
|
||||
handleError $
|
||||
U.serverReader server call mempty (convertServerReaderHandler h)
|
||||
ssHandler :: (Message a, Message b) =>
|
||||
U.ServerCall
|
||||
-> ServerWriterHandler' a b
|
||||
-> IO ()
|
||||
ssHandler call h =
|
||||
handleError $
|
||||
U.serverWriter server call mempty (convertServerWriterHandler h)
|
||||
bdHandler :: (Message a, Message b) =>
|
||||
U.ServerCall
|
||||
-> ServerRWHandler' a b
|
||||
-> IO ()
|
||||
bdHandler call h =
|
||||
handleError $
|
||||
U.serverRW server call mempty (convertServerRWHandler h)
|
||||
|
||||
serverLoop :: ServerOptions -> IO ()
|
||||
serverLoop opts@ServerOptions{..} =
|
||||
withGRPC $ \grpc ->
|
||||
withServer grpc (mkConfig opts) $ \server -> do
|
||||
dispatchLoop server
|
||||
optNormalHandlers
|
||||
optClientStreamHandlers
|
||||
optServerStreamHandlers
|
||||
optBiDiStreamHandlers
|
||||
where
|
||||
mkConfig ServerOptions{..} =
|
||||
ServerConfig
|
||||
{ host = "localhost"
|
||||
, port = optServerPort
|
||||
, methodsToRegisterNormal = []
|
||||
, methodsToRegisterClientStreaming = []
|
||||
, methodsToRegisterServerStreaming = []
|
||||
, methodsToRegisterBiDiStreaming = []
|
||||
, serverArgs =
|
||||
([CompressionAlgArg GrpcCompressDeflate | optUseCompression]
|
||||
++
|
||||
[UserAgentPrefix optUserAgentPrefix
|
||||
, UserAgentSuffix optUserAgentSuffix])
|
||||
}
|
|
@ -207,8 +207,9 @@ destroyClientCall cc = do
|
|||
C.grpcCallDestroy (unsafeCC cc)
|
||||
|
||||
destroyServerCall :: ServerCall a -> IO ()
|
||||
destroyServerCall sc@ServerCall{ unsafeSC = c } = do
|
||||
destroyServerCall sc@ServerCall{ unsafeSC = c, .. } = do
|
||||
grpcDebug "destroyServerCall(R): entered."
|
||||
debugServerCall sc
|
||||
_ <- shutdownCompletionQueue callCQ
|
||||
grpcDebug $ "Destroying server-side call object: " ++ show c
|
||||
C.grpcCallDestroy c
|
||||
|
|
|
@ -8,8 +8,8 @@ import Foreign.Ptr (Ptr)
|
|||
#ifdef DEBUG
|
||||
import Foreign.Storable (peek)
|
||||
#endif
|
||||
import Network.GRPC.LowLevel.Call (Host (..),
|
||||
MethodName (..))
|
||||
import qualified Network.GRPC.LowLevel.Call as Reg
|
||||
import Network.GRPC.LowLevel.CompletionQueue
|
||||
import Network.GRPC.LowLevel.CompletionQueue.Internal
|
||||
import Network.GRPC.LowLevel.GRPC (MetadataMap,
|
||||
grpcDebug)
|
||||
|
@ -23,12 +23,15 @@ data ServerCall = ServerCall
|
|||
{ unsafeSC :: C.Call
|
||||
, callCQ :: CompletionQueue
|
||||
, requestMetadataRecv :: MetadataMap
|
||||
, parentPtr :: Maybe (Ptr C.Call)
|
||||
, callDeadline :: TimeSpec
|
||||
, callMethod :: MethodName
|
||||
, callHost :: Host
|
||||
, callMethod :: Reg.MethodName
|
||||
, callHost :: Reg.Host
|
||||
}
|
||||
|
||||
convertCall :: ServerCall -> Reg.ServerCall ()
|
||||
convertCall ServerCall{..} =
|
||||
Reg.ServerCall unsafeSC callCQ requestMetadataRecv () callDeadline
|
||||
|
||||
serverCallCancel :: ServerCall -> C.StatusCode -> String -> IO ()
|
||||
serverCallCancel sc code reason =
|
||||
C.grpcCallCancelWithStatus (unsafeSC sc) code reason C.reserved
|
||||
|
@ -42,11 +45,6 @@ debugServerCall ServerCall{..} = do
|
|||
dbug $ "server call: " ++ show ptr
|
||||
dbug $ "metadata: " ++ show requestMetadataRecv
|
||||
|
||||
forM_ parentPtr $ \parentPtr' -> do
|
||||
dbug $ "parent ptr: " ++ show parentPtr'
|
||||
C.Call parent <- peek parentPtr'
|
||||
dbug $ "parent: " ++ show parent
|
||||
|
||||
dbug $ "deadline: " ++ show callDeadline
|
||||
dbug $ "method: " ++ show callMethod
|
||||
dbug $ "host: " ++ show callHost
|
||||
|
@ -60,6 +58,5 @@ destroyServerCall call@ServerCall{..} = do
|
|||
grpcDebug "destroyServerCall(U): entered."
|
||||
debugServerCall call
|
||||
grpcDebug $ "Destroying server-side call object: " ++ show unsafeSC
|
||||
shutdownCompletionQueue callCQ
|
||||
C.grpcCallDestroy unsafeSC
|
||||
grpcDebug $ "freeing parentPtr: " ++ show parentPtr
|
||||
forM_ parentPtr free
|
||||
|
|
|
@ -61,7 +61,6 @@ import qualified Network.GRPC.Unsafe.Time as C
|
|||
import System.Clock (Clock (..),
|
||||
getTime)
|
||||
import System.Info (os)
|
||||
import System.Timeout (timeout)
|
||||
|
||||
withCompletionQueue :: GRPC -> (CompletionQueue -> IO a) -> IO a
|
||||
withCompletionQueue grpc = bracket (createCompletionQueue grpc)
|
||||
|
@ -89,36 +88,6 @@ startBatch cq@CompletionQueue{..} call opArray opArraySize tag =
|
|||
grpcDebug "startBatch: grpc_call_start_batch call returned."
|
||||
return res
|
||||
|
||||
|
||||
-- | Shuts down the completion queue. See the comment above 'CompletionQueue'
|
||||
-- for the strategy we use to ensure that no one tries to use the
|
||||
-- queue after we begin the shutdown process. Errors with
|
||||
-- 'GRPCIOShutdownFailure' if the queue can't be shut down within 5 seconds.
|
||||
shutdownCompletionQueue :: CompletionQueue -> IO (Either GRPCIOError ())
|
||||
shutdownCompletionQueue CompletionQueue{..} = do
|
||||
atomically $ writeTVar shuttingDown True
|
||||
atomically $ do
|
||||
readTVar currentPushers >>= check . (==0)
|
||||
readTVar currentPluckers >>= check . (==0)
|
||||
--drain the queue
|
||||
C.grpcCompletionQueueShutdown unsafeCQ
|
||||
loopRes <- timeout (5*10^(6::Int)) drainLoop
|
||||
grpcDebug $ "Got CQ loop shutdown result of: " ++ show loopRes
|
||||
case loopRes of
|
||||
Nothing -> return $ Left GRPCIOShutdownFailure
|
||||
Just () -> C.grpcCompletionQueueDestroy unsafeCQ >> return (Right ())
|
||||
|
||||
where drainLoop :: IO ()
|
||||
drainLoop = do
|
||||
grpcDebug "drainLoop: before next() call"
|
||||
ev <- C.withDeadlineSeconds 1 $ \deadline ->
|
||||
C.grpcCompletionQueueNext unsafeCQ deadline C.reserved
|
||||
grpcDebug $ "drainLoop: next() call got " ++ show ev
|
||||
case C.eventCompletionType ev of
|
||||
C.QueueShutdown -> return ()
|
||||
C.QueueTimeout -> drainLoop
|
||||
C.OpComplete -> drainLoop
|
||||
|
||||
channelCreateCall :: C.Channel
|
||||
-> Maybe (ServerCall a)
|
||||
-> C.PropagationMask
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
module Network.GRPC.LowLevel.CompletionQueue.Internal where
|
||||
|
||||
import Control.Concurrent.STM (atomically, retry)
|
||||
import Control.Concurrent.STM (atomically, retry, check)
|
||||
import Control.Concurrent.STM.TVar (TVar, modifyTVar', readTVar,
|
||||
writeTVar)
|
||||
import Control.Exception (bracket)
|
||||
|
@ -13,6 +13,7 @@ import Network.GRPC.LowLevel.GRPC
|
|||
import qualified Network.GRPC.Unsafe as C
|
||||
import qualified Network.GRPC.Unsafe.Constants as C
|
||||
import qualified Network.GRPC.Unsafe.Time as C
|
||||
import System.Timeout (timeout)
|
||||
|
||||
-- NOTE: the concurrency requirements for a CompletionQueue are a little
|
||||
-- complicated. There are two read operations: next and pluck. We can either
|
||||
|
@ -143,3 +144,32 @@ getCount Pluck = currentPluckers
|
|||
getLimit :: CQOpType -> Int
|
||||
getLimit Push = maxWorkPushers
|
||||
getLimit Pluck = C.maxCompletionQueuePluckers
|
||||
|
||||
-- | Shuts down the completion queue. See the comment above 'CompletionQueue'
|
||||
-- for the strategy we use to ensure that no one tries to use the
|
||||
-- queue after we begin the shutdown process. Errors with
|
||||
-- 'GRPCIOShutdownFailure' if the queue can't be shut down within 5 seconds.
|
||||
shutdownCompletionQueue :: CompletionQueue -> IO (Either GRPCIOError ())
|
||||
shutdownCompletionQueue CompletionQueue{..} = do
|
||||
atomically $ writeTVar shuttingDown True
|
||||
atomically $ do
|
||||
readTVar currentPushers >>= check . (==0)
|
||||
readTVar currentPluckers >>= check . (==0)
|
||||
--drain the queue
|
||||
C.grpcCompletionQueueShutdown unsafeCQ
|
||||
loopRes <- timeout (5*10^(6::Int)) drainLoop
|
||||
grpcDebug $ "Got CQ loop shutdown result of: " ++ show loopRes
|
||||
case loopRes of
|
||||
Nothing -> return $ Left GRPCIOShutdownFailure
|
||||
Just () -> C.grpcCompletionQueueDestroy unsafeCQ >> return (Right ())
|
||||
|
||||
where drainLoop :: IO ()
|
||||
drainLoop = do
|
||||
grpcDebug "drainLoop: before next() call"
|
||||
ev <- C.withDeadlineSeconds 1 $ \deadline ->
|
||||
C.grpcCompletionQueueNext unsafeCQ deadline C.reserved
|
||||
grpcDebug $ "drainLoop: next() call got " ++ show ev
|
||||
case C.eventCompletionType ev of
|
||||
C.QueueShutdown -> return ()
|
||||
C.QueueTimeout -> drainLoop
|
||||
C.OpComplete -> drainLoop
|
||||
|
|
|
@ -61,7 +61,6 @@ serverRequestCall s scq ccq =
|
|||
<$> peek call
|
||||
<*> return ccq
|
||||
<*> C.getAllMetadataArray md
|
||||
<*> return Nothing
|
||||
<*> (C.timeSpec <$> C.callDetailsGetDeadline cd)
|
||||
<*> (MethodName <$> C.callDetailsGetMethod cd)
|
||||
<*> (Host <$> C.callDetailsGetHost cd)
|
||||
|
|
|
@ -175,9 +175,10 @@ resultFromOpContext (OpRecvMessageContext pbb) = do
|
|||
grpcDebug "resultFromOpContext: OpRecvMessageContext"
|
||||
bb@(C.ByteBuffer bbptr) <- peek pbb
|
||||
if bbptr == nullPtr
|
||||
then return $ Just $ OpRecvMessageResult Nothing
|
||||
then do grpcDebug "resultFromOpContext: WARNING: got empty message."
|
||||
return $ Just $ OpRecvMessageResult Nothing
|
||||
else do bs <- C.copyByteBufferToByteString bb
|
||||
grpcDebug "resultFromOpContext: bb copied."
|
||||
grpcDebug $ "resultFromOpContext: bb copied: " ++ show bs
|
||||
return $ Just $ OpRecvMessageResult (Just bs)
|
||||
resultFromOpContext (OpRecvStatusOnClientContext pmetadata pcode pstr) = do
|
||||
grpcDebug "resultFromOpContext: OpRecvStatusOnClientContext"
|
||||
|
@ -294,6 +295,14 @@ recvStatusOnClient c cq = runOps' c cq [OpRecvStatusOnClient] >>= \case
|
|||
-> return (md, st, StatusDetails ds)
|
||||
_ -> throwE (GRPCIOInternalUnexpectedRecv "recvStatusOnClient")
|
||||
|
||||
recvInitialMessage :: RecvSingle ByteString
|
||||
recvInitialMessage c cq = runOps' c cq [OpRecvMessage] >>= \case
|
||||
[OpRecvMessageResult (Just bs)]
|
||||
-> return bs
|
||||
[OpRecvMessageResult Nothing]
|
||||
-> throwE (GRPCIOInternalUnexpectedRecv "recvInitialMessage: no message.")
|
||||
_ -> throwE (GRPCIOInternalUnexpectedRecv "recvInitialMessage")
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Streaming types and helpers
|
||||
|
||||
|
|
|
@ -228,17 +228,17 @@ serverRegisterMethodBiDiStreaming internalServer meth e = do
|
|||
-- method.
|
||||
serverCreateCall :: Server
|
||||
-> RegisteredMethod mt
|
||||
-> CompletionQueue -- ^ call CQ
|
||||
-> IO (Either GRPCIOError (ServerCall (MethodPayload mt)))
|
||||
serverCreateCall Server{..} rm = serverRequestCall rm unsafeServer serverCQ
|
||||
serverCreateCall Server{..} rm = do
|
||||
callCQ <- createCompletionQueue serverGRPC
|
||||
serverRequestCall rm unsafeServer serverCQ callCQ
|
||||
|
||||
withServerCall :: Server
|
||||
-> RegisteredMethod mt
|
||||
-> (ServerCall (MethodPayload mt) -> IO (Either GRPCIOError a))
|
||||
-> IO (Either GRPCIOError a)
|
||||
withServerCall s rm f =
|
||||
withCompletionQueue (serverGRPC s) $
|
||||
serverCreateCall s rm >=> \case
|
||||
serverCreateCall s rm >>= \case
|
||||
Left e -> return (Left e)
|
||||
Right c -> do
|
||||
debugServerCall c
|
||||
|
|
|
@ -3,34 +3,64 @@
|
|||
|
||||
module Network.GRPC.LowLevel.Server.Unregistered where
|
||||
|
||||
import Control.Concurrent (forkIO)
|
||||
import Control.Exception (finally)
|
||||
import Control.Monad
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.ByteString (ByteString)
|
||||
import Network.GRPC.LowLevel.Call.Unregistered
|
||||
import Network.GRPC.LowLevel.CompletionQueue (CompletionQueue, withCompletionQueue)
|
||||
import Network.GRPC.LowLevel.CompletionQueue (CompletionQueue
|
||||
, withCompletionQueue
|
||||
, createCompletionQueue)
|
||||
import Network.GRPC.LowLevel.CompletionQueue.Unregistered (serverRequestCall)
|
||||
import Network.GRPC.LowLevel.GRPC
|
||||
import Network.GRPC.LowLevel.Op (Op (..), OpRecvResult (..),
|
||||
runOps)
|
||||
import Network.GRPC.LowLevel.Server (Server (..))
|
||||
import Network.GRPC.LowLevel.Op (Op (..)
|
||||
, OpRecvResult (..)
|
||||
, runOps
|
||||
, runStreamingProxy
|
||||
, streamRecv
|
||||
, streamSend
|
||||
, runOps'
|
||||
, sendInitialMetadata
|
||||
, sendStatusFromServer
|
||||
, recvInitialMessage)
|
||||
import Network.GRPC.LowLevel.Server (Server (..)
|
||||
, ServerReaderHandler
|
||||
, ServerWriterHandler
|
||||
, ServerRWHandler)
|
||||
import qualified Network.GRPC.Unsafe.Op as C
|
||||
|
||||
serverCreateCall :: Server
|
||||
-> CompletionQueue -- ^ call CQ
|
||||
-> IO (Either GRPCIOError ServerCall)
|
||||
serverCreateCall Server{..} = serverRequestCall unsafeServer serverCQ
|
||||
serverCreateCall Server{..} = do
|
||||
callCQ <- createCompletionQueue serverGRPC
|
||||
serverRequestCall unsafeServer serverCQ callCQ
|
||||
|
||||
withServerCall :: Server
|
||||
-> (ServerCall -> IO (Either GRPCIOError a))
|
||||
-> IO (Either GRPCIOError a)
|
||||
withServerCall s f =
|
||||
withCompletionQueue (serverGRPC s) $
|
||||
serverCreateCall s >=> \case
|
||||
serverCreateCall s >>= \case
|
||||
Left e -> return (Left e)
|
||||
Right c -> f c `finally` do
|
||||
grpcDebug "withServerCall: destroying."
|
||||
destroyServerCall c
|
||||
|
||||
-- | Gets a call and then forks the given function on a new thread, with the
|
||||
-- new call as input. Blocks until a call is received, then returns immediately.
|
||||
-- Handles cleaning up the call safely.
|
||||
-- Because this function doesn't wait for the handler to return, it cannot
|
||||
-- return errors.
|
||||
withServerCallAsync :: Server
|
||||
-> (ServerCall -> IO ())
|
||||
-> IO ()
|
||||
withServerCallAsync s f =
|
||||
serverCreateCall s >>= \case
|
||||
Left e -> return ()
|
||||
Right c -> void $ forkIO (f c `finally` do
|
||||
grpcDebug "withServerCallAsync: destroying."
|
||||
destroyServerCall c)
|
||||
|
||||
-- | Sequence of 'Op's needed to receive a normal (non-streaming) call.
|
||||
-- TODO: We have to put 'OpRecvCloseOnServer' in the response ops, or else the
|
||||
-- client times out. Given this, I have no idea how to check for cancellation on
|
||||
|
@ -63,9 +93,16 @@ serverHandleNormalCall :: Server
|
|||
-> MetadataMap -- ^ Initial server metadata.
|
||||
-> ServerHandler
|
||||
-> IO (Either GRPCIOError ())
|
||||
serverHandleNormalCall s initMeta f = withServerCall s go
|
||||
where
|
||||
go sc@ServerCall{ unsafeSC = c, callCQ = cq, .. } = do
|
||||
serverHandleNormalCall s initMeta f =
|
||||
withServerCall s $ \c -> serverHandleNormalCall' s c initMeta f
|
||||
|
||||
serverHandleNormalCall' :: Server
|
||||
-> ServerCall
|
||||
-> MetadataMap -- ^ Initial server metadata.
|
||||
-> ServerHandler
|
||||
-> IO (Either GRPCIOError ())
|
||||
serverHandleNormalCall'
|
||||
s sc@ServerCall{ unsafeSC = c, callCQ = cq, .. } initMeta f = do
|
||||
grpcDebug "serverHandleNormalCall(U): starting batch."
|
||||
runOps c cq
|
||||
[ OpSendInitialMetadata initMeta
|
||||
|
@ -92,3 +129,45 @@ serverHandleNormalCall s initMeta f = withServerCall s go
|
|||
grpcDebug "serverHandleNormalCall(U): ops done."
|
||||
return $ Right ()
|
||||
x -> error $ "impossible pattern match: " ++ show x
|
||||
|
||||
serverReader :: Server
|
||||
-> ServerCall
|
||||
-> MetadataMap -- ^ initial server metadata
|
||||
-> ServerReaderHandler
|
||||
-> IO (Either GRPCIOError ())
|
||||
serverReader s sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f =
|
||||
runExceptT $ do
|
||||
(mmsg, trailMeta, st, ds) <-
|
||||
runStreamingProxy "serverReader" c ccq (f (convertCall sc) streamRecv)
|
||||
runOps' c ccq ( OpSendInitialMetadata initMeta
|
||||
: OpSendStatusFromServer trailMeta st ds
|
||||
: maybe [] ((:[]) . OpSendMessage) mmsg
|
||||
)
|
||||
return ()
|
||||
|
||||
serverWriter :: Server
|
||||
-> ServerCall
|
||||
-> MetadataMap
|
||||
-- ^ Initial server metadata
|
||||
-> ServerWriterHandler
|
||||
-> IO (Either GRPCIOError ())
|
||||
serverWriter s sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f =
|
||||
runExceptT $ do
|
||||
bs <- recvInitialMessage c ccq
|
||||
sendInitialMetadata c ccq initMeta
|
||||
let regCall = fmap (const bs) (convertCall sc)
|
||||
st <- runStreamingProxy "serverWriter" c ccq (f regCall streamSend)
|
||||
sendStatusFromServer c ccq st
|
||||
|
||||
serverRW :: Server
|
||||
-> ServerCall
|
||||
-> MetadataMap
|
||||
-- ^ initial server metadata
|
||||
-> ServerRWHandler
|
||||
-> IO (Either GRPCIOError ())
|
||||
serverRW s sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f =
|
||||
runExceptT $ do
|
||||
sendInitialMetadata c ccq initMeta
|
||||
let regCall = convertCall sc
|
||||
st <- runStreamingProxy "serverRW" c ccq (f regCall streamRecv streamSend)
|
||||
sendStatusFromServer c ccq st
|
||||
|
|
|
@ -33,7 +33,7 @@ deriving instance Show Slice
|
|||
-- slice.
|
||||
{#fun gpr_slice_start_ as ^ {`Slice'} -> `Ptr CChar' castPtr #}
|
||||
|
||||
{#fun gpr_slice_from_copied_string_ as ^ {`CString'} -> `Slice'#}
|
||||
{#fun gpr_slice_from_copied_buffer_ as ^ {`CString', `Int'} -> `Slice'#}
|
||||
|
||||
-- | Properly cleans up all memory used by a 'Slice'. Danger: the Slice should
|
||||
-- not be used after this function is called on it.
|
||||
|
@ -52,4 +52,4 @@ sliceToByteString slice = do
|
|||
|
||||
-- | Copies a 'ByteString' to a 'Slice'.
|
||||
byteStringToSlice :: B.ByteString -> IO Slice
|
||||
byteStringToSlice bs = B.useAsCString bs gprSliceFromCopiedString
|
||||
byteStringToSlice bs = B.useAsCStringLen bs $ uncurry gprSliceFromCopiedBuffer
|
||||
|
|
|
@ -49,8 +49,11 @@ lowLevelTests = testGroup "Unit tests of low-level Haskell library"
|
|||
, testClientCompression
|
||||
, testClientServerCompression
|
||||
, testClientStreaming
|
||||
, testClientStreamingUnregistered
|
||||
, testServerStreaming
|
||||
, testServerStreamingUnregistered
|
||||
, testBiDiStreaming
|
||||
, testBiDiStreamingUnregistered
|
||||
]
|
||||
|
||||
testGRPCBracket :: TestTree
|
||||
|
@ -184,6 +187,37 @@ testServerStreaming =
|
|||
return (dummyMeta, StatusOk, "dtls")
|
||||
r @?= Right ()
|
||||
|
||||
-- TODO: these unregistered streaming tests are basically the same as the
|
||||
-- registered ones. Reduce duplication.
|
||||
-- TODO: Once client-side unregistered streaming functions are added, switch
|
||||
-- to using them in these tests.
|
||||
testServerStreamingUnregistered :: TestTree
|
||||
testServerStreamingUnregistered =
|
||||
csTest "unregistered server streaming" client server ([],[],[],[])
|
||||
where
|
||||
clientInitMD = [("client","initmd")]
|
||||
serverInitMD = [("server","initmd")]
|
||||
clientPay = "FEED ME!"
|
||||
pays = ["ONE", "TWO", "THREE", "FOUR"] :: [ByteString]
|
||||
|
||||
client c = do
|
||||
rm <- clientRegisterMethodServerStreaming c "/feed"
|
||||
eea <- clientReader c rm 10 clientPay clientInitMD $ \initMD recv -> do
|
||||
liftIO $ checkMD "Server initial metadata mismatch" serverInitMD initMD
|
||||
forM_ pays $ \p -> recv `is` Right (Just p)
|
||||
recv `is` Right Nothing
|
||||
eea @?= Right (dummyMeta, StatusOk, "dtls")
|
||||
|
||||
server s = U.withServerCallAsync s $ \call -> do
|
||||
r <- U.serverWriter s call serverInitMD $ \sc send -> do
|
||||
liftIO $ do
|
||||
checkMD "Server request metadata mismatch"
|
||||
clientInitMD (requestMetadataRecv sc)
|
||||
optionalPayload sc @?= clientPay
|
||||
forM_ pays $ \p -> send p `is` Right ()
|
||||
return (dummyMeta, StatusOk, "dtls")
|
||||
r @?= Right ()
|
||||
|
||||
testClientStreaming :: TestTree
|
||||
testClientStreaming =
|
||||
csTest "client streaming" client server ([],["/slurp"],[],[])
|
||||
|
@ -213,6 +247,34 @@ testClientStreaming =
|
|||
return (Just serverRsp, trailMD, serverStatus, serverDtls)
|
||||
eea @?= Right ()
|
||||
|
||||
testClientStreamingUnregistered :: TestTree
|
||||
testClientStreamingUnregistered =
|
||||
csTest "unregistered client streaming" client server ([],[],[],[])
|
||||
where
|
||||
clientInitMD = [("a","b")]
|
||||
serverInitMD = [("x","y")]
|
||||
trailMD = dummyMeta
|
||||
serverRsp = "serverReader reply"
|
||||
serverDtls = "deets"
|
||||
serverStatus = StatusOk
|
||||
pays = ["P_ONE", "P_TWO", "P_THREE"] :: [ByteString]
|
||||
|
||||
client c = do
|
||||
rm <- clientRegisterMethodClientStreaming c "/slurp"
|
||||
eea <- clientWriter c rm 10 clientInitMD $ \send -> do
|
||||
-- liftIO $ checkMD "Server initial metadata mismatch" serverInitMD initMD
|
||||
forM_ pays $ \p -> send p `is` Right ()
|
||||
eea @?= Right (Just serverRsp, serverInitMD, trailMD, serverStatus, serverDtls)
|
||||
|
||||
server s = U.withServerCallAsync s $ \call -> do
|
||||
eea <- U.serverReader s call serverInitMD $ \sc recv -> do
|
||||
liftIO $ checkMD "Client request metadata mismatch"
|
||||
clientInitMD (requestMetadataRecv sc)
|
||||
forM_ pays $ \p -> recv `is` Right (Just p)
|
||||
recv `is` Right Nothing
|
||||
return (Just serverRsp, trailMD, serverStatus, serverDtls)
|
||||
eea @?= Right ()
|
||||
|
||||
testBiDiStreaming :: TestTree
|
||||
testBiDiStreaming =
|
||||
csTest "bidirectional streaming" client server ([],[],[],["/bidi"])
|
||||
|
@ -249,6 +311,41 @@ testBiDiStreaming =
|
|||
return (trailMD, serverStatus, serverDtls)
|
||||
eea @?= Right ()
|
||||
|
||||
testBiDiStreamingUnregistered :: TestTree
|
||||
testBiDiStreamingUnregistered =
|
||||
csTest "unregistered bidirectional streaming" client server ([],[],[],[])
|
||||
where
|
||||
clientInitMD = [("bidi-streaming","client")]
|
||||
serverInitMD = [("bidi-streaming","server")]
|
||||
trailMD = dummyMeta
|
||||
serverStatus = StatusOk
|
||||
serverDtls = "deets"
|
||||
is act x = act >>= liftIO . (@?= x)
|
||||
|
||||
client c = do
|
||||
rm <- clientRegisterMethodBiDiStreaming c "/bidi"
|
||||
eea <- clientRW c rm 10 clientInitMD $ \initMD recv send -> do
|
||||
send "cw0" `is` Right ()
|
||||
recv `is` Right (Just "sw0")
|
||||
send "cw1" `is` Right ()
|
||||
recv `is` Right (Just "sw1")
|
||||
recv `is` Right (Just "sw2")
|
||||
return ()
|
||||
eea @?= Right (trailMD, serverStatus, serverDtls)
|
||||
|
||||
server s = U.withServerCallAsync s $ \call -> do
|
||||
eea <- U.serverRW s call serverInitMD $ \sc recv send -> do
|
||||
liftIO $ checkMD "Client request metadata mismatch"
|
||||
clientInitMD (requestMetadataRecv sc)
|
||||
recv `is` Right (Just "cw0")
|
||||
send "sw0" `is` Right ()
|
||||
recv `is` Right (Just "cw1")
|
||||
send "sw1" `is` Right ()
|
||||
send "sw2" `is` Right ()
|
||||
recv `is` Right Nothing
|
||||
return (trailMD, serverStatus, serverDtls)
|
||||
eea @?= Right ()
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Unregistered tests
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@ import Test.Tasty.HUnit as HU (testCase, (@?=),
|
|||
unsafeTests :: TestTree
|
||||
unsafeTests = testGroup "Unit tests for unsafe C bindings"
|
||||
[ roundtripSlice "Hello, world!"
|
||||
, roundtripSlice "\NULabc\NUL"
|
||||
, roundtripByteBuffer "Hwaet! We gardena in geardagum..."
|
||||
, roundtripSlice largeByteString
|
||||
, roundtripByteBuffer largeByteString
|
||||
|
@ -45,7 +46,7 @@ roundtripSlice :: B.ByteString -> TestTree
|
|||
roundtripSlice bs = testCase "ByteString slice roundtrip" $ do
|
||||
slice <- byteStringToSlice bs
|
||||
unslice <- sliceToByteString slice
|
||||
bs HU.@?= unslice
|
||||
unslice HU.@?= bs
|
||||
freeSlice slice
|
||||
|
||||
roundtripByteBuffer :: B.ByteString -> TestTree
|
||||
|
|
Loading…
Reference in a new issue