mirror of
https://github.com/unclechu/gRPC-haskell.git
synced 2024-11-26 21:19:43 +01:00
Expose ClientCall in ClientReaderHandler and ClientRWHandler (#87)
This allows you to cancel the call from within the callback using `clientCallCancel`.
This commit is contained in:
parent
6e09678dc7
commit
a26497c82c
5 changed files with 20 additions and 19 deletions
|
@ -256,7 +256,7 @@ compileNormalRequestResults x =
|
||||||
-- clientReader (client side of server streaming mode)
|
-- clientReader (client side of server streaming mode)
|
||||||
|
|
||||||
-- | First parameter is initial server metadata.
|
-- | First parameter is initial server metadata.
|
||||||
type ClientReaderHandler = MetadataMap -> StreamRecv ByteString -> IO ()
|
type ClientReaderHandler = ClientCall -> MetadataMap -> StreamRecv ByteString -> IO ()
|
||||||
type ClientReaderResult = (MetadataMap, C.StatusCode, StatusDetails)
|
type ClientReaderResult = (MetadataMap, C.StatusCode, StatusDetails)
|
||||||
|
|
||||||
clientReader :: Client
|
clientReader :: Client
|
||||||
|
@ -269,13 +269,13 @@ clientReader :: Client
|
||||||
clientReader cl@Client{ clientCQ = cq } rm tm body initMeta f =
|
clientReader cl@Client{ clientCQ = cq } rm tm body initMeta f =
|
||||||
withClientCall cl rm tm go
|
withClientCall cl rm tm go
|
||||||
where
|
where
|
||||||
go (unsafeCC -> c) = runExceptT $ do
|
go cc@(unsafeCC -> c) = runExceptT $ do
|
||||||
void $ runOps' c cq [ OpSendInitialMetadata initMeta
|
void $ runOps' c cq [ OpSendInitialMetadata initMeta
|
||||||
, OpSendMessage body
|
, OpSendMessage body
|
||||||
, OpSendCloseFromClient
|
, OpSendCloseFromClient
|
||||||
]
|
]
|
||||||
srvMD <- recvInitialMetadata c cq
|
srvMD <- recvInitialMetadata c cq
|
||||||
liftIO $ f srvMD (streamRecvPrim c cq)
|
liftIO $ f cc srvMD (streamRecvPrim c cq)
|
||||||
recvStatusOnClient c cq
|
recvStatusOnClient c cq
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
@ -326,7 +326,8 @@ pattern CWRFinal mmsg initMD trailMD st ds
|
||||||
-- clientRW (client side of bidirectional streaming mode)
|
-- clientRW (client side of bidirectional streaming mode)
|
||||||
|
|
||||||
type ClientRWHandler
|
type ClientRWHandler
|
||||||
= IO (Either GRPCIOError MetadataMap)
|
= ClientCall
|
||||||
|
-> IO (Either GRPCIOError MetadataMap)
|
||||||
-> StreamRecv ByteString
|
-> StreamRecv ByteString
|
||||||
-> StreamSend ByteString
|
-> StreamSend ByteString
|
||||||
-> WritesDone
|
-> WritesDone
|
||||||
|
@ -352,7 +353,7 @@ clientRW' :: Client
|
||||||
-> MetadataMap
|
-> MetadataMap
|
||||||
-> ClientRWHandler
|
-> ClientRWHandler
|
||||||
-> IO (Either GRPCIOError ClientRWResult)
|
-> IO (Either GRPCIOError ClientRWResult)
|
||||||
clientRW' (clientCQ -> cq) (unsafeCC -> c) initMeta f = runExceptT $ do
|
clientRW' (clientCQ -> cq) cc@(unsafeCC -> c) initMeta f = runExceptT $ do
|
||||||
sendInitialMetadata c cq initMeta
|
sendInitialMetadata c cq initMeta
|
||||||
|
|
||||||
-- 'mdmv' is used to synchronize between callers of 'getMD' and 'recv'
|
-- 'mdmv' is used to synchronize between callers of 'getMD' and 'recv'
|
||||||
|
@ -412,7 +413,7 @@ clientRW' (clientCQ -> cq) (unsafeCC -> c) initMeta f = runExceptT $ do
|
||||||
-- programmer.
|
-- programmer.
|
||||||
writesDone = writesDonePrim c cq
|
writesDone = writesDonePrim c cq
|
||||||
|
|
||||||
liftIO (f getMD recv send writesDone)
|
liftIO (f cc getMD recv send writesDone)
|
||||||
recvStatusOnClient c cq -- Finish()
|
recvStatusOnClient c cq -- Finish()
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
|
@ -406,7 +406,7 @@ testServerStreaming =
|
||||||
|
|
||||||
client c = do
|
client c = do
|
||||||
rm <- clientRegisterMethodServerStreaming c "/feed"
|
rm <- clientRegisterMethodServerStreaming c "/feed"
|
||||||
eea <- clientReader c rm 10 clientPay clientInitMD $ \initMD recv -> do
|
eea <- clientReader c rm 10 clientPay clientInitMD $ \_cc initMD recv -> do
|
||||||
checkMD "Server initial metadata mismatch" serverInitMD initMD
|
checkMD "Server initial metadata mismatch" serverInitMD initMD
|
||||||
forM_ pays $ \p -> recv `is` Right (Just p)
|
forM_ pays $ \p -> recv `is` Right (Just p)
|
||||||
recv `is` Right Nothing
|
recv `is` Right Nothing
|
||||||
|
@ -436,7 +436,7 @@ testServerStreamingUnregistered =
|
||||||
|
|
||||||
client c = do
|
client c = do
|
||||||
rm <- clientRegisterMethodServerStreaming c "/feed"
|
rm <- clientRegisterMethodServerStreaming c "/feed"
|
||||||
eea <- clientReader c rm 10 clientPay clientInitMD $ \initMD recv -> do
|
eea <- clientReader c rm 10 clientPay clientInitMD $ \_cc initMD recv -> do
|
||||||
checkMD "Server initial metadata mismatch" serverInitMD initMD
|
checkMD "Server initial metadata mismatch" serverInitMD initMD
|
||||||
forM_ pays $ \p -> recv `is` Right (Just p)
|
forM_ pays $ \p -> recv `is` Right (Just p)
|
||||||
recv `is` Right Nothing
|
recv `is` Right Nothing
|
||||||
|
@ -517,7 +517,7 @@ testBiDiStreaming =
|
||||||
|
|
||||||
client c = do
|
client c = do
|
||||||
rm <- clientRegisterMethodBiDiStreaming c "/bidi"
|
rm <- clientRegisterMethodBiDiStreaming c "/bidi"
|
||||||
eea <- clientRW c rm 10 clientInitMD $ \getMD recv send writesDone -> do
|
eea <- clientRW c rm 10 clientInitMD $ \_cc getMD recv send writesDone -> do
|
||||||
either clientFail (checkMD "Server rsp metadata mismatch" serverInitMD) =<< getMD
|
either clientFail (checkMD "Server rsp metadata mismatch" serverInitMD) =<< getMD
|
||||||
send "cw0" `is` Right ()
|
send "cw0" `is` Right ()
|
||||||
recv `is` Right (Just "sw0")
|
recv `is` Right (Just "sw0")
|
||||||
|
@ -553,7 +553,7 @@ testBiDiStreamingUnregistered =
|
||||||
|
|
||||||
client c = do
|
client c = do
|
||||||
rm <- clientRegisterMethodBiDiStreaming c "/bidi"
|
rm <- clientRegisterMethodBiDiStreaming c "/bidi"
|
||||||
eea <- clientRW c rm 10 clientInitMD $ \getMD recv send writesDone -> do
|
eea <- clientRW c rm 10 clientInitMD $ \_cc getMD recv send writesDone -> do
|
||||||
either clientFail (checkMD "Server rsp metadata mismatch" serverInitMD) =<< getMD
|
either clientFail (checkMD "Server rsp metadata mismatch" serverInitMD) =<< getMD
|
||||||
send "cw0" `is` Right ()
|
send "cw0" `is` Right ()
|
||||||
recv `is` Right (Just "sw0")
|
recv `is` Right (Just "sw0")
|
||||||
|
|
|
@ -42,7 +42,7 @@ doHelloSS c n = do
|
||||||
let pay = SSRqt "server streaming mode" (fromIntegral n)
|
let pay = SSRqt "server streaming mode" (fromIntegral n)
|
||||||
enc = BL.toStrict . toLazyByteString $ pay
|
enc = BL.toStrict . toLazyByteString $ pay
|
||||||
err desc e = fail $ "doHelloSS: " ++ desc ++ " error: " ++ show e
|
err desc e = fail $ "doHelloSS: " ++ desc ++ " error: " ++ show e
|
||||||
eea <- clientReader c rm n enc mempty $ \_md recv -> do
|
eea <- clientReader c rm n enc mempty $ \_cc _md recv -> do
|
||||||
n' <- flip fix (0::Int) $ \go i -> recv >>= \case
|
n' <- flip fix (0::Int) $ \go i -> recv >>= \case
|
||||||
Left e -> err "recv" e
|
Left e -> err "recv" e
|
||||||
Right Nothing -> return i
|
Right Nothing -> return i
|
||||||
|
@ -84,7 +84,7 @@ doHelloBi c n = do
|
||||||
let pay = BiRqtRpy "bidi payload"
|
let pay = BiRqtRpy "bidi payload"
|
||||||
enc = BL.toStrict . toLazyByteString $ pay
|
enc = BL.toStrict . toLazyByteString $ pay
|
||||||
err desc e = fail $ "doHelloBi: " ++ desc ++ " error: " ++ show e
|
err desc e = fail $ "doHelloBi: " ++ desc ++ " error: " ++ show e
|
||||||
eea <- clientRW c rm n mempty $ \_getMD recv send writesDone -> do
|
eea <- clientRW c rm n mempty $ \_cc _getMD recv send writesDone -> do
|
||||||
-- perform n writes on a worker thread
|
-- perform n writes on a worker thread
|
||||||
thd <- async $ do
|
thd <- async $ do
|
||||||
replicateM_ n $ send enc >>= \case
|
replicateM_ n $ send enc >>= \case
|
||||||
|
|
|
@ -72,8 +72,8 @@ data ClientRequest (streamType :: GRPCMethodType) request response where
|
||||||
-- | The final field will be invoked once, and it should repeatedly
|
-- | The final field will be invoked once, and it should repeatedly
|
||||||
-- invoke its final argument (of type @(StreamRecv response)@)
|
-- invoke its final argument (of type @(StreamRecv response)@)
|
||||||
-- in order to obtain the streaming response incrementally.
|
-- in order to obtain the streaming response incrementally.
|
||||||
ClientReaderRequest :: request -> TimeoutSeconds -> MetadataMap -> (MetadataMap -> StreamRecv response -> IO ()) -> ClientRequest 'ServerStreaming request response
|
ClientReaderRequest :: request -> TimeoutSeconds -> MetadataMap -> (LL.ClientCall -> MetadataMap -> StreamRecv response -> IO ()) -> ClientRequest 'ServerStreaming request response
|
||||||
ClientBiDiRequest :: TimeoutSeconds -> MetadataMap -> (MetadataMap -> StreamRecv response -> StreamSend request -> WritesDone -> IO ()) -> ClientRequest 'BiDiStreaming request response
|
ClientBiDiRequest :: TimeoutSeconds -> MetadataMap -> (LL.ClientCall -> MetadataMap -> StreamRecv response -> StreamSend request -> WritesDone -> IO ()) -> ClientRequest 'BiDiStreaming request response
|
||||||
|
|
||||||
data ClientResult (streamType :: GRPCMethodType) response where
|
data ClientResult (streamType :: GRPCMethodType) response where
|
||||||
ClientNormalResponse :: response -> MetadataMap -> MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'Normal response
|
ClientNormalResponse :: response -> MetadataMap -> MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'Normal response
|
||||||
|
@ -125,13 +125,13 @@ clientRequest client (RegisteredMethod method) (ClientWriterRequest timeout meta
|
||||||
Right parsedRsp ->
|
Right parsedRsp ->
|
||||||
ClientWriterResponse parsedRsp initMD_ trailMD_ rspCode_ details_
|
ClientWriterResponse parsedRsp initMD_ trailMD_ rspCode_ details_
|
||||||
clientRequest client (RegisteredMethod method) (ClientReaderRequest req timeout meta handler) =
|
clientRequest client (RegisteredMethod method) (ClientReaderRequest req timeout meta handler) =
|
||||||
mkResponse <$> LL.clientReader client method timeout (BL.toStrict (toLazyByteString req)) meta (\m recv -> handler m (convertRecv recv))
|
mkResponse <$> LL.clientReader client method timeout (BL.toStrict (toLazyByteString req)) meta (\cc m recv -> handler cc m (convertRecv recv))
|
||||||
where
|
where
|
||||||
mkResponse (Left ioError_) = ClientErrorResponse (ClientIOError ioError_)
|
mkResponse (Left ioError_) = ClientErrorResponse (ClientIOError ioError_)
|
||||||
mkResponse (Right (meta_, rspCode_, details_)) =
|
mkResponse (Right (meta_, rspCode_, details_)) =
|
||||||
ClientReaderResponse meta_ rspCode_ details_
|
ClientReaderResponse meta_ rspCode_ details_
|
||||||
clientRequest client (RegisteredMethod method) (ClientBiDiRequest timeout meta handler) =
|
clientRequest client (RegisteredMethod method) (ClientBiDiRequest timeout meta handler) =
|
||||||
mkResponse <$> LL.clientRW client method timeout meta (\_m recv send writesDone -> handler meta (convertRecv recv) (convertSend send) writesDone)
|
mkResponse <$> LL.clientRW client method timeout meta (\cc _m recv send writesDone -> handler cc meta (convertRecv recv) (convertSend send) writesDone)
|
||||||
where
|
where
|
||||||
mkResponse (Left ioError_) = ClientErrorResponse (ClientIOError ioError_)
|
mkResponse (Left ioError_) = ClientErrorResponse (ClientIOError ioError_)
|
||||||
mkResponse (Right (meta_, rspCode_, details_)) =
|
mkResponse (Right (meta_, rspCode_, details_)) =
|
||||||
|
@ -164,7 +164,7 @@ simplifyServerStreaming :: TimeoutSeconds
|
||||||
-- ^ Endpoint implementation (typically generated by grpc-haskell)
|
-- ^ Endpoint implementation (typically generated by grpc-haskell)
|
||||||
-> request
|
-> request
|
||||||
-- ^ Request payload
|
-- ^ Request payload
|
||||||
-> (MetadataMap -> StreamRecv response -> IO ())
|
-> (LL.ClientCall -> MetadataMap -> StreamRecv response -> IO ())
|
||||||
-- ^ Stream handler; note that the 'StreamRecv'
|
-- ^ Stream handler; note that the 'StreamRecv'
|
||||||
-- action must be called repeatedly in order to
|
-- action must be called repeatedly in order to
|
||||||
-- consume the stream
|
-- consume the stream
|
||||||
|
|
|
@ -88,7 +88,7 @@ testServerStreamingCall client = testCase "Server-streaming call" $
|
||||||
checkResults nums recv
|
checkResults nums recv
|
||||||
res <- simpleServiceServerStreamingCall client $
|
res <- simpleServiceServerStreamingCall client $
|
||||||
ClientReaderRequest (SimpleServiceRequest "Test" (fromList nums)) 10 mempty
|
ClientReaderRequest (SimpleServiceRequest "Test" (fromList nums)) 10 mempty
|
||||||
(\_ -> checkResults nums)
|
(\_ _ -> checkResults nums)
|
||||||
case res of
|
case res of
|
||||||
ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err)
|
ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err)
|
||||||
ClientReaderResponse _ sts _ ->
|
ClientReaderResponse _ sts _ ->
|
||||||
|
@ -114,7 +114,7 @@ testBiDiStreamingCall client = testCase "Bidi-streaming call" $
|
||||||
iterations <- randomRIO (50, 500)
|
iterations <- randomRIO (50, 500)
|
||||||
|
|
||||||
res <- simpleServiceBiDiStreamingCall client $
|
res <- simpleServiceBiDiStreamingCall client $
|
||||||
ClientBiDiRequest 10 mempty (\_ -> handleRequests iterations)
|
ClientBiDiRequest 10 mempty (\_ _ -> handleRequests iterations)
|
||||||
case res of
|
case res of
|
||||||
ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err)
|
ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err)
|
||||||
ClientBiDiResponse _ sts _ ->
|
ClientBiDiResponse _ sts _ ->
|
||||||
|
|
Loading…
Reference in a new issue