diff --git a/cbits/grpc_haskell.c b/cbits/grpc_haskell.c index 1403ccc..80942d7 100644 --- a/cbits/grpc_haskell.c +++ b/cbits/grpc_haskell.c @@ -401,13 +401,10 @@ gpr_timespec* call_details_get_deadline(grpc_call_details* details){ return &(details->deadline); } -void* grpc_server_register_method_(grpc_server* server, const char* method, - const char* host){ - //NOTE: grpc 0.14.0 added more params to this function. None of our code takes - //advantage of them, so we hardcode to the equivalent of 0.13.0's behavior. - return grpc_server_register_method(server, method, host, - GRPC_SRM_PAYLOAD_READ_INITIAL_BYTE_BUFFER, - 0); +void* grpc_server_register_method_( + grpc_server* server, const char* method, + const char* host, grpc_server_register_method_payload_handling payload_handling ){ + return grpc_server_register_method(server, method, host, payload_handling, 0); } grpc_arg* create_arg_array(size_t n){ diff --git a/examples/echo/echo-server/Main.hs b/examples/echo/echo-server/Main.hs index 79006f2..aa3f5c5 100644 --- a/examples/echo/echo-server/Main.hs +++ b/examples/echo/echo-server/Main.hs @@ -1,6 +1,7 @@ -{-# LANGUAGE OverloadedLists #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} {-# OPTIONS_GHC -fno-warn-missing-signatures #-} {-# OPTIONS_GHC -fno-warn-unused-binds #-} @@ -35,7 +36,7 @@ regMain = withGRPC $ \grpc -> do let methods = [(MethodName "/echo.Echo/DoEcho", Normal)] withServer grpc (ServerConfig "localhost" 50051 methods []) $ \server -> forever $ do - let method = head (registeredMethods server) + let method = head (normalMethods server) result <- serverHandleNormalCall server method serverMeta $ \_call reqBody _reqMeta -> return (reqBody, serverMeta, StatusOk, StatusDetails "") @@ -44,7 +45,7 @@ regMain = withGRPC $ \grpc -> do Right _ -> return () -- | loop to fork n times -regLoop :: Server -> RegisteredMethod -> IO () +regLoop :: Server -> RegisteredMethod 'Normal -> IO () regLoop server method = forever $ do result <- serverHandleNormalCall server method serverMeta $ \_call reqBody _reqMeta -> return (reqBody, serverMeta, StatusOk, @@ -58,7 +59,7 @@ regMainThreaded = do withGRPC $ \grpc -> do let methods = [(MethodName "/echo.Echo/DoEcho", Normal)] withServer grpc (ServerConfig "localhost" 50051 methods []) $ \server -> do - let method = head (registeredMethods server) + let method = head (normalMethods server) tid1 <- async $ regLoop server method tid2 <- async $ regLoop server method wait tid1 diff --git a/grpc-haskell.cabal b/grpc-haskell.cabal index 0958f0d..2f06c55 100644 --- a/grpc-haskell.cabal +++ b/grpc-haskell.cabal @@ -30,6 +30,15 @@ library , bytestring ==0.10.* , stm == 2.4.* , containers ==0.5.* + , managed >= 1.0.5 && < 1.1 + , pipes ==4.1.* + , transformers + + , async + , tasty >= 0.11 && <0.12 + , tasty-hunit >= 0.9 && <0.10 + , safe + c-sources: cbits/grpc_haskell.c exposed-modules: @@ -70,7 +79,6 @@ library include-dirs: include hs-source-dirs: src default-extensions: CPP - if flag(debug) CPP-Options: -DDEBUG CC-Options: -DGRPC_HASKELL_DEBUG @@ -115,6 +123,10 @@ test-suite test , tasty >= 0.11 && <0.12 , tasty-hunit >= 0.9 && <0.10 , containers ==0.5.* + , managed >= 1.0.5 && < 1.1 + , pipes ==4.1.* + , transformers + , safe other-modules: LowLevelTests, LowLevelTests.Op, @@ -125,7 +137,6 @@ test-suite test main-is: Properties.hs type: exitcode-stdio-1.0 default-extensions: CPP - if flag(debug) CPP-Options: -DDEBUG CC-Options: -DGRPC_HASKELL_DEBUG diff --git a/include/grpc_haskell.h b/include/grpc_haskell.h index 884cc5d..907eae8 100644 --- a/include/grpc_haskell.h +++ b/include/grpc_haskell.h @@ -137,8 +137,9 @@ char* call_details_get_host(grpc_call_details* details); gpr_timespec* call_details_get_deadline(grpc_call_details* details); -void* grpc_server_register_method_(grpc_server* server, const char* method, - const char* host); +void* grpc_server_register_method_( + grpc_server* server, const char* method, const char* host, + grpc_server_register_method_payload_handling payload_handling); //c2hs doesn't support #const pragmas referring to #define'd strings, so we use //this enum as a workaround. These are converted into actual GRPC #defines in diff --git a/src/Network/GRPC/LowLevel.hs b/src/Network/GRPC/LowLevel.hs index d4e4ced..bc69028 100644 --- a/src/Network/GRPC/LowLevel.hs +++ b/src/Network/GRPC/LowLevel.hs @@ -32,14 +32,17 @@ GRPC -- * Server , ServerConfig(..) -, Server -, ServerCall -, registeredMethods +, Server(normalMethods, sstreamingMethods, cstreamingMethods, + bidiStreamingMethods) +, ServerCall(optionalPayload, requestMetadataRecv) , withServer , serverHandleNormalCall , withServerCall , serverCallCancel , serverCallIsExpired +, serverReader -- for client streaming +, serverWriter -- for server streaming +, serverRW -- for bidirectional streaming -- * Client , ClientConfig(..) @@ -50,6 +53,9 @@ GRPC , withClient , clientRegisterMethod , clientRequest +, clientReader -- for server streaming +, clientWriter -- for client streaming +, clientRW -- for bidirectional streaming , withClientCall , withClientCallParent , clientCallCancel diff --git a/src/Network/GRPC/LowLevel/Call.hs b/src/Network/GRPC/LowLevel/Call.hs index bceb2e0..91e79ad 100644 --- a/src/Network/GRPC/LowLevel/Call.hs +++ b/src/Network/GRPC/LowLevel/Call.hs @@ -1,16 +1,19 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE FlexibleInstances #-} -- | This module defines data structures and operations pertaining to registered -- calls; for unregistered call support, see -- `Network.GRPC.LowLevel.Call.Unregistered`. module Network.GRPC.LowLevel.Call where -import Control.Monad import Data.ByteString (ByteString) import Data.String (IsString) -import Foreign.Marshal.Alloc (free) -import Foreign.Ptr (Ptr) +#ifdef DEBUG +import Foreign.Storable (peek) +#endif import System.Clock import qualified Network.GRPC.Unsafe as C @@ -18,9 +21,13 @@ import qualified Network.GRPC.Unsafe.Op as C import Network.GRPC.LowLevel.GRPC (MetadataMap, grpcDebug) --- | Models the four types of RPC call supported by gRPC. We currently only --- support the first alternative, and only in a preliminary fashion. -data GRPCMethodType = Normal | ClientStreaming | ServerStreaming | BiDiStreaming +-- | Models the four types of RPC call supported by gRPC (and correspond to +-- DataKinds phantom types on RegisteredMethods). +data GRPCMethodType + = Normal + | ClientStreaming + | ServerStreaming + | BiDiStreaming deriving (Show, Eq, Ord, Enum) newtype MethodName = MethodName {unMethodName :: String} @@ -40,15 +47,18 @@ endpoint :: Host -> Port -> Endpoint endpoint (Host h) (Port p) = Endpoint (h ++ ":" ++ show p) -- | Represents a registered method. Methods can optionally be registered in --- order to make the C-level request/response code simpler. --- Before making or awaiting a registered call, the --- method must be registered with the client (see 'clientRegisterMethod') and --- the server (see 'serverRegisterMethod'). --- Contains state for identifying that method in the underlying gRPC library. -data RegisteredMethod = RegisteredMethod {methodType :: GRPCMethodType, - methodName :: MethodName, - methodEndpoint :: Endpoint, - methodHandle :: C.CallHandle} +-- order to make the C-level request/response code simpler. Before making or +-- awaiting a registered call, the method must be registered with the client +-- (see 'clientRegisterMethod') and the server (see 'serverRegisterMethod'). +-- Contains state for identifying that method in the underlying gRPC +-- library. Note that we use a DataKind-ed phantom type to help constrain use of +-- different kinds of registered methods. +data RegisteredMethod (mt :: GRPCMethodType) = RegisteredMethod + { methodType :: GRPCMethodType + , methodName :: MethodName + , methodEndpoint :: Endpoint + , methodHandle :: C.CallHandle + } -- | Represents one GRPC call (i.e. request) on the client. -- This is used to associate send/receive 'Op's with a request. @@ -70,6 +80,14 @@ serverCallCancel :: ServerCall -> C.StatusCode -> String -> IO () serverCallCancel sc code reason = C.grpcCallCancelWithStatus (unServerCall sc) code reason C.reserved +-- | NB: For now, we've assumed that the method type is all the info we need to +-- decide the server payload handling method. +payloadHandling :: GRPCMethodType -> C.ServerRegisterMethodPayloadHandling +payloadHandling Normal = C.SrmPayloadReadInitialByteBuffer +payloadHandling ClientStreaming = C.SrmPayloadNone +payloadHandling ServerStreaming = C.SrmPayloadReadInitialByteBuffer +payloadHandling BiDiStreaming = C.SrmPayloadNone + serverCallIsExpired :: ServerCall -> IO Bool serverCallIsExpired sc = do currTime <- getTime Monotonic diff --git a/src/Network/GRPC/LowLevel/Call/Unregistered.hs b/src/Network/GRPC/LowLevel/Call/Unregistered.hs index 9e724a5..edeb54c 100644 --- a/src/Network/GRPC/LowLevel/Call/Unregistered.hs +++ b/src/Network/GRPC/LowLevel/Call/Unregistered.hs @@ -5,6 +5,9 @@ module Network.GRPC.LowLevel.Call.Unregistered where import Control.Monad import Foreign.Marshal.Alloc (free) import Foreign.Ptr (Ptr) +#ifdef DEBUG +import Foreign.Storable (peek) +#endif import System.Clock (TimeSpec) import Network.GRPC.LowLevel.Call (Host (..), MethodName (..)) @@ -30,14 +33,14 @@ serverCallCancel sc code reason = debugServerCall :: ServerCall -> IO () #ifdef DEBUG -debugServerCall call@ServerCall{..} = do - let (C.Call ptr) = unServerCall - grpcDebug $ "debugServerCall(U): server call: " ++ (show ptr) +debugServerCall ServerCall{..} = do + let C.Call ptr = unServerCall + grpcDebug $ "debugServerCall(U): server call: " ++ show ptr grpcDebug $ "debugServerCall(U): metadata: " ++ show requestMetadataRecv forM_ parentPtr $ \parentPtr' -> do grpcDebug $ "debugServerCall(U): parent ptr: " ++ show parentPtr' - (C.Call parent) <- peek parentPtr' + C.Call parent <- peek parentPtr' grpcDebug $ "debugServerCall(U): parent: " ++ show parent grpcDebug $ "debugServerCall(U): deadline: " ++ show callDeadline grpcDebug $ "debugServerCall(U): method: " ++ show callMethod diff --git a/src/Network/GRPC/LowLevel/Client.hs b/src/Network/GRPC/LowLevel/Client.hs index d9e1359..975200b 100644 --- a/src/Network/GRPC/LowLevel/Client.hs +++ b/src/Network/GRPC/LowLevel/Client.hs @@ -1,25 +1,33 @@ -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ViewPatterns #-} -- | This module defines data structures and operations pertaining to registered -- clients using registered calls; for unregistered support, see -- `Network.GRPC.LowLevel.Client.Unregistered`. module Network.GRPC.LowLevel.Client where +import Control.Arrow import Control.Exception (bracket, finally) -import Control.Monad (join) +import Control.Monad +import Control.Monad.Trans.Class (MonadTrans(lift)) +import Control.Monad.Trans.Except import Data.ByteString (ByteString) -import Foreign.Ptr (nullPtr) +import Network.GRPC.LowLevel.Call +import Network.GRPC.LowLevel.CompletionQueue +import Network.GRPC.LowLevel.GRPC +import Network.GRPC.LowLevel.Op import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe.ChannelArgs as C import qualified Network.GRPC.Unsafe.Constants as C import qualified Network.GRPC.Unsafe.Op as C import qualified Network.GRPC.Unsafe.Time as C - -import Network.GRPC.LowLevel.Call -import Network.GRPC.LowLevel.CompletionQueue -import Network.GRPC.LowLevel.GRPC -import Network.GRPC.LowLevel.Op +import qualified Pipes as P +import qualified Pipes.Core as P -- | Represents the context needed to perform client-side gRPC operations. data Client = Client {clientChannel :: C.Channel, @@ -71,22 +79,20 @@ clientConnectivity Client{..} = -- | Register a method on the client so that we can call it with -- 'clientRequest'. clientRegisterMethod :: Client - -> MethodName - -- ^ method name, e.g. "/foo" - -> GRPCMethodType - -> IO RegisteredMethod -clientRegisterMethod Client{..} meth Normal = do + -> MethodName + -> GRPCMethodType + -> IO (RegisteredMethod mt) +clientRegisterMethod Client{..} meth mty = do let e = clientEndpoint clientConfig - handle <- C.grpcChannelRegisterCall clientChannel - (unMethodName meth) (unEndpoint e) C.reserved - return $ RegisteredMethod Normal meth e handle -clientRegisterMethod _ _ _ = error "Streaming methods not yet implemented." + RegisteredMethod mty meth e <$> + C.grpcChannelRegisterCall clientChannel + (unMethodName meth) (unEndpoint e) C.reserved -- | Create a new call on the client for a registered method. -- Returns 'Left' if the CQ is shutting down or if the job to create a call -- timed out. clientCreateCall :: Client - -> RegisteredMethod + -> RegisteredMethod mt -> TimeoutSeconds -> IO (Either GRPCIOError ClientCall) clientCreateCall c rm ts = clientCreateCallParent c rm ts Nothing @@ -95,9 +101,9 @@ clientCreateCall c rm ts = clientCreateCallParent c rm ts Nothing -- a client call with an optional parent server call. This allows for cascading -- call cancellation from the `ServerCall` to the `ClientCall`. clientCreateCallParent :: Client - -> RegisteredMethod + -> RegisteredMethod mt -> TimeoutSeconds - -> (Maybe ServerCall) + -> Maybe ServerCall -- ^ Optional parent call for cascading cancellation. -> IO (Either GRPCIOError ClientCall) clientCreateCallParent Client{..} RegisteredMethod{..} timeout parent = do @@ -107,7 +113,7 @@ clientCreateCallParent Client{..} RegisteredMethod{..} timeout parent = do -- | Handles safe creation and cleanup of a client call withClientCall :: Client - -> RegisteredMethod + -> RegisteredMethod mt -> TimeoutSeconds -> (ClientCall -> IO (Either GRPCIOError a)) -> IO (Either GRPCIOError a) @@ -119,7 +125,7 @@ withClientCall client regmethod timeout f = -- `ServerCall` to the created `ClientCall`. Obviously, this is only useful if -- the given gRPC client is also a server. withClientCallParent :: Client - -> RegisteredMethod + -> RegisteredMethod mt -> TimeoutSeconds -> (Maybe ServerCall) -- ^ Optional parent call for cascading cancellation. @@ -135,8 +141,8 @@ withClientCallParent client regmethod timeout parent f = do data NormalRequestResult = NormalRequestResult { rspBody :: ByteString - , initMD :: MetadataMap -- initial metadata - , trailMD :: MetadataMap -- trailing metadata + , initMD :: MetadataMap -- ^ initial metadata + , trailMD :: MetadataMap -- ^ trailing metadata , rspCode :: C.StatusCode , details :: StatusDetails } @@ -156,58 +162,136 @@ compileNormalRequestResults x = Just (_meta, status, details) -> Left (GRPCIOBadStatusCode status (StatusDetails details)) +-------------------------------------------------------------------------------- +-- clientReader (client side of server streaming mode) + +-- | First parameter is initial server metadata. +type ClientReaderHandler = MetadataMap -> StreamRecv -> Streaming () + +clientReader :: Client + -> RegisteredMethod 'ServerStreaming + -> TimeoutSeconds + -> ByteString -- ^ The body of the request + -> MetadataMap -- ^ Metadata to send with the request + -> ClientReaderHandler + -> IO (Either GRPCIOError (MetadataMap, C.StatusCode, StatusDetails)) +clientReader cl@Client{ clientCQ = cq } rm tm body initMeta f = + withClientCall cl rm tm go + where + go cc@(unClientCall -> c) = runExceptT $ do + lift (debugClientCall cc) + runOps' c cq [ OpSendInitialMetadata initMeta + , OpSendMessage body + , OpSendCloseFromClient + ] + srvMD <- recvInitialMetadata c cq + runStreamingProxy "clientReader'" c cq (f srvMD streamRecv) + recvStatusOnClient c cq + +-------------------------------------------------------------------------------- +-- clientWriter (client side of client streaming mode) + +type ClientWriterHandler = StreamSend -> Streaming () +type ClientWriterResult = (Maybe ByteString, MetadataMap, MetadataMap, + C.StatusCode, StatusDetails) + +clientWriter :: Client + -> RegisteredMethod 'ClientStreaming + -> TimeoutSeconds + -> MetadataMap -- ^ Initial client metadata + -> ClientWriterHandler + -> IO (Either GRPCIOError ClientWriterResult) +clientWriter cl rm tm initMeta = + withClientCall cl rm tm . clientWriterCmn cl initMeta + +clientWriterCmn :: Client -- ^ The active client + -> MetadataMap -- ^ Initial client metadata + -> ClientWriterHandler + -> ClientCall -- ^ The active client call + -> IO (Either GRPCIOError ClientWriterResult) +clientWriterCmn (clientCQ -> cq) initMeta f cc@(unClientCall -> c) = + runExceptT $ do + lift (debugClientCall cc) + sendInitialMetadata c cq initMeta + runStreamingProxy "clientWriterCmn" c cq (f streamSend) + sendSingle c cq OpSendCloseFromClient + let ops = [OpRecvInitialMetadata, OpRecvMessage, OpRecvStatusOnClient] + runOps' c cq ops >>= \case + CWRFinal mmsg initMD trailMD st ds + -> return (mmsg, initMD, trailMD, st, ds) + _ -> throwE (GRPCIOInternalUnexpectedRecv "clientWriter") + +pattern CWRFinal mmsg initMD trailMD st ds + <- [ OpRecvInitialMetadataResult initMD + , OpRecvMessageResult mmsg + , OpRecvStatusOnClientResult trailMD st (StatusDetails -> ds) + ] + +-------------------------------------------------------------------------------- +-- clientRW (client side of bidirectional streaming mode) + +-- | First parameter is initial server metadata. +type ClientRWHandler = MetadataMap -> StreamRecv -> StreamSend -> Streaming () + +-- | For bidirectional-streaming registered requests +clientRW :: Client + -> RegisteredMethod 'BiDiStreaming + -> TimeoutSeconds + -> MetadataMap + -- ^ request metadata + -> ClientRWHandler + -> IO (Either GRPCIOError (MetadataMap, C.StatusCode, StatusDetails)) +clientRW c@Client{ clientCQ = cq } rm tm initMeta f = + withClientCall c rm tm go + where + go cc@(unClientCall -> call) = runExceptT $ do + lift (debugClientCall cc) + sendInitialMetadata call cq initMeta + srvMeta <- recvInitialMetadata call cq + runStreamingProxy "clientRW" call cq (f srvMeta streamRecv streamSend) + runOps' call cq [OpSendCloseFromClient] -- WritesDone() + recvStatusOnClient call cq -- Finish() + +-------------------------------------------------------------------------------- +-- clientRequest (client side of normal request/response) + -- | Make a request of the given method with the given body. Returns the --- server's response. TODO: This is preliminary until we figure out how many --- different variations on sending request ops will be needed for full gRPC --- functionality. +-- server's response. clientRequest :: Client - -> RegisteredMethod + -> RegisteredMethod 'Normal -> TimeoutSeconds - -- ^ Timeout of both the grpc_call and the max time to wait for - -- the completion of the batch. TODO: I think we will need to - -- decouple the lifetime of the call from the queue deadline once - -- we expose functionality for streaming calls, where one call - -- object persists across many batches. -> ByteString -- ^ The body of the request -> MetadataMap -- ^ Metadata to send with the request -> IO (Either GRPCIOError NormalRequestResult) -clientRequest client@(Client{..}) rm@(RegisteredMethod{..}) - timeLimit body meta = - fmap join $ case methodType of - Normal -> withClientCall client rm timeLimit $ \call -> do - grpcDebug "clientRequest(R): created call." - debugClientCall call - let call' = unClientCall call - -- NOTE: sendOps and recvOps *must* be in separate batches or - -- the client hangs when the server can't be reached. - let sendOps = [OpSendInitialMetadata meta - , OpSendMessage body - , OpSendCloseFromClient] - sendRes <- runOps call' clientCQ sendOps - case sendRes of - Left x -> do grpcDebug "clientRequest(R) : batch error sending." - return $ Left x - Right rs -> do - let recvOps = [OpRecvInitialMetadata, - OpRecvMessage, - OpRecvStatusOnClient] - recvRes <- runOps call' clientCQ recvOps - case recvRes of - Left x -> do - grpcDebug "clientRequest(R): batch error receiving." - return $ Left x - Right rs' -> do - grpcDebug $ "clientRequest(R): got " ++ show rs' - return $ Right $ compileNormalRequestResults (rs ++ rs') - _ -> error "Streaming methods not yet implemented." - -clientNormalRequestOps :: ByteString -> MetadataMap -> [Op] -clientNormalRequestOps body metadata = - [OpSendInitialMetadata metadata, - OpSendMessage body, - OpSendCloseFromClient, - OpRecvInitialMetadata, - OpRecvMessage, - OpRecvStatusOnClient] +clientRequest c@Client{ clientCQ = cq } rm tm body initMeta = + withClientCall c rm tm (fmap join . go) + where + go cc@(unClientCall -> call) = do + grpcDebug "clientRequest(R): created call." + debugClientCall cc + -- NB: the send and receive operations below *must* be in separate + -- batches, or the client hangs when the server can't be reached. + runOps call cq + [ OpSendInitialMetadata initMeta + , OpSendMessage body + , OpSendCloseFromClient + ] + >>= \case + Left x -> do + grpcDebug "clientRequest(R) : batch error sending." + return $ Left x + Right rs -> + runOps call cq + [ OpRecvInitialMetadata + , OpRecvMessage + , OpRecvStatusOnClient + ] + >>= \case + Left x -> do + grpcDebug "clientRequest(R): batch error receiving.." + return $ Left x + Right rs' -> do + grpcDebug $ "clientRequest(R): got " ++ show rs' + return $ Right $ compileNormalRequestResults (rs ++ rs') diff --git a/src/Network/GRPC/LowLevel/Client/Unregistered.hs b/src/Network/GRPC/LowLevel/Client/Unregistered.hs index e0a47d7..0fe6da9 100644 --- a/src/Network/GRPC/LowLevel/Client/Unregistered.hs +++ b/src/Network/GRPC/LowLevel/Client/Unregistered.hs @@ -2,6 +2,7 @@ module Network.GRPC.LowLevel.Client.Unregistered where +import Control.Arrow import Control.Exception (finally) import Control.Monad (join) import Data.ByteString (ByteString) @@ -14,7 +15,6 @@ import Network.GRPC.LowLevel.Call import Network.GRPC.LowLevel.Client (Client (..), NormalRequestResult (..), clientEndpoint, - clientNormalRequestOps, compileNormalRequestResults) import Network.GRPC.LowLevel.CompletionQueue (TimeoutSeconds) import qualified Network.GRPC.LowLevel.CompletionQueue.Unregistered as U @@ -60,11 +60,14 @@ clientRequest :: Client -- ^ Request metadata. -> IO (Either GRPCIOError NormalRequestResult) clientRequest client@Client{..} meth timeLimit body meta = - fmap join $ do - withClientCall client meth timeLimit $ \call -> do - let ops = clientNormalRequestOps body meta - results <- runOps (unClientCall call) clientCQ ops + fmap join $ withClientCall client meth timeLimit $ \call -> do + results <- runOps (unClientCall call) clientCQ + [ OpSendInitialMetadata meta + , OpSendMessage body + , OpSendCloseFromClient + , OpRecvInitialMetadata + , OpRecvMessage + , OpRecvStatusOnClient + ] grpcDebug "clientRequest(U): ops ran." - case results of - Left x -> return $ Left x - Right rs -> return $ Right $ compileNormalRequestResults rs + return $ right compileNormalRequestResults results diff --git a/src/Network/GRPC/LowLevel/CompletionQueue.hs b/src/Network/GRPC/LowLevel/CompletionQueue.hs index 8c32bc6..85b4b3e 100644 --- a/src/Network/GRPC/LowLevel/CompletionQueue.hs +++ b/src/Network/GRPC/LowLevel/CompletionQueue.hs @@ -10,7 +10,11 @@ -- implementation details to both are kept in -- `Network.GRPC.LowLevel.CompletionQueue.Internal`. +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} module Network.GRPC.LowLevel.CompletionQueue ( CompletionQueue @@ -33,12 +37,15 @@ import Control.Concurrent.STM (atomically, check) import Control.Concurrent.STM.TVar (newTVarIO, readTVar, writeTVar) import Control.Exception (bracket) +import Control.Monad.Trans.Class (MonadTrans(lift)) +import Control.Monad.Trans.Except import Control.Monad (liftM2) +import Control.Monad.Managed import Data.IORef (newIORef) import Data.List (intersperse) import Foreign.Marshal.Alloc (free, malloc) -import Foreign.Ptr (nullPtr) -import Foreign.Storable (peek) +import Foreign.Ptr (Ptr, nullPtr) +import Foreign.Storable (Storable, peek) import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe.Constants as C import qualified Network.GRPC.Unsafe.Metadata as C @@ -63,11 +70,8 @@ createCompletionQueue _ = do currentPushers <- newTVarIO 0 shuttingDown <- newTVarIO False nextTag <- newIORef minBound - return $ CompletionQueue{..} + return CompletionQueue{..} --- TODO: I'm thinking it might be easier to use 'Either' uniformly everywhere --- even when it's isomorphic to 'Maybe'. If that doesn't turn out to be the --- case, switch these to 'Maybe'. -- | Very simple wrapper around 'grpcCallStartBatch'. Throws 'GRPCIOShutdown' -- without calling 'grpcCallStartBatch' if the queue is shutting down. -- Throws 'CallError' if 'grpcCallStartBatch' returns a non-OK code. @@ -87,7 +91,7 @@ startBatch cq@CompletionQueue{..} call opArray opArraySize tag = -- queue after we begin the shutdown process. Errors with -- 'GRPCIOShutdownFailure' if the queue can't be shut down within 5 seconds. shutdownCompletionQueue :: CompletionQueue -> IO (Either GRPCIOError ()) -shutdownCompletionQueue (CompletionQueue{..}) = do +shutdownCompletionQueue CompletionQueue{..} = do atomically $ writeTVar shuttingDown True atomically $ readTVar currentPushers >>= \x -> check (x == 0) atomically $ readTVar currentPluckers >>= \x -> check (x == 0) @@ -105,7 +109,7 @@ shutdownCompletionQueue (CompletionQueue{..}) = do ev <- C.withDeadlineSeconds 1 $ \deadline -> C.grpcCompletionQueueNext unsafeCQ deadline C.reserved grpcDebug $ "drainLoop: next() call got " ++ show ev - case (C.eventCompletionType ev) of + case C.eventCompletionType ev of C.QueueShutdown -> return () C.QueueTimeout -> drainLoop C.OpComplete -> drainLoop @@ -133,68 +137,46 @@ channelCreateCall -- | Create the call object to handle a registered call. serverRequestCall :: C.Server -> CompletionQueue - -> RegisteredMethod + -> RegisteredMethod mt -> IO (Either GRPCIOError ServerCall) -serverRequestCall - server cq@CompletionQueue{..} RegisteredMethod{..} = - withPermission Push cq $ - bracket (liftM2 (,) malloc malloc) - (\(p1,p2) -> free p1 >> free p2) - $ \(deadlinePtr, callPtr) -> - C.withByteBufferPtr $ \bbPtr -> - C.withMetadataArrayPtr $ \metadataArrayPtr -> do - metadataArray <- peek metadataArrayPtr - tag <- newTag cq - grpcDebug $ "serverRequestCall(R): tag is " ++ show tag - callError <- C.grpcServerRequestRegisteredCall - server methodHandle callPtr deadlinePtr - metadataArray bbPtr unsafeCQ unsafeCQ tag - grpcDebug $ "serverRequestCall(R): callError: " - ++ show callError - if callError /= C.CallOk - then do grpcDebug "serverRequestCall(R): callError. cleaning up" - return $ Left $ GRPCIOCallError callError - else do pluckResult <- pluck cq tag Nothing - grpcDebug $ "serverRequestCall(R): finished pluck:" - ++ show pluckResult - case pluckResult of - Left x -> do - grpcDebug "serverRequestCall(R): cleanup pluck err" - return $ Left x - Right () -> do - rawCall <- peek callPtr - deadline <- convertDeadline deadlinePtr - payload <- convertPayload bbPtr - meta <- convertMeta metadataArrayPtr - let assembledCall = ServerCall rawCall - meta - payload - deadline - grpcDebug "serverRequestCall(R): About to return" - return $ Right assembledCall - where convertDeadline deadline = do - --gRPC gives us a deadline that is just a delta, so we convert it - --to a proper deadline. - deadline' <- C.timeSpec <$> peek deadline - now <- getTime Monotonic - return $ now + deadline' - convertPayload bbPtr = do - -- TODO: the reason this returns @Maybe ByteString@ is because the - -- gRPC library calls the underlying out parameter - -- "optional_payload". I am not sure exactly in what cases it - -- won't be present. The C++ library checks a - -- has_request_payload_ bool and passes in nullptr to - -- request_registered_call if the bool is false, so we - -- may be able to do the payload present/absent check earlier. - bb@(C.ByteBuffer rawPtr) <- peek bbPtr - if rawPtr == nullPtr - then return Nothing - else do bs <- C.copyByteBufferToByteString bb - return $ Just bs - convertMeta requestMetadataRecv = do - mArray <- peek requestMetadataRecv - metamap <- C.getAllMetadataArray mArray - return metamap +serverRequestCall s cq@CompletionQueue{.. } RegisteredMethod{..} = + -- NB: The method type dictates whether or not a payload is present, according + -- to the payloadHandling function. We do not allocate a buffer for the + -- payload when it is not present. + withPermission Push cq . with allocs $ \(dead, call, pay, meta) -> do + md <- peek meta + tag <- newTag cq + dbug $ "tag is " ++ show tag + ce <- C.grpcServerRequestRegisteredCall s methodHandle call + dead md pay unsafeCQ unsafeCQ tag + dbug $ "callError: " ++ show ce + runExceptT $ case ce of + C.CallOk -> do + ExceptT $ do + r <- pluck cq tag Nothing + dbug $ "pluck finished:" ++ show r + return r + lift $ + ServerCall + <$> peek call + <*> C.getAllMetadataArray md + <*> (if havePay then toBS pay else return Nothing) + <*> liftM2 (+) (getTime Monotonic) (C.timeSpec <$> peek dead) + -- gRPC gives us a deadline that is just a delta, so we convert it + -- to a proper deadline. + _ -> throwE (GRPCIOCallError ce) + where + allocs = (,,,) <$> ptr <*> ptr <*> pay <*> md + where + md = managed C.withMetadataArrayPtr + pay = if havePay then managed C.withByteBufferPtr else return nullPtr + ptr :: forall a. Storable a => Managed (Ptr a) + ptr = managed (bracket malloc free) + dbug = grpcDebug . ("serverRequestCall(R): " ++) + havePay = payloadHandling methodType /= C.SrmPayloadNone + toBS p = peek p >>= \bb@(C.ByteBuffer rawPtr) -> + if | rawPtr == nullPtr -> return Nothing + | otherwise -> Just <$> C.copyByteBufferToByteString bb -- | Register the server's completion queue. Must be done before the server is -- started. diff --git a/src/Network/GRPC/LowLevel/GRPC.hs b/src/Network/GRPC/LowLevel/GRPC.hs index 17e91aa..bb1999f 100644 --- a/src/Network/GRPC/LowLevel/GRPC.hs +++ b/src/Network/GRPC/LowLevel/GRPC.hs @@ -1,23 +1,26 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE StandaloneDeriving #-} module Network.GRPC.LowLevel.GRPC where -import Control.Concurrent (threadDelay) + +import Control.Concurrent (threadDelay) import Control.Exception -import qualified Data.ByteString as B -import qualified Data.Map as M -import Data.String (IsString) -import qualified Network.GRPC.Unsafe as C +import Data.String (IsString) +import qualified Data.ByteString as B +import qualified Data.Map as M +import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe.Op as C + #ifdef DEBUG -import GHC.Conc (myThreadId) +import GHC.Conc (myThreadId) #endif type MetadataMap = M.Map B.ByteString B.ByteString -newtype StatusDetails = StatusDetails B.ByteString deriving (Show, Eq, IsString) +newtype StatusDetails = StatusDetails B.ByteString + deriving (Eq, IsString, Monoid, Show) -- | Functions as a proof that the gRPC core has been started. The gRPC core -- must be initialized to create any gRPC state, so this is a requirement for @@ -44,6 +47,9 @@ data GRPCIOError = GRPCIOCallError C.CallError -- reasonable amount of time. | GRPCIOUnknownError | GRPCIOBadStatusCode C.StatusCode StatusDetails + + | GRPCIOInternalMissingExpectedPayload + | GRPCIOInternalUnexpectedRecv String -- debugging description deriving (Show, Eq) throwIfCallError :: C.CallError -> Either GRPCIOError () diff --git a/src/Network/GRPC/LowLevel/Op.hs b/src/Network/GRPC/LowLevel/Op.hs index e34fc2d..cd16016 100644 --- a/src/Network/GRPC/LowLevel/Op.hs +++ b/src/Network/GRPC/LowLevel/Op.hs @@ -1,9 +1,17 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ViewPatterns #-} module Network.GRPC.LowLevel.Op where +import Control.Arrow import Control.Exception +import Control.Monad +import Control.Monad.Trans.Class (MonadTrans(lift)) +import Control.Monad.Trans.Except +import Data.ByteString (ByteString) import qualified Data.ByteString as B import qualified Data.Map.Strict as M import Data.Maybe (catMaybes) @@ -20,6 +28,9 @@ import qualified Network.GRPC.Unsafe.ByteBuffer as C import qualified Network.GRPC.Unsafe.Metadata as C import qualified Network.GRPC.Unsafe.Op as C import qualified Network.GRPC.Unsafe.Slice as C (Slice, freeSlice) +import Pipes ((>->)) +import qualified Pipes as P +import qualified Pipes.Core as P -- | Sum describing all possible send and receive operations that can be batched -- and executed by gRPC. Usually these are processed in a handful of @@ -144,7 +155,8 @@ withOpArrayAndCtxts ops = bracket setup teardown data OpRecvResult = OpRecvInitialMetadataResult MetadataMap | OpRecvMessageResult (Maybe B.ByteString) - -- ^ If the client or server dies, we might not receive a response body, in + -- ^ If a streaming call is in progress and the stream terminates normally, + -- or If the client or server dies, we might not receive a response body, in -- which case this will be 'Nothing'. | OpRecvStatusOnClientResult MetadataMap C.StatusCode B.ByteString | OpRecvCloseOnServerResult Bool -- ^ True if call was cancelled. @@ -202,7 +214,6 @@ resultFromOpContext _ = do -- GRPC_CALL_ERROR_TOO_MANY_OPERATIONS error if we use the same 'Op' twice in -- the same batch, so we might want to change the list to a set. I don't think -- order matters within a batch. Need to check. - runOps :: C.Call -- ^ 'Call' that this batch is associated with. One call can be -- associated with many batches. @@ -232,6 +243,12 @@ runOps call cq ops = fmap (Right . catMaybes) $ mapM resultFromOpContext contexts Left err -> return $ Left err +runOps' :: C.Call + -> CompletionQueue + -> [Op] + -> ExceptT GRPCIOError IO [OpRecvResult] +runOps' c cq = ExceptT . runOps c cq + -- | If response status info is present in the given 'OpRecvResult's, returns -- a tuple of trailing metadata, status code, and status details. extractStatusInfo :: [OpRecvResult] @@ -240,3 +257,100 @@ extractStatusInfo [] = Nothing extractStatusInfo (OpRecvStatusOnClientResult meta code details:_) = Just (meta, code, details) extractStatusInfo (_:xs) = extractStatusInfo xs + +-------------------------------------------------------------------------------- +-- Types and helpers for common ops batches + +type SendSingle a + = C.Call + -> CompletionQueue + -> a + -> ExceptT GRPCIOError IO () + +type RecvSingle a + = C.Call + -> CompletionQueue + -> ExceptT GRPCIOError IO a + +sendSingle :: SendSingle Op +sendSingle c cq op = void (runOps' c cq [op]) + +sendInitialMetadata :: SendSingle MetadataMap +sendInitialMetadata c cq = sendSingle c cq . OpSendInitialMetadata + +sendStatusFromServer :: SendSingle (MetadataMap, C.StatusCode, StatusDetails) +sendStatusFromServer c cq (md, st, ds) = + sendSingle c cq (OpSendStatusFromServer md st ds) + +recvInitialMetadata :: RecvSingle MetadataMap +recvInitialMetadata c cq = runOps' c cq [OpRecvInitialMetadata] >>= \case + [OpRecvInitialMetadataResult md] + -> return md + _ -> throwE (GRPCIOInternalUnexpectedRecv "recvInitialMetadata") + +recvStatusOnClient :: RecvSingle (MetadataMap, C.StatusCode, StatusDetails) +recvStatusOnClient c cq = runOps' c cq [OpRecvStatusOnClient] >>= \case + [OpRecvStatusOnClientResult md st ds] + -> return (md, st, StatusDetails ds) + _ -> throwE (GRPCIOInternalUnexpectedRecv "recvStatusOnClient") + +-------------------------------------------------------------------------------- +-- Streaming types and helpers + +-- | Requests use Nothing to denote read, Just to denote +-- write. Right-constructed responses use Just to indicate a successful read, +-- and Nothing to denote end of stream when reading or a successful write. +type Streaming a = + P.Client (Maybe ByteString) (Either GRPCIOError (Maybe ByteString)) IO a + +-- | Run the given 'Streaming' operation via an appropriate upstream +-- proxy. I.e., if called on the client side, the given 'Streaming' operation +-- talks to a server proxy, and vice versa. +runStreamingProxy :: String + -- ^ context string for including in errors + -> C.Call + -- ^ the call associated with this streaming operation + -> CompletionQueue + -- ^ the completion queue for ops batches + -> Streaming a + -- ^ the requesting side of the streaming operation + -> ExceptT GRPCIOError IO a +runStreamingProxy nm c cq + = ExceptT . P.runEffect . (streamingProxy nm c cq P.+>>) . fmap Right + +streamingProxy :: String + -- ^ context string for including in errors + -> C.Call + -- ^ the call associated with this streaming operation + -> CompletionQueue + -- ^ the completion queue for ops batches + -> Maybe ByteString + -- ^ the request to the proxy + -> P.Server + (Maybe ByteString) + (Either GRPCIOError (Maybe ByteString)) + IO (Either GRPCIOError a) +streamingProxy nm c cq = maybe recv send + where + recv = run [OpRecvMessage] >>= \case + RecvMsgRslt mr -> rsp mr >>= streamingProxy nm c cq + Right{} -> err (urecv "recv") + Left e -> err e + send msg = run [OpSendMessage msg] >>= \case + Right [] -> rsp Nothing >>= streamingProxy nm c cq + Right _ -> err (urecv "send") + Left e -> err e + err e = P.respond (Left e) >> return (Left e) + rsp = P.respond . Right + run = lift . runOps c cq + urecv = GRPCIOInternalUnexpectedRecv . (nm ++) + +type StreamRecv = Streaming (Either GRPCIOError (Maybe ByteString)) +streamRecv :: StreamRecv +streamRecv = P.request Nothing + +type StreamSend = ByteString -> Streaming (Either GRPCIOError ()) +streamSend :: StreamSend +streamSend = fmap void . P.request . Just + +pattern RecvMsgRslt mmsg <- Right [OpRecvMessageResult mmsg] diff --git a/src/Network/GRPC/LowLevel/Server.hs b/src/Network/GRPC/LowLevel/Server.hs index 073a4c4..6fce45a 100644 --- a/src/Network/GRPC/LowLevel/Server.hs +++ b/src/Network/GRPC/LowLevel/Server.hs @@ -1,12 +1,22 @@ -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE ViewPatterns #-} -- | This module defines data structures and operations pertaining to registered -- servers using registered calls; for unregistered support, see -- `Network.GRPC.LowLevel.Server.Unregistered`. module Network.GRPC.LowLevel.Server where +import Control.Arrow import Control.Exception (bracket, finally) import Control.Monad +import Control.Monad.Trans.Class (MonadTrans (lift)) +import Control.Monad.Trans.Except import Data.ByteString (ByteString) import Network.GRPC.LowLevel.Call import Network.GRPC.LowLevel.CompletionQueue (CompletionQueue, @@ -19,15 +29,20 @@ import Network.GRPC.LowLevel.CompletionQueue (CompletionQueue, import Network.GRPC.LowLevel.GRPC import Network.GRPC.LowLevel.Op import qualified Network.GRPC.Unsafe as C -import qualified Network.GRPC.Unsafe.Op as C import qualified Network.GRPC.Unsafe.ChannelArgs as C +import qualified Network.GRPC.Unsafe.Op as C +import qualified Pipes as P +import qualified Pipes.Core as P -- | Wraps various gRPC state needed to run a server. data Server = Server - { internalServer :: C.Server - , serverCQ :: CompletionQueue - , registeredMethods :: [RegisteredMethod] - , serverConfig :: ServerConfig + { internalServer :: C.Server + , serverCQ :: CompletionQueue + , normalMethods :: [RegisteredMethod 'Normal] + , sstreamingMethods :: [RegisteredMethod 'ServerStreaming] + , cstreamingMethods :: [RegisteredMethod 'ClientStreaming] + , bidiStreamingMethods :: [RegisteredMethod 'BiDiStreaming] + , serverConfig :: ServerConfig } -- | Configuration needed to start a server. @@ -39,13 +54,10 @@ data ServerConfig = ServerConfig -- ^ Port on which to listen for requests. , methodsToRegister :: [(MethodName, GRPCMethodType)] -- ^ List of (method name, method type) tuples specifying all methods to - -- register. You can also handle other unregistered methods with - -- `serverHandleNormalCall`. - , serverArgs :: [C.Arg] - -- ^ Optional arguments for setting up the - -- channel on the server. Supplying an empty - -- list will cause the channel to use gRPC's - -- default options. + -- register. + , serverArgs :: [C.Arg] + -- ^ Optional arguments for setting up the channel on the server. Supplying an + -- empty list will cause the channel to use gRPC's default options. } deriving (Show, Eq) @@ -62,34 +74,46 @@ startServer grpc conf@ServerConfig{..} = error $ "Unable to bind port: " ++ show port cq <- createCompletionQueue grpc serverRegisterCompletionQueue server cq - methods <- forM methodsToRegister $ \(name, mtype) -> - serverRegisterMethod server name e mtype + + -- Register methods according to their GRPCMethodType kind. It's a bit ugly + -- to partition them this way, but we get very convenient phantom typing + -- elsewhere by doing so. + (ns, ss, cs, bs) <- do + let f (ns, ss, cs, bs) (nm, mt) = do + let reg = serverRegisterMethod server nm e mt + case mt of + Normal -> ( , ss, cs, bs) . (:ns) <$> reg + ServerStreaming -> (ns, , cs, bs) . (:ss) <$> reg + ClientStreaming -> (ns, ss, , bs) . (:cs) <$> reg + BiDiStreaming -> (ns, ss, cs, ) . (:bs) <$> reg + foldM f ([],[],[],[]) methodsToRegister + C.grpcServerStart server - return $ Server server cq methods conf + return $ Server server cq ns ss cs bs conf stopServer :: Server -> IO () -- TODO: Do method handles need to be freed? -stopServer (Server server cq _ _) = do +stopServer Server{..} = do grpcDebug "stopServer: calling shutdownNotify." shutdownNotify grpcDebug "stopServer: cancelling all calls." - C.grpcServerCancelAllCalls server + C.grpcServerCancelAllCalls internalServer grpcDebug "stopServer: call grpc_server_destroy." - C.grpcServerDestroy server + C.grpcServerDestroy internalServer grpcDebug "stopServer: shutting down CQ." shutdownCQ where shutdownCQ = do - shutdownResult <- shutdownCompletionQueue cq + shutdownResult <- shutdownCompletionQueue serverCQ case shutdownResult of Left _ -> do putStrLn "Warning: completion queue didn't shut down." putStrLn "Trying to stop server anyway." Right _ -> return () shutdownNotify = do let shutdownTag = C.tag 0 - serverShutdownAndNotify server cq shutdownTag + serverShutdownAndNotify internalServer serverCQ shutdownTag grpcDebug "called serverShutdownAndNotify; plucking." - shutdownEvent <- pluck cq shutdownTag (Just 30) + shutdownEvent <- pluck serverCQ shutdownTag (Just 30) grpcDebug $ "shutdownNotify: got shutdown event" ++ show shutdownEvent case shutdownEvent of -- This case occurs when we pluck but the queue is already in the @@ -100,7 +124,7 @@ stopServer (Server server cq _ _) = do -- Uses 'bracket' to safely start and stop a server, even if exceptions occur. withServer :: GRPC -> ServerConfig -> (Server -> IO a) -> IO a -withServer grpc cfg f = bracket (startServer grpc cfg) stopServer f +withServer grpc cfg = bracket (startServer grpc cfg) stopServer -- | Register a method on a server. The 'RegisteredMethod' type can then be used -- to wait for a request to arrive. Note: gRPC claims this must be called before @@ -118,25 +142,23 @@ serverRegisterMethod :: C.Server -> GRPCMethodType -- ^ Type of method this will be. In the future, this will -- be used to switch to the correct handling logic. - -- Currently, the only valid choice is 'Normal'. - -> IO RegisteredMethod -serverRegisterMethod internalServer meth e Normal = do - handle <- C.grpcServerRegisterMethod internalServer - (unMethodName meth) (unEndpoint e) - grpcDebug $ "registered method to handle " ++ show handle - return $ RegisteredMethod Normal meth e handle -serverRegisterMethod _ _ _ _ = error "Streaming methods not implemented yet." + -> IO (RegisteredMethod mt) +serverRegisterMethod internalServer meth e mty = + RegisteredMethod mty meth e <$> do + h <- C.grpcServerRegisterMethod internalServer + (unMethodName meth) (unEndpoint e) (payloadHandling mty) + grpcDebug $ "registered method handle: " ++ show h ++ " of type " ++ show mty + return h -- | Create a 'Call' with which to wait for the invocation of a registered -- method. serverCreateCall :: Server - -> RegisteredMethod + -> RegisteredMethod mt -> IO (Either GRPCIOError ServerCall) -serverCreateCall Server{..} rm = - serverRequestCall internalServer serverCQ rm +serverCreateCall Server{..} = serverRequestCall internalServer serverCQ withServerCall :: Server - -> RegisteredMethod + -> RegisteredMethod mt -> (ServerCall -> IO (Either GRPCIOError a)) -> IO (Either GRPCIOError a) withServerCall server regmethod f = do @@ -147,61 +169,112 @@ withServerCall server regmethod f = do where logDestroy c = grpcDebug "withServerRegisteredCall: destroying." >> destroyServerCall c -serverOpsSendNormalRegisteredResponse :: ByteString - -> MetadataMap - -- ^ initial metadata - -> MetadataMap - -- ^ trailing metadata - -> C.StatusCode - -> StatusDetails - -> [Op] -serverOpsSendNormalRegisteredResponse - body initMetadata trailingMeta code details = - [OpSendInitialMetadata initMetadata, - OpRecvCloseOnServer, - OpSendMessage body, - OpSendStatusFromServer trailingMeta code details] +-------------------------------------------------------------------------------- +-- serverReader (server side of client streaming mode) --- | A handler for an registered server call; bytestring parameter is request +type ServerReaderHandler + = ServerCall + -> StreamRecv + -> Streaming (Maybe ByteString, MetadataMap, C.StatusCode, StatusDetails) + +serverReader :: Server + -> RegisteredMethod 'ClientStreaming + -> MetadataMap -- ^ initial server metadata + -> ServerReaderHandler + -> IO (Either GRPCIOError ()) +serverReader s@Server{ serverCQ = cq } rm initMeta f = withServerCall s rm go + where + go sc@(unServerCall -> c) = runExceptT $ do + lift $ debugServerCall sc + (mmsg, trailMD, st, ds) <- + runStreamingProxy "serverReader" c cq (f sc streamRecv) + runOps' c cq ( OpSendInitialMetadata initMeta + : OpSendStatusFromServer trailMD st ds + : maybe [] ((:[]) . OpSendMessage) mmsg + ) + return () + +-------------------------------------------------------------------------------- +-- serverWriter (server side of server streaming mode) + +type ServerWriterHandler + = ServerCall + -> StreamSend + -> Streaming (MetadataMap, C.StatusCode, StatusDetails) + +-- | Wait for and then handle a registered, server-streaming call. +serverWriter :: Server + -> RegisteredMethod 'ServerStreaming + -> MetadataMap + -- ^ Initial server metadata + -> ServerWriterHandler + -> IO (Either GRPCIOError ()) +serverWriter s@Server{ serverCQ = cq } rm initMeta f = withServerCall s rm go + where + go sc@ServerCall{ unServerCall = c } = runExceptT $ do + lift (debugServerCall sc) + sendInitialMetadata c cq initMeta + st <- runStreamingProxy "serverWriter" c cq (f sc streamSend) + sendStatusFromServer c cq st + +-------------------------------------------------------------------------------- +-- serverRW (server side of bidirectional streaming mode) + +type ServerRWHandler + = ServerCall + -> StreamRecv + -> StreamSend + -> Streaming (MetadataMap, C.StatusCode, StatusDetails) + +serverRW :: Server + -> RegisteredMethod 'BiDiStreaming + -> MetadataMap + -- ^ initial server metadata + -> ServerRWHandler + -> IO (Either GRPCIOError ()) +serverRW s@Server{ serverCQ = cq } rm initMeta f = withServerCall s rm go + where + go sc@(unServerCall -> c) = runExceptT $ do + lift $ debugServerCall sc + sendInitialMetadata c cq initMeta + st <- runStreamingProxy "serverRW" c cq (f sc streamRecv streamSend) + sendStatusFromServer c cq st + +-------------------------------------------------------------------------------- +-- serverHandleNormalCall (server side of normal request/response) + +-- | A handler for a registered server call; bytestring parameter is request -- body, with the bytestring response body in the result tuple. The first -- metadata parameter refers to the request metadata, with the two metadata -- values in the result tuple being the initial and trailing metadata -- respectively. We pass in the 'ServerCall' so that the server can call -- 'serverCallCancel' on it if needed. - --- TODO: make a more rigid type for this with a Maybe MetadataMap for the --- trailing meta, and use it for both kinds of call handlers. type ServerHandler = ServerCall -> ByteString -> MetadataMap -> IO (ByteString, MetadataMap, C.StatusCode, StatusDetails) --- TODO: we will want to replace this with some more general concept that also --- works with streaming calls in the future. -- | Wait for and then handle a normal (non-streaming) call. serverHandleNormalCall :: Server - -> RegisteredMethod + -> RegisteredMethod 'Normal -> MetadataMap -- ^ Initial server metadata -> ServerHandler -> IO (Either GRPCIOError ()) -serverHandleNormalCall s@Server{..} rm initMeta f = do - withServerCall s rm $ \call -> do - grpcDebug "serverHandleNormalCall(R): starting batch." - debugServerCall call - let payload = optionalPayload call - case payload of - --TODO: what should we do with an empty payload? Have the handler take - -- @Maybe ByteString@? Need to figure out when/why payload would be empty. - Nothing -> error "serverHandleNormalCall(R): payload empty." - Just requestBody -> do - let requestMeta = requestMetadataRecv call - (respBody, trailingMeta, status, details) <- f call - requestBody - requestMeta - let respOps = serverOpsSendNormalRegisteredResponse - respBody initMeta trailingMeta status details - respOpsResults <- runOps (unServerCall call) serverCQ respOps - grpcDebug "serverHandleNormalCall(R): finished response ops." - case respOpsResults of - Left x -> return $ Left x - Right _ -> return $ Right () +serverHandleNormalCall s@Server{ serverCQ = cq } rm initMeta f = + withServerCall s rm go + where + go sc@(unServerCall -> call) = do + grpcDebug "serverHandleNormalCall(R): starting batch." + debugServerCall sc + case optionalPayload sc of + Nothing -> return (Left GRPCIOInternalMissingExpectedPayload) + Just pay -> do + (rspBody, trailMeta, status, ds) <- f sc pay (requestMetadataRecv sc) + eea <- runOps call cq + [ OpSendInitialMetadata initMeta + , OpRecvCloseOnServer + , OpSendMessage rspBody + , OpSendStatusFromServer trailMeta status ds + ] + <* grpcDebug "serverHandleNormalCall(R): finished response ops." + return (void eea) diff --git a/src/Network/GRPC/LowLevel/Server/Unregistered.hs b/src/Network/GRPC/LowLevel/Server/Unregistered.hs index 02766de..103dbcb 100644 --- a/src/Network/GRPC/LowLevel/Server/Unregistered.hs +++ b/src/Network/GRPC/LowLevel/Server/Unregistered.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE RecordWildCards #-} module Network.GRPC.LowLevel.Server.Unregistered where @@ -7,7 +8,8 @@ import Data.ByteString (ByteString) import Network.GRPC.LowLevel.Call.Unregistered import Network.GRPC.LowLevel.CompletionQueue.Unregistered (serverRequestCall) import Network.GRPC.LowLevel.GRPC -import Network.GRPC.LowLevel.Op (Op(..), OpRecvResult (..), runOps) +import Network.GRPC.LowLevel.Op (Op (..), OpRecvResult (..), + runOps) import Network.GRPC.LowLevel.Server (Server (..)) import qualified Network.GRPC.Unsafe.Op as C @@ -57,24 +59,31 @@ serverHandleNormalCall :: Server -> MetadataMap -- ^ Initial server metadata. -> ServerHandler -> IO (Either GRPCIOError ()) -serverHandleNormalCall s@Server{..} srvMetadata f = do +serverHandleNormalCall s@Server{..} srvMetadata f = withServerCall s $ \call@ServerCall{..} -> do grpcDebug "serverHandleNormalCall(U): starting batch." - let recvOps = serverOpsGetNormalCall srvMetadata - opResults <- runOps unServerCall serverCQ recvOps - case opResults of - Left x -> do grpcDebug "serverHandleNormalCall(U): ops failed; aborting" - return $ Left x - Right [OpRecvMessageResult (Just body)] -> do - grpcDebug $ "got client metadata: " ++ show requestMetadataRecv - grpcDebug $ "call_details host is: " ++ show callHost - (respBody, respMetadata, status, details) <- f call body - let respOps = serverOpsSendNormalResponse - respBody respMetadata status details - respOpsResults <- runOps unServerCall serverCQ respOps - case respOpsResults of - Left x -> do grpcDebug "serverHandleNormalCall(U): resp failed." - return $ Left x - Right _ -> grpcDebug "serverHandleNormalCall(U): ops done." - >> return (Right ()) - x -> error $ "impossible pattern match: " ++ show x + runOps unServerCall serverCQ + [ OpSendInitialMetadata srvMetadata + , OpRecvMessage + ] + >>= \case + Left x -> do + grpcDebug "serverHandleNormalCall(U): ops failed; aborting" + return $ Left x + Right [OpRecvMessageResult (Just body)] -> do + grpcDebug $ "got client metadata: " ++ show requestMetadataRecv + grpcDebug $ "call_details host is: " ++ show callHost + (rspBody, rspMeta, status, ds) <- f call body + runOps unServerCall serverCQ + [ OpRecvCloseOnServer + , OpSendMessage rspBody, + OpSendStatusFromServer rspMeta status ds + ] + >>= \case + Left x -> do + grpcDebug "serverHandleNormalCall(U): resp failed." + return $ Left x + Right _ -> do + grpcDebug "serverHandleNormalCall(U): ops done." + return $ Right () + x -> error $ "impossible pattern match: " ++ show x diff --git a/src/Network/GRPC/Unsafe.chs b/src/Network/GRPC/Unsafe.chs index 3c96e87..c58cc87 100644 --- a/src/Network/GRPC/Unsafe.chs +++ b/src/Network/GRPC/Unsafe.chs @@ -61,6 +61,8 @@ instance Storable Call where peek p = fmap Call (peek (castPtr p)) poke p (Call r) = poke (castPtr p) r +{#enum grpc_server_register_method_payload_handling as ServerRegisterMethodPayloadHandling {underscoreToCase} deriving (Eq, Show)#} + -- | A 'Tag' is an identifier that is used with a 'CompletionQueue' to signal -- that the corresponding operation has completed. newtype Tag = Tag {unTag :: Ptr ()} deriving (Show, Eq) @@ -235,7 +237,7 @@ getPeerPeek cstr = do {`GrpcChannelArgs',unReserved `Reserved'} -> `Server'#} {#fun grpc_server_register_method_ as ^ - {`Server', `String', `String'} -> `CallHandle' CallHandle#} + {`Server', `String', `String', `ServerRegisterMethodPayloadHandling'} -> `CallHandle' CallHandle#} {#fun grpc_server_register_completion_queue as ^ {`Server', `CompletionQueue', unReserved `Reserved'} -> `()'#} diff --git a/stack.yaml b/stack.yaml index a614a5b..1d865d1 100644 --- a/stack.yaml +++ b/stack.yaml @@ -8,7 +8,7 @@ resolver: lts-5.10 packages: - '.' # Packages to be pulled from upstream that are not in the resolver (e.g., acme-missiles-0.3) -extra-deps: [] +extra-deps: [managed-1.0.5] # Override default flag values for local packages and extra-deps flags: {} diff --git a/tests/LowLevelTests.hs b/tests/LowLevelTests.hs index 9939a8a..92e8de8 100644 --- a/tests/LowLevelTests.hs +++ b/tests/LowLevelTests.hs @@ -1,13 +1,17 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE ViewPatterns #-} module LowLevelTests where import Control.Concurrent (threadDelay) import Control.Concurrent.Async import Control.Monad +import Control.Monad.Managed import Data.ByteString (ByteString, isPrefixOf, isSuffixOf) @@ -16,6 +20,8 @@ import Network.GRPC.LowLevel import qualified Network.GRPC.LowLevel.Call.Unregistered as U import qualified Network.GRPC.LowLevel.Client.Unregistered as U import qualified Network.GRPC.LowLevel.Server.Unregistered as U +import Pipes ((>->)) +import qualified Pipes as P import Test.Tasty import Test.Tasty.HUnit as HU (Assertion, assertBool, @@ -42,27 +48,25 @@ lowLevelTests = testGroup "Unit tests of low-level Haskell library" , testCustomUserAgent , testClientCompression , testClientServerCompression + , testClientStreaming + , testServerStreaming + , testBiDiStreaming ] testGRPCBracket :: TestTree testGRPCBracket = - testCase "Start/stop GRPC" $ withGRPC nop + testCase "Start/stop GRPC" $ runManaged $ void mgdGRPC testCompletionQueueCreateDestroy :: TestTree testCompletionQueueCreateDestroy = - testCase "Create/destroy CQ" $ withGRPC $ \g -> - withCompletionQueue g nop + testCase "Create/destroy CQ" $ runManaged $ do + g <- mgdGRPC + liftIO (withCompletionQueue g nop) testClientCreateDestroy :: TestTree testClientCreateDestroy = clientOnlyTest "start/stop" nop -testClientCall :: TestTree -testClientCall = - clientOnlyTest "create/destroy call" $ \c -> do - r <- U.withClientCall c "/foo" 10 $ const $ return $ Right () - r @?= Right () - testClientTimeoutNoServer :: TestTree testClientTimeoutNoServer = clientOnlyTest "request timeout when server DNE" $ \c -> do @@ -97,13 +101,13 @@ testMixRegisteredUnregistered = concurrently regThread unregThread return () where regThread = do - let rm = head (registeredMethods s) + let rm = head (normalMethods s) r <- serverHandleNormalCall s rm dummyMeta $ \_ body _ -> do body @?= "Hello" - return ("reply test", dummyMeta, StatusOk, StatusDetails "") + return ("reply test", dummyMeta, StatusOk, "") return () unregThread = do - r1 <- U.serverHandleNormalCall s mempty $ \call _ -> do + U.serverHandleNormalCall s mempty $ \call _ -> do U.callMethod call @?= "/bar" return ("", mempty, StatusOk, StatusDetails "Wrong endpoint") @@ -130,13 +134,11 @@ testPayload = initMD @?= dummyMeta trailMD @?= dummyMeta server s = do - length (registeredMethods s) @?= 1 - let rm = head (registeredMethods s) + let rm = head (normalMethods s) r <- serverHandleNormalCall s rm dummyMeta $ \_ reqBody reqMD -> do reqBody @?= "Hello!" checkMD "Server metadata mismatch" clientMD reqMD - return ("reply test", dummyMeta, StatusOk, - StatusDetails "details string") + return ("reply test", dummyMeta, StatusOk, "details string") r @?= Right () testServerCancel :: TestTree @@ -146,21 +148,130 @@ testServerCancel = client c = do rm <- clientRegisterMethod c "/foo" Normal res <- clientRequest c rm 10 "" mempty - res @?= Left (GRPCIOBadStatusCode StatusCancelled - (StatusDetails - "Received RST_STREAM err=8")) + res @?= badStatus StatusCancelled server s = do - let rm = head (registeredMethods s) + let rm = head (normalMethods s) r <- serverHandleNormalCall s rm mempty $ \c _ _ -> do serverCallCancel c StatusCancelled "" return (mempty, mempty, StatusCancelled, "") r @?= Right () +testServerStreaming :: TestTree +testServerStreaming = + csTest "server streaming" client server [("/feed", ServerStreaming)] + where + clientInitMD = [("client","initmd")] + serverInitMD = [("server","initmd")] + clientPay = "FEED ME!" + pays = ["ONE", "TWO", "THREE", "FOUR"] :: [ByteString] + + client c = do + rm <- clientRegisterMethod c "/feed" ServerStreaming + eea <- clientReader c rm 10 clientPay clientInitMD $ \initMD recv -> do + liftIO $ checkMD "Server initial metadata mismatch" serverInitMD initMD + forM_ pays $ \p -> recv `is` Right (Just p) + recv `is` Right Nothing + eea @?= Right (dummyMeta, StatusOk, "dtls") + + server s = do + let rm = head (sstreamingMethods s) + eea <- serverWriter s rm serverInitMD $ \sc send -> do + liftIO $ do + checkMD "Client request metadata mismatch" + clientInitMD (requestMetadataRecv sc) + case optionalPayload sc of + Nothing -> assertFailure "expected optional payload" + Just pay -> pay @?= clientPay + forM_ pays $ \p -> send p `is` Right () + return (dummyMeta, StatusOk, "dtls") + eea @?= Right () + +testClientStreaming :: TestTree +testClientStreaming = + csTest "client streaming" client server [("/slurp", ClientStreaming)] + where + clientInitMD = [("a","b")] + serverInitMD = [("x","y")] + trailMD = dummyMeta + serverRsp = "serverReader reply" + serverDtls = "deets" + serverStatus = StatusOk + pays = ["P_ONE", "P_TWO", "P_THREE"] :: [ByteString] + + client c = do + rm <- clientRegisterMethod c "/slurp" ClientStreaming + eea <- clientWriter c rm 10 clientInitMD $ \send -> do + -- liftIO $ checkMD "Server initial metadata mismatch" serverInitMD initMD + forM_ pays $ \p -> send p `is` Right () + eea @?= Right (Just serverRsp, serverInitMD, trailMD, serverStatus, serverDtls) + + server s = do + let rm = head (cstreamingMethods s) + eea <- serverReader s rm serverInitMD $ \sc recv -> do + liftIO $ checkMD "Client request metadata mismatch" + clientInitMD (requestMetadataRecv sc) + forM_ pays $ \p -> recv `is` Right (Just p) + recv `is` Right Nothing + return (Just serverRsp, trailMD, serverStatus, serverDtls) + eea @?= Right () + +testBiDiStreaming :: TestTree +testBiDiStreaming = + csTest "bidirectional streaming" client server [("/bidi", BiDiStreaming)] + where + clientInitMD = [("bidi-streaming","client")] + serverInitMD = [("bidi-streaming","server")] + trailMD = dummyMeta + serverStatus = StatusOk + serverDtls = "deets" + + client c = do + rm <- clientRegisterMethod c "/bidi" BiDiStreaming + eea <- clientRW c rm 10 clientInitMD $ \initMD recv send -> do + liftIO $ checkMD "Server initial metadata mismatch" + serverInitMD initMD + send "cw0" `is` Right () + recv `is` Right (Just "sw0") + send "cw1" `is` Right () + recv `is` Right (Just "sw1") + recv `is` Right (Just "sw2") + return () + eea @?= Right (trailMD, serverStatus, serverDtls) + + server s = do + let rm = head (bidiStreamingMethods s) + eea <- serverRW s rm serverInitMD $ \sc recv send -> do + liftIO $ checkMD "Client request metadata mismatch" + clientInitMD (requestMetadataRecv sc) + recv `is` Right (Just "cw0") + send "sw0" `is` Right () + recv `is` Right (Just "cw1") + send "sw1" `is` Right () + send "sw2" `is` Right () + recv `is` Right Nothing + return (trailMD, serverStatus, serverDtls) + eea @?= Right () + +-------------------------------------------------------------------------------- +-- Unregistered tests + +testClientCall :: TestTree +testClientCall = + clientOnlyTest "create/destroy call" $ \c -> do + r <- U.withClientCall c "/foo" 10 $ const $ return $ Right () + r @?= Right () + +testServerCall :: TestTree +testServerCall = + serverOnlyTest "create/destroy call" [] $ \s -> do + r <- U.withServerCall s $ const $ return $ Right () + r @?= Left GRPCIOTimeout + testPayloadUnregistered :: TestTree testPayloadUnregistered = csTest "unregistered normal request/response" client server [] where - client c = do + client c = U.clientRequest c "/foo" 10 "Hello!" mempty >>= do checkReqRslt $ \NormalRequestResult{..} -> do rspCode @?= StatusOk @@ -186,13 +297,13 @@ testGoaway = clientRequest c rm 10 "" mempty lastResult <- clientRequest c rm 1 "" mempty assertBool "Client handles server shutdown gracefully" $ - lastResult == unavailableStatus + lastResult == badStatus StatusUnavailable || - lastResult == deadlineExceededStatus + lastResult == badStatus StatusDeadlineExceeded || lastResult == Left GRPCIOTimeout server s = do - let rm = head (registeredMethods s) + let rm = head (normalMethods s) serverHandleNormalCall s rm mempty dummyHandler serverHandleNormalCall s rm mempty dummyHandler return () @@ -204,9 +315,9 @@ testSlowServer = client c = do rm <- clientRegisterMethod c "/foo" Normal result <- clientRequest c rm 1 "" mempty - result @?= deadlineExceededStatus + result @?= badStatus StatusDeadlineExceeded server s = do - let rm = head (registeredMethods s) + let rm = head (normalMethods s) serverHandleNormalCall s rm mempty $ \_ _ _ -> do threadDelay (2*10^(6 :: Int)) return dummyResp @@ -221,7 +332,7 @@ testServerCallExpirationCheck = result <- clientRequest c rm 3 "" mempty return () server s = do - let rm = head (registeredMethods s) + let rm = head (normalMethods s) serverHandleNormalCall s rm mempty $ \c _ _ -> do exp1 <- serverCallIsExpired c assertBool "Call isn't expired when handler starts" $ not exp1 @@ -245,7 +356,7 @@ testCustomUserAgent = result <- clientRequest c rm 4 "" mempty return () server = TestServer (stdServerConf [("/foo", Normal)]) $ \s -> do - let rm = head (registeredMethods s) + let rm = head (normalMethods s) serverHandleNormalCall s rm mempty $ \_ _ meta -> do let ua = meta M.! "user-agent" assertBool "User agent prefix is present" $ isPrefixOf "prefix!" ua @@ -266,8 +377,8 @@ testClientCompression = result <- clientRequest c rm 1 "hello" mempty return () server = TestServer (stdServerConf [("/foo", Normal)]) $ \s -> do - let rm = head (registeredMethods s) - serverHandleNormalCall s rm mempty $ \c body _ -> do + let rm = head (normalMethods s) + serverHandleNormalCall s rm mempty $ \_ body _ -> do body @?= "hello" return dummyResp return () @@ -294,8 +405,8 @@ testClientServerCompression = [("/foo", Normal)] [CompressionAlgArg GrpcCompressDeflate] server = TestServer sconf $ \s -> do - let rm = head (registeredMethods s) - serverHandleNormalCall s rm dummyMeta $ \c body _ -> do + let rm = head (normalMethods s) + serverHandleNormalCall s rm dummyMeta $ \_sc body _ -> do body @?= "hello" return ("hello", dummyMeta, StatusOk, StatusDetails "") return () @@ -303,23 +414,28 @@ testClientServerCompression = -------------------------------------------------------------------------------- -- Utilities and helpers +is :: (Eq a, Show a, MonadIO m) => m a -> a -> m () +is act x = act >>= liftIO . (@?= x) + dummyMeta :: M.Map ByteString ByteString dummyMeta = [("foo","bar")] +dummyResp :: (ByteString, MetadataMap, StatusCode, StatusDetails) dummyResp = ("", mempty, StatusOk, StatusDetails "") dummyHandler :: ServerCall -> ByteString -> MetadataMap -> IO (ByteString, MetadataMap, StatusCode, StatusDetails) dummyHandler _ _ _ = return dummyResp -unavailableStatus :: Either GRPCIOError a -unavailableStatus = - Left (GRPCIOBadStatusCode StatusUnavailable (StatusDetails "")) +dummyResult' :: StatusDetails + -> IO (ByteString, MetadataMap, StatusCode, StatusDetails) +dummyResult' = return . (mempty, mempty, StatusOk, ) -deadlineExceededStatus :: Either GRPCIOError a -deadlineExceededStatus = - Left (GRPCIOBadStatusCode StatusDeadlineExceeded - (StatusDetails "Deadline Exceeded")) +badStatus :: StatusCode -> Either GRPCIOError a +badStatus st = Left . GRPCIOBadStatusCode st $ case st of + StatusDeadlineExceeded -> "Deadline Exceeded" + StatusCancelled -> "Received RST_STREAM err=8" + _ -> mempty nop :: Monad m => a -> m () nop = const (return ()) @@ -354,8 +470,8 @@ csTest' nm tc ts = -- | @checkMD msg expected actual@ fails when keys from @expected@ are not in -- @actual@, or when values differ for matching keys. checkMD :: String -> MetadataMap -> MetadataMap -> Assertion -checkMD desc expected actual = do - when (not $ M.null $ expected `diff` actual) $ do +checkMD desc expected actual = + unless (M.null $ expected `diff` actual) $ assertEqual desc expected (actual `M.intersection` expected) where diff = M.differenceWith $ \a b -> if a == b then Nothing else Just b @@ -363,13 +479,19 @@ checkMD desc expected actual = do checkReqRslt :: Show a => (b -> Assertion) -> Either a b -> Assertion checkReqRslt = either clientFail +-- | The consumer which asserts that the next value it consumes is equal to the +-- given value; string parameter used as in 'assertEqual'. +assertConsumeEq :: (Eq a, Show a) => String -> a -> P.Consumer a IO () +assertConsumeEq s v = P.lift . assertEqual s v =<< P.await + clientFail :: Show a => a -> Assertion clientFail = assertFailure . ("Client error: " ++). show data TestClient = TestClient ClientConfig (Client -> IO ()) runTestClient :: TestClient -> IO () -runTestClient (TestClient conf c) = withGRPC $ \g -> withClient g conf c +runTestClient (TestClient conf f) = + runManaged $ mgdGRPC >>= mgdClient conf >>= liftIO . f stdTestClient :: (Client -> IO ()) -> TestClient stdTestClient = TestClient stdClientConf @@ -380,7 +502,8 @@ stdClientConf = ClientConfig "localhost" 50051 [] data TestServer = TestServer ServerConfig (Server -> IO ()) runTestServer :: TestServer -> IO () -runTestServer (TestServer conf s) = withGRPC $ \g -> withServer g conf s +runTestServer (TestServer conf f) = + runManaged $ mgdGRPC >>= mgdServer conf >>= liftIO . f stdTestServer :: [(MethodName, GRPCMethodType)] -> (Server -> IO ()) -> TestServer stdTestServer = TestServer . stdServerConf @@ -388,6 +511,14 @@ stdTestServer = TestServer . stdServerConf stdServerConf :: [(MethodName, GRPCMethodType)] -> ServerConfig stdServerConf xs = ServerConfig "localhost" 50051 xs [] - threadDelaySecs :: Int -> IO () threadDelaySecs = threadDelay . (* 10^(6::Int)) + +mgdGRPC :: Managed GRPC +mgdGRPC = managed withGRPC + +mgdClient :: ClientConfig -> GRPC -> Managed Client +mgdClient conf g = managed $ withClient g conf + +mgdServer :: ServerConfig -> GRPC -> Managed Server +mgdServer conf g = managed $ withServer g conf diff --git a/tests/LowLevelTests/Op.hs b/tests/LowLevelTests/Op.hs index 59a6e63..8c47168 100644 --- a/tests/LowLevelTests/Op.hs +++ b/tests/LowLevelTests/Op.hs @@ -1,6 +1,6 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE RecordWildCards #-} module LowLevelTests.Op where @@ -76,7 +76,7 @@ withClientServerUnaryCall grpc f = do crm <- clientRegisterMethod c "/foo" Normal withServer grpc serverConf $ \s -> withClientCall c crm 10 $ \cc -> do - let srm = head (registeredMethods s) + let srm = head (normalMethods s) -- NOTE: We need to send client ops here or else `withServerCall` hangs, -- because registered methods try to do recv ops immediately when -- created. If later we want to send payloads or metadata, we'll need