From a26497c82cd5b75a779a762aa9e7f3748a2526d6 Mon Sep 17 00:00:00 2001 From: Moritz Kiefer Date: Thu, 22 Aug 2019 17:53:41 +0200 Subject: [PATCH] Expose ClientCall in ClientReaderHandler and ClientRWHandler (#87) This allows you to cancel the call from within the callback using `clientCallCancel`. --- core/src/Network/GRPC/LowLevel/Client.hs | 13 +++++++------ core/tests/LowLevelTests.hs | 8 ++++---- examples/hellos/hellos-client/Main.hs | 4 ++-- src/Network/GRPC/HighLevel/Client.hs | 10 +++++----- tests/TestClient.hs | 4 ++-- 5 files changed, 20 insertions(+), 19 deletions(-) diff --git a/core/src/Network/GRPC/LowLevel/Client.hs b/core/src/Network/GRPC/LowLevel/Client.hs index a221c82..77c461e 100644 --- a/core/src/Network/GRPC/LowLevel/Client.hs +++ b/core/src/Network/GRPC/LowLevel/Client.hs @@ -256,7 +256,7 @@ compileNormalRequestResults x = -- clientReader (client side of server streaming mode) -- | First parameter is initial server metadata. -type ClientReaderHandler = MetadataMap -> StreamRecv ByteString -> IO () +type ClientReaderHandler = ClientCall -> MetadataMap -> StreamRecv ByteString -> IO () type ClientReaderResult = (MetadataMap, C.StatusCode, StatusDetails) clientReader :: Client @@ -269,13 +269,13 @@ clientReader :: Client clientReader cl@Client{ clientCQ = cq } rm tm body initMeta f = withClientCall cl rm tm go where - go (unsafeCC -> c) = runExceptT $ do + go cc@(unsafeCC -> c) = runExceptT $ do void $ runOps' c cq [ OpSendInitialMetadata initMeta , OpSendMessage body , OpSendCloseFromClient ] srvMD <- recvInitialMetadata c cq - liftIO $ f srvMD (streamRecvPrim c cq) + liftIO $ f cc srvMD (streamRecvPrim c cq) recvStatusOnClient c cq -------------------------------------------------------------------------------- @@ -326,7 +326,8 @@ pattern CWRFinal mmsg initMD trailMD st ds -- clientRW (client side of bidirectional streaming mode) type ClientRWHandler - = IO (Either GRPCIOError MetadataMap) + = ClientCall + -> IO (Either GRPCIOError MetadataMap) -> StreamRecv ByteString -> StreamSend ByteString -> WritesDone @@ -352,7 +353,7 @@ clientRW' :: Client -> MetadataMap -> ClientRWHandler -> IO (Either GRPCIOError ClientRWResult) -clientRW' (clientCQ -> cq) (unsafeCC -> c) initMeta f = runExceptT $ do +clientRW' (clientCQ -> cq) cc@(unsafeCC -> c) initMeta f = runExceptT $ do sendInitialMetadata c cq initMeta -- 'mdmv' is used to synchronize between callers of 'getMD' and 'recv' @@ -412,7 +413,7 @@ clientRW' (clientCQ -> cq) (unsafeCC -> c) initMeta f = runExceptT $ do -- programmer. writesDone = writesDonePrim c cq - liftIO (f getMD recv send writesDone) + liftIO (f cc getMD recv send writesDone) recvStatusOnClient c cq -- Finish() -------------------------------------------------------------------------------- diff --git a/core/tests/LowLevelTests.hs b/core/tests/LowLevelTests.hs index efd774e..e3fa00b 100644 --- a/core/tests/LowLevelTests.hs +++ b/core/tests/LowLevelTests.hs @@ -406,7 +406,7 @@ testServerStreaming = client c = do rm <- clientRegisterMethodServerStreaming c "/feed" - eea <- clientReader c rm 10 clientPay clientInitMD $ \initMD recv -> do + eea <- clientReader c rm 10 clientPay clientInitMD $ \_cc initMD recv -> do checkMD "Server initial metadata mismatch" serverInitMD initMD forM_ pays $ \p -> recv `is` Right (Just p) recv `is` Right Nothing @@ -436,7 +436,7 @@ testServerStreamingUnregistered = client c = do rm <- clientRegisterMethodServerStreaming c "/feed" - eea <- clientReader c rm 10 clientPay clientInitMD $ \initMD recv -> do + eea <- clientReader c rm 10 clientPay clientInitMD $ \_cc initMD recv -> do checkMD "Server initial metadata mismatch" serverInitMD initMD forM_ pays $ \p -> recv `is` Right (Just p) recv `is` Right Nothing @@ -517,7 +517,7 @@ testBiDiStreaming = client c = do rm <- clientRegisterMethodBiDiStreaming c "/bidi" - eea <- clientRW c rm 10 clientInitMD $ \getMD recv send writesDone -> do + eea <- clientRW c rm 10 clientInitMD $ \_cc getMD recv send writesDone -> do either clientFail (checkMD "Server rsp metadata mismatch" serverInitMD) =<< getMD send "cw0" `is` Right () recv `is` Right (Just "sw0") @@ -553,7 +553,7 @@ testBiDiStreamingUnregistered = client c = do rm <- clientRegisterMethodBiDiStreaming c "/bidi" - eea <- clientRW c rm 10 clientInitMD $ \getMD recv send writesDone -> do + eea <- clientRW c rm 10 clientInitMD $ \_cc getMD recv send writesDone -> do either clientFail (checkMD "Server rsp metadata mismatch" serverInitMD) =<< getMD send "cw0" `is` Right () recv `is` Right (Just "sw0") diff --git a/examples/hellos/hellos-client/Main.hs b/examples/hellos/hellos-client/Main.hs index 58e244c..5cfea55 100644 --- a/examples/hellos/hellos-client/Main.hs +++ b/examples/hellos/hellos-client/Main.hs @@ -42,7 +42,7 @@ doHelloSS c n = do let pay = SSRqt "server streaming mode" (fromIntegral n) enc = BL.toStrict . toLazyByteString $ pay err desc e = fail $ "doHelloSS: " ++ desc ++ " error: " ++ show e - eea <- clientReader c rm n enc mempty $ \_md recv -> do + eea <- clientReader c rm n enc mempty $ \_cc _md recv -> do n' <- flip fix (0::Int) $ \go i -> recv >>= \case Left e -> err "recv" e Right Nothing -> return i @@ -84,7 +84,7 @@ doHelloBi c n = do let pay = BiRqtRpy "bidi payload" enc = BL.toStrict . toLazyByteString $ pay err desc e = fail $ "doHelloBi: " ++ desc ++ " error: " ++ show e - eea <- clientRW c rm n mempty $ \_getMD recv send writesDone -> do + eea <- clientRW c rm n mempty $ \_cc _getMD recv send writesDone -> do -- perform n writes on a worker thread thd <- async $ do replicateM_ n $ send enc >>= \case diff --git a/src/Network/GRPC/HighLevel/Client.hs b/src/Network/GRPC/HighLevel/Client.hs index 619b31f..1067ae1 100644 --- a/src/Network/GRPC/HighLevel/Client.hs +++ b/src/Network/GRPC/HighLevel/Client.hs @@ -72,8 +72,8 @@ data ClientRequest (streamType :: GRPCMethodType) request response where -- | The final field will be invoked once, and it should repeatedly -- invoke its final argument (of type @(StreamRecv response)@) -- in order to obtain the streaming response incrementally. - ClientReaderRequest :: request -> TimeoutSeconds -> MetadataMap -> (MetadataMap -> StreamRecv response -> IO ()) -> ClientRequest 'ServerStreaming request response - ClientBiDiRequest :: TimeoutSeconds -> MetadataMap -> (MetadataMap -> StreamRecv response -> StreamSend request -> WritesDone -> IO ()) -> ClientRequest 'BiDiStreaming request response + ClientReaderRequest :: request -> TimeoutSeconds -> MetadataMap -> (LL.ClientCall -> MetadataMap -> StreamRecv response -> IO ()) -> ClientRequest 'ServerStreaming request response + ClientBiDiRequest :: TimeoutSeconds -> MetadataMap -> (LL.ClientCall -> MetadataMap -> StreamRecv response -> StreamSend request -> WritesDone -> IO ()) -> ClientRequest 'BiDiStreaming request response data ClientResult (streamType :: GRPCMethodType) response where ClientNormalResponse :: response -> MetadataMap -> MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'Normal response @@ -125,13 +125,13 @@ clientRequest client (RegisteredMethod method) (ClientWriterRequest timeout meta Right parsedRsp -> ClientWriterResponse parsedRsp initMD_ trailMD_ rspCode_ details_ clientRequest client (RegisteredMethod method) (ClientReaderRequest req timeout meta handler) = - mkResponse <$> LL.clientReader client method timeout (BL.toStrict (toLazyByteString req)) meta (\m recv -> handler m (convertRecv recv)) + mkResponse <$> LL.clientReader client method timeout (BL.toStrict (toLazyByteString req)) meta (\cc m recv -> handler cc m (convertRecv recv)) where mkResponse (Left ioError_) = ClientErrorResponse (ClientIOError ioError_) mkResponse (Right (meta_, rspCode_, details_)) = ClientReaderResponse meta_ rspCode_ details_ clientRequest client (RegisteredMethod method) (ClientBiDiRequest timeout meta handler) = - mkResponse <$> LL.clientRW client method timeout meta (\_m recv send writesDone -> handler meta (convertRecv recv) (convertSend send) writesDone) + mkResponse <$> LL.clientRW client method timeout meta (\cc _m recv send writesDone -> handler cc meta (convertRecv recv) (convertSend send) writesDone) where mkResponse (Left ioError_) = ClientErrorResponse (ClientIOError ioError_) mkResponse (Right (meta_, rspCode_, details_)) = @@ -164,7 +164,7 @@ simplifyServerStreaming :: TimeoutSeconds -- ^ Endpoint implementation (typically generated by grpc-haskell) -> request -- ^ Request payload - -> (MetadataMap -> StreamRecv response -> IO ()) + -> (LL.ClientCall -> MetadataMap -> StreamRecv response -> IO ()) -- ^ Stream handler; note that the 'StreamRecv' -- action must be called repeatedly in order to -- consume the stream diff --git a/tests/TestClient.hs b/tests/TestClient.hs index 8ef96a1..9642757 100644 --- a/tests/TestClient.hs +++ b/tests/TestClient.hs @@ -88,7 +88,7 @@ testServerStreamingCall client = testCase "Server-streaming call" $ checkResults nums recv res <- simpleServiceServerStreamingCall client $ ClientReaderRequest (SimpleServiceRequest "Test" (fromList nums)) 10 mempty - (\_ -> checkResults nums) + (\_ _ -> checkResults nums) case res of ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err) ClientReaderResponse _ sts _ -> @@ -114,7 +114,7 @@ testBiDiStreamingCall client = testCase "Bidi-streaming call" $ iterations <- randomRIO (50, 500) res <- simpleServiceBiDiStreamingCall client $ - ClientBiDiRequest 10 mempty (\_ -> handleRequests iterations) + ClientBiDiRequest 10 mempty (\_ _ -> handleRequests iterations) case res of ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err) ClientBiDiResponse _ sts _ ->