Expose ClientCall in ClientReaderHandler and ClientRWHandler (#87)

This allows you to cancel the call from within the callback using
`clientCallCancel`.
This commit is contained in:
Moritz Kiefer 2019-08-22 17:53:41 +02:00 committed by Gabriel Gonzalez
parent 6e09678dc7
commit a26497c82c
5 changed files with 20 additions and 19 deletions

View file

@ -256,7 +256,7 @@ compileNormalRequestResults x =
-- clientReader (client side of server streaming mode) -- clientReader (client side of server streaming mode)
-- | First parameter is initial server metadata. -- | First parameter is initial server metadata.
type ClientReaderHandler = MetadataMap -> StreamRecv ByteString -> IO () type ClientReaderHandler = ClientCall -> MetadataMap -> StreamRecv ByteString -> IO ()
type ClientReaderResult = (MetadataMap, C.StatusCode, StatusDetails) type ClientReaderResult = (MetadataMap, C.StatusCode, StatusDetails)
clientReader :: Client clientReader :: Client
@ -269,13 +269,13 @@ clientReader :: Client
clientReader cl@Client{ clientCQ = cq } rm tm body initMeta f = clientReader cl@Client{ clientCQ = cq } rm tm body initMeta f =
withClientCall cl rm tm go withClientCall cl rm tm go
where where
go (unsafeCC -> c) = runExceptT $ do go cc@(unsafeCC -> c) = runExceptT $ do
void $ runOps' c cq [ OpSendInitialMetadata initMeta void $ runOps' c cq [ OpSendInitialMetadata initMeta
, OpSendMessage body , OpSendMessage body
, OpSendCloseFromClient , OpSendCloseFromClient
] ]
srvMD <- recvInitialMetadata c cq srvMD <- recvInitialMetadata c cq
liftIO $ f srvMD (streamRecvPrim c cq) liftIO $ f cc srvMD (streamRecvPrim c cq)
recvStatusOnClient c cq recvStatusOnClient c cq
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -326,7 +326,8 @@ pattern CWRFinal mmsg initMD trailMD st ds
-- clientRW (client side of bidirectional streaming mode) -- clientRW (client side of bidirectional streaming mode)
type ClientRWHandler type ClientRWHandler
= IO (Either GRPCIOError MetadataMap) = ClientCall
-> IO (Either GRPCIOError MetadataMap)
-> StreamRecv ByteString -> StreamRecv ByteString
-> StreamSend ByteString -> StreamSend ByteString
-> WritesDone -> WritesDone
@ -352,7 +353,7 @@ clientRW' :: Client
-> MetadataMap -> MetadataMap
-> ClientRWHandler -> ClientRWHandler
-> IO (Either GRPCIOError ClientRWResult) -> IO (Either GRPCIOError ClientRWResult)
clientRW' (clientCQ -> cq) (unsafeCC -> c) initMeta f = runExceptT $ do clientRW' (clientCQ -> cq) cc@(unsafeCC -> c) initMeta f = runExceptT $ do
sendInitialMetadata c cq initMeta sendInitialMetadata c cq initMeta
-- 'mdmv' is used to synchronize between callers of 'getMD' and 'recv' -- 'mdmv' is used to synchronize between callers of 'getMD' and 'recv'
@ -412,7 +413,7 @@ clientRW' (clientCQ -> cq) (unsafeCC -> c) initMeta f = runExceptT $ do
-- programmer. -- programmer.
writesDone = writesDonePrim c cq writesDone = writesDonePrim c cq
liftIO (f getMD recv send writesDone) liftIO (f cc getMD recv send writesDone)
recvStatusOnClient c cq -- Finish() recvStatusOnClient c cq -- Finish()
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------

View file

