Merge branch 'joel/bugfix/server-request-registered-call-pluck-permission'

This commit is contained in:
Joel Stanley 2016-07-11 15:13:55 -05:00
commit 84dc076a64
6 changed files with 159 additions and 143 deletions

View file

@ -1,25 +1,21 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -fno-warn-missing-signatures #-} {-# OPTIONS_GHC -fno-warn-missing-signatures #-}
{-# OPTIONS_GHC -fno-warn-unused-binds #-} {-# OPTIONS_GHC -fno-warn-unused-binds #-}
import Control.Monad import Control.Monad
import Network.GRPC.LowLevel import Network.GRPC.LowLevel
import Network.GRPC.LowLevel.Call
import qualified Network.GRPC.LowLevel.Client.Unregistered as U import qualified Network.GRPC.LowLevel.Client.Unregistered as U
import System.Environment
echoMethod = MethodName "/echo.Echo/DoEcho" echoMethod = MethodName "/echo.Echo/DoEcho"
unregistered c = do _unregistered c = U.clientRequest c echoMethod 1 "hi" mempty
U.clientRequest c echoMethod 1 "hi" mempty
registered c = do main = withGRPC $ \g ->
meth <- clientRegisterMethod c echoMethod Normal withClient g (ClientConfig "localhost" 50051 []) $ \c -> do
clientRequest c meth 1 "hi" mempty rm <- clientRegisterMethod c echoMethod Normal
replicateM_ 100000 $ clientRequest c rm 5 "hi" mempty >>= \case
run f = withGRPC $ \g -> withClient g (ClientConfig "localhost" 50051 []) $ \c ->
f c >>= \case
Left e -> error $ "Got client error: " ++ show e Left e -> error $ "Got client error: " ++ show e
_ -> return () _ -> return ()
main = replicateM_ 100 $ run $
registered

View file

@ -5,10 +5,12 @@
{-# OPTIONS_GHC -fno-warn-missing-signatures #-} {-# OPTIONS_GHC -fno-warn-missing-signatures #-}
{-# OPTIONS_GHC -fno-warn-unused-binds #-} {-# OPTIONS_GHC -fno-warn-unused-binds #-}
import Control.Concurrent.Async (async, wait) import Control.Concurrent
import Control.Monad (forever) import Control.Concurrent.Async
import Control.Monad
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import Network.GRPC.LowLevel import Network.GRPC.LowLevel
import Network.GRPC.LowLevel.Call
import qualified Network.GRPC.LowLevel.Server.Unregistered as U import qualified Network.GRPC.LowLevel.Server.Unregistered as U
import qualified Network.GRPC.LowLevel.Call.Unregistered as U import qualified Network.GRPC.LowLevel.Call.Unregistered as U
@ -44,14 +46,18 @@ regMain = withGRPC $ \grpc -> do
Left x -> putStrLn $ "registered call result error: " ++ show x Left x -> putStrLn $ "registered call result error: " ++ show x
Right _ -> return () Right _ -> return ()
-- | loop to fork n times tputStrLn x = do
tid <- myThreadId
putStrLn $ "[" ++ show tid ++ "]: " ++ x
regLoop :: Server -> RegisteredMethod 'Normal -> IO () regLoop :: Server -> RegisteredMethod 'Normal -> IO ()
regLoop server method = forever $ do regLoop server method = forever $ do
-- tputStrLn "about to block on call handler"
result <- serverHandleNormalCall server method serverMeta $ result <- serverHandleNormalCall server method serverMeta $
\_call reqBody _reqMeta -> return (reqBody, serverMeta, StatusOk, \_call reqBody _reqMeta ->
StatusDetails "") return (reqBody, serverMeta, StatusOk, StatusDetails "")
case result of case result of
Left x -> putStrLn $ "registered call result error: " ++ show x Left x -> error $! "registered call result error: " ++ show x
Right _ -> return () Right _ -> return ()
regMainThreaded :: IO () regMainThreaded :: IO ()
@ -60,11 +66,10 @@ regMainThreaded = do
let methods = [(MethodName "/echo.Echo/DoEcho", Normal)] let methods = [(MethodName "/echo.Echo/DoEcho", Normal)]
withServer grpc (ServerConfig "localhost" 50051 methods []) $ \server -> do withServer grpc (ServerConfig "localhost" 50051 methods []) $ \server -> do
let method = head (normalMethods server) let method = head (normalMethods server)
tid1 <- async $ regLoop server method tids <- replicateM 7 $ async $ do tputStrLn "starting handler"
tid2 <- async $ regLoop server method regLoop server method
wait tid1 waitAnyCancel tids
wait tid2 tputStrLn "finishing"
return ()
main :: IO () main :: IO ()
main = regMainThreaded main = regMainThreaded

View file

@ -10,11 +10,12 @@
-- implementation details to both are kept in -- implementation details to both are kept in
-- `Network.GRPC.LowLevel.CompletionQueue.Internal`. -- `Network.GRPC.LowLevel.CompletionQueue.Internal`.
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-} {-# LANGUAGE TupleSections #-}
module Network.GRPC.LowLevel.CompletionQueue module Network.GRPC.LowLevel.CompletionQueue
( CompletionQueue ( CompletionQueue
@ -33,32 +34,34 @@ module Network.GRPC.LowLevel.CompletionQueue
) )
where where
import Control.Concurrent.STM (atomically, check) import Control.Concurrent.STM (atomically,
import Control.Concurrent.STM.TVar (newTVarIO, readTVar, check)
writeTVar) import Control.Concurrent.STM.TVar (newTVarIO,
import Control.Exception (bracket) readTVar,
import Control.Monad.Trans.Class (MonadTrans(lift)) writeTVar)
import Control.Monad.Trans.Except import Control.Exception (bracket)
import Control.Monad (liftM2) import Control.Monad (liftM2)
import Control.Monad.Managed import Control.Monad.Managed
import Data.IORef (newIORef) import Control.Monad.Trans.Class (MonadTrans (lift))
import Data.List (intersperse) import Control.Monad.Trans.Except
import Foreign.Marshal.Alloc (free, malloc) import Data.IORef (newIORef)
import Foreign.Ptr (Ptr, nullPtr) import Data.List (intersperse)
import Foreign.Storable (Storable, peek) import Foreign.Marshal.Alloc (free, malloc)
import qualified Network.GRPC.Unsafe as C import Foreign.Ptr (Ptr, nullPtr)
import qualified Network.GRPC.Unsafe.Constants as C import Foreign.Storable (Storable, peek)
import qualified Network.GRPC.Unsafe.Metadata as C
import qualified Network.GRPC.Unsafe.Op as C
import qualified Network.GRPC.Unsafe.Time as C
import System.Clock (getTime, Clock(..))
import System.Info (os)
import System.Timeout (timeout)
import Network.GRPC.LowLevel.Call import Network.GRPC.LowLevel.Call
import Network.GRPC.LowLevel.GRPC
import Network.GRPC.LowLevel.CompletionQueue.Internal import Network.GRPC.LowLevel.CompletionQueue.Internal
import qualified Network.GRPC.Unsafe.ByteBuffer as C import Network.GRPC.LowLevel.GRPC
import qualified Network.GRPC.Unsafe as C
import qualified Network.GRPC.Unsafe.ByteBuffer as C
import qualified Network.GRPC.Unsafe.Constants as C
import qualified Network.GRPC.Unsafe.Metadata as C
import qualified Network.GRPC.Unsafe.Op as C
import qualified Network.GRPC.Unsafe.Time as C
import System.Clock (Clock (..),
getTime)
import System.Info (os)
import System.Timeout (timeout)
withCompletionQueue :: GRPC -> (CompletionQueue -> IO a) -> IO a withCompletionQueue :: GRPC -> (CompletionQueue -> IO a) -> IO a
withCompletionQueue grpc = bracket (createCompletionQueue grpc) withCompletionQueue grpc = bracket (createCompletionQueue grpc)
@ -94,8 +97,9 @@ startBatch cq@CompletionQueue{..} call opArray opArraySize tag =
shutdownCompletionQueue :: CompletionQueue -> IO (Either GRPCIOError ()) shutdownCompletionQueue :: CompletionQueue -> IO (Either GRPCIOError ())
shutdownCompletionQueue CompletionQueue{..} = do shutdownCompletionQueue CompletionQueue{..} = do
atomically $ writeTVar shuttingDown True atomically $ writeTVar shuttingDown True
atomically $ readTVar currentPushers >>= \x -> check (x == 0) atomically $ do
atomically $ readTVar currentPluckers >>= \x -> check (x == 0) readTVar currentPushers >>= check . (==0)
readTVar currentPluckers >>= check . (==0)
--drain the queue --drain the queue
C.grpcCompletionQueueShutdown unsafeCQ C.grpcCompletionQueueShutdown unsafeCQ
loopRes <- timeout (5*10^(6::Int)) drainLoop loopRes <- timeout (5*10^(6::Int)) drainLoop
@ -140,32 +144,34 @@ serverRequestCall :: C.Server
-> CompletionQueue -> CompletionQueue
-> RegisteredMethod mt -> RegisteredMethod mt
-> IO (Either GRPCIOError ServerCall) -> IO (Either GRPCIOError ServerCall)
serverRequestCall s cq@CompletionQueue{.. } RegisteredMethod{..} = serverRequestCall s cq@CompletionQueue{.. } rm =
-- NB: The method type dictates whether or not a payload is present, according -- NB: The method type dictates whether or not a payload is present, according
-- to the payloadHandling function. We do not allocate a buffer for the -- to the payloadHandling function. We do not allocate a buffer for the
-- payload when it is not present. -- payload when it is not present.
withPermission Push cq . with allocs $ \(dead, call, pay, meta) -> do withPermission Push cq . with allocs $ \(dead, call, pay, meta) -> do
md <- peek meta dbug "pre-pluck block"
tag <- newTag cq withPermission Pluck cq $ do
dbug $ "tag is " ++ show tag md <- peek meta
ce <- C.grpcServerRequestRegisteredCall s methodHandle call tag <- newTag cq
dead md pay unsafeCQ unsafeCQ tag dbug $ "got pluck permission, registering call for tag=" ++ show tag
dbug $ "callError: " ++ show ce ce <- C.grpcServerRequestRegisteredCall s (methodHandle rm) call dead md pay unsafeCQ unsafeCQ tag
runExceptT $ case ce of runExceptT $ case ce of
C.CallOk -> do C.CallOk -> do
ExceptT $ do ExceptT $ do
r <- pluck cq tag Nothing r <- pluck' cq tag Nothing
dbug $ "pluck finished:" ++ show r dbug $ "pluck' finished:" ++ show r
return r return r
lift $ lift $
ServerCall ServerCall
<$> peek call <$> peek call
<*> C.getAllMetadataArray md <*> C.getAllMetadataArray md
<*> (if havePay then toBS pay else return Nothing) <*> (if havePay then toBS pay else return Nothing)
<*> convertDeadline dead <*> liftM2 (+) (getTime Monotonic) (C.timeSpec <$> peek dead)
-- gRPC gives us a deadline that is just a delta, so we convert it -- gRPC gives us a deadline that is just a delta, so we convert
-- to a proper deadline. -- it to a proper deadline.
_ -> throwE (GRPCIOCallError ce) _ -> do
lift $ dbug $ "Throwing callError: " ++ show ce
throwE (GRPCIOCallError ce)
where where
allocs = (,,,) <$> ptr <*> ptr <*> pay <*> md allocs = (,,,) <$> ptr <*> ptr <*> pay <*> md
where where
@ -174,7 +180,7 @@ serverRequestCall s cq@CompletionQueue{.. } RegisteredMethod{..} =
ptr :: forall a. Storable a => Managed (Ptr a) ptr :: forall a. Storable a => Managed (Ptr a)
ptr = managed (bracket malloc free) ptr = managed (bracket malloc free)
dbug = grpcDebug . ("serverRequestCall(R): " ++) dbug = grpcDebug . ("serverRequestCall(R): " ++)
havePay = payloadHandling methodType /= C.SrmPayloadNone havePay = payloadHandling (methodType rm) /= C.SrmPayloadNone
toBS p = peek p >>= \bb@(C.ByteBuffer rawPtr) -> toBS p = peek p >>= \bb@(C.ByteBuffer rawPtr) ->
if | rawPtr == nullPtr -> return Nothing if | rawPtr == nullPtr -> return Nothing
| otherwise -> Just <$> C.copyByteBufferToByteString bb | otherwise -> Just <$> C.copyByteBufferToByteString bb

View file

@ -3,8 +3,11 @@
module Network.GRPC.LowLevel.CompletionQueue.Internal where module Network.GRPC.LowLevel.CompletionQueue.Internal where
import Control.Concurrent.STM (atomically, retry) import Control.Concurrent.STM (atomically, retry)
import Control.Concurrent.STM.TVar (TVar, modifyTVar', readTVar) import Control.Concurrent.STM.TVar (TVar, modifyTVar', readTVar,
writeTVar)
import Control.Monad.IO.Class
import Control.Exception (bracket) import Control.Exception (bracket)
import Control.Monad
import Data.IORef (IORef, atomicModifyIORef') import Data.IORef (IORef, atomicModifyIORef')
import Foreign.Ptr (nullPtr, plusPtr) import Foreign.Ptr (nullPtr, plusPtr)
import Network.GRPC.LowLevel.GRPC import Network.GRPC.LowLevel.GRPC
@ -74,26 +77,22 @@ newTag CompletionQueue{..} = do
-- | Safely brackets an operation that pushes work onto or plucks results from -- | Safely brackets an operation that pushes work onto or plucks results from
-- the given 'CompletionQueue'. -- the given 'CompletionQueue'.
withPermission :: CQOpType withPermission :: CQOpType
-> CompletionQueue -> CompletionQueue
-> IO (Either GRPCIOError a) -> IO (Either GRPCIOError a)
-> IO (Either GRPCIOError a) -> IO (Either GRPCIOError a)
withPermission op cq f = withPermission op cq act = bracket acquire release $ \gotResource ->
bracket acquire release doOp if gotResource then act else return (Left GRPCIOShutdown)
where acquire = atomically $ do where
isShuttingDown <- readTVar (shuttingDown cq) acquire = atomically $ do
if isShuttingDown isShuttingDown <- readTVar (shuttingDown cq)
then return False unless isShuttingDown $ do
else do currCount <- readTVar $ getCount op cq currCount <- readTVar (getCount op cq)
if currCount < getLimit op if currCount < getLimit op
then modifyTVar' (getCount op cq) (+1) >> return True then writeTVar (getCount op cq) (currCount + 1)
else retry else retry
doOp gotResource = if gotResource return (not isShuttingDown)
then f release gotResource = when gotResource $
else return $ Left GRPCIOShutdown atomically $ modifyTVar' (getCount op cq) (subtract 1)
release gotResource =
if gotResource
then atomically $ modifyTVar' (getCount op cq) (subtract 1)
else return ()
-- | Waits for the given number of seconds for the given tag to appear on the -- | Waits for the given number of seconds for the given tag to appear on the
-- completion queue. Throws 'GRPCIOShutdown' if the completion queue is shutting -- completion queue. Throws 'GRPCIOShutdown' if the completion queue is shutting
@ -103,17 +102,21 @@ withPermission op cq f =
-- 'serverRequestCall', this will block forever unless a timeout is given. -- 'serverRequestCall', this will block forever unless a timeout is given.
pluck :: CompletionQueue -> C.Tag -> Maybe TimeoutSeconds pluck :: CompletionQueue -> C.Tag -> Maybe TimeoutSeconds
-> IO (Either GRPCIOError ()) -> IO (Either GRPCIOError ())
pluck cq@CompletionQueue{..} tag waitSeconds = do pluck cq@CompletionQueue{..} tag mwait = do
grpcDebug $ "pluck: called with tag: " ++ show tag grpcDebug $ "pluck: called with tag=" ++ show tag ++ ",mwait=" ++ show mwait
++ " and wait: " ++ show waitSeconds withPermission Pluck cq $ pluck' cq tag mwait
withPermission Pluck cq $
case waitSeconds of -- Variant of pluck' which assumes pluck permission has been granted.
Nothing -> C.withInfiniteDeadline go pluck' :: CompletionQueue
Just seconds -> C.withDeadlineSeconds seconds go -> C.Tag
where go deadline = do -> Maybe TimeoutSeconds
ev <- C.grpcCompletionQueuePluck unsafeCQ tag deadline C.reserved -> IO (Either GRPCIOError ())
grpcDebug $ "pluck: finished. Event: " ++ show ev pluck' CompletionQueue{..} tag mwait =
return $ if isEventSuccessful ev then Right () else eventToError ev maybe C.withInfiniteDeadline C.withDeadlineSeconds mwait $ \dead -> do
grpcDebug $ "pluck: blocking on grpc_completion_queue_pluck for tag=" ++ show tag
ev <- C.grpcCompletionQueuePluck unsafeCQ tag dead C.reserved
grpcDebug $ "pluck finished: " ++ show ev
return $ if isEventSuccessful ev then Right () else eventToError ev
-- | Translate 'C.Event' to an error. The caller is responsible for ensuring -- | Translate 'C.Event' to an error. The caller is responsible for ensuring
-- that the event actually corresponds to an error condition; a successful event -- that the event actually corresponds to an error condition; a successful event
@ -133,9 +136,9 @@ maxWorkPushers :: Int
maxWorkPushers = 100 --TODO: figure out what this should be. maxWorkPushers = 100 --TODO: figure out what this should be.
getCount :: CQOpType -> CompletionQueue -> TVar Int getCount :: CQOpType -> CompletionQueue -> TVar Int
getCount Push = currentPushers getCount Push = currentPushers
getCount Pluck = currentPluckers getCount Pluck = currentPluckers
getLimit :: CQOpType -> Int getLimit :: CQOpType -> Int
getLimit Push = maxWorkPushers getLimit Push = maxWorkPushers
getLimit Pluck = C.maxCompletionQueuePluckers getLimit Pluck = C.maxCompletionQueuePluckers

View file

@ -36,36 +36,37 @@ serverRequestCall server cq@CompletionQueue{..} =
withPermission Push cq $ withPermission Push cq $
bracket malloc free $ \callPtr -> bracket malloc free $ \callPtr ->
C.withMetadataArrayPtr $ \metadataArrayPtr -> C.withMetadataArrayPtr $ \metadataArrayPtr ->
C.withCallDetails $ \callDetails -> do C.withCallDetails $ \callDetails ->
grpcDebug $ "serverRequestCall: callPtr is " ++ show callPtr withPermission Pluck cq $ do
metadataArray <- peek metadataArrayPtr grpcDebug $ "serverRequestCall: callPtr is " ++ show callPtr
tag <- newTag cq metadataArray <- peek metadataArrayPtr
callError <- C.grpcServerRequestCall server callPtr callDetails tag <- newTag cq
metadataArray unsafeCQ unsafeCQ tag callError <- C.grpcServerRequestCall server callPtr callDetails
grpcDebug $ "serverRequestCall: callError was " ++ show callError metadataArray unsafeCQ unsafeCQ tag
if callError /= C.CallOk grpcDebug $ "serverRequestCall: callError was " ++ show callError
then do grpcDebug "serverRequestCall: got call error; cleaning up." if callError /= C.CallOk
return $ Left $ GRPCIOCallError callError then do grpcDebug "serverRequestCall: got call error; cleaning up."
else do pluckResult <- pluck cq tag Nothing return $ Left $ GRPCIOCallError callError
grpcDebug $ "serverRequestCall: pluckResult was " else do pluckResult <- pluck cq tag Nothing
++ show pluckResult grpcDebug $ "serverRequestCall: pluckResult was "
case pluckResult of ++ show pluckResult
Left x -> do case pluckResult of
grpcDebug "serverRequestCall: pluck error." Left x -> do
return $ Left x grpcDebug "serverRequestCall: pluck error."
Right () -> do return $ Left x
rawCall <- peek callPtr Right () -> do
metadata <- C.getAllMetadataArray metadataArray rawCall <- peek callPtr
deadline <- getDeadline callDetails metadata <- C.getAllMetadataArray metadataArray
method <- getMethod callDetails deadline <- getDeadline callDetails
host <- getHost callDetails method <- getMethod callDetails
let call = U.ServerCall rawCall host <- getHost callDetails
metadata let call = U.ServerCall rawCall
Nothing metadata
deadline Nothing
method deadline
host method
return $ Right call host
return $ Right call
where getDeadline callDetails = do where getDeadline callDetails = do
C.timeSpec <$> C.callDetailsGetDeadline callDetails C.timeSpec <$> C.callDetailsGetDeadline callDetails

View file

@ -4,7 +4,7 @@
module Network.GRPC.LowLevel.GRPC where module Network.GRPC.LowLevel.GRPC where
import Control.Concurrent (threadDelay) import Control.Concurrent (threadDelay, myThreadId)
import Control.Exception import Control.Exception
import Data.String (IsString) import Data.String (IsString)
import qualified Data.ByteString as B import qualified Data.ByteString as B
@ -59,12 +59,17 @@ throwIfCallError x = Left $ GRPCIOCallError x
grpcDebug :: String -> IO () grpcDebug :: String -> IO ()
{-# INLINE grpcDebug #-} {-# INLINE grpcDebug #-}
#ifdef DEBUG #ifdef DEBUG
grpcDebug str = do tid <- myThreadId grpcDebug = grpcDebug'
putStrLn $ (show tid) ++ ": " ++ str
#else #else
grpcDebug _ = return () grpcDebug _ = return ()
#endif #endif
grpcDebug' :: String -> IO ()
{-# INLINE grpcDebug' #-}
grpcDebug' str = do
tid <- myThreadId
putStrLn $ "[" ++ show tid ++ "]: " ++ str
threadDelaySecs :: Int -> IO () threadDelaySecs :: Int -> IO ()
threadDelaySecs = threadDelay . (* 10^(6::Int)) threadDelaySecs = threadDelay . (* 10^(6::Int))