diff --git a/examples/echo/echo-server/Main.hs b/examples/echo/echo-server/Main.hs index db89af2..9a08fac 100644 --- a/examples/echo/echo-server/Main.hs +++ b/examples/echo/echo-server/Main.hs @@ -8,13 +8,14 @@ import Control.Monad (forever) import Data.ByteString (ByteString) import Network.GRPC.LowLevel import qualified Network.GRPC.LowLevel.Server.Unregistered as U +import qualified Network.GRPC.LowLevel.Call.Unregistered as U serverMeta :: MetadataMap serverMeta = [("test_meta", "test_meta_value")] -handler :: ByteString -> MetadataMap -> MethodName +handler :: U.ServerCall -> ByteString -> MetadataMap -> MethodName -> IO (ByteString, MetadataMap, StatusDetails) -handler reqBody _reqMeta _method = do +handler _call reqBody _reqMeta _method = do --putStrLn $ "Got request for method: " ++ show method --putStrLn $ "Got metadata: " ++ show reqMeta return (reqBody, serverMeta, StatusDetails "") @@ -34,8 +35,8 @@ regMain = withGRPC $ \grpc -> do forever $ do let method = head (registeredMethods server) result <- serverHandleNormalCall server method 15 serverMeta $ - \reqBody _reqMeta -> return (reqBody, serverMeta, serverMeta, - StatusDetails "") + \_call reqBody _reqMeta -> return (reqBody, serverMeta, + StatusDetails "") case result of Left x -> putStrLn $ "registered call result error: " ++ show x Right _ -> return () @@ -44,8 +45,8 @@ regMain = withGRPC $ \grpc -> do regLoop :: Server -> RegisteredMethod -> IO () regLoop server method = forever $ do result <- serverHandleNormalCall server method 15 serverMeta $ - \reqBody _reqMeta -> return (reqBody, serverMeta, serverMeta, - StatusDetails "") + \_call reqBody _reqMeta -> return (reqBody, serverMeta, + StatusDetails "") case result of Left x -> putStrLn $ "registered call result error: " ++ show x Right _ -> return () diff --git a/grpc-haskell.cabal b/grpc-haskell.cabal index 2522f79..fc96865 100644 --- a/grpc-haskell.cabal +++ b/grpc-haskell.cabal @@ -44,7 +44,6 @@ library Network.GRPC.LowLevel Network.GRPC.LowLevel.Server.Unregistered Network.GRPC.LowLevel.Client.Unregistered - other-modules: Network.GRPC.LowLevel.CompletionQueue Network.GRPC.LowLevel.CompletionQueue.Internal Network.GRPC.LowLevel.CompletionQueue.Unregistered @@ -117,6 +116,7 @@ test-suite test , containers ==0.5.* other-modules: LowLevelTests, + LowLevelTests.Op, UnsafeTests default-language: Haskell2010 ghc-options: -Wall -fwarn-incomplete-patterns -fno-warn-unused-do-bind -g -threaded diff --git a/src/Network/GRPC/LowLevel.hs b/src/Network/GRPC/LowLevel.hs index 3976d07..dfa8155 100644 --- a/src/Network/GRPC/LowLevel.hs +++ b/src/Network/GRPC/LowLevel.hs @@ -34,6 +34,7 @@ GRPC , withServer , serverHandleNormalCall , withServerCall +, serverCallCancel -- * Client , ClientConfig(..) @@ -45,6 +46,7 @@ GRPC , clientRegisterMethod , clientRequest , withClientCall +, clientCallCancel -- * Ops , Op(..) diff --git a/src/Network/GRPC/LowLevel/Call.hs b/src/Network/GRPC/LowLevel/Call.hs index 2907b01..a74faf8 100644 --- a/src/Network/GRPC/LowLevel/Call.hs +++ b/src/Network/GRPC/LowLevel/Call.hs @@ -17,6 +17,7 @@ import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe.ByteBuffer as C import qualified Network.GRPC.Unsafe.Metadata as C import qualified Network.GRPC.Unsafe.Time as C +import qualified Network.GRPC.Unsafe.Op as C import Network.GRPC.LowLevel.GRPC (MetadataMap, grpcDebug) @@ -56,6 +57,9 @@ data RegisteredMethod = RegisteredMethod {methodType :: GRPCMethodType, -- This is used to associate send/receive 'Op's with a request. data ClientCall = ClientCall { unClientCall :: C.Call } +clientCallCancel :: ClientCall -> IO () +clientCallCancel cc = C.grpcCallCancel (unClientCall cc) C.reserved + -- | Represents one registered GRPC call on the server. Contains pointers to all -- the C state needed to respond to a registered call. data ServerCall = ServerCall @@ -66,6 +70,10 @@ data ServerCall = ServerCall callDeadline :: C.CTimeSpecPtr } +serverCallCancel :: ServerCall -> C.StatusCode -> String -> IO () +serverCallCancel sc code reason = + C.grpcCallCancelWithStatus (unServerCall sc) code reason C.reserved + serverCallGetMetadata :: ServerCall -> IO MetadataMap serverCallGetMetadata ServerCall{..} = do marray <- peek requestMetadataRecv diff --git a/src/Network/GRPC/LowLevel/Call/Unregistered.hs b/src/Network/GRPC/LowLevel/Call/Unregistered.hs index dc714f3..b14537a 100644 --- a/src/Network/GRPC/LowLevel/Call/Unregistered.hs +++ b/src/Network/GRPC/LowLevel/Call/Unregistered.hs @@ -10,6 +10,7 @@ import Network.GRPC.LowLevel.Call (Host (..), MethodName (..)) import Network.GRPC.LowLevel.GRPC (MetadataMap, grpcDebug) import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe.Metadata as C +import qualified Network.GRPC.Unsafe.Op as C -- | Represents one unregistered GRPC call on the server. -- Contains pointers to all the C state needed to respond to an unregistered @@ -21,6 +22,10 @@ data ServerCall = ServerCall , callDetails :: C.CallDetails } +serverCallCancel :: ServerCall -> C.StatusCode -> String -> IO () +serverCallCancel sc code reason = + C.grpcCallCancelWithStatus (unServerCall sc) code reason C.reserved + serverCallGetMetadata :: ServerCall -> IO MetadataMap serverCallGetMetadata ServerCall{..} = do marray <- peek requestMetadataRecv diff --git a/src/Network/GRPC/LowLevel/CompletionQueue.hs b/src/Network/GRPC/LowLevel/CompletionQueue.hs index d9e1bd8..c491714 100644 --- a/src/Network/GRPC/LowLevel/CompletionQueue.hs +++ b/src/Network/GRPC/LowLevel/CompletionQueue.hs @@ -128,10 +128,9 @@ serverRequestCall :: C.Server -> CompletionQueue -> TimeoutSeconds -> RegisteredMethod - -> MetadataMap -> IO (Either GRPCIOError ServerCall) serverRequestCall - server cq@CompletionQueue{..} timeLimit RegisteredMethod{..} initMeta = + server cq@CompletionQueue{..} timeLimit RegisteredMethod{..} = withPermission Push cq $ do -- TODO: Is gRPC supposed to populate this deadline? -- NOTE: the below stuff is freed when we free the call we return. @@ -139,15 +138,6 @@ serverRequestCall callPtr <- malloc metadataArrayPtr <- C.metadataArrayCreate metadataArray <- peek metadataArrayPtr - #ifdef DEBUG - metaCount <- C.metadataArrayGetCount metadataArray - metaCap <- C.metadataArrayGetCapacity metadataArray - kvPtr <- C.metadataArrayGetMetadata metadataArray - grpcDebug $ "grpc-created meta: count: " ++ show metaCount - ++ " capacity: " ++ show metaCap ++ " ptr: " ++ show kvPtr - #endif - metadataContents <- C.createMetadata initMeta - C.metadataArraySetMetadata metadataArray metadataContents bbPtr <- malloc tag <- newTag cq callError <- C.grpcServerRequestRegisteredCall diff --git a/src/Network/GRPC/LowLevel/Op.hs b/src/Network/GRPC/LowLevel/Op.hs index 57109ca..88fb211 100644 --- a/src/Network/GRPC/LowLevel/Op.hs +++ b/src/Network/GRPC/LowLevel/Op.hs @@ -129,15 +129,17 @@ freeOpContext (OpRecvCloseOnServerContext pcancelled) = grpcDebug ("freeOpContext: freeing pcancelled: " ++ show pcancelled) >> free pcancelled --- | Converts a list of 'Op's into the corresponding 'OpContext's and guarantees --- they will be cleaned up correctly. -withOpContexts :: [Op] -> ([OpContext] -> IO a) -> IO a -withOpContexts ops = bracket (mapM createOpContext ops) - (mapM freeOpContext) - -withOpArray :: Int -> (C.OpArray -> IO a) -> IO a -withOpArray n = bracket (C.opArrayCreate n) - (flip C.opArrayDestroy n) +-- | Allocates an `OpArray` and a list of `OpContext`s from the given list of +-- `Op`s. +withOpArrayAndCtxts :: [Op] -> ((C.OpArray, [OpContext]) -> IO a) -> IO a +withOpArrayAndCtxts ops = bracket setup teardown + where setup = do ctxts <- mapM createOpContext ops + let l = length ops + arr <- C.opArrayCreate l + sequence_ $ zipWith (setOpArray arr) [0..l-1] ctxts + return (arr, ctxts) + teardown (arr, ctxts) = do C.opArrayDestroy arr (length ctxts) + mapM_ freeOpContext ctxts -- | Container holding GC-managed results for 'Op's which receive data. data OpRecvResult = @@ -216,25 +218,22 @@ runOps :: C.Call -> IO (Either GRPCIOError [OpRecvResult]) runOps call cq ops timeLimit = let l = length ops in - withOpArray l $ \opArray -> do - grpcDebug "runOps: created op array." - withOpContexts ops $ \contexts -> do - grpcDebug $ "runOps: allocated op contexts: " ++ show contexts - sequence_ $ zipWith (setOpArray opArray) [0..l-1] contexts - tag <- newTag cq - callError <- startBatch cq call opArray l tag - grpcDebug $ "runOps: called start_batch. callError: " - ++ (show callError) - case callError of - Left x -> return $ Left x - Right () -> do - ev <- pluck cq tag timeLimit - grpcDebug $ "runOps: pluck returned " ++ show ev - case ev of - Right () -> do - grpcDebug "runOps: got good op; starting." - fmap (Right . catMaybes) $ mapM resultFromOpContext contexts - Left err -> return $ Left err + withOpArrayAndCtxts ops $ \(opArray, contexts) -> do + grpcDebug $ "runOps: allocated op contexts: " ++ show contexts + tag <- newTag cq + callError <- startBatch cq call opArray l tag + grpcDebug $ "runOps: called start_batch. callError: " + ++ (show callError) + case callError of + Left x -> return $ Left x + Right () -> do + ev <- pluck cq tag timeLimit + grpcDebug $ "runOps: pluck returned " ++ show ev + case ev of + Right () -> do + grpcDebug "runOps: got good op; starting." + fmap (Right . catMaybes) $ mapM resultFromOpContext contexts + Left err -> return $ Left err -- | If response status info is present in the given 'OpRecvResult's, returns -- a tuple of trailing metadata, status code, and status details. diff --git a/src/Network/GRPC/LowLevel/Server.hs b/src/Network/GRPC/LowLevel/Server.hs index aaca99b..e2c4172 100644 --- a/src/Network/GRPC/LowLevel/Server.hs +++ b/src/Network/GRPC/LowLevel/Server.hs @@ -125,42 +125,23 @@ serverRegisterMethod _ _ _ _ = error "Streaming methods not implemented yet." serverCreateCall :: Server -> RegisteredMethod -> TimeoutSeconds - -> MetadataMap -> IO (Either GRPCIOError ServerCall) -serverCreateCall Server{..} rm timeLimit initMeta = - serverRequestCall internalServer serverCQ timeLimit rm initMeta +serverCreateCall Server{..} rm timeLimit = + serverRequestCall internalServer serverCQ timeLimit rm withServerCall :: Server -> RegisteredMethod -> TimeoutSeconds - -> MetadataMap -> (ServerCall -> IO (Either GRPCIOError a)) -> IO (Either GRPCIOError a) -withServerCall server regmethod timeout initMeta f = do - createResult <- serverCreateCall server regmethod timeout initMeta +withServerCall server regmethod timeout f = do + createResult <- serverCreateCall server regmethod timeout case createResult of Left x -> return $ Left x Right call -> f call `finally` logDestroy call where logDestroy c = grpcDebug "withServerRegisteredCall: destroying." >> destroyServerCall c --- | Sequence of 'Op's needed to receive a normal (non-streaming) call. -serverOpsGetNormalCall :: MetadataMap -> [Op] -serverOpsGetNormalCall initMetadata = - [OpSendInitialMetadata initMetadata, - OpRecvMessage] - --- | Sequence of 'Op's needed to respond to a normal (non-streaming) call. -serverOpsSendNormalResponse :: ByteString - -> MetadataMap - -> C.StatusCode - -> StatusDetails - -> [Op] -serverOpsSendNormalResponse body metadata code details = - [OpRecvCloseOnServer, - OpSendMessage body, - OpSendStatusFromServer metadata code details] - serverOpsSendNormalRegisteredResponse :: ByteString -> MetadataMap -- ^ initial metadata @@ -180,13 +161,14 @@ serverOpsSendNormalRegisteredResponse -- body, with the bytestring response body in the result tuple. The first -- metadata parameter refers to the request metadata, with the two metadata -- values in the result tuple being the initial and trailing metadata --- respectively. +-- respectively. We pass in the 'ServerCall' so that the server can call +-- 'serverCallCancel' on it if needed. -- TODO: make a more rigid type for this with a Maybe MetadataMap for the -- trailing meta, and use it for both kinds of call handlers. type ServerHandler - = ByteString -> MetadataMap - -> IO (ByteString, MetadataMap, MetadataMap, StatusDetails) + = ServerCall -> ByteString -> MetadataMap + -> IO (ByteString, MetadataMap, StatusDetails) -- TODO: we will want to replace this with some more general concept that also -- works with streaming calls in the future. @@ -198,12 +180,12 @@ serverHandleNormalCall :: Server -- ^ Initial server metadata -> ServerHandler -> IO (Either GRPCIOError ()) -serverHandleNormalCall s@Server{..} rm timeLimit srvMetadata f = do +serverHandleNormalCall s@Server{..} rm timeLimit initMeta f = do -- TODO: we use this timeLimit twice, so the max time spent is 2*timeLimit. -- Should we just hard-code time limits instead? Not sure if client -- programmer cares, since this function will likely just be put in a loop -- anyway. - withServerCall s rm timeLimit srvMetadata $ \call -> do + withServerCall s rm timeLimit $ \call -> do grpcDebug "serverHandleNormalCall(R): starting batch." debugServerCall call payload <- serverCallGetPayload call @@ -213,7 +195,7 @@ serverHandleNormalCall s@Server{..} rm timeLimit srvMetadata f = do Nothing -> error "serverHandleNormalCall(R): payload empty." Just requestBody -> do requestMeta <- serverCallGetMetadata call - (respBody, initMeta, trailingMeta, details) <- f requestBody requestMeta + (respBody, trailingMeta, details) <- f call requestBody requestMeta let status = C.GrpcStatusOk let respOps = serverOpsSendNormalRegisteredResponse respBody initMeta trailingMeta status details diff --git a/src/Network/GRPC/LowLevel/Server/Unregistered.hs b/src/Network/GRPC/LowLevel/Server/Unregistered.hs index e513bac..9340f3c 100644 --- a/src/Network/GRPC/LowLevel/Server/Unregistered.hs +++ b/src/Network/GRPC/LowLevel/Server/Unregistered.hs @@ -9,10 +9,8 @@ import Network.GRPC.LowLevel.Call.Unregistered import Network.GRPC.LowLevel.CompletionQueue (TimeoutSeconds) import Network.GRPC.LowLevel.CompletionQueue.Unregistered (serverRequestCall) import Network.GRPC.LowLevel.GRPC -import Network.GRPC.LowLevel.Op (OpRecvResult (..), runOps) -import Network.GRPC.LowLevel.Server (Server (..), - serverOpsGetNormalCall, - serverOpsSendNormalResponse) +import Network.GRPC.LowLevel.Op (Op(..), OpRecvResult (..), runOps) +import Network.GRPC.LowLevel.Server (Server (..)) import qualified Network.GRPC.Unsafe.Op as C serverCreateCall :: Server -> TimeoutSeconds @@ -31,10 +29,30 @@ withServerCall server timeout f = do where logDestroy c = grpcDebug "withServerCall: destroying." >> destroyServerCall c +-- | Sequence of 'Op's needed to receive a normal (non-streaming) call. +-- TODO: We have to put 'OpRecvCloseOnServer' in the response ops, or else the +-- client times out. Given this, I have no idea how to check for cancellation on +-- the server. +serverOpsGetNormalCall :: MetadataMap -> [Op] +serverOpsGetNormalCall initMetadata = + [OpSendInitialMetadata initMetadata, + OpRecvMessage] + +-- | Sequence of 'Op's needed to respond to a normal (non-streaming) call. +serverOpsSendNormalResponse :: ByteString + -> MetadataMap + -> C.StatusCode + -> StatusDetails + -> [Op] +serverOpsSendNormalResponse body metadata code details = + [OpRecvCloseOnServer, + OpSendMessage body, + OpSendStatusFromServer metadata code details] + -- | A handler for an unregistered server call; bytestring arguments are the -- request body and response body respectively. type ServerHandler - = ByteString -> MetadataMap -> MethodName + = ServerCall -> ByteString -> MetadataMap -> MethodName -> IO (ByteString, MetadataMap, StatusDetails) -- | Handle one unregistered call. @@ -58,7 +76,7 @@ serverHandleNormalCall s@Server{..} timeLimit srvMetadata f = do methodName <- serverCallGetMethodName call hostName <- serverCallGetHost call grpcDebug $ "call_details host is: " ++ show hostName - (respBody, respMetadata, details) <- f body requestMeta methodName + (respBody, respMetadata, details) <- f call body requestMeta methodName let status = C.GrpcStatusOk let respOps = serverOpsSendNormalResponse respBody respMetadata status details diff --git a/tests/LowLevelTests.hs b/tests/LowLevelTests.hs index dfea42a..d17f0c8 100644 --- a/tests/LowLevelTests.hs +++ b/tests/LowLevelTests.hs @@ -3,7 +3,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecordWildCards #-} -module LowLevelTests (lowLevelTests) where +module LowLevelTests where import Control.Concurrent (threadDelay) import Control.Concurrent.Async @@ -34,6 +34,7 @@ lowLevelTests = testGroup "Unit tests of low-level Haskell library" -- , testWrongEndpoint , testPayload , testPayloadUnregistered + , testServerCancel , testGoaway , testSlowServer ] @@ -78,8 +79,8 @@ testServerTimeoutNoClient :: TestTree testServerTimeoutNoClient = serverOnlyTest "wait timeout when client DNE" [("/foo", Normal)] $ \s -> do let rm = head (registeredMethods s) - r <- serverHandleNormalCall s rm 1 mempty $ \_ _ -> - return ("", mempty, mempty, StatusDetails "details") + r <- serverHandleNormalCall s rm 1 mempty $ \_ _ _ -> + return ("", mempty, StatusDetails "details") r @?= Left GRPCIOTimeout -- TODO: fix this test: currently, client seems to hang and server times out, @@ -99,8 +100,8 @@ testWrongEndpoint = server s = do length (registeredMethods s) @?= 1 let rm = head (registeredMethods s) - r <- serverHandleNormalCall s rm 10 mempty $ \_ _ -> do - return ("reply test", dummyMeta, dummyMeta, StatusDetails "details string") + r <- serverHandleNormalCall s rm 10 mempty $ \_ _ _ -> do + return ("reply test", dummyMeta, StatusDetails "details string") r @?= Right () -- TODO: There seems to be a race here (and in other client/server pairs, of @@ -126,10 +127,27 @@ testPayload = server s = do length (registeredMethods s) @?= 1 let rm = head (registeredMethods s) - r <- serverHandleNormalCall s rm 11 mempty $ \reqBody reqMD -> do + r <- serverHandleNormalCall s rm 11 dummyMeta $ \_ reqBody reqMD -> do reqBody @?= "Hello!" checkMD "Server metadata mismatch" clientMD reqMD - return ("reply test", dummyMeta, dummyMeta, StatusDetails "details string") + return ("reply test", dummyMeta, StatusDetails "details string") + r @?= Right () + +testServerCancel :: TestTree +testServerCancel = + csTest "server cancel call" client server [("/foo", Normal)] + where + client c = do + rm <- clientRegisterMethod c "/foo" Normal + res <- clientRequest c rm 10 "" mempty + res @?= Left (GRPCIOBadStatusCode GrpcStatusCancelled + (StatusDetails + "Received RST_STREAM err=8")) + server s = do + let rm = head (registeredMethods s) + r <- serverHandleNormalCall s rm 10 mempty $ \c _ _ -> do + serverCallCancel c GrpcStatusCancelled "" + return (mempty, mempty, "") r @?= Right () testPayloadUnregistered :: TestTree @@ -143,7 +161,7 @@ testPayloadUnregistered = rspBody @?= "reply test" details @?= "details string" server s = do - r <- U.serverHandleNormalCall s 11 mempty $ \body _md meth -> do + r <- U.serverHandleNormalCall s 11 mempty $ \_ body _md meth -> do body @?= "Hello!" meth @?= "/foo" return ("reply test", mempty, "details string") @@ -184,9 +202,9 @@ testSlowServer = result == deadlineExceededStatus server s = do let rm = head (registeredMethods s) - serverHandleNormalCall s rm 1 mempty $ \_ _ -> do + serverHandleNormalCall s rm 1 mempty $ \_ _ _ -> do threadDelay (2*10^(6 :: Int)) - return ("", mempty, mempty, StatusDetails "") + return ("", mempty, StatusDetails "") return () -------------------------------------------------------------------------------- @@ -195,9 +213,9 @@ testSlowServer = dummyMeta :: M.Map ByteString ByteString dummyMeta = [("foo","bar")] -dummyHandler :: ByteString -> MetadataMap - -> IO (ByteString, MetadataMap, MetadataMap, StatusDetails) -dummyHandler _ _ = return ("", mempty, mempty, StatusDetails "") +dummyHandler :: ServerCall -> ByteString -> MetadataMap + -> IO (ByteString, MetadataMap, StatusDetails) +dummyHandler _ _ _ = return ("", mempty, StatusDetails "") unavailableStatus :: Either GRPCIOError a unavailableStatus = diff --git a/tests/LowLevelTests/Op.hs b/tests/LowLevelTests/Op.hs new file mode 100644 index 0000000..f0b1e5b --- /dev/null +++ b/tests/LowLevelTests/Op.hs @@ -0,0 +1,104 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} + +module LowLevelTests.Op where + +import Control.Concurrent (threadDelay) +import Control.Concurrent.Async +import Control.Monad +import Data.ByteString (ByteString, isPrefixOf) +import qualified Data.Map as M +import Foreign.Storable (peek) +import Test.Tasty +import Test.Tasty.HUnit as HU (testCase, (@?=), + assertBool) + +import Network.GRPC.LowLevel +import Network.GRPC.LowLevel.Call +import Network.GRPC.LowLevel.Client +import Network.GRPC.LowLevel.Server +import Network.GRPC.LowLevel.Op +import Network.GRPC.LowLevel.CompletionQueue + +lowLevelOpTests :: TestTree +lowLevelOpTests = testGroup "Synchronous unit tests of low-level Op interface" + [testCancelWhileHandling + ,testCancelFromServer] + +testCancelWhileHandling :: TestTree +testCancelWhileHandling = + testCase "Client/Server - cancel after handler starts does nothing" $ + runSerialTest $ \grpc -> + withClientServerUnaryCall grpc $ + \(c@Client{..}, s@Server{..}, cc@ClientCall{..}, sc@ServerCall{..}) -> do + withOpArrayAndCtxts serverEmptyRecvOps $ \(opArray, ctxts) -> do + tag <- newTag serverCQ + startBatch serverCQ unServerCall opArray 3 tag + pluck serverCQ tag 1 + let (OpRecvCloseOnServerContext pcancelled) = last ctxts + cancelledBefore <- peek pcancelled + cancelledBefore @?= 0 + clientCallCancel cc + threadDelay 1000000 + cancelledAfter <- peek pcancelled + cancelledAfter @?= 0 + return $ Right () + +testCancelFromServer :: TestTree +testCancelFromServer = + testCase "Client/Server - client receives server cancellation" $ + runSerialTest $ \grpc -> + withClientServerUnaryCall grpc $ + \(c@Client{..}, s@Server{..}, cc@ClientCall{..}, sc@ServerCall{..}) -> do + serverCallCancel sc GrpcStatusPermissionDenied "TestStatus" + clientRes <- runOps unClientCall clientCQ clientRecvOps 1 + case clientRes of + Left x -> error $ "Client recv error: " ++ show x + Right [_,_,OpRecvStatusOnClientResult _ code details] -> do + code @?= GrpcStatusPermissionDenied + assertBool "Received status details or RST_STREAM error" $ + details == "TestStatus" + || + isPrefixOf "Received RST_STREAM" details + return $ Right () + + +runSerialTest :: (GRPC -> IO (Either GRPCIOError ())) -> IO () +runSerialTest f = + withGRPC f >>= \case Left x -> error $ show x + Right () -> return () + +withClientServerUnaryCall :: GRPC + -> ((Client, Server, ClientCall, ServerCall) + -> IO (Either GRPCIOError a)) + -> IO (Either GRPCIOError a) +withClientServerUnaryCall grpc f = do + withClient grpc clientConf $ \c -> do + crm <- clientRegisterMethod c "/foo" Normal + withServer grpc serverConf $ \s -> + withClientCall c crm 10 $ \cc -> do + let srm = head (registeredMethods s) + -- NOTE: We need to send client ops here or else `withServerCall` hangs, + -- because registered methods try to do recv ops immediately when + -- created. If later we want to send payloads or metadata, we'll need + -- to tweak this. + clientRes <- runOps (unClientCall cc) (clientCQ c) clientEmptySendOps 1 + withServerCall s srm 10 $ \sc -> + f (c, s, cc, sc) + +serverConf = (ServerConfig "localhost" 50051 [("/foo", Normal)]) + +clientConf = (ClientConfig "localhost" 50051) + +clientEmptySendOps = [OpSendInitialMetadata mempty, + OpSendMessage "", + OpSendCloseFromClient] + +clientRecvOps = [OpRecvInitialMetadata, + OpRecvMessage, + OpRecvStatusOnClient] + +serverEmptyRecvOps = [OpSendInitialMetadata mempty, + OpRecvMessage, + OpRecvCloseOnServer] diff --git a/tests/Properties.hs b/tests/Properties.hs index 395c038..f7bf152 100644 --- a/tests/Properties.hs +++ b/tests/Properties.hs @@ -1,9 +1,11 @@ import LowLevelTests +import LowLevelTests.Op import Test.Tasty import UnsafeTests main :: IO () main = defaultMain $ testGroup "GRPC Unit Tests" [ unsafeTests + , lowLevelOpTests , lowLevelTests ]