@ -406,7 +406,7 @@ testServerStreaming =
client c = do client c = do
rm <- clientRegisterMethodServerStreaming c "/feed" rm <- clientRegisterMethodServerStreaming c "/feed"
eea <- clientReader c rm 10 clientPay clientInitMD $ \initMD recv -> do eea <- clientReader c rm 10 clientPay clientInitMD $ \_cc initMD recv -> do
checkMD "Server initial metadata mismatch" serverInitMD initMD checkMD "Server initial metadata mismatch" serverInitMD initMD
forM_ pays $ \p -> recv `is` Right (Just p) forM_ pays $ \p -> recv `is` Right (Just p)
recv `is` Right Nothing recv `is` Right Nothing
@ -436,7 +436,7 @@ testServerStreamingUnregistered =
client c = do client c = do
rm <- clientRegisterMethodServerStreaming c "/feed" rm <- clientRegisterMethodServerStreaming c "/feed"
eea <- clientReader c rm 10 clientPay clientInitMD $ \initMD recv -> do eea <- clientReader c rm 10 clientPay clientInitMD $ \_cc initMD recv -> do
checkMD "Server initial metadata mismatch" serverInitMD initMD checkMD "Server initial metadata mismatch" serverInitMD initMD
forM_ pays $ \p -> recv `is` Right (Just p) forM_ pays $ \p -> recv `is` Right (Just p)
recv `is` Right Nothing recv `is` Right Nothing
@ -517,7 +517,7 @@ testBiDiStreaming =
client c = do client c = do
rm <- clientRegisterMethodBiDiStreaming c "/bidi" rm <- clientRegisterMethodBiDiStreaming c "/bidi"
eea <- clientRW c rm 10 clientInitMD $ \getMD recv send writesDone -> do eea <- clientRW c rm 10 clientInitMD $ \_cc getMD recv send writesDone -> do
either clientFail (checkMD "Server rsp metadata mismatch" serverInitMD) =<< getMD either clientFail (checkMD "Server rsp metadata mismatch" serverInitMD) =<< getMD
send "cw0" `is` Right () send "cw0" `is` Right ()
recv `is` Right (Just "sw0") recv `is` Right (Just "sw0")
@ -553,7 +553,7 @@ testBiDiStreamingUnregistered =
client c = do client c = do
rm <- clientRegisterMethodBiDiStreaming c "/bidi" rm <- clientRegisterMethodBiDiStreaming c "/bidi"
eea <- clientRW c rm 10 clientInitMD $ \getMD recv send writesDone -> do eea <- clientRW c rm 10 clientInitMD $ \_cc getMD recv send writesDone -> do
either clientFail (checkMD "Server rsp metadata mismatch" serverInitMD) =<< getMD either clientFail (checkMD "Server rsp metadata mismatch" serverInitMD) =<< getMD
send "cw0" `is` Right () send "cw0" `is` Right ()
recv `is` Right (Just "sw0") recv `is` Right (Just "sw0")

View file

@ -42,7 +42,7 @@ doHelloSS c n = do
let pay = SSRqt "server streaming mode" (fromIntegral n) let pay = SSRqt "server streaming mode" (fromIntegral n)
enc = BL.toStrict . toLazyByteString $ pay enc = BL.toStrict . toLazyByteString $ pay
err desc e = fail $ "doHelloSS: " ++ desc ++ " error: " ++ show e err desc e = fail $ "doHelloSS: " ++ desc ++ " error: " ++ show e
eea <- clientReader c rm n enc mempty $ \_md recv -> do eea <- clientReader c rm n enc mempty $ \_cc _md recv -> do
n' <- flip fix (0::Int) $ \go i -> recv >>= \case n' <- flip fix (0::Int) $ \go i -> recv >>= \case
Left e -> err "recv" e Left e -> err "recv" e
Right Nothing -> return i Right Nothing -> return i
@ -84,7 +84,7 @@ doHelloBi c n = do
let pay = BiRqtRpy "bidi payload" let pay = BiRqtRpy "bidi payload"
enc = BL.toStrict . toLazyByteString $ pay enc = BL.toStrict . toLazyByteString $ pay
err desc e = fail $ "doHelloBi: " ++ desc ++ " error: " ++ show e err desc e = fail $ "doHelloBi: " ++ desc ++ " error: " ++ show e
eea <- clientRW c rm n mempty $ \_getMD recv send writesDone -> do eea <- clientRW c rm n mempty $ \_cc _getMD recv send writesDone -> do
-- perform n writes on a worker thread -- perform n writes on a worker thread
thd <- async $ do thd <- async $ do
replicateM_ n $ send enc >>= \case replicateM_ n $ send enc >>= \case

View file

