mirror of
https://github.com/unclechu/gRPC-haskell.git
synced 2024-12-25 03:09:44 +01:00
Merge branch 'joel/bugfix/server-request-registered-call-pluck-permission'
This commit is contained in:
commit
84dc076a64
6 changed files with 159 additions and 143 deletions
|
@ -1,25 +1,21 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
|
||||
{-# OPTIONS_GHC -fno-warn-unused-binds #-}
|
||||
|
||||
import Control.Monad
|
||||
import Network.GRPC.LowLevel
|
||||
import Network.GRPC.LowLevel.Call
|
||||
import qualified Network.GRPC.LowLevel.Client.Unregistered as U
|
||||
import System.Environment
|
||||
|
||||
echoMethod = MethodName "/echo.Echo/DoEcho"
|
||||
|
||||
unregistered c = do
|
||||
U.clientRequest c echoMethod 1 "hi" mempty
|
||||
_unregistered c = U.clientRequest c echoMethod 1 "hi" mempty
|
||||
|
||||
registered c = do
|
||||
meth <- clientRegisterMethod c echoMethod Normal
|
||||
clientRequest c meth 1 "hi" mempty
|
||||
|
||||
run f = withGRPC $ \g -> withClient g (ClientConfig "localhost" 50051 []) $ \c ->
|
||||
f c >>= \case
|
||||
main = withGRPC $ \g ->
|
||||
withClient g (ClientConfig "localhost" 50051 []) $ \c -> do
|
||||
rm <- clientRegisterMethod c echoMethod Normal
|
||||
replicateM_ 100000 $ clientRequest c rm 5 "hi" mempty >>= \case
|
||||
Left e -> error $ "Got client error: " ++ show e
|
||||
_ -> return ()
|
||||
|
||||
main = replicateM_ 100 $ run $
|
||||
registered
|
||||
|
|
|
@ -5,10 +5,12 @@
|
|||
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
|
||||
{-# OPTIONS_GHC -fno-warn-unused-binds #-}
|
||||
|
||||
import Control.Concurrent.Async (async, wait)
|
||||
import Control.Monad (forever)
|
||||
import Control.Concurrent
|
||||
import Control.Concurrent.Async
|
||||
import Control.Monad
|
||||
import Data.ByteString (ByteString)
|
||||
import Network.GRPC.LowLevel
|
||||
import Network.GRPC.LowLevel.Call
|
||||
import qualified Network.GRPC.LowLevel.Server.Unregistered as U
|
||||
import qualified Network.GRPC.LowLevel.Call.Unregistered as U
|
||||
|
||||
|
@ -44,14 +46,18 @@ regMain = withGRPC $ \grpc -> do
|
|||
Left x -> putStrLn $ "registered call result error: " ++ show x
|
||||
Right _ -> return ()
|
||||
|
||||
-- | loop to fork n times
|
||||
tputStrLn x = do
|
||||
tid <- myThreadId
|
||||
putStrLn $ "[" ++ show tid ++ "]: " ++ x
|
||||
|
||||
regLoop :: Server -> RegisteredMethod 'Normal -> IO ()
|
||||
regLoop server method = forever $ do
|
||||
-- tputStrLn "about to block on call handler"
|
||||
result <- serverHandleNormalCall server method serverMeta $
|
||||
\_call reqBody _reqMeta -> return (reqBody, serverMeta, StatusOk,
|
||||
StatusDetails "")
|
||||
\_call reqBody _reqMeta ->
|
||||
return (reqBody, serverMeta, StatusOk, StatusDetails "")
|
||||
case result of
|
||||
Left x -> putStrLn $ "registered call result error: " ++ show x
|
||||
Left x -> error $! "registered call result error: " ++ show x
|
||||
Right _ -> return ()
|
||||
|
||||
regMainThreaded :: IO ()
|
||||
|
@ -60,11 +66,10 @@ regMainThreaded = do
|
|||
let methods = [(MethodName "/echo.Echo/DoEcho", Normal)]
|
||||
withServer grpc (ServerConfig "localhost" 50051 methods []) $ \server -> do
|
||||
let method = head (normalMethods server)
|
||||
tid1 <- async $ regLoop server method
|
||||
tid2 <- async $ regLoop server method
|
||||
wait tid1
|
||||
wait tid2
|
||||
return ()
|
||||
tids <- replicateM 7 $ async $ do tputStrLn "starting handler"
|
||||
regLoop server method
|
||||
waitAnyCancel tids
|
||||
tputStrLn "finishing"
|
||||
|
||||
main :: IO ()
|
||||
main = regMainThreaded
|
||||
|
|
|
@ -10,11 +10,12 @@
|
|||
-- implementation details to both are kept in
|
||||
-- `Network.GRPC.LowLevel.CompletionQueue.Internal`.
|
||||
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiWayIf #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiWayIf #-}
|
||||
{-# LANGUAGE PolyKinds #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module Network.GRPC.LowLevel.CompletionQueue
|
||||
( CompletionQueue
|
||||
|
@ -33,32 +34,34 @@ module Network.GRPC.LowLevel.CompletionQueue
|
|||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent.STM (atomically, check)
|
||||
import Control.Concurrent.STM.TVar (newTVarIO, readTVar,
|
||||
writeTVar)
|
||||
import Control.Exception (bracket)
|
||||
import Control.Monad.Trans.Class (MonadTrans(lift))
|
||||
import Control.Monad.Trans.Except
|
||||
import Control.Monad (liftM2)
|
||||
import Control.Concurrent.STM (atomically,
|
||||
check)
|
||||
import Control.Concurrent.STM.TVar (newTVarIO,
|
||||
readTVar,
|
||||
writeTVar)
|
||||
import Control.Exception (bracket)
|
||||
import Control.Monad (liftM2)
|
||||
import Control.Monad.Managed
|
||||
import Data.IORef (newIORef)
|
||||
import Data.List (intersperse)
|
||||
import Foreign.Marshal.Alloc (free, malloc)
|
||||
import Foreign.Ptr (Ptr, nullPtr)
|
||||
import Foreign.Storable (Storable, peek)
|
||||
import qualified Network.GRPC.Unsafe as C
|
||||
import qualified Network.GRPC.Unsafe.Constants as C
|
||||
import qualified Network.GRPC.Unsafe.Metadata as C
|
||||
import qualified Network.GRPC.Unsafe.Op as C
|
||||
import qualified Network.GRPC.Unsafe.Time as C
|
||||
import System.Clock (getTime, Clock(..))
|
||||
import System.Info (os)
|
||||
import System.Timeout (timeout)
|
||||
|
||||
import Control.Monad.Trans.Class (MonadTrans (lift))
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.IORef (newIORef)
|
||||
import Data.List (intersperse)
|
||||
import Foreign.Marshal.Alloc (free, malloc)
|
||||
import Foreign.Ptr (Ptr, nullPtr)
|
||||
import Foreign.Storable (Storable, peek)
|
||||
import Network.GRPC.LowLevel.Call
|
||||
import Network.GRPC.LowLevel.GRPC
|
||||
import Network.GRPC.LowLevel.CompletionQueue.Internal
|
||||
import qualified Network.GRPC.Unsafe.ByteBuffer as C
|
||||
import Network.GRPC.LowLevel.GRPC
|
||||
import qualified Network.GRPC.Unsafe as C
|
||||
import qualified Network.GRPC.Unsafe.ByteBuffer as C
|
||||
import qualified Network.GRPC.Unsafe.Constants as C
|
||||
import qualified Network.GRPC.Unsafe.Metadata as C
|
||||
import qualified Network.GRPC.Unsafe.Op as C
|
||||
import qualified Network.GRPC.Unsafe.Time as C
|
||||
import System.Clock (Clock (..),
|
||||
getTime)
|
||||
import System.Info (os)
|
||||
import System.Timeout (timeout)
|
||||
|
||||
withCompletionQueue :: GRPC -> (CompletionQueue -> IO a) -> IO a
|
||||
withCompletionQueue grpc = bracket (createCompletionQueue grpc)
|
||||
|
@ -94,8 +97,9 @@ startBatch cq@CompletionQueue{..} call opArray opArraySize tag =
|
|||
shutdownCompletionQueue :: CompletionQueue -> IO (Either GRPCIOError ())
|
||||
shutdownCompletionQueue CompletionQueue{..} = do
|
||||
atomically $ writeTVar shuttingDown True
|
||||
atomically $ readTVar currentPushers >>= \x -> check (x == 0)
|
||||
atomically $ readTVar currentPluckers >>= \x -> check (x == 0)
|
||||
atomically $ do
|
||||
readTVar currentPushers >>= check . (==0)
|
||||
readTVar currentPluckers >>= check . (==0)
|
||||
--drain the queue
|
||||
C.grpcCompletionQueueShutdown unsafeCQ
|
||||
loopRes <- timeout (5*10^(6::Int)) drainLoop
|
||||
|
@ -140,32 +144,34 @@ serverRequestCall :: C.Server
|
|||
-> CompletionQueue
|
||||
-> RegisteredMethod mt
|
||||
-> IO (Either GRPCIOError ServerCall)
|
||||
serverRequestCall s cq@CompletionQueue{.. } RegisteredMethod{..} =
|
||||
serverRequestCall s cq@CompletionQueue{.. } rm =
|
||||
-- NB: The method type dictates whether or not a payload is present, according
|
||||
-- to the payloadHandling function. We do not allocate a buffer for the
|
||||
-- payload when it is not present.
|
||||
withPermission Push cq . with allocs $ \(dead, call, pay, meta) -> do
|
||||
md <- peek meta
|
||||
tag <- newTag cq
|
||||
dbug $ "tag is " ++ show tag
|
||||
ce <- C.grpcServerRequestRegisteredCall s methodHandle call
|
||||
dead md pay unsafeCQ unsafeCQ tag
|
||||
dbug $ "callError: " ++ show ce
|
||||
runExceptT $ case ce of
|
||||
C.CallOk -> do
|
||||
ExceptT $ do
|
||||
r <- pluck cq tag Nothing
|
||||
dbug $ "pluck finished:" ++ show r
|
||||
return r
|
||||
lift $
|
||||
ServerCall
|
||||
<$> peek call
|
||||
<*> C.getAllMetadataArray md
|
||||
<*> (if havePay then toBS pay else return Nothing)
|
||||
<*> convertDeadline dead
|
||||
-- gRPC gives us a deadline that is just a delta, so we convert it
|
||||
-- to a proper deadline.
|
||||
_ -> throwE (GRPCIOCallError ce)
|
||||
dbug "pre-pluck block"
|
||||
withPermission Pluck cq $ do
|
||||
md <- peek meta
|
||||
tag <- newTag cq
|
||||
dbug $ "got pluck permission, registering call for tag=" ++ show tag
|
||||
ce <- C.grpcServerRequestRegisteredCall s (methodHandle rm) call dead md pay unsafeCQ unsafeCQ tag
|
||||
runExceptT $ case ce of
|
||||
C.CallOk -> do
|
||||
ExceptT $ do
|
||||
r <- pluck' cq tag Nothing
|
||||
dbug $ "pluck' finished:" ++ show r
|
||||
return r
|
||||
lift $
|
||||
ServerCall
|
||||
<$> peek call
|
||||
<*> C.getAllMetadataArray md
|
||||
<*> (if havePay then toBS pay else return Nothing)
|
||||
<*> liftM2 (+) (getTime Monotonic) (C.timeSpec <$> peek dead)
|
||||
-- gRPC gives us a deadline that is just a delta, so we convert
|
||||
-- it to a proper deadline.
|
||||
_ -> do
|
||||
lift $ dbug $ "Throwing callError: " ++ show ce
|
||||
throwE (GRPCIOCallError ce)
|
||||
where
|
||||
allocs = (,,,) <$> ptr <*> ptr <*> pay <*> md
|
||||
where
|
||||
|
@ -174,7 +180,7 @@ serverRequestCall s cq@CompletionQueue{.. } RegisteredMethod{..} =
|
|||
ptr :: forall a. Storable a => Managed (Ptr a)
|
||||
ptr = managed (bracket malloc free)
|
||||
dbug = grpcDebug . ("serverRequestCall(R): " ++)
|
||||
havePay = payloadHandling methodType /= C.SrmPayloadNone
|
||||
havePay = payloadHandling (methodType rm) /= C.SrmPayloadNone
|
||||
toBS p = peek p >>= \bb@(C.ByteBuffer rawPtr) ->
|
||||
if | rawPtr == nullPtr -> return Nothing
|
||||
| otherwise -> Just <$> C.copyByteBufferToByteString bb
|
||||
|
|
|
@ -3,8 +3,11 @@
|
|||
module Network.GRPC.LowLevel.CompletionQueue.Internal where
|
||||
|
||||
import Control.Concurrent.STM (atomically, retry)
|
||||
import Control.Concurrent.STM.TVar (TVar, modifyTVar', readTVar)
|
||||
import Control.Concurrent.STM.TVar (TVar, modifyTVar', readTVar,
|
||||
writeTVar)
|
||||
import Control.Monad.IO.Class
|
||||
import Control.Exception (bracket)
|
||||
import Control.Monad
|
||||
import Data.IORef (IORef, atomicModifyIORef')
|
||||
import Foreign.Ptr (nullPtr, plusPtr)
|
||||
import Network.GRPC.LowLevel.GRPC
|
||||
|
@ -74,26 +77,22 @@ newTag CompletionQueue{..} = do
|
|||
-- | Safely brackets an operation that pushes work onto or plucks results from
|
||||
-- the given 'CompletionQueue'.
|
||||
withPermission :: CQOpType
|
||||
-> CompletionQueue
|
||||
-> IO (Either GRPCIOError a)
|
||||
-> IO (Either GRPCIOError a)
|
||||
withPermission op cq f =
|
||||
bracket acquire release doOp
|
||||
where acquire = atomically $ do
|
||||
isShuttingDown <- readTVar (shuttingDown cq)
|
||||
if isShuttingDown
|
||||
then return False
|
||||
else do currCount <- readTVar $ getCount op cq
|
||||
if currCount < getLimit op
|
||||
then modifyTVar' (getCount op cq) (+1) >> return True
|
||||
else retry
|
||||
doOp gotResource = if gotResource
|
||||
then f
|
||||
else return $ Left GRPCIOShutdown
|
||||
release gotResource =
|
||||
if gotResource
|
||||
then atomically $ modifyTVar' (getCount op cq) (subtract 1)
|
||||
else return ()
|
||||
-> CompletionQueue
|
||||
-> IO (Either GRPCIOError a)
|
||||
-> IO (Either GRPCIOError a)
|
||||
withPermission op cq act = bracket acquire release $ \gotResource ->
|
||||
if gotResource then act else return (Left GRPCIOShutdown)
|
||||
where
|
||||
acquire = atomically $ do
|
||||
isShuttingDown <- readTVar (shuttingDown cq)
|
||||
unless isShuttingDown $ do
|
||||
currCount <- readTVar (getCount op cq)
|
||||
if currCount < getLimit op
|
||||
then writeTVar (getCount op cq) (currCount + 1)
|
||||
else retry
|
||||
return (not isShuttingDown)
|
||||
release gotResource = when gotResource $
|
||||
atomically $ modifyTVar' (getCount op cq) (subtract 1)
|
||||
|
||||
-- | Waits for the given number of seconds for the given tag to appear on the
|
||||
-- completion queue. Throws 'GRPCIOShutdown' if the completion queue is shutting
|
||||
|
@ -103,17 +102,21 @@ withPermission op cq f =
|
|||
-- 'serverRequestCall', this will block forever unless a timeout is given.
|
||||
pluck :: CompletionQueue -> C.Tag -> Maybe TimeoutSeconds
|
||||
-> IO (Either GRPCIOError ())
|
||||
pluck cq@CompletionQueue{..} tag waitSeconds = do
|
||||
grpcDebug $ "pluck: called with tag: " ++ show tag
|
||||
++ " and wait: " ++ show waitSeconds
|
||||
withPermission Pluck cq $
|
||||
case waitSeconds of
|
||||
Nothing -> C.withInfiniteDeadline go
|
||||
Just seconds -> C.withDeadlineSeconds seconds go
|
||||
where go deadline = do
|
||||
ev <- C.grpcCompletionQueuePluck unsafeCQ tag deadline C.reserved
|
||||
grpcDebug $ "pluck: finished. Event: " ++ show ev
|
||||
return $ if isEventSuccessful ev then Right () else eventToError ev
|
||||
pluck cq@CompletionQueue{..} tag mwait = do
|
||||
grpcDebug $ "pluck: called with tag=" ++ show tag ++ ",mwait=" ++ show mwait
|
||||
withPermission Pluck cq $ pluck' cq tag mwait
|
||||
|
||||
-- Variant of pluck' which assumes pluck permission has been granted.
|
||||
pluck' :: CompletionQueue
|
||||
-> C.Tag
|
||||
-> Maybe TimeoutSeconds
|
||||
-> IO (Either GRPCIOError ())
|
||||
pluck' CompletionQueue{..} tag mwait =
|
||||
maybe C.withInfiniteDeadline C.withDeadlineSeconds mwait $ \dead -> do
|
||||
grpcDebug $ "pluck: blocking on grpc_completion_queue_pluck for tag=" ++ show tag
|
||||
ev <- C.grpcCompletionQueuePluck unsafeCQ tag dead C.reserved
|
||||
grpcDebug $ "pluck finished: " ++ show ev
|
||||
return $ if isEventSuccessful ev then Right () else eventToError ev
|
||||
|
||||
-- | Translate 'C.Event' to an error. The caller is responsible for ensuring
|
||||
-- that the event actually corresponds to an error condition; a successful event
|
||||
|
@ -133,9 +136,9 @@ maxWorkPushers :: Int
|
|||
maxWorkPushers = 100 --TODO: figure out what this should be.
|
||||
|
||||
getCount :: CQOpType -> CompletionQueue -> TVar Int
|
||||
getCount Push = currentPushers
|
||||
getCount Push = currentPushers
|
||||
getCount Pluck = currentPluckers
|
||||
|
||||
getLimit :: CQOpType -> Int
|
||||
getLimit Push = maxWorkPushers
|
||||
getLimit Push = maxWorkPushers
|
||||
getLimit Pluck = C.maxCompletionQueuePluckers
|
||||
|
|
|
@ -36,36 +36,37 @@ serverRequestCall server cq@CompletionQueue{..} =
|
|||
withPermission Push cq $
|
||||
bracket malloc free $ \callPtr ->
|
||||
C.withMetadataArrayPtr $ \metadataArrayPtr ->
|
||||
C.withCallDetails $ \callDetails -> do
|
||||
grpcDebug $ "serverRequestCall: callPtr is " ++ show callPtr
|
||||
metadataArray <- peek metadataArrayPtr
|
||||
tag <- newTag cq
|
||||
callError <- C.grpcServerRequestCall server callPtr callDetails
|
||||
metadataArray unsafeCQ unsafeCQ tag
|
||||
grpcDebug $ "serverRequestCall: callError was " ++ show callError
|
||||
if callError /= C.CallOk
|
||||
then do grpcDebug "serverRequestCall: got call error; cleaning up."
|
||||
return $ Left $ GRPCIOCallError callError
|
||||
else do pluckResult <- pluck cq tag Nothing
|
||||
grpcDebug $ "serverRequestCall: pluckResult was "
|
||||
++ show pluckResult
|
||||
case pluckResult of
|
||||
Left x -> do
|
||||
grpcDebug "serverRequestCall: pluck error."
|
||||
return $ Left x
|
||||
Right () -> do
|
||||
rawCall <- peek callPtr
|
||||
metadata <- C.getAllMetadataArray metadataArray
|
||||
deadline <- getDeadline callDetails
|
||||
method <- getMethod callDetails
|
||||
host <- getHost callDetails
|
||||
let call = U.ServerCall rawCall
|
||||
metadata
|
||||
Nothing
|
||||
deadline
|
||||
method
|
||||
host
|
||||
return $ Right call
|
||||
C.withCallDetails $ \callDetails ->
|
||||
withPermission Pluck cq $ do
|
||||
grpcDebug $ "serverRequestCall: callPtr is " ++ show callPtr
|
||||
metadataArray <- peek metadataArrayPtr
|
||||
tag <- newTag cq
|
||||
callError <- C.grpcServerRequestCall server callPtr callDetails
|
||||
metadataArray unsafeCQ unsafeCQ tag
|
||||
grpcDebug $ "serverRequestCall: callError was " ++ show callError
|
||||
if callError /= C.CallOk
|
||||
then do grpcDebug "serverRequestCall: got call error; cleaning up."
|
||||
return $ Left $ GRPCIOCallError callError
|
||||
else do pluckResult <- pluck cq tag Nothing
|
||||
grpcDebug $ "serverRequestCall: pluckResult was "
|
||||
++ show pluckResult
|
||||
case pluckResult of
|
||||
Left x -> do
|
||||
grpcDebug "serverRequestCall: pluck error."
|
||||
return $ Left x
|
||||
Right () -> do
|
||||
rawCall <- peek callPtr
|
||||
metadata <- C.getAllMetadataArray metadataArray
|
||||
deadline <- getDeadline callDetails
|
||||
method <- getMethod callDetails
|
||||
host <- getHost callDetails
|
||||
let call = U.ServerCall rawCall
|
||||
metadata
|
||||
Nothing
|
||||
deadline
|
||||
method
|
||||
host
|
||||
return $ Right call
|
||||
|
||||
where getDeadline callDetails = do
|
||||
C.timeSpec <$> C.callDetailsGetDeadline callDetails
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
module Network.GRPC.LowLevel.GRPC where
|
||||
|
||||
import Control.Concurrent (threadDelay)
|
||||
import Control.Concurrent (threadDelay, myThreadId)
|
||||
import Control.Exception
|
||||
import Data.String (IsString)
|
||||
import qualified Data.ByteString as B
|
||||
|
@ -59,12 +59,17 @@ throwIfCallError x = Left $ GRPCIOCallError x
|
|||
grpcDebug :: String -> IO ()
|
||||
{-# INLINE grpcDebug #-}
|
||||
#ifdef DEBUG
|
||||
grpcDebug str = do tid <- myThreadId
|
||||
putStrLn $ (show tid) ++ ": " ++ str
|
||||
grpcDebug = grpcDebug'
|
||||
#else
|
||||
grpcDebug _ = return ()
|
||||
#endif
|
||||
|
||||
grpcDebug' :: String -> IO ()
|
||||
{-# INLINE grpcDebug' #-}
|
||||
grpcDebug' str = do
|
||||
tid <- myThreadId
|
||||
putStrLn $ "[" ++ show tid ++ "]: " ++ str
|
||||
|
||||
threadDelaySecs :: Int -> IO ()
|
||||
threadDelaySecs = threadDelay . (* 10^(6::Int))
|
||||
|
||||
|
|
Loading…
Reference in a new issue