diff --git a/README.md b/README.md index c64bb35..aa3e68b 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ In order to run the tests, you will need to have the `grpcio`, `gevent`, and ``` $ virtualenv path/to/virtualenv # to create a virtualenv -$ . path/to/virtual/env/bin/activate # to use an existing virtualenv +$ . path/to/virtualenv/bin/activate # to use an existing virtualenv $ pip install grpcio-tools gevent $ pip install grpcio # Need to install grpcio-tools first to avoid a versioning problem ``` diff --git a/examples/hellos/hellos-client/Main.hs b/examples/hellos/hellos-client/Main.hs index bc32909..0b87a28 100644 --- a/examples/hellos/hellos-client/Main.hs +++ b/examples/hellos/hellos-client/Main.hs @@ -3,19 +3,22 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fno-warn-missing-signatures #-} {-# OPTIONS_GHC -fno-warn-unused-binds #-} +import Control.Concurrent.Async import Control.Monad import qualified Data.ByteString.Lazy as BL +import Data.Function import Data.Protobuf.Wire.Class import qualified Data.Text as T import Data.Word import GHC.Generics (Generic) import Network.GRPC.LowLevel +helloSS, helloCS, helloBi :: MethodName helloSS = MethodName "/hellos.Hellos/HelloSS" helloCS = MethodName "/hellos.Hellos/HelloCS" +helloBi = MethodName "/hellos.Hellos/HelloBi" data SSRqt = SSRqt { ssName :: T.Text, ssNumReplies :: Word32 } deriving (Show, Eq, Ord, Generic) instance Message SSRqt @@ -25,43 +28,43 @@ data CSRqt = CSRqt { csMessage :: T.Text } deriving (Show, Eq, Ord, Generic) instance Message CSRqt data CSRpy = CSRpy { csNumRequests :: Word32 } deriving (Show, Eq, Ord, Generic) instance Message CSRpy +data BiRqtRpy = BiRqtRpy { biMessage :: T.Text } deriving (Show, Eq, Ord, Generic) +instance Message BiRqtRpy expect :: (Eq a, Monad m, Show a) => String -> a -> a -> m () expect ctx ex got | ex /= got = fail $ ctx ++ " error: expected " ++ show ex ++ ", got " ++ show got | otherwise = return () -doHelloSS c = do +doHelloSS :: Client -> Int -> IO () +doHelloSS c n = do rm <- clientRegisterMethodServerStreaming c helloSS - let nr = 10 - pay = SSRqt "server streaming mode" nr - enc = BL.toStrict . toLazyByteString $ pay - eea <- clientReader c rm 5 enc mempty $ \_md recv -> do - n :: Int <- go recv 0 - expect "doHelloSS/cnt" (fromIntegral nr) n + let pay = SSRqt "server streaming mode" (fromIntegral n) + enc = BL.toStrict . toLazyByteString $ pay + err desc e = fail $ "doHelloSS: " ++ desc ++ " error: " ++ show e + eea <- clientReader c rm n enc mempty $ \_md recv -> do + n' <- flip fix (0::Int) $ \go i -> recv >>= \case + Left e -> err "recv" e + Right Nothing -> return i + Right (Just bs) -> case fromByteString bs of + Left e -> err "decoding" e + Right r -> expect "doHelloSS/rpy" expay (ssGreeting r) >> go (i+1) + expect "doHelloSS/cnt" n n' case eea of - Left e -> fail $ "clientReader error: " ++ show e + Left e -> err "clientReader" e Right (_, st, _) | st /= StatusOk -> fail "clientReader: non-OK status" - | otherwise -> return () + | otherwise -> putStrLn "doHelloSS: RPC successful" where expay = "Hello there, server streaming mode!" - go recv n = recv >>= \case - Left e -> fail $ "doHelloSS error: " ++ show e - Right Nothing -> return n - Right (Just r) -> case fromByteString r of - Left e -> fail $ "Decoding error: " ++ show e - Right r' -> do - expect "doHelloSS/rpy" expay (ssGreeting r') - go recv (n+1) -doHelloCS c = do +doHelloCS :: Client -> Int -> IO () +doHelloCS c n = do rm <- clientRegisterMethodClientStreaming c helloCS - let nr = 10 - pay = CSRqt "client streaming payload" + let pay = CSRqt "client streaming payload" enc = BL.toStrict . toLazyByteString $ pay - eea <- clientWriter c rm 10 mempty $ \send -> - replicateM_ (fromIntegral nr) $ send enc >>= \case + eea <- clientWriter c rm n mempty $ \send -> + replicateM_ n $ send enc >>= \case Left e -> fail $ "doHelloCS: send error: " ++ show e Right{} -> return () case eea of @@ -71,11 +74,49 @@ doHelloCS c = do | st /= StatusOk -> fail "clientWriter: non-OK status" | otherwise -> case fromByteString bs of Left e -> fail $ "Decoding error: " ++ show e - Right dec -> expect "doHelloCS/cnt" nr (csNumRequests dec) + Right dec -> do + expect "doHelloCS/cnt" (fromIntegral n) (csNumRequests dec) + putStrLn "doHelloCS: RPC successful" +doHelloBi :: Client -> Int -> IO () +doHelloBi c n = do + rm <- clientRegisterMethodBiDiStreaming c helloBi + let pay = BiRqtRpy "bidi payload" + enc = BL.toStrict . toLazyByteString $ pay + err desc e = fail $ "doHelloBi: " ++ desc ++ " error: " ++ show e + eea <- clientRW c rm n mempty $ \_ recv send writesDone -> do + -- perform n writes on a worker thread + thd <- async $ do + replicateM_ n $ send enc >>= \case + Left e -> err "send" e + _ -> return () + writesDone >>= \case + Left e -> err "writesDone" e + _ -> return () + -- perform reads on this thread until the stream is terminated + fix $ \go -> recv >>= \case + Left e -> err "recv" e + Right Nothing -> return () + Right (Just bs) -> case fromByteString bs of + Left e -> err "decoding" e + Right r -> when (r /= pay) (fail "Reply payload mismatch") >> go + wait thd + case eea of + Left e -> err "clientRW'" e + Right (_, st, _) -> do + when (st /= StatusOk) $ fail $ "clientRW: non-OK status: " ++ show st + putStrLn "doHelloBi: RPC successful" + +highlevelMain :: IO () highlevelMain = withGRPC $ \g -> withClient g (ClientConfig "localhost" 50051 []) $ \c -> do - doHelloSS c - doHelloCS c + let n = 100000 + putStrLn "-------------- HelloSS --------------" + doHelloSS c n + putStrLn "-------------- HelloCS --------------" + doHelloCS c n + putStrLn "-------------- HelloBi --------------" + doHelloBi c n +main :: IO () main = highlevelMain diff --git a/src/Network/GRPC/HighLevel/Server.hs b/src/Network/GRPC/HighLevel/Server.hs index be22278..03b9e46 100644 --- a/src/Network/GRPC/HighLevel/Server.hs +++ b/src/Network/GRPC/HighLevel/Server.hs @@ -8,19 +8,16 @@ module Network.GRPC.HighLevel.Server where -import Control.Concurrent.Async import qualified Control.Exception as CE import Control.Monad import Data.ByteString (ByteString) import qualified Data.ByteString.Lazy as BL import Data.Protobuf.Wire.Class import Network.GRPC.LowLevel -import qualified Network.GRPC.LowLevel.Call.Unregistered as U -import qualified Network.GRPC.LowLevel.Server.Unregistered as U import System.IO -type ServerHandler a b = - ServerCall a +type ServerHandler a b + = ServerCall a -> IO (b, MetadataMap, StatusCode, StatusDetails) convertServerHandler :: (Message a, Message b) @@ -28,48 +25,45 @@ convertServerHandler :: (Message a, Message b) -> ServerHandlerLL convertServerHandler f c = case fromByteString (payload c) of Left x -> CE.throw (GRPCIODecodeError x) - Right x -> do (y, tm, sc, sd) <- f (fmap (const x) c) + Right x -> do (y, tm, sc, sd) <- f (const x <$> c) return (toBS y, tm, sc, sd) -type ServerReaderHandler a b = - ServerCall () +type ServerReaderHandler a b + = ServerCall (MethodPayload 'ClientStreaming) -> StreamRecv a - -> Streaming (Maybe b, MetadataMap, StatusCode, StatusDetails) + -> IO (Maybe b, MetadataMap, StatusCode, StatusDetails) convertServerReaderHandler :: (Message a, Message b) => ServerReaderHandler a b -> ServerReaderHandlerLL -convertServerReaderHandler f c recv = - serialize <$> f c (convertRecv recv) +convertServerReaderHandler f c recv = serialize <$> f c (convertRecv recv) where serialize (mmsg, m, sc, sd) = (toBS <$> mmsg, m, sc, sd) type ServerWriterHandler a b = - ServerCall a + ServerCall a -> StreamSend b - -> Streaming (MetadataMap, StatusCode, StatusDetails) + -> IO (MetadataMap, StatusCode, StatusDetails) -convertServerWriterHandler :: (Message a, Message b) => - ServerWriterHandler a b - -> ServerWriterHandlerLL -convertServerWriterHandler f c send = - f (convert <$> c) (convertSend send) +convertServerWriterHandler :: (Message a, Message b) + => ServerWriterHandler a b + -> ServerWriterHandlerLL +convertServerWriterHandler f c send = f (convert <$> c) (convertSend send) where convert bs = case fromByteString bs of Left x -> CE.throw (GRPCIODecodeError x) Right x -> x -type ServerRWHandler a b = - ServerCall () +type ServerRWHandler a b + = ServerCall (MethodPayload 'BiDiStreaming) -> StreamRecv a -> StreamSend b - -> Streaming (MetadataMap, StatusCode, StatusDetails) + -> IO (MetadataMap, StatusCode, StatusDetails) convertServerRWHandler :: (Message a, Message b) => ServerRWHandler a b -> ServerRWHandlerLL -convertServerRWHandler f c recv send = - f c (convertRecv recv) (convertSend send) +convertServerRWHandler f c r s = f c (convertRecv r) (convertSend s) convertRecv :: Message a => StreamRecv ByteString -> StreamRecv a convertRecv = @@ -88,40 +82,21 @@ toBS :: Message a => a -> ByteString toBS = BL.toStrict . toLazyByteString data Handler (a :: GRPCMethodType) where - UnaryHandler - :: (Message c, Message d) - => MethodName - -> ServerHandler c d - -> Handler 'Normal + UnaryHandler :: (Message c, Message d) => MethodName -> ServerHandler c d -> Handler 'Normal + ClientStreamHandler :: (Message c, Message d) => MethodName -> ServerReaderHandler c d -> Handler 'ClientStreaming + ServerStreamHandler :: (Message c, Message d) => MethodName -> ServerWriterHandler c d -> Handler 'ServerStreaming + BiDiStreamHandler :: (Message c, Message d) => MethodName -> ServerRWHandler c d -> Handler 'BiDiStreaming - ClientStreamHandler - :: (Message c, Message d) - => MethodName - -> ServerReaderHandler c d - -> Handler 'ClientStreaming - - ServerStreamHandler - :: (Message c, Message d) - => MethodName - -> ServerWriterHandler c d - -> Handler 'ServerStreaming - - BiDiStreamHandler - :: (Message c, Message d) - => MethodName - -> ServerRWHandler c d - -> Handler 'BiDiStreaming - -data AnyHandler = forall (a :: GRPCMethodType) . AnyHandler (Handler a) +data AnyHandler = forall (a :: GRPCMethodType). AnyHandler (Handler a) anyHandlerMethodName :: AnyHandler -> MethodName anyHandlerMethodName (AnyHandler m) = handlerMethodName m handlerMethodName :: Handler a -> MethodName -handlerMethodName (UnaryHandler m _) = m +handlerMethodName (UnaryHandler m _) = m handlerMethodName (ClientStreamHandler m _) = m handlerMethodName (ServerStreamHandler m _) = m -handlerMethodName (BiDiStreamHandler m _) = m +handlerMethodName (BiDiStreamHandler m _) = m logMsg :: String -> IO () logMsg = hPutStrLn stderr @@ -146,17 +121,17 @@ handleCallError (Left x) = logMsg $ show x ++ ": This probably indicates a bug in gRPC-haskell. Please report this error." loopWError :: Int - -> IO (Either GRPCIOError a) - -> IO () + -> IO (Either GRPCIOError a) + -> IO () loopWError i f = do when (i `mod` 100 == 0) $ putStrLn $ "i = " ++ show i f >>= handleCallError loopWError (i + 1) f ---TODO: options for setting initial/trailing metadata +-- TODO: options for setting initial/trailing metadata handleLoop :: Server - -> (Handler a, RegisteredMethod a) - -> IO () + -> (Handler a, RegisteredMethod a) + -> IO () handleLoop s (UnaryHandler _ f, rm) = loopWError 0 $ serverHandleNormalCall s rm mempty $ convertServerHandler f handleLoop s (ClientStreamHandler _ f, rm) = @@ -167,30 +142,33 @@ handleLoop s (BiDiStreamHandler _ f, rm) = loopWError 0 $ serverRW s rm mempty $ convertServerRWHandler f data ServerOptions = ServerOptions - {optNormalHandlers :: [Handler 'Normal], - optClientStreamHandlers :: [Handler 'ClientStreaming], - optServerStreamHandlers :: [Handler 'ServerStreaming], - optBiDiStreamHandlers :: [Handler 'BiDiStreaming], - optServerPort :: Port, - optUseCompression :: Bool, - optUserAgentPrefix :: String, - optUserAgentSuffix :: String, - optInitialMetadata :: MetadataMap} + { optNormalHandlers :: [Handler 'Normal] + , optClientStreamHandlers :: [Handler 'ClientStreaming] + , optServerStreamHandlers :: [Handler 'ServerStreaming] + , optBiDiStreamHandlers :: [Handler 'BiDiStreaming] + , optServerPort :: Port + , optUseCompression :: Bool + , optUserAgentPrefix :: String + , optUserAgentSuffix :: String + , optInitialMetadata :: MetadataMap + } defaultOptions :: ServerOptions -defaultOptions = - ServerOptions {optNormalHandlers = [], - optClientStreamHandlers = [], - optServerStreamHandlers = [], - optBiDiStreamHandlers = [], - optServerPort = 50051, - optUseCompression = False, - optUserAgentPrefix = "grpc-haskell/0.0.0", - optUserAgentSuffix = "", - optInitialMetadata = mempty} +defaultOptions = ServerOptions + { optNormalHandlers = [] + , optClientStreamHandlers = [] + , optServerStreamHandlers = [] + , optBiDiStreamHandlers = [] + , optServerPort = 50051 + , optUseCompression = False + , optUserAgentPrefix = "grpc-haskell/0.0.0" + , optUserAgentSuffix = "" + , optInitialMetadata = mempty + } serverLoop :: ServerOptions -> IO () -serverLoop opts = +serverLoop _opts = fail "Registered method-based serverLoop NYI" +{- withGRPC $ \grpc -> withServer grpc (mkConfig opts) $ \server -> do let rmsN = zip (optNormalHandlers opts) $ normalMethods server @@ -231,3 +209,4 @@ serverLoop opts = logMsg $ "Requested unknown endpoint: " ++ show (U.callMethod call) return ("", mempty, StatusNotFound, StatusDetails "Unknown method") +-} diff --git a/src/Network/GRPC/HighLevel/Server/Unregistered.hs b/src/Network/GRPC/HighLevel/Server/Unregistered.hs index 77ef56b..467fd2b 100644 --- a/src/Network/GRPC/HighLevel/Server/Unregistered.hs +++ b/src/Network/GRPC/HighLevel/Server/Unregistered.hs @@ -21,65 +21,49 @@ import qualified Network.GRPC.LowLevel.Call.Unregistered as U import qualified Network.GRPC.LowLevel.Server.Unregistered as U dispatchLoop :: Server - -> MetadataMap - -> [Handler 'Normal] - -> [Handler 'ClientStreaming] - -> [Handler 'ServerStreaming] - -> [Handler 'BiDiStreaming] - -> IO () -dispatchLoop server meta hN hC hS hB = - forever $ U.withServerCallAsync server $ \call -> do - case findHandler call allHandlers of - Just (AnyHandler (UnaryHandler _ h)) -> unaryHandler call h - Just (AnyHandler (ClientStreamHandler _ h)) -> csHandler call h - Just (AnyHandler (ServerStreamHandler _ h)) -> ssHandler call h - Just (AnyHandler (BiDiStreamHandler _ h)) -> bdHandler call h - Nothing -> unknownHandler call - where allHandlers = map AnyHandler hN - ++ map AnyHandler hC - ++ map AnyHandler hS - ++ map AnyHandler hB - findHandler call = find ((== (U.callMethod call)) - . anyHandlerMethodName) - unknownHandler call = - void $ U.serverHandleNormalCall' server call meta $ \_ _ -> - return (mempty - , mempty - , StatusNotFound - , StatusDetails "unknown method") + -> MetadataMap + -> [Handler 'Normal] + -> [Handler 'ClientStreaming] + -> [Handler 'ServerStreaming] + -> [Handler 'BiDiStreaming] + -> IO () +dispatchLoop s md hN hC hS hB = + forever $ U.withServerCallAsync s $ \sc -> + case findHandler sc allHandlers of + Just (AnyHandler ah) -> case ah of + UnaryHandler _ h -> unaryHandler sc h + ClientStreamHandler _ h -> csHandler sc h + ServerStreamHandler _ h -> ssHandler sc h + BiDiStreamHandler _ h -> bdHandler sc h + Nothing -> unknownHandler sc + where + allHandlers = map AnyHandler hN ++ map AnyHandler hC + ++ map AnyHandler hS ++ map AnyHandler hB - handleError = (handleCallError . left herr =<<) . CE.try - where herr (e :: CE.SomeException) = GRPCIOHandlerException (show e) + findHandler sc = find ((== U.callMethod sc) . anyHandlerMethodName) - unaryHandler :: (Message a, Message b) => - U.ServerCall - -> ServerHandler a b - -> IO () - unaryHandler call h = - handleError $ - U.serverHandleNormalCall' server call meta $ \_call' bs -> - convertServerHandler h (fmap (const bs) $ U.convertCall call) - csHandler :: (Message a, Message b) => - U.ServerCall - -> ServerReaderHandler a b - -> IO () - csHandler call h = - handleError $ - U.serverReader server call meta (convertServerReaderHandler h) - ssHandler :: (Message a, Message b) => - U.ServerCall - -> ServerWriterHandler a b - -> IO () - ssHandler call h = - handleError $ - U.serverWriter server call meta (convertServerWriterHandler h) - bdHandler :: (Message a, Message b) => - U.ServerCall - -> ServerRWHandler a b - -> IO () - bdHandler call h = - handleError $ - U.serverRW server call meta (convertServerRWHandler h) + unaryHandler :: (Message a, Message b) => U.ServerCall -> ServerHandler a b -> IO () + unaryHandler sc h = + handleError $ + U.serverHandleNormalCall' s sc md $ \_sc' bs -> + convertServerHandler h (const bs <$> U.convertCall sc) + + csHandler :: (Message a, Message b) => U.ServerCall -> ServerReaderHandler a b -> IO () + csHandler sc = handleError . U.serverReader s sc md . convertServerReaderHandler + + ssHandler :: (Message a, Message b) => U.ServerCall -> ServerWriterHandler a b -> IO () + ssHandler sc = handleError . U.serverWriter s sc md . convertServerWriterHandler + + bdHandler :: (Message a, Message b) => U.ServerCall -> ServerRWHandler a b -> IO () + bdHandler sc = handleError . U.serverRW s sc md . convertServerRWHandler + + unknownHandler :: U.ServerCall -> IO () + unknownHandler sc = void $ U.serverHandleNormalCall' s sc md $ \_ _ -> + return (mempty, mempty, StatusNotFound, StatusDetails "unknown method") + + handleError :: IO a -> IO () + handleError = (handleCallError . left herr =<<) . CE.try + where herr (e :: CE.SomeException) = GRPCIOHandlerException (show e) serverLoop :: ServerOptions -> IO () serverLoop ServerOptions{..} = do @@ -95,17 +79,17 @@ serverLoop ServerOptions{..} = do optBiDiStreamHandlers wait tid where - config = - ServerConfig - { host = "localhost" - , port = optServerPort - , methodsToRegisterNormal = [] - , methodsToRegisterClientStreaming = [] - , methodsToRegisterServerStreaming = [] - , methodsToRegisterBiDiStreaming = [] - , serverArgs = - ([CompressionAlgArg GrpcCompressDeflate | optUseCompression] - ++ - [UserAgentPrefix optUserAgentPrefix - , UserAgentSuffix optUserAgentSuffix]) - } + config = ServerConfig + { host = "localhost" + , port = optServerPort + , methodsToRegisterNormal = [] + , methodsToRegisterClientStreaming = [] + , methodsToRegisterServerStreaming = [] + , methodsToRegisterBiDiStreaming = [] + , serverArgs = + [CompressionAlgArg GrpcCompressDeflate | optUseCompression] + ++ + [ UserAgentPrefix optUserAgentPrefix + , UserAgentSuffix optUserAgentSuffix + ] + } diff --git a/src/Network/GRPC/LowLevel.hs b/src/Network/GRPC/LowLevel.hs index a4eab1d..880b813 100644 --- a/src/Network/GRPC/LowLevel.hs +++ b/src/Network/GRPC/LowLevel.hs @@ -21,6 +21,7 @@ GRPC -- * Calls , GRPCMethodType(..) , RegisteredMethod +, MethodPayload , NormalRequestResult(..) , MetadataMap(..) , MethodName(..) @@ -74,7 +75,6 @@ GRPC , OpRecvResult(..) -- * Streaming utilities -, Streaming , StreamSend , StreamRecv diff --git a/src/Network/GRPC/LowLevel/Client.hs b/src/Network/GRPC/LowLevel/Client.hs index 9ee6d18..da912e7 100644 --- a/src/Network/GRPC/LowLevel/Client.hs +++ b/src/Network/GRPC/LowLevel/Client.hs @@ -12,6 +12,7 @@ module Network.GRPC.LowLevel.Client where import Control.Exception (bracket, finally) import Control.Monad +import Control.Monad.IO.Class import Control.Monad.Trans.Except import Data.ByteString (ByteString) import Network.GRPC.LowLevel.Call @@ -198,7 +199,8 @@ compileNormalRequestResults x = -- clientReader (client side of server streaming mode) -- | First parameter is initial server metadata. -type ClientReaderHandler = MetadataMap -> StreamRecv ByteString -> Streaming () +type ClientReaderHandler = MetadataMap -> StreamRecv ByteString -> IO () +type ClientReaderResult = (MetadataMap, C.StatusCode, StatusDetails) clientReader :: Client -> RegisteredMethod 'ServerStreaming @@ -206,7 +208,7 @@ clientReader :: Client -> ByteString -- ^ The body of the request -> MetadataMap -- ^ Metadata to send with the request -> ClientReaderHandler - -> IO (Either GRPCIOError (MetadataMap, C.StatusCode, StatusDetails)) + -> IO (Either GRPCIOError ClientReaderResult) clientReader cl@Client{ clientCQ = cq } rm tm body initMeta f = withClientCall cl rm tm go where @@ -216,13 +218,13 @@ clientReader cl@Client{ clientCQ = cq } rm tm body initMeta f = , OpSendCloseFromClient ] srvMD <- recvInitialMetadata c cq - runStreamingProxy "clientReader'" c cq (f srvMD streamRecv) + liftIO $ f srvMD (streamRecvPrim c cq) recvStatusOnClient c cq -------------------------------------------------------------------------------- -- clientWriter (client side of client streaming mode) -type ClientWriterHandler = StreamSend ByteString -> Streaming () +type ClientWriterHandler = StreamSend ByteString -> IO () type ClientWriterResult = (Maybe ByteString, MetadataMap, MetadataMap, C.StatusCode, StatusDetails) @@ -243,7 +245,7 @@ clientWriterCmn :: Client -- ^ The active client clientWriterCmn (clientCQ -> cq) initMeta f (unsafeCC -> c) = runExceptT $ do sendInitialMetadata c cq initMeta - runStreamingProxy "clientWriterCmn" c cq (f streamSend) + liftIO $ f (streamSendPrim c cq) sendSingle c cq OpSendCloseFromClient let ops = [OpRecvInitialMetadata, OpRecvMessage, OpRecvStatusOnClient] runOps' c cq ops >>= \case @@ -260,28 +262,41 @@ pattern CWRFinal mmsg initMD trailMD st ds -------------------------------------------------------------------------------- -- clientRW (client side of bidirectional streaming mode) --- | First parameter is initial server metadata. -type ClientRWHandler = MetadataMap - -> StreamRecv ByteString - -> StreamSend ByteString - -> Streaming () +type ClientRWHandler + = MetadataMap + -> StreamRecv ByteString + -> StreamSend ByteString + -> WritesDone + -> IO () +type ClientRWResult = (MetadataMap, C.StatusCode, StatusDetails) --- | For bidirectional-streaming registered requests +-- | The most generic version of clientRW. It does not assume anything about +-- threading model; caller must invoke the WritesDone operation, exactly once, +-- for the half-close, after all threads have completed writing. TODO: It'd be +-- nice to find a way to type-enforce this usage pattern rather than accomplish +-- it via usage convention and documentation. clientRW :: Client -> RegisteredMethod 'BiDiStreaming -> TimeoutSeconds -> MetadataMap - -- ^ request metadata -> ClientRWHandler - -> IO (Either GRPCIOError (MetadataMap, C.StatusCode, StatusDetails)) -clientRW cl@(clientCQ -> cq) rm tm initMeta f = - withClientCall cl rm tm go + -> IO (Either GRPCIOError ClientRWResult) +clientRW cl@(clientCQ -> cq) rm tm initMeta f = withClientCall cl rm tm go where go (unsafeCC -> c) = runExceptT $ do sendInitialMetadata c cq initMeta srvMeta <- recvInitialMetadata c cq - runStreamingProxy "clientRW" c cq (f srvMeta streamRecv streamSend) - runOps' c cq [OpSendCloseFromClient] -- WritesDone() + liftIO $ f srvMeta (streamRecvPrim c cq) (streamSendPrim c cq) (writesDonePrim c cq) + -- NB: We could consider having the passed writesDone action safely set a + -- flag once it had been called, and invoke it ourselves if not set after + -- returning from the handler (although this is actually borked in the + -- concurrent case, because a reader may remain blocked without the + -- half-close and thus not return control to us -- doh). Alternately, we + -- can document just this general-purpose function well, and then create + -- slightly simpler versions of the bidi interface which support (a) + -- monothreaded send/recv interleaving with implicit half-close and (b) + -- send/recv threads with implicit half-close after writer thread + -- termination. recvStatusOnClient c cq -- Finish() -------------------------------------------------------------------------------- diff --git a/src/Network/GRPC/LowLevel/Op.hs b/src/Network/GRPC/LowLevel/Op.hs index 9711c59..c448d46 100644 --- a/src/Network/GRPC/LowLevel/Op.hs +++ b/src/Network/GRPC/LowLevel/Op.hs @@ -8,7 +8,6 @@ module Network.GRPC.LowLevel.Op where import Control.Exception import Control.Monad -import Control.Monad.Trans.Class (MonadTrans(lift)) import Control.Monad.Trans.Except import Data.ByteString (ByteString) import qualified Data.ByteString as B @@ -27,8 +26,6 @@ import qualified Network.GRPC.Unsafe.ByteBuffer as C import qualified Network.GRPC.Unsafe.Metadata as C import qualified Network.GRPC.Unsafe.Op as C import qualified Network.GRPC.Unsafe.Slice as C (Slice, freeSlice) -import qualified Pipes as P -import qualified Pipes.Core as P -- | Sum describing all possible send and receive operations that can be batched -- and executed by gRPC. Usually these are processed in a handful of @@ -304,60 +301,28 @@ recvInitialMessage c cq = runOps' c cq [OpRecvMessage] >>= \case -------------------------------------------------------------------------------- -- Streaming types and helpers --- | Requests use Nothing to denote read, Just to denote --- write. Right-constructed responses use Just to indicate a successful read, --- and Nothing to denote end of stream when reading or a successful write. -type Streaming a = - P.Client (Maybe ByteString) (Either GRPCIOError (Maybe ByteString)) IO a - --- | Run the given 'Streaming' operation via an appropriate upstream --- proxy. I.e., if called on the client side, the given 'Streaming' operation --- talks to a server proxy, and vice versa. -runStreamingProxy :: String - -- ^ context string for including in errors - -> C.Call - -- ^ the call associated with this streaming operation - -> CompletionQueue - -- ^ the completion queue for ops batches - -> Streaming a - -- ^ the requesting side of the streaming operation - -> ExceptT GRPCIOError IO a -runStreamingProxy nm c cq - = ExceptT . P.runEffect . (streamingProxy nm c cq P.+>>) . fmap Right - -streamingProxy :: String - -- ^ context string for including in errors - -> C.Call - -- ^ the call associated with this streaming operation - -> CompletionQueue - -- ^ the completion queue for ops batches - -> Maybe ByteString - -- ^ the request to the proxy - -> P.Server - (Maybe ByteString) - (Either GRPCIOError (Maybe ByteString)) - IO (Either GRPCIOError a) -streamingProxy nm c cq = maybe recv send +type StreamRecv a = IO (Either GRPCIOError (Maybe a)) +streamRecvPrim :: C.Call -> CompletionQueue -> StreamRecv ByteString +streamRecvPrim c cq = f <$> runOps c cq [OpRecvMessage] where - recv = run [OpRecvMessage] >>= \case - RecvMsgRslt mr -> rsp mr >>= streamingProxy nm c cq - Right{} -> err (urecv "recv") - Left e -> err e - send msg = run [OpSendMessage msg] >>= \case - Right [] -> rsp Nothing >>= streamingProxy nm c cq - Right _ -> err (urecv "send") - Left e -> err e - err e = P.respond (Left e) >> return (Left e) - rsp = P.respond . Right - run = lift . runOps c cq - urecv = GRPCIOInternalUnexpectedRecv . (nm ++) - -type StreamRecv a = Streaming (Either GRPCIOError (Maybe a)) -streamRecv :: StreamRecv ByteString -streamRecv = P.request Nothing - -type StreamSend a = a -> Streaming (Either GRPCIOError ()) -streamSend :: StreamSend ByteString -streamSend = fmap void . P.request . Just + f (RecvMsgRslt mmsg) = Right mmsg + f Right{} = Left (GRPCIOInternalUnexpectedRecv "streamRecvPrim") + f (Left e) = Left e pattern RecvMsgRslt mmsg <- Right [OpRecvMessageResult mmsg] + +type StreamSend a = a -> IO (Either GRPCIOError ()) +streamSendPrim :: C.Call -> CompletionQueue -> StreamSend ByteString +streamSendPrim c cq bs = f <$> runOps c cq [OpSendMessage bs] + where + f (Right []) = Right () + f Right{} = Left (GRPCIOInternalUnexpectedRecv "streamSendPrim") + f (Left e) = Left e + +type WritesDone = IO (Either GRPCIOError ()) +writesDonePrim :: C.Call -> CompletionQueue -> WritesDone +writesDonePrim c cq = f <$> runOps c cq [OpSendCloseFromClient] + where + f (Right []) = Right () + f Right{} = Left (GRPCIOInternalUnexpectedRecv "writesDonePrim") + f (Left e) = Left e diff --git a/src/Network/GRPC/LowLevel/Server.hs b/src/Network/GRPC/LowLevel/Server.hs index e17ee1f..799b2eb 100644 --- a/src/Network/GRPC/LowLevel/Server.hs +++ b/src/Network/GRPC/LowLevel/Server.hs @@ -28,6 +28,7 @@ import Control.Concurrent.STM.TVar (TVar , newTVarIO) import Control.Exception (bracket, finally) import Control.Monad +import Control.Monad.IO.Class import Control.Monad.Trans.Except import Data.ByteString (ByteString) import qualified Data.Set as S @@ -199,7 +200,7 @@ serverRegisterMethod :: C.Server -> MethodName -> Endpoint -> GRPCMethodType - -> IO (C.CallHandle) + -> IO C.CallHandle serverRegisterMethod s nm e mty = C.grpcServerRegisterMethod s (unMethodName nm) @@ -312,20 +313,19 @@ withServerCall s rm f = -- serverReader (server side of client streaming mode) type ServerReaderHandlerLL - = ServerCall () + = ServerCall (MethodPayload 'ClientStreaming) -> StreamRecv ByteString - -> Streaming (Maybe ByteString, MetadataMap, C.StatusCode, StatusDetails) + -> IO (Maybe ByteString, MetadataMap, C.StatusCode, StatusDetails) serverReader :: Server -> RegisteredMethod 'ClientStreaming - -> MetadataMap -- ^ initial server metadata + -> MetadataMap -- ^ Initial server metadata -> ServerReaderHandlerLL -> IO (Either GRPCIOError ()) serverReader s rm initMeta f = withServerCall s rm go where go sc@ServerCall{ unsafeSC = c, callCQ = ccq } = runExceptT $ do - (mmsg, trailMeta, st, ds) <- - runStreamingProxy "serverReader" c ccq (f sc streamRecv) + (mmsg, trailMeta, st, ds) <- liftIO $ f sc (streamRecvPrim c ccq) runOps' c ccq ( OpSendInitialMetadata initMeta : OpSendStatusFromServer trailMeta st ds : maybe [] ((:[]) . OpSendMessage) mmsg @@ -336,44 +336,42 @@ serverReader s rm initMeta f = withServerCall s rm go -- serverWriter (server side of server streaming mode) type ServerWriterHandlerLL - = ServerCall ByteString + = ServerCall (MethodPayload 'ServerStreaming) -> StreamSend ByteString - -> Streaming (MetadataMap, C.StatusCode, StatusDetails) + -> IO (MetadataMap, C.StatusCode, StatusDetails) -- | Wait for and then handle a registered, server-streaming call. serverWriter :: Server -> RegisteredMethod 'ServerStreaming - -> MetadataMap - -- ^ Initial server metadata + -> MetadataMap -- ^ Initial server metadata -> ServerWriterHandlerLL -> IO (Either GRPCIOError ()) serverWriter s rm initMeta f = withServerCall s rm go where go sc@ServerCall{ unsafeSC = c, callCQ = ccq } = runExceptT $ do sendInitialMetadata c ccq initMeta - st <- runStreamingProxy "serverWriter" c ccq (f sc streamSend) + st <- liftIO $ f sc (streamSendPrim c ccq) sendStatusFromServer c ccq st -------------------------------------------------------------------------------- --- serverRW (server side of bidirectional streaming mode) +-- serverRW (bidirectional streaming mode) type ServerRWHandlerLL - = ServerCall () + = ServerCall (MethodPayload 'BiDiStreaming) -> StreamRecv ByteString -> StreamSend ByteString - -> Streaming (MetadataMap, C.StatusCode, StatusDetails) + -> IO (MetadataMap, C.StatusCode, StatusDetails) serverRW :: Server -> RegisteredMethod 'BiDiStreaming - -> MetadataMap - -- ^ initial server metadata + -> MetadataMap -- ^ initial server metadata -> ServerRWHandlerLL -> IO (Either GRPCIOError ()) serverRW s rm initMeta f = withServerCall s rm go where go sc@ServerCall{ unsafeSC = c, callCQ = ccq } = runExceptT $ do sendInitialMetadata c ccq initMeta - st <- runStreamingProxy "serverRW" c ccq (f sc streamRecv streamSend) + st <- liftIO $ f sc (streamRecvPrim c ccq) (streamSendPrim c ccq) sendStatusFromServer c ccq st -------------------------------------------------------------------------------- @@ -386,7 +384,7 @@ serverRW s rm initMeta f = withServerCall s rm go -- respectively. We pass in the 'ServerCall' so that the server can call -- 'serverCallCancel' on it if needed. type ServerHandlerLL - = ServerCall ByteString + = ServerCall (MethodPayload 'Normal) -> IO (ByteString, MetadataMap, C.StatusCode, StatusDetails) -- | Wait for and then handle a normal (non-streaming) call. @@ -399,12 +397,10 @@ serverHandleNormalCall :: Server serverHandleNormalCall s rm initMeta f = withServerCall s rm go where - go sc@ServerCall{..} = do + go sc@ServerCall{ unsafeSC = c, callCQ = ccq } = do (rsp, trailMeta, st, ds) <- f sc - void <$> runOps unsafeSC callCQ - [ OpSendInitialMetadata initMeta - , OpRecvCloseOnServer - , OpSendMessage rsp - , OpSendStatusFromServer trailMeta st ds - ] - <* grpcDebug "serverHandleNormalCall(R): finished response ops." + void <$> runOps c ccq [ OpSendInitialMetadata initMeta + , OpRecvCloseOnServer + , OpSendMessage rsp + , OpSendStatusFromServer trailMeta st ds + ] diff --git a/src/Network/GRPC/LowLevel/Server/Unregistered.hs b/src/Network/GRPC/LowLevel/Server/Unregistered.hs index c65ffd6..2f3876b 100644 --- a/src/Network/GRPC/LowLevel/Server/Unregistered.hs +++ b/src/Network/GRPC/LowLevel/Server/Unregistered.hs @@ -4,26 +4,19 @@ module Network.GRPC.LowLevel.Server.Unregistered where import Control.Exception (finally) +import Control.Monad +import Control.Monad.IO.Class import Control.Monad.Trans.Except import Data.ByteString (ByteString) import Network.GRPC.LowLevel.Call.Unregistered import Network.GRPC.LowLevel.CompletionQueue.Unregistered (serverRequestCall) import Network.GRPC.LowLevel.GRPC -import Network.GRPC.LowLevel.Op (Op (..) - , OpRecvResult (..) - , runOps - , runStreamingProxy - , streamRecv - , streamSend - , runOps' - , sendInitialMetadata - , sendStatusFromServer - , recvInitialMessage) -import Network.GRPC.LowLevel.Server (Server (..) - , ServerReaderHandlerLL - , ServerWriterHandlerLL - , ServerRWHandlerLL - , forkServer) +import Network.GRPC.LowLevel.Op +import Network.GRPC.LowLevel.Server (Server (..), + ServerRWHandlerLL, + ServerReaderHandlerLL, + ServerWriterHandlerLL, + forkServer) import qualified Network.GRPC.Unsafe.Op as C serverCreateCall :: Server @@ -47,49 +40,27 @@ withServerCall s f = -- Because this function doesn't wait for the handler to return, it cannot -- return errors. withServerCallAsync :: Server - -> (ServerCall -> IO ()) - -> IO () + -> (ServerCall -> IO ()) + -> IO () withServerCallAsync s f = serverCreateCall s >>= \case Left e -> do grpcDebug $ "withServerCallAsync: call error: " ++ show e return () Right c -> do wasForkSuccess <- forkServer s handler - if wasForkSuccess - then return () - else destroy + unless wasForkSuccess destroy where handler = f c `finally` destroy - --TODO: We sometimes never finish cleanup if the server - -- is shutting down and calls killThread. This causes - -- gRPC core to complain about leaks. - -- I think the cause of this is that killThread gets - -- called after we are already in destroyServerCall, - -- and wrapping uninterruptibleMask doesn't seem to help. - -- Doesn't crash, but does emit annoying log messages. + -- TODO: We sometimes never finish cleanup if the server + -- is shutting down and calls killThread. This causes gRPC + -- core to complain about leaks. I think the cause of + -- this is that killThread gets called after we are + -- already in destroyServerCall, and wrapping + -- uninterruptibleMask doesn't seem to help. Doesn't + -- crash, but does emit annoying log messages. destroy = do grpcDebug "withServerCallAsync: destroying." destroyServerCall c grpcDebug "withServerCallAsync: cleanup finished." --- | Sequence of 'Op's needed to receive a normal (non-streaming) call. --- TODO: We have to put 'OpRecvCloseOnServer' in the response ops, or else the --- client times out. Given this, I have no idea how to check for cancellation on --- the server. -serverOpsGetNormalCall :: MetadataMap -> [Op] -serverOpsGetNormalCall initMetadata = - [OpSendInitialMetadata initMetadata, - OpRecvMessage] - --- | Sequence of 'Op's needed to respond to a normal (non-streaming) call. -serverOpsSendNormalResponse :: ByteString - -> MetadataMap - -> C.StatusCode - -> StatusDetails - -> [Op] -serverOpsSendNormalResponse body metadata code details = - [OpRecvCloseOnServer, - OpSendMessage body, - OpSendStatusFromServer metadata code details] - -- | A handler for an unregistered server call; bytestring arguments are the -- request body and response body respectively. type ServerHandler @@ -125,6 +96,9 @@ serverHandleNormalCall' grpcDebug $ "got client metadata: " ++ show metadata grpcDebug $ "call_details host is: " ++ show callHost (rsp, trailMeta, st, ds) <- f sc body + -- TODO: We have to put 'OpRecvCloseOnServer' in the response ops, + -- or else the client times out. Given this, I have no idea how to + -- check for cancellation on the server. runOps c cq [ OpRecvCloseOnServer , OpSendMessage rsp, @@ -141,13 +115,12 @@ serverHandleNormalCall' serverReader :: Server -> ServerCall - -> MetadataMap -- ^ initial server metadata + -> MetadataMap -- ^ Initial server metadata -> ServerReaderHandlerLL -> IO (Either GRPCIOError ()) serverReader _ sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f = runExceptT $ do - (mmsg, trailMeta, st, ds) <- - runStreamingProxy "serverReader" c ccq (f (convertCall sc) streamRecv) + (mmsg, trailMeta, st, ds) <- liftIO $ f (convertCall sc) (streamRecvPrim c ccq) runOps' c ccq ( OpSendInitialMetadata initMeta : OpSendStatusFromServer trailMeta st ds : maybe [] ((:[]) . OpSendMessage) mmsg @@ -156,27 +129,23 @@ serverReader _ sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f = serverWriter :: Server -> ServerCall - -> MetadataMap - -- ^ Initial server metadata + -> MetadataMap -- ^ Initial server metadata -> ServerWriterHandlerLL -> IO (Either GRPCIOError ()) serverWriter _ sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f = runExceptT $ do bs <- recvInitialMessage c ccq sendInitialMetadata c ccq initMeta - let regCall = fmap (const bs) (convertCall sc) - st <- runStreamingProxy "serverWriter" c ccq (f regCall streamSend) + st <- liftIO $ f (const bs <$> convertCall sc) (streamSendPrim c ccq) sendStatusFromServer c ccq st serverRW :: Server -> ServerCall - -> MetadataMap - -- ^ initial server metadata + -> MetadataMap -- ^ Initial server metadata -> ServerRWHandlerLL -> IO (Either GRPCIOError ()) serverRW _ sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f = runExceptT $ do sendInitialMetadata c ccq initMeta - let regCall = convertCall sc - st <- runStreamingProxy "serverRW" c ccq (f regCall streamRecv streamSend) + st <- liftIO $ f (convertCall sc) (streamRecvPrim c ccq) (streamSendPrim c ccq) sendStatusFromServer c ccq st diff --git a/tests/LowLevelTests.hs b/tests/LowLevelTests.hs index 3aa9e0d..e6e91e6 100644 --- a/tests/LowLevelTests.hs +++ b/tests/LowLevelTests.hs @@ -295,13 +295,14 @@ testBiDiStreaming = client c = do rm <- clientRegisterMethodBiDiStreaming c "/bidi" - eea <- clientRW c rm 10 clientInitMD $ \_initMD recv send -> do + eea <- clientRW c rm 10 clientInitMD $ \_srvInitMD recv send writesDone -> do send "cw0" `is` Right () recv `is` Right (Just "sw0") send "cw1" `is` Right () recv `is` Right (Just "sw1") recv `is` Right (Just "sw2") - return () + writesDone `is` Right () + recv `is` Right Nothing eea @?= Right (trailMD, serverStatus, serverDtls) server s = do @@ -330,19 +331,19 @@ testBiDiStreamingUnregistered = client c = do rm <- clientRegisterMethodBiDiStreaming c "/bidi" - eea <- clientRW c rm 10 clientInitMD $ \_initMD recv send -> do + eea <- clientRW c rm 10 clientInitMD $ \_srvInitMD recv send writesDone -> do send "cw0" `is` Right () recv `is` Right (Just "sw0") send "cw1" `is` Right () recv `is` Right (Just "sw1") recv `is` Right (Just "sw2") - return () + writesDone `is` Right () + recv `is` Right Nothing eea @?= Right (trailMD, serverStatus, serverDtls) server s = U.withServerCallAsync s $ \call -> do eea <- U.serverRW s call serverInitMD $ \sc recv send -> do - liftIO $ checkMD "Client request metadata mismatch" - clientInitMD (metadata sc) + checkMD "Client request metadata mismatch" clientInitMD (metadata sc) recv `is` Right (Just "cw0") send "sw0" `is` Right () recv `is` Right (Just "cw1") diff --git a/tests/TestServer.hs b/tests/TestServer.hs index 58ae506..81a028b 100644 --- a/tests/TestServer.hs +++ b/tests/TestServer.hs @@ -23,7 +23,7 @@ handleNormalCall call = result = sum nums -handleClientStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> Streaming (Maybe SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails) +handleClientStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> IO (Maybe SimpleServiceResponse, MetadataMap, StatusCode, StatusDetails) handleClientStreamingCall call recvRequest = go 0 "" where go sumAccum nameAccum = recvRequest >>= \req -> @@ -34,7 +34,7 @@ handleClientStreamingCall call recvRequest = go 0 "" Right (Just (SimpleServiceRequest name nums)) -> go (sumAccum + sum nums) (nameAccum <> name) -handleServerStreamingCall :: ServerCall SimpleServiceRequest -> StreamSend SimpleServiceResponse -> Streaming (MetadataMap, StatusCode, StatusDetails) +handleServerStreamingCall :: ServerCall SimpleServiceRequest -> StreamSend SimpleServiceResponse -> IO (MetadataMap, StatusCode, StatusDetails) handleServerStreamingCall call sendResponse = go where go = do forM_ nums $ \num -> sendResponse (SimpleServiceResponse requestName num) @@ -42,7 +42,7 @@ handleServerStreamingCall call sendResponse = go SimpleServiceRequest requestName nums = payload call -handleBiDiStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> StreamSend SimpleServiceResponse -> Streaming (MetadataMap, StatusCode, StatusDetails) +handleBiDiStreamingCall :: ServerCall () -> StreamRecv SimpleServiceRequest -> StreamSend SimpleServiceResponse -> IO (MetadataMap, StatusCode, StatusDetails) handleBiDiStreamingCall call recvRequest sendResponse = go where go = recvRequest >>= \req -> case req of