@ -72,8 +72,8 @@ data ClientRequest (streamType :: GRPCMethodType) request response where
-- | The final field will be invoked once, and it should repeatedly -- | The final field will be invoked once, and it should repeatedly
-- invoke its final argument (of type @(StreamRecv response)@) -- invoke its final argument (of type @(StreamRecv response)@)
-- in order to obtain the streaming response incrementally. -- in order to obtain the streaming response incrementally.
ClientReaderRequest :: request -> TimeoutSeconds -> MetadataMap -> (MetadataMap -> StreamRecv response -> IO ()) -> ClientRequest 'ServerStreaming request response ClientReaderRequest :: request -> TimeoutSeconds -> MetadataMap -> (LL.ClientCall -> MetadataMap -> StreamRecv response -> IO ()) -> ClientRequest 'ServerStreaming request response
ClientBiDiRequest :: TimeoutSeconds -> MetadataMap -> (MetadataMap -> StreamRecv response -> StreamSend request -> WritesDone -> IO ()) -> ClientRequest 'BiDiStreaming request response ClientBiDiRequest :: TimeoutSeconds -> MetadataMap -> (LL.ClientCall -> MetadataMap -> StreamRecv response -> StreamSend request -> WritesDone -> IO ()) -> ClientRequest 'BiDiStreaming request response
data ClientResult (streamType :: GRPCMethodType) response where data ClientResult (streamType :: GRPCMethodType) response where
ClientNormalResponse :: response -> MetadataMap -> MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'Normal response ClientNormalResponse :: response -> MetadataMap -> MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'Normal response
@ -125,13 +125,13 @@ clientRequest client (RegisteredMethod method) (ClientWriterRequest timeout meta
Right parsedRsp -> Right parsedRsp ->
ClientWriterResponse parsedRsp initMD_ trailMD_ rspCode_ details_ ClientWriterResponse parsedRsp initMD_ trailMD_ rspCode_ details_
clientRequest client (RegisteredMethod method) (ClientReaderRequest req timeout meta handler) = clientRequest client (RegisteredMethod method) (ClientReaderRequest req timeout meta handler) =
mkResponse <$> LL.clientReader client method timeout (BL.toStrict (toLazyByteString req)) meta (\m recv -> handler m (convertRecv recv)) mkResponse <$> LL.clientReader client method timeout (BL.toStrict (toLazyByteString req)) meta (\cc m recv -> handler cc m (convertRecv recv))
where where
mkResponse (Left ioError_) = ClientErrorResponse (ClientIOError ioError_) mkResponse (Left ioError_) = ClientErrorResponse (ClientIOError ioError_)
mkResponse (Right (meta_, rspCode_, details_)) = mkResponse (Right (meta_, rspCode_, details_)) =
ClientReaderResponse meta_ rspCode_ details_ ClientReaderResponse meta_ rspCode_ details_
clientRequest client (RegisteredMethod method) (ClientBiDiRequest timeout meta handler) = clientRequest client (RegisteredMethod method) (ClientBiDiRequest timeout meta handler) =
mkResponse <$> LL.clientRW client method timeout meta (\_m recv send writesDone -> handler meta (convertRecv recv) (convertSend send) writesDone) mkResponse <$> LL.clientRW client method timeout meta (\cc _m recv send writesDone -> handler cc meta (convertRecv recv) (convertSend send) writesDone)
where where
mkResponse (Left ioError_) = ClientErrorResponse (ClientIOError ioError_) mkResponse (Left ioError_) = ClientErrorResponse (ClientIOError ioError_)
mkResponse (Right (meta_, rspCode_, details_)) = mkResponse (Right (meta_, rspCode_, details_)) =
@ -164,7 +164,7 @@ simplifyServerStreaming :: TimeoutSeconds
-- ^ Endpoint implementation (typically generated by grpc-haskell) -- ^ Endpoint implementation (typically generated by grpc-haskell)
-> request -> request
-- ^ Request payload -- ^ Request payload
-> (MetadataMap -> StreamRecv response -> IO ()) -> (LL.ClientCall -> MetadataMap -> StreamRecv response -> IO ())
-- ^ Stream handler; note that the 'StreamRecv' -- ^ Stream handler; note that the 'StreamRecv'
-- action must be called repeatedly in order to -- action must be called repeatedly in order to
-- consume the stream -- consume the stream

View file

@ -88,7 +88,7 @@ testServerStreamingCall client = testCase "Server-streaming call" $
checkResults nums recv checkResults nums recv
res <- simpleServiceServerStreamingCall client $ res <- simpleServiceServerStreamingCall client $
ClientReaderRequest (SimpleServiceRequest "Test" (fromList nums)) 10 mempty ClientReaderRequest (SimpleServiceRequest "Test" (fromList nums)) 10 mempty
(\_ -> checkResults nums) (\_ _ -> checkResults nums)
case res of case res of
ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err) ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err)
ClientReaderResponse _ sts _ -> ClientReaderResponse _ sts _ ->
@ -114,7 +114,7 @@ testBiDiStreamingCall client = testCase "Bidi-streaming call" $
iterations <- randomRIO (50, 500) iterations <- randomRIO (50, 500)
res <- simpleServiceBiDiStreamingCall client $ res <- simpleServiceBiDiStreamingCall client $
ClientBiDiRequest 10 mempty (\_ -> handleRequests iterations) ClientBiDiRequest 10 mempty (\_ _ -> handleRequests iterations)
case res of case res of
ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err) ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err)
ClientBiDiResponse _ sts _ -> ClientBiDiResponse _ sts _ ->