diff --git a/examples/echo/echo-client/Main.hs b/examples/echo/echo-client/Main.hs index 1ed6e90..e2e0c2a 100644 --- a/examples/echo/echo-client/Main.hs +++ b/examples/echo/echo-client/Main.hs @@ -1,25 +1,21 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} {-# OPTIONS_GHC -fno-warn-missing-signatures #-} {-# OPTIONS_GHC -fno-warn-unused-binds #-} import Control.Monad import Network.GRPC.LowLevel +import Network.GRPC.LowLevel.Call import qualified Network.GRPC.LowLevel.Client.Unregistered as U +import System.Environment echoMethod = MethodName "/echo.Echo/DoEcho" -unregistered c = do - U.clientRequest c echoMethod 1 "hi" mempty +_unregistered c = U.clientRequest c echoMethod 1 "hi" mempty -registered c = do - meth <- clientRegisterMethod c echoMethod Normal - clientRequest c meth 1 "hi" mempty - -run f = withGRPC $ \g -> withClient g (ClientConfig "localhost" 50051 []) $ \c -> - f c >>= \case +main = withGRPC $ \g -> + withClient g (ClientConfig "localhost" 50051 []) $ \c -> do + rm <- clientRegisterMethod c echoMethod Normal + replicateM_ 100000 $ clientRequest c rm 5 "hi" mempty >>= \case Left e -> error $ "Got client error: " ++ show e _ -> return () - -main = replicateM_ 100 $ run $ - registered diff --git a/examples/echo/echo-server/Main.hs b/examples/echo/echo-server/Main.hs index aa3f5c5..8c97856 100644 --- a/examples/echo/echo-server/Main.hs +++ b/examples/echo/echo-server/Main.hs @@ -5,10 +5,12 @@ {-# OPTIONS_GHC -fno-warn-missing-signatures #-} {-# OPTIONS_GHC -fno-warn-unused-binds #-} -import Control.Concurrent.Async (async, wait) -import Control.Monad (forever) +import Control.Concurrent +import Control.Concurrent.Async +import Control.Monad import Data.ByteString (ByteString) import Network.GRPC.LowLevel +import Network.GRPC.LowLevel.Call import qualified Network.GRPC.LowLevel.Server.Unregistered as U import qualified Network.GRPC.LowLevel.Call.Unregistered as U @@ -44,14 +46,18 @@ regMain = withGRPC $ \grpc -> do Left x -> putStrLn $ "registered call result error: " ++ show x Right _ -> return () --- | loop to fork n times +tputStrLn x = do + tid <- myThreadId + putStrLn $ "[" ++ show tid ++ "]: " ++ x + regLoop :: Server -> RegisteredMethod 'Normal -> IO () regLoop server method = forever $ do +-- tputStrLn "about to block on call handler" result <- serverHandleNormalCall server method serverMeta $ - \_call reqBody _reqMeta -> return (reqBody, serverMeta, StatusOk, - StatusDetails "") + \_call reqBody _reqMeta -> + return (reqBody, serverMeta, StatusOk, StatusDetails "") case result of - Left x -> putStrLn $ "registered call result error: " ++ show x + Left x -> error $! "registered call result error: " ++ show x Right _ -> return () regMainThreaded :: IO () @@ -60,11 +66,10 @@ regMainThreaded = do let methods = [(MethodName "/echo.Echo/DoEcho", Normal)] withServer grpc (ServerConfig "localhost" 50051 methods []) $ \server -> do let method = head (normalMethods server) - tid1 <- async $ regLoop server method - tid2 <- async $ regLoop server method - wait tid1 - wait tid2 - return () + tids <- replicateM 7 $ async $ do tputStrLn "starting handler" + regLoop server method + waitAnyCancel tids + tputStrLn "finishing" main :: IO () main = regMainThreaded diff --git a/src/Network/GRPC/LowLevel/CompletionQueue.hs b/src/Network/GRPC/LowLevel/CompletionQueue.hs index 85b4b3e..91159dd 100644 --- a/src/Network/GRPC/LowLevel/CompletionQueue.hs +++ b/src/Network/GRPC/LowLevel/CompletionQueue.hs @@ -10,11 +10,12 @@ -- implementation details to both are kept in -- `Network.GRPC.LowLevel.CompletionQueue.Internal`. -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TupleSections #-} module Network.GRPC.LowLevel.CompletionQueue ( CompletionQueue @@ -33,31 +34,34 @@ module Network.GRPC.LowLevel.CompletionQueue ) where -import Control.Concurrent.STM (atomically, check) -import Control.Concurrent.STM.TVar (newTVarIO, readTVar, - writeTVar) -import Control.Exception (bracket) -import Control.Monad.Trans.Class (MonadTrans(lift)) -import Control.Monad.Trans.Except -import Control.Monad (liftM2) +import Control.Concurrent.STM (atomically, + check) +import Control.Concurrent.STM.TVar (newTVarIO, + readTVar, + writeTVar) +import Control.Exception (bracket) +import Control.Monad (liftM2) import Control.Monad.Managed -import Data.IORef (newIORef) -import Data.List (intersperse) -import Foreign.Marshal.Alloc (free, malloc) -import Foreign.Ptr (Ptr, nullPtr) -import Foreign.Storable (Storable, peek) -import qualified Network.GRPC.Unsafe as C -import qualified Network.GRPC.Unsafe.Constants as C -import qualified Network.GRPC.Unsafe.Metadata as C -import qualified Network.GRPC.Unsafe.Op as C -import qualified Network.GRPC.Unsafe.Time as C -import System.Clock (getTime, Clock(..)) -import System.Timeout (timeout) +import Control.Monad.Trans.Class (MonadTrans (lift)) +import Control.Monad.Trans.Except +import Data.IORef (newIORef) +import Data.List (intersperse) +import Foreign.Marshal.Alloc (free, malloc) +import Foreign.Ptr (Ptr, nullPtr) +import Foreign.Storable (Storable, peek) +import qualified Network.GRPC.Unsafe as C +import qualified Network.GRPC.Unsafe.Constants as C +import qualified Network.GRPC.Unsafe.Metadata as C +import qualified Network.GRPC.Unsafe.Op as C +import qualified Network.GRPC.Unsafe.Time as C +import System.Clock (Clock (..), + getTime) +import System.Timeout (timeout) import Network.GRPC.LowLevel.Call -import Network.GRPC.LowLevel.GRPC import Network.GRPC.LowLevel.CompletionQueue.Internal -import qualified Network.GRPC.Unsafe.ByteBuffer as C +import Network.GRPC.LowLevel.GRPC +import qualified Network.GRPC.Unsafe.ByteBuffer as C withCompletionQueue :: GRPC -> (CompletionQueue -> IO a) -> IO a withCompletionQueue grpc = bracket (createCompletionQueue grpc) @@ -81,7 +85,9 @@ startBatch cq@CompletionQueue{..} call opArray opArraySize tag = withPermission Push cq $ fmap throwIfCallError $ do grpcDebug $ "startBatch: calling grpc_call_start_batch with pointers: " ++ show call ++ " " ++ show opArray + grpcDebug "About to enter grpc_call_start_batch" res <- C.grpcCallStartBatch call opArray opArraySize tag C.reserved + grpcDebug "Returned from grpc_call_start_batch" grpcDebug "startBatch: grpc_call_start_batch call returned." return res @@ -93,8 +99,9 @@ startBatch cq@CompletionQueue{..} call opArray opArraySize tag = shutdownCompletionQueue :: CompletionQueue -> IO (Either GRPCIOError ()) shutdownCompletionQueue CompletionQueue{..} = do atomically $ writeTVar shuttingDown True - atomically $ readTVar currentPushers >>= \x -> check (x == 0) - atomically $ readTVar currentPluckers >>= \x -> check (x == 0) + atomically $ do + readTVar currentPushers >>= check . (==0) + readTVar currentPluckers >>= check . (==0) --drain the queue C.grpcCompletionQueueShutdown unsafeCQ loopRes <- timeout (5*10^(6::Int)) drainLoop @@ -139,32 +146,34 @@ serverRequestCall :: C.Server -> CompletionQueue -> RegisteredMethod mt -> IO (Either GRPCIOError ServerCall) -serverRequestCall s cq@CompletionQueue{.. } RegisteredMethod{..} = +serverRequestCall s cq@CompletionQueue{.. } rm = -- NB: The method type dictates whether or not a payload is present, according -- to the payloadHandling function. We do not allocate a buffer for the -- payload when it is not present. withPermission Push cq . with allocs $ \(dead, call, pay, meta) -> do - md <- peek meta - tag <- newTag cq - dbug $ "tag is " ++ show tag - ce <- C.grpcServerRequestRegisteredCall s methodHandle call - dead md pay unsafeCQ unsafeCQ tag - dbug $ "callError: " ++ show ce - runExceptT $ case ce of - C.CallOk -> do - ExceptT $ do - r <- pluck cq tag Nothing - dbug $ "pluck finished:" ++ show r - return r - lift $ - ServerCall - <$> peek call - <*> C.getAllMetadataArray md - <*> (if havePay then toBS pay else return Nothing) - <*> liftM2 (+) (getTime Monotonic) (C.timeSpec <$> peek dead) - -- gRPC gives us a deadline that is just a delta, so we convert it - -- to a proper deadline. - _ -> throwE (GRPCIOCallError ce) + dbug "pre-pluck block" + withPermission Pluck cq $ do + md <- peek meta + tag <- newTag cq + dbug $ "got pluck permission, registering call for tag=" ++ show tag + ce <- C.grpcServerRequestRegisteredCall s (methodHandle rm) call dead md pay unsafeCQ unsafeCQ tag + runExceptT $ case ce of + C.CallOk -> do + ExceptT $ do + r <- pluck' cq tag Nothing + dbug $ "pluck' finished:" ++ show r + return r + lift $ + ServerCall + <$> peek call + <*> C.getAllMetadataArray md + <*> (if havePay then toBS pay else return Nothing) + <*> liftM2 (+) (getTime Monotonic) (C.timeSpec <$> peek dead) + -- gRPC gives us a deadline that is just a delta, so we convert + -- it to a proper deadline. + _ -> do + lift $ dbug $ "Throwing callError: " ++ show ce + throwE (GRPCIOCallError ce) where allocs = (,,,) <$> ptr <*> ptr <*> pay <*> md where @@ -173,7 +182,7 @@ serverRequestCall s cq@CompletionQueue{.. } RegisteredMethod{..} = ptr :: forall a. Storable a => Managed (Ptr a) ptr = managed (bracket malloc free) dbug = grpcDebug . ("serverRequestCall(R): " ++) - havePay = payloadHandling methodType /= C.SrmPayloadNone + havePay = payloadHandling (methodType rm) /= C.SrmPayloadNone toBS p = peek p >>= \bb@(C.ByteBuffer rawPtr) -> if | rawPtr == nullPtr -> return Nothing | otherwise -> Just <$> C.copyByteBufferToByteString bb diff --git a/src/Network/GRPC/LowLevel/CompletionQueue/Internal.hs b/src/Network/GRPC/LowLevel/CompletionQueue/Internal.hs index 2e64cd4..af87452 100644 --- a/src/Network/GRPC/LowLevel/CompletionQueue/Internal.hs +++ b/src/Network/GRPC/LowLevel/CompletionQueue/Internal.hs @@ -3,8 +3,11 @@ module Network.GRPC.LowLevel.CompletionQueue.Internal where import Control.Concurrent.STM (atomically, retry) -import Control.Concurrent.STM.TVar (TVar, modifyTVar', readTVar) +import Control.Concurrent.STM.TVar (TVar, modifyTVar', readTVar, + writeTVar) +import Control.Monad.IO.Class import Control.Exception (bracket) +import Control.Monad import Data.IORef (IORef, atomicModifyIORef') import Foreign.Ptr (nullPtr, plusPtr) import Network.GRPC.LowLevel.GRPC @@ -12,6 +15,8 @@ import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe.Constants as C import qualified Network.GRPC.Unsafe.Time as C +import Debug.Trace + -- NOTE: the concurrency requirements for a CompletionQueue are a little -- complicated. There are two read operations: next and pluck. We can either -- call next on a CQ or call pluck up to 'maxCompletionQueuePluckers' times @@ -74,26 +79,27 @@ newTag CompletionQueue{..} = do -- | Safely brackets an operation that pushes work onto or plucks results from -- the given 'CompletionQueue'. withPermission :: CQOpType - -> CompletionQueue - -> IO (Either GRPCIOError a) - -> IO (Either GRPCIOError a) -withPermission op cq f = - bracket acquire release doOp - where acquire = atomically $ do - isShuttingDown <- readTVar (shuttingDown cq) - if isShuttingDown - then return False - else do currCount <- readTVar $ getCount op cq - if currCount < getLimit op - then modifyTVar' (getCount op cq) (+1) >> return True - else retry - doOp gotResource = if gotResource - then f - else return $ Left GRPCIOShutdown - release gotResource = - if gotResource - then atomically $ modifyTVar' (getCount op cq) (subtract 1) - else return () + -> CompletionQueue + -> IO (Either GRPCIOError a) + -> IO (Either GRPCIOError a) +withPermission op cq act = bracket acquire release doOp + where + acquire = atomically $ do + isShuttingDown <- readTVar (shuttingDown cq) + if isShuttingDown + then return False + else do currCount <- readTVar (getCount op cq) + if currCount < getLimit op + then do + writeTVar (getCount op cq) (currCount+1) + return True + else retry + + doOp gotResource = + if gotResource then act else return (Left GRPCIOShutdown) + + release gotResource = when gotResource $ + atomically $ modifyTVar' (getCount op cq) (subtract 1) -- | Waits for the given number of seconds for the given tag to appear on the -- completion queue. Throws 'GRPCIOShutdown' if the completion queue is shutting @@ -103,17 +109,21 @@ withPermission op cq f = -- 'serverRequestCall', this will block forever unless a timeout is given. pluck :: CompletionQueue -> C.Tag -> Maybe TimeoutSeconds -> IO (Either GRPCIOError ()) -pluck cq@CompletionQueue{..} tag waitSeconds = do - grpcDebug $ "pluck: called with tag: " ++ show tag - ++ " and wait: " ++ show waitSeconds - withPermission Pluck cq $ - case waitSeconds of - Nothing -> C.withInfiniteDeadline go - Just seconds -> C.withDeadlineSeconds seconds go - where go deadline = do - ev <- C.grpcCompletionQueuePluck unsafeCQ tag deadline C.reserved - grpcDebug $ "pluck: finished. Event: " ++ show ev - return $ if isEventSuccessful ev then Right () else eventToError ev +pluck cq@CompletionQueue{..} tag mwait = do + grpcDebug $ "pluck: called with tag=" ++ show tag ++ ",mwait=" ++ show mwait + withPermission Pluck cq $ pluck' cq tag mwait + +-- Variant of pluck' which assumes pluck permission has been granted. +pluck' :: CompletionQueue + -> C.Tag + -> Maybe TimeoutSeconds + -> IO (Either GRPCIOError ()) +pluck' CompletionQueue{..} tag mwait = + maybe C.withInfiniteDeadline C.withDeadlineSeconds mwait $ \dead -> do + grpcDebug $ "pluck: blocking on grpc_completion_queue_pluck for tag=" ++ show tag + ev <- C.grpcCompletionQueuePluck unsafeCQ tag dead C.reserved + grpcDebug $ "pluck finished: " ++ show ev + return $ if isEventSuccessful ev then Right () else eventToError ev -- | Translate 'C.Event' to an error. The caller is responsible for ensuring -- that the event actually corresponds to an error condition; a successful event @@ -133,9 +143,9 @@ maxWorkPushers :: Int maxWorkPushers = 100 --TODO: figure out what this should be. getCount :: CQOpType -> CompletionQueue -> TVar Int -getCount Push = currentPushers +getCount Push = currentPushers getCount Pluck = currentPluckers getLimit :: CQOpType -> Int -getLimit Push = maxWorkPushers +getLimit Push = maxWorkPushers getLimit Pluck = C.maxCompletionQueuePluckers diff --git a/src/Network/GRPC/LowLevel/GRPC.hs b/src/Network/GRPC/LowLevel/GRPC.hs index bb1999f..6fdb448 100644 --- a/src/Network/GRPC/LowLevel/GRPC.hs +++ b/src/Network/GRPC/LowLevel/GRPC.hs @@ -4,7 +4,7 @@ module Network.GRPC.LowLevel.GRPC where -import Control.Concurrent (threadDelay) +import Control.Concurrent (threadDelay, myThreadId) import Control.Exception import Data.String (IsString) import qualified Data.ByteString as B @@ -59,12 +59,17 @@ throwIfCallError x = Left $ GRPCIOCallError x grpcDebug :: String -> IO () {-# INLINE grpcDebug #-} #ifdef DEBUG -grpcDebug str = do tid <- myThreadId - putStrLn $ (show tid) ++ ": " ++ str +grpcDebug = grpcDebug' #else grpcDebug _ = return () #endif +grpcDebug' :: String -> IO () +{-# INLINE grpcDebug' #-} +grpcDebug' str = do + tid <- myThreadId + putStrLn $ "[" ++ show tid ++ "]: " ++ str + threadDelaySecs :: Int -> IO () threadDelaySecs = threadDelay . (* 10^(6::Int))