fix up handler type, tweak ServerCall record names (#42)

* fix up handler type, tweak ServerCall record names

* remove ' from handler types, use LL suffix for low-level handlers

* fix all build warnings
This commit is contained in:
Connor Clark 2016-07-14 16:33:56 -07:00
parent e4a28e9e4b
commit 99e6f0652d
14 changed files with 108 additions and 146 deletions

View file

@ -13,7 +13,6 @@ import Data.Word
import GHC.Generics (Generic) import GHC.Generics (Generic)
import Network.GRPC.LowLevel import Network.GRPC.LowLevel
import qualified Network.GRPC.LowLevel.Client.Unregistered as U import qualified Network.GRPC.LowLevel.Client.Unregistered as U
import Proto3.Wire.Decode (ParseError)
echoMethod = MethodName "/echo.Echo/DoEcho" echoMethod = MethodName "/echo.Echo/DoEcho"
addMethod = MethodName "/echo.Add/DoAdd" addMethod = MethodName "/echo.Add/DoAdd"

View file

@ -47,8 +47,7 @@ regMain = withGRPC $ \grpc -> do
forever $ do forever $ do
let method = head (normalMethods server) let method = head (normalMethods server)
result <- serverHandleNormalCall server method serverMeta $ result <- serverHandleNormalCall server method serverMeta $
\_call reqBody _reqMeta -> return (reqBody, serverMeta, StatusOk, \call -> return (payload call, serverMeta, StatusOk, StatusDetails "")
StatusDetails "")
case result of case result of
Left x -> putStrLn $ "registered call result error: " ++ show x Left x -> putStrLn $ "registered call result error: " ++ show x
Right _ -> return () Right _ -> return ()
@ -61,8 +60,8 @@ regLoop :: Server -> RegisteredMethod 'Normal -> IO ()
regLoop server method = forever $ do regLoop server method = forever $ do
-- tputStrLn "about to block on call handler" -- tputStrLn "about to block on call handler"
result <- serverHandleNormalCall server method serverMeta $ result <- serverHandleNormalCall server method serverMeta $
\_call reqBody _reqMeta -> \call ->
return (reqBody, serverMeta, StatusOk, StatusDetails "") return (payload call, serverMeta, StatusOk, StatusDetails "")
case result of case result of
Left x -> error $! "registered call result error: " ++ show x Left x -> error $! "registered call result error: " ++ show x
Right _ -> return () Right _ -> return ()
@ -75,7 +74,7 @@ regMainThreaded = do
let method = head (normalMethods server) let method = head (normalMethods server)
tids <- replicateM 7 $ async $ do tputStrLn "starting handler" tids <- replicateM 7 $ async $ do tputStrLn "starting handler"
regLoop server method regLoop server method
waitAnyCancel tids _ <- waitAnyCancel tids
tputStrLn "finishing" tputStrLn "finishing"
-- NB: If you change these, make sure to change them in the client as well. -- NB: If you change these, make sure to change them in the client as well.
@ -86,9 +85,9 @@ instance Message EchoRequest
echoHandler :: Handler 'Normal echoHandler :: Handler 'Normal
echoHandler = echoHandler =
UnaryHandler "/echo.Echo/DoEcho" $ UnaryHandler "/echo.Echo/DoEcho" $
\_c body m -> do \call -> do
return ( body :: EchoRequest return ( payload call :: EchoRequest
, m , metadata call
, StatusOk , StatusOk
, StatusDetails "" , StatusDetails ""
) )
@ -104,12 +103,13 @@ instance Message AddResponse
addHandler :: Handler 'Normal addHandler :: Handler 'Normal
addHandler = addHandler =
UnaryHandler "/echo.Add/DoAdd" $ UnaryHandler "/echo.Add/DoAdd" $
\_c b m -> do \c -> do
--tputStrLn $ "UnaryHandler for DoAdd hit, b=" ++ show b --tputStrLn $ "UnaryHandler for DoAdd hit, b=" ++ show b
let b = payload c
print (addX b) print (addX b)
print (addY b) print (addY b)
return ( AddResponse $ addX b + addY b return ( AddResponse $ addX b + addY b
, m , metadata c
, StatusOk , StatusOk
, StatusDetails "" , StatusDetails ""
) )

View file

@ -15,45 +15,41 @@ import qualified Data.ByteString.Lazy as BL
import Data.Protobuf.Wire.Class import Data.Protobuf.Wire.Class
import Network.GRPC.LowLevel import Network.GRPC.LowLevel
import qualified Network.GRPC.LowLevel.Call.Unregistered as U import qualified Network.GRPC.LowLevel.Call.Unregistered as U
import Network.GRPC.LowLevel.GRPC
import qualified Network.GRPC.LowLevel.Server.Unregistered as U import qualified Network.GRPC.LowLevel.Server.Unregistered as U
type ServerHandler' a b = type ServerHandler a b =
forall c . ServerCall a
ServerCall c
-> a
-> MetadataMap
-> IO (b, MetadataMap, StatusCode, StatusDetails) -> IO (b, MetadataMap, StatusCode, StatusDetails)
convertServerHandler :: (Message a, Message b) convertServerHandler :: (Message a, Message b)
=> ServerHandler' a b => ServerHandler a b
-> ServerHandler -> ServerHandlerLL
convertServerHandler f c bs m = case fromByteString bs of convertServerHandler f c = case fromByteString (payload c) of
Left x -> error $ "Failed to deserialize message: " ++ show x Left x -> error $ "Failed to deserialize message: " ++ show x
Right x -> do (y, tm, sc, sd) <- f c x m Right x -> do (y, tm, sc, sd) <- f (fmap (const x) c)
return (toBS y, tm, sc, sd) return (toBS y, tm, sc, sd)
type ServerReaderHandler' a b = type ServerReaderHandler a b =
ServerCall () ServerCall ()
-> StreamRecv a -> StreamRecv a
-> Streaming (Maybe b, MetadataMap, StatusCode, StatusDetails) -> Streaming (Maybe b, MetadataMap, StatusCode, StatusDetails)
convertServerReaderHandler :: (Message a, Message b) convertServerReaderHandler :: (Message a, Message b)
=> ServerReaderHandler' a b => ServerReaderHandler a b
-> ServerReaderHandler -> ServerReaderHandlerLL
convertServerReaderHandler f c recv = convertServerReaderHandler f c recv =
serialize <$> f c (convertRecv recv) serialize <$> f c (convertRecv recv)
where where
serialize (mmsg, m, sc, sd) = (toBS <$> mmsg, m, sc, sd) serialize (mmsg, m, sc, sd) = (toBS <$> mmsg, m, sc, sd)
type ServerWriterHandler' a b = type ServerWriterHandler a b =
ServerCall a ServerCall a
-> StreamSend b -> StreamSend b
-> Streaming (MetadataMap, StatusCode, StatusDetails) -> Streaming (MetadataMap, StatusCode, StatusDetails)
convertServerWriterHandler :: (Message a, Message b) => convertServerWriterHandler :: (Message a, Message b) =>
ServerWriterHandler' a b ServerWriterHandler a b
-> ServerWriterHandler -> ServerWriterHandlerLL
convertServerWriterHandler f c send = convertServerWriterHandler f c send =
f (convert <$> c) (convertSend send) f (convert <$> c) (convertSend send)
where where
@ -61,15 +57,15 @@ convertServerWriterHandler f c send =
Left x -> error $ "deserialization error: " ++ show x -- TODO FIXME Left x -> error $ "deserialization error: " ++ show x -- TODO FIXME
Right x -> x Right x -> x
type ServerRWHandler' a b = type ServerRWHandler a b =
ServerCall () ServerCall ()
-> StreamRecv a -> StreamRecv a
-> StreamSend b -> StreamSend b
-> Streaming (MetadataMap, StatusCode, StatusDetails) -> Streaming (MetadataMap, StatusCode, StatusDetails)
convertServerRWHandler :: (Message a, Message b) convertServerRWHandler :: (Message a, Message b)
=> ServerRWHandler' a b => ServerRWHandler a b
-> ServerRWHandler -> ServerRWHandlerLL
convertServerRWHandler f c recv send = convertServerRWHandler f c recv send =
f c (convertRecv recv) (convertSend send) f c (convertRecv recv) (convertSend send)
@ -93,25 +89,25 @@ data Handler (a :: GRPCMethodType) where
UnaryHandler UnaryHandler
:: (Message c, Message d) :: (Message c, Message d)
=> MethodName => MethodName
-> ServerHandler' c d -> ServerHandler c d
-> Handler 'Normal -> Handler 'Normal
ClientStreamHandler ClientStreamHandler
:: (Message c, Message d) :: (Message c, Message d)
=> MethodName => MethodName
-> ServerReaderHandler' c d -> ServerReaderHandler c d
-> Handler 'ClientStreaming -> Handler 'ClientStreaming
ServerStreamHandler ServerStreamHandler
:: (Message c, Message d) :: (Message c, Message d)
=> MethodName => MethodName
-> ServerWriterHandler' c d -> ServerWriterHandler c d
-> Handler 'ServerStreaming -> Handler 'ServerStreaming
BiDiStreamHandler BiDiStreamHandler
:: (Message c, Message d) :: (Message c, Message d)
=> MethodName => MethodName
-> ServerRWHandler' c d -> ServerRWHandler c d
-> Handler 'BiDiStreaming -> Handler 'BiDiStreaming
data AnyHandler = forall (a :: GRPCMethodType) . AnyHandler (Handler a) data AnyHandler = forall (a :: GRPCMethodType) . AnyHandler (Handler a)

View file

@ -6,16 +6,11 @@
module Network.GRPC.HighLevel.Server.Unregistered where module Network.GRPC.HighLevel.Server.Unregistered where
import Control.Applicative ((<|>))
import Control.Concurrent.Async
import Control.Monad import Control.Monad
import Data.ByteString (ByteString)
import Data.Protobuf.Wire.Class import Data.Protobuf.Wire.Class
import Data.Foldable (find) import Data.Foldable (find)
import Network.GRPC.HighLevel.Server import Network.GRPC.HighLevel.Server
import Network.GRPC.LowLevel import Network.GRPC.LowLevel
import Network.GRPC.LowLevel.GRPC
import Network.GRPC.LowLevel.Call
import qualified Network.GRPC.LowLevel.Server.Unregistered as U import qualified Network.GRPC.LowLevel.Server.Unregistered as U
import qualified Network.GRPC.LowLevel.Call.Unregistered as U import qualified Network.GRPC.LowLevel.Call.Unregistered as U
@ -48,48 +43,45 @@ dispatchLoop server hN hC hS hB =
handleError f = f >>= handleCallError handleError f = f >>= handleCallError
unaryHandler :: (Message a, Message b) => unaryHandler :: (Message a, Message b) =>
U.ServerCall U.ServerCall
-> ServerHandler' a b -> ServerHandler a b
-> IO () -> IO ()
unaryHandler call h = unaryHandler call h =
handleError $ handleError $
U.serverHandleNormalCall' server call mempty $ \call' bs -> do U.serverHandleNormalCall' server call mempty $ \_call' bs ->
let h' = convertServerHandler h convertServerHandler h (fmap (const bs) $ U.convertCall call)
h' (fmap (const bs) $ U.convertCall call)
bs
(U.requestMetadataRecv call)
csHandler :: (Message a, Message b) => csHandler :: (Message a, Message b) =>
U.ServerCall U.ServerCall
-> ServerReaderHandler' a b -> ServerReaderHandler a b
-> IO () -> IO ()
csHandler call h = csHandler call h =
handleError $ handleError $
U.serverReader server call mempty (convertServerReaderHandler h) U.serverReader server call mempty (convertServerReaderHandler h)
ssHandler :: (Message a, Message b) => ssHandler :: (Message a, Message b) =>
U.ServerCall U.ServerCall
-> ServerWriterHandler' a b -> ServerWriterHandler a b
-> IO () -> IO ()
ssHandler call h = ssHandler call h =
handleError $ handleError $
U.serverWriter server call mempty (convertServerWriterHandler h) U.serverWriter server call mempty (convertServerWriterHandler h)
bdHandler :: (Message a, Message b) => bdHandler :: (Message a, Message b) =>
U.ServerCall U.ServerCall
-> ServerRWHandler' a b -> ServerRWHandler a b
-> IO () -> IO ()
bdHandler call h = bdHandler call h =
handleError $ handleError $
U.serverRW server call mempty (convertServerRWHandler h) U.serverRW server call mempty (convertServerRWHandler h)
serverLoop :: ServerOptions -> IO () serverLoop :: ServerOptions -> IO ()
serverLoop opts@ServerOptions{..} = serverLoop ServerOptions{..} =
withGRPC $ \grpc -> withGRPC $ \grpc ->
withServer grpc (mkConfig opts) $ \server -> do withServer grpc config $ \server -> do
dispatchLoop server dispatchLoop server
optNormalHandlers optNormalHandlers
optClientStreamHandlers optClientStreamHandlers
optServerStreamHandlers optServerStreamHandlers
optBiDiStreamHandlers optBiDiStreamHandlers
where where
mkConfig ServerOptions{..} = config =
ServerConfig ServerConfig
{ host = "localhost" { host = "localhost"
, port = optServerPort , port = optServerPort

View file

@ -35,19 +35,19 @@ GRPC
, ServerConfig(..) , ServerConfig(..)
, Server(normalMethods, sstreamingMethods, cstreamingMethods, , Server(normalMethods, sstreamingMethods, cstreamingMethods,
bidiStreamingMethods) bidiStreamingMethods)
, ServerCall(optionalPayload, requestMetadataRecv) , ServerCall(payload, metadata)
, withServer , withServer
, serverHandleNormalCall , serverHandleNormalCall
, ServerHandler , ServerHandlerLL
, withServerCall , withServerCall
, serverCallCancel , serverCallCancel
, serverCallIsExpired , serverCallIsExpired
, serverReader -- for client streaming , serverReader -- for client streaming
, ServerReaderHandler , ServerReaderHandlerLL
, serverWriter -- for server streaming , serverWriter -- for server streaming
, ServerWriterHandler , ServerWriterHandlerLL
, serverRW -- for bidirectional streaming , serverRW -- for bidirectional streaming
, ServerRWHandler , ServerRWHandlerLL
-- * Client -- * Client
, ClientConfig(..) , ClientConfig(..)

View file

@ -48,13 +48,15 @@ type family MethodPayload a where
--TODO: try replacing this class with a plain old function so we don't have the --TODO: try replacing this class with a plain old function so we don't have the
-- Payloadable constraint everywhere. -- Payloadable constraint everywhere.
payload :: RegisteredMethod mt -> Ptr C.ByteBuffer -> IO (MethodPayload mt) extractPayload :: RegisteredMethod mt
payload (RegisteredMethodNormal _ _ _) p = -> Ptr C.ByteBuffer
-> IO (MethodPayload mt)
extractPayload (RegisteredMethodNormal _ _ _) p =
peek p >>= C.copyByteBufferToByteString peek p >>= C.copyByteBufferToByteString
payload (RegisteredMethodClientStreaming _ _ _) _ = return () extractPayload (RegisteredMethodClientStreaming _ _ _) _ = return ()
payload (RegisteredMethodServerStreaming _ _ _) p = extractPayload (RegisteredMethodServerStreaming _ _ _) p =
peek p >>= C.copyByteBufferToByteString peek p >>= C.copyByteBufferToByteString
payload (RegisteredMethodBiDiStreaming _ _ _) _ = return () extractPayload (RegisteredMethodBiDiStreaming _ _ _) _ = return ()
newtype MethodName = MethodName {unMethodName :: String} newtype MethodName = MethodName {unMethodName :: String}
deriving (Show, Eq, IsString) deriving (Show, Eq, IsString)
@ -147,8 +149,8 @@ clientCallCancel cc = C.grpcCallCancel (unsafeCC cc) C.reserved
data ServerCall a = ServerCall data ServerCall a = ServerCall
{ unsafeSC :: C.Call { unsafeSC :: C.Call
, callCQ :: CompletionQueue , callCQ :: CompletionQueue
, requestMetadataRecv :: MetadataMap , metadata :: MetadataMap
, optionalPayload :: a , payload :: a
, callDeadline :: TimeSpec , callDeadline :: TimeSpec
} deriving (Functor, Show) } deriving (Functor, Show)
@ -194,7 +196,7 @@ debugServerCall sc@(ServerCall (C.Call ptr) _ _ _ _) = do
let dbug = grpcDebug . ("debugServerCall(R): " ++) let dbug = grpcDebug . ("debugServerCall(R): " ++)
dbug $ "server call: " ++ show ptr dbug $ "server call: " ++ show ptr
dbug $ "callCQ: " ++ show (callCQ sc) dbug $ "callCQ: " ++ show (callCQ sc)
dbug $ "metadata ptr: " ++ show (requestMetadataRecv sc) dbug $ "metadata: " ++ show (metadata sc)
dbug $ "deadline ptr: " ++ show (callDeadline sc) dbug $ "deadline ptr: " ++ show (callDeadline sc)
#else #else
{-# INLINE debugServerCall #-} {-# INLINE debugServerCall #-}

View file

@ -2,15 +2,8 @@
module Network.GRPC.LowLevel.Call.Unregistered where module Network.GRPC.LowLevel.Call.Unregistered where
import Control.Monad
import Foreign.Marshal.Alloc (free)
import Foreign.Ptr (Ptr)
#ifdef DEBUG
import Foreign.Storable (peek)
#endif
import qualified Network.GRPC.LowLevel.Call as Reg import qualified Network.GRPC.LowLevel.Call as Reg
import Network.GRPC.LowLevel.CompletionQueue import Network.GRPC.LowLevel.CompletionQueue
import Network.GRPC.LowLevel.CompletionQueue.Internal
import Network.GRPC.LowLevel.GRPC (MetadataMap, import Network.GRPC.LowLevel.GRPC (MetadataMap,
grpcDebug) grpcDebug)
import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe as C
@ -22,7 +15,7 @@ import System.Clock (TimeSpec)
data ServerCall = ServerCall data ServerCall = ServerCall
{ unsafeSC :: C.Call { unsafeSC :: C.Call
, callCQ :: CompletionQueue , callCQ :: CompletionQueue
, requestMetadataRecv :: MetadataMap , metadata :: MetadataMap
, callDeadline :: TimeSpec , callDeadline :: TimeSpec
, callMethod :: Reg.MethodName , callMethod :: Reg.MethodName
, callHost :: Reg.Host , callHost :: Reg.Host
@ -30,7 +23,7 @@ data ServerCall = ServerCall
convertCall :: ServerCall -> Reg.ServerCall () convertCall :: ServerCall -> Reg.ServerCall ()
convertCall ServerCall{..} = convertCall ServerCall{..} =
Reg.ServerCall unsafeSC callCQ requestMetadataRecv () callDeadline Reg.ServerCall unsafeSC callCQ metadata () callDeadline
serverCallCancel :: ServerCall -> C.StatusCode -> String -> IO () serverCallCancel :: ServerCall -> C.StatusCode -> String -> IO ()
serverCallCancel sc code reason = serverCallCancel sc code reason =
@ -43,7 +36,7 @@ debugServerCall ServerCall{..} = do
dbug = grpcDebug . ("debugServerCall(U): " ++) dbug = grpcDebug . ("debugServerCall(U): " ++)
dbug $ "server call: " ++ show ptr dbug $ "server call: " ++ show ptr
dbug $ "metadata: " ++ show requestMetadataRecv dbug $ "metadata: " ++ show metadata
dbug $ "deadline: " ++ show callDeadline dbug $ "deadline: " ++ show callDeadline
dbug $ "method: " ++ show callMethod dbug $ "method: " ++ show callMethod

View file

@ -35,25 +35,19 @@ module Network.GRPC.LowLevel.CompletionQueue
) )
where where
import Control.Concurrent.STM (atomically, import Control.Concurrent.STM.TVar (newTVarIO)
check)
import Control.Concurrent.STM.TVar (newTVarIO,
readTVar,
writeTVar)
import Control.Exception (bracket) import Control.Exception (bracket)
import Control.Monad.Managed import Control.Monad.Managed
import Control.Monad.Trans.Class (MonadTrans (lift)) import Control.Monad.Trans.Class (MonadTrans (lift))
import Control.Monad.Trans.Except import Control.Monad.Trans.Except
import Data.IORef (newIORef) import Data.IORef (newIORef)
import Data.List (intersperse) import Data.List (intersperse)
import Foreign.Marshal.Alloc (free, malloc) import Foreign.Ptr (nullPtr)
import Foreign.Ptr (Ptr, nullPtr) import Foreign.Storable (peek)
import Foreign.Storable (Storable, peek)
import Network.GRPC.LowLevel.Call import Network.GRPC.LowLevel.Call
import Network.GRPC.LowLevel.CompletionQueue.Internal import Network.GRPC.LowLevel.CompletionQueue.Internal
import Network.GRPC.LowLevel.GRPC import Network.GRPC.LowLevel.GRPC
import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe as C
import qualified Network.GRPC.Unsafe.ByteBuffer as C
import qualified Network.GRPC.Unsafe.Constants as C import qualified Network.GRPC.Unsafe.Constants as C
import qualified Network.GRPC.Unsafe.Metadata as C import qualified Network.GRPC.Unsafe.Metadata as C
import qualified Network.GRPC.Unsafe.Op as C import qualified Network.GRPC.Unsafe.Op as C
@ -136,7 +130,7 @@ serverRequestCall rm s scq ccq =
<$> peek call <$> peek call
<*> return ccq <*> return ccq
<*> C.getAllMetadataArray md <*> C.getAllMetadataArray md
<*> payload rm pay <*> extractPayload rm pay
<*> convertDeadline dead <*> convertDeadline dead
_ -> do _ -> do
lift $ dbug $ "Throwing callError: " ++ show ce lift $ dbug $ "Throwing callError: " ++ show ce

View file

@ -8,13 +8,10 @@
module Network.GRPC.LowLevel.CompletionQueue.Unregistered where module Network.GRPC.LowLevel.CompletionQueue.Unregistered where
import Control.Exception (bracket)
import Control.Monad.Managed import Control.Monad.Managed
import Control.Monad.Trans.Class (MonadTrans (lift)) import Control.Monad.Trans.Class (MonadTrans (lift))
import Control.Monad.Trans.Except import Control.Monad.Trans.Except
import Foreign.Marshal.Alloc (free, malloc) import Foreign.Storable (peek)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, peek)
import Network.GRPC.LowLevel.Call import Network.GRPC.LowLevel.Call
import qualified Network.GRPC.LowLevel.Call.Unregistered as U import qualified Network.GRPC.LowLevel.Call.Unregistered as U
import Network.GRPC.LowLevel.CompletionQueue.Internal import Network.GRPC.LowLevel.CompletionQueue.Internal

View file

@ -13,10 +13,6 @@ import qualified Network.GRPC.Unsafe as C
import qualified Network.GRPC.Unsafe.Op as C import qualified Network.GRPC.Unsafe.Op as C
import Proto3.Wire.Decode (ParseError) import Proto3.Wire.Decode (ParseError)
#ifdef DEBUG
import GHC.Conc (myThreadId)
#endif
type MetadataMap = M.Map B.ByteString B.ByteString type MetadataMap = M.Map B.ByteString B.ByteString
newtype StatusDetails = StatusDetails B.ByteString newtype StatusDetails = StatusDetails B.ByteString

View file

@ -6,7 +6,6 @@
module Network.GRPC.LowLevel.Op where module Network.GRPC.LowLevel.Op where
import Control.Arrow
import Control.Exception import Control.Exception
import Control.Monad import Control.Monad
import Control.Monad.Trans.Class (MonadTrans(lift)) import Control.Monad.Trans.Class (MonadTrans(lift))
@ -28,7 +27,6 @@ import qualified Network.GRPC.Unsafe.ByteBuffer as C
import qualified Network.GRPC.Unsafe.Metadata as C import qualified Network.GRPC.Unsafe.Metadata as C
import qualified Network.GRPC.Unsafe.Op as C import qualified Network.GRPC.Unsafe.Op as C
import qualified Network.GRPC.Unsafe.Slice as C (Slice, freeSlice) import qualified Network.GRPC.Unsafe.Slice as C (Slice, freeSlice)
import Pipes ((>->))
import qualified Pipes as P import qualified Pipes as P
import qualified Pipes.Core as P import qualified Pipes.Core as P

View file

@ -25,8 +25,7 @@ import Network.GRPC.LowLevel.CompletionQueue (CompletionQueue,
serverRegisterCompletionQueue, serverRegisterCompletionQueue,
serverRequestCall, serverRequestCall,
serverShutdownAndNotify, serverShutdownAndNotify,
shutdownCompletionQueue, shutdownCompletionQueue)
withCompletionQueue)
import Network.GRPC.LowLevel.GRPC import Network.GRPC.LowLevel.GRPC
import Network.GRPC.LowLevel.Op import Network.GRPC.LowLevel.Op
import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe as C
@ -249,7 +248,7 @@ withServerCall s rm f =
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- serverReader (server side of client streaming mode) -- serverReader (server side of client streaming mode)
type ServerReaderHandler type ServerReaderHandlerLL
= ServerCall () = ServerCall ()
-> StreamRecv ByteString -> StreamRecv ByteString
-> Streaming (Maybe ByteString, MetadataMap, C.StatusCode, StatusDetails) -> Streaming (Maybe ByteString, MetadataMap, C.StatusCode, StatusDetails)
@ -257,7 +256,7 @@ type ServerReaderHandler
serverReader :: Server serverReader :: Server
-> RegisteredMethod 'ClientStreaming -> RegisteredMethod 'ClientStreaming
-> MetadataMap -- ^ initial server metadata -> MetadataMap -- ^ initial server metadata
-> ServerReaderHandler -> ServerReaderHandlerLL
-> IO (Either GRPCIOError ()) -> IO (Either GRPCIOError ())
serverReader s rm initMeta f = withServerCall s rm go serverReader s rm initMeta f = withServerCall s rm go
where where
@ -273,7 +272,7 @@ serverReader s rm initMeta f = withServerCall s rm go
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- serverWriter (server side of server streaming mode) -- serverWriter (server side of server streaming mode)
type ServerWriterHandler type ServerWriterHandlerLL
= ServerCall ByteString = ServerCall ByteString
-> StreamSend ByteString -> StreamSend ByteString
-> Streaming (MetadataMap, C.StatusCode, StatusDetails) -> Streaming (MetadataMap, C.StatusCode, StatusDetails)
@ -283,7 +282,7 @@ serverWriter :: Server
-> RegisteredMethod 'ServerStreaming -> RegisteredMethod 'ServerStreaming
-> MetadataMap -> MetadataMap
-- ^ Initial server metadata -- ^ Initial server metadata
-> ServerWriterHandler -> ServerWriterHandlerLL
-> IO (Either GRPCIOError ()) -> IO (Either GRPCIOError ())
serverWriter s rm initMeta f = withServerCall s rm go serverWriter s rm initMeta f = withServerCall s rm go
where where
@ -295,7 +294,7 @@ serverWriter s rm initMeta f = withServerCall s rm go
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- serverRW (server side of bidirectional streaming mode) -- serverRW (server side of bidirectional streaming mode)
type ServerRWHandler type ServerRWHandlerLL
= ServerCall () = ServerCall ()
-> StreamRecv ByteString -> StreamRecv ByteString
-> StreamSend ByteString -> StreamSend ByteString
@ -305,7 +304,7 @@ serverRW :: Server
-> RegisteredMethod 'BiDiStreaming -> RegisteredMethod 'BiDiStreaming
-> MetadataMap -> MetadataMap
-- ^ initial server metadata -- ^ initial server metadata
-> ServerRWHandler -> ServerRWHandlerLL
-> IO (Either GRPCIOError ()) -> IO (Either GRPCIOError ())
serverRW s rm initMeta f = withServerCall s rm go serverRW s rm initMeta f = withServerCall s rm go
where where
@ -323,10 +322,8 @@ serverRW s rm initMeta f = withServerCall s rm go
-- values in the result tuple being the initial and trailing metadata -- values in the result tuple being the initial and trailing metadata
-- respectively. We pass in the 'ServerCall' so that the server can call -- respectively. We pass in the 'ServerCall' so that the server can call
-- 'serverCallCancel' on it if needed. -- 'serverCallCancel' on it if needed.
type ServerHandler type ServerHandlerLL
= ServerCall ByteString = ServerCall ByteString
-> ByteString
-> MetadataMap
-> IO (ByteString, MetadataMap, C.StatusCode, StatusDetails) -> IO (ByteString, MetadataMap, C.StatusCode, StatusDetails)
-- | Wait for and then handle a normal (non-streaming) call. -- | Wait for and then handle a normal (non-streaming) call.
@ -334,13 +331,13 @@ serverHandleNormalCall :: Server
-> RegisteredMethod 'Normal -> RegisteredMethod 'Normal
-> MetadataMap -> MetadataMap
-- ^ Initial server metadata -- ^ Initial server metadata
-> ServerHandler -> ServerHandlerLL
-> IO (Either GRPCIOError ()) -> IO (Either GRPCIOError ())
serverHandleNormalCall s rm initMeta f = serverHandleNormalCall s rm initMeta f =
withServerCall s rm go withServerCall s rm go
where where
go sc@ServerCall{..} = do go sc@ServerCall{..} = do
(rsp, trailMeta, st, ds) <- f sc optionalPayload requestMetadataRecv (rsp, trailMeta, st, ds) <- f sc
void <$> runOps unsafeSC callCQ void <$> runOps unsafeSC callCQ
[ OpSendInitialMetadata initMeta [ OpSendInitialMetadata initMeta
, OpRecvCloseOnServer , OpRecvCloseOnServer

View file

@ -9,9 +9,7 @@ import Control.Monad
import Control.Monad.Trans.Except import Control.Monad.Trans.Except
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import Network.GRPC.LowLevel.Call.Unregistered import Network.GRPC.LowLevel.Call.Unregistered
import Network.GRPC.LowLevel.CompletionQueue (CompletionQueue import Network.GRPC.LowLevel.CompletionQueue (createCompletionQueue)
, withCompletionQueue
, createCompletionQueue)
import Network.GRPC.LowLevel.CompletionQueue.Unregistered (serverRequestCall) import Network.GRPC.LowLevel.CompletionQueue.Unregistered (serverRequestCall)
import Network.GRPC.LowLevel.GRPC import Network.GRPC.LowLevel.GRPC
import Network.GRPC.LowLevel.Op (Op (..) import Network.GRPC.LowLevel.Op (Op (..)
@ -25,9 +23,9 @@ import Network.GRPC.LowLevel.Op (Op (..)
, sendStatusFromServer , sendStatusFromServer
, recvInitialMessage) , recvInitialMessage)
import Network.GRPC.LowLevel.Server (Server (..) import Network.GRPC.LowLevel.Server (Server (..)
, ServerReaderHandler , ServerReaderHandlerLL
, ServerWriterHandler , ServerWriterHandlerLL
, ServerRWHandler) , ServerRWHandlerLL)
import qualified Network.GRPC.Unsafe.Op as C import qualified Network.GRPC.Unsafe.Op as C
serverCreateCall :: Server serverCreateCall :: Server
@ -56,7 +54,7 @@ withServerCallAsync :: Server
-> IO () -> IO ()
withServerCallAsync s f = withServerCallAsync s f =
serverCreateCall s >>= \case serverCreateCall s >>= \case
Left e -> return () Left _ -> return ()
Right c -> void $ forkIO (f c `finally` do Right c -> void $ forkIO (f c `finally` do
grpcDebug "withServerCallAsync: destroying." grpcDebug "withServerCallAsync: destroying."
destroyServerCall c) destroyServerCall c)
@ -102,7 +100,7 @@ serverHandleNormalCall' :: Server
-> ServerHandler -> ServerHandler
-> IO (Either GRPCIOError ()) -> IO (Either GRPCIOError ())
serverHandleNormalCall' serverHandleNormalCall'
s sc@ServerCall{ unsafeSC = c, callCQ = cq, .. } initMeta f = do _ sc@ServerCall{ unsafeSC = c, callCQ = cq, .. } initMeta f = do
grpcDebug "serverHandleNormalCall(U): starting batch." grpcDebug "serverHandleNormalCall(U): starting batch."
runOps c cq runOps c cq
[ OpSendInitialMetadata initMeta [ OpSendInitialMetadata initMeta
@ -113,7 +111,7 @@ serverHandleNormalCall'
grpcDebug "serverHandleNormalCall(U): ops failed; aborting" grpcDebug "serverHandleNormalCall(U): ops failed; aborting"
return $ Left x return $ Left x
Right [OpRecvMessageResult (Just body)] -> do Right [OpRecvMessageResult (Just body)] -> do
grpcDebug $ "got client metadata: " ++ show requestMetadataRecv grpcDebug $ "got client metadata: " ++ show metadata
grpcDebug $ "call_details host is: " ++ show callHost grpcDebug $ "call_details host is: " ++ show callHost
(rsp, trailMeta, st, ds) <- f sc body (rsp, trailMeta, st, ds) <- f sc body
runOps c cq runOps c cq
@ -133,9 +131,9 @@ serverHandleNormalCall'
serverReader :: Server serverReader :: Server
-> ServerCall -> ServerCall
-> MetadataMap -- ^ initial server metadata -> MetadataMap -- ^ initial server metadata
-> ServerReaderHandler -> ServerReaderHandlerLL
-> IO (Either GRPCIOError ()) -> IO (Either GRPCIOError ())
serverReader s sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f = serverReader _ sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f =
runExceptT $ do runExceptT $ do
(mmsg, trailMeta, st, ds) <- (mmsg, trailMeta, st, ds) <-
runStreamingProxy "serverReader" c ccq (f (convertCall sc) streamRecv) runStreamingProxy "serverReader" c ccq (f (convertCall sc) streamRecv)
@ -149,9 +147,9 @@ serverWriter :: Server
-> ServerCall -> ServerCall
-> MetadataMap -> MetadataMap
-- ^ Initial server metadata -- ^ Initial server metadata
-> ServerWriterHandler -> ServerWriterHandlerLL
-> IO (Either GRPCIOError ()) -> IO (Either GRPCIOError ())
serverWriter s sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f = serverWriter _ sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f =
runExceptT $ do runExceptT $ do
bs <- recvInitialMessage c ccq bs <- recvInitialMessage c ccq
sendInitialMetadata c ccq initMeta sendInitialMetadata c ccq initMeta
@ -163,9 +161,9 @@ serverRW :: Server
-> ServerCall -> ServerCall
-> MetadataMap -> MetadataMap
-- ^ initial server metadata -- ^ initial server metadata
-> ServerRWHandler -> ServerRWHandlerLL
-> IO (Either GRPCIOError ()) -> IO (Either GRPCIOError ())
serverRW s sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f = serverRW _ sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f =
runExceptT $ do runExceptT $ do
sendInitialMetadata c ccq initMeta sendInitialMetadata c ccq initMeta
let regCall = convertCall sc let regCall = convertCall sc

View file

@ -105,8 +105,8 @@ testMixRegisteredUnregistered =
return () return ()
where regThread = do where regThread = do
let rm = head (normalMethods s) let rm = head (normalMethods s)
r <- serverHandleNormalCall s rm dummyMeta $ \_ body _ -> do r <- serverHandleNormalCall s rm dummyMeta $ \c -> do
body @?= "Hello" payload c @?= "Hello"
return ("reply test", dummyMeta, StatusOk, "") return ("reply test", dummyMeta, StatusOk, "")
return () return ()
unregThread = do unregThread = do
@ -138,9 +138,9 @@ testPayload =
trailMD @?= dummyMeta trailMD @?= dummyMeta
server s = do server s = do
let rm = head (normalMethods s) let rm = head (normalMethods s)
r <- serverHandleNormalCall s rm dummyMeta $ \_ reqBody reqMD -> do r <- serverHandleNormalCall s rm dummyMeta $ \c -> do
reqBody @?= "Hello!" payload c @?= "Hello!"
checkMD "Server metadata mismatch" clientMD reqMD checkMD "Server metadata mismatch" clientMD (metadata c)
return ("reply test", dummyMeta, StatusOk, "details string") return ("reply test", dummyMeta, StatusOk, "details string")
r @?= Right () r @?= Right ()
@ -154,7 +154,7 @@ testServerCancel =
res @?= badStatus StatusCancelled res @?= badStatus StatusCancelled
server s = do server s = do
let rm = head (normalMethods s) let rm = head (normalMethods s)
r <- serverHandleNormalCall s rm mempty $ \c _ _ -> do r <- serverHandleNormalCall s rm mempty $ \c -> do
serverCallCancel c StatusCancelled "" serverCallCancel c StatusCancelled ""
return (mempty, mempty, StatusCancelled, "") return (mempty, mempty, StatusCancelled, "")
r @?= Right () r @?= Right ()
@ -181,8 +181,8 @@ testServerStreaming =
r <- serverWriter s rm serverInitMD $ \sc send -> do r <- serverWriter s rm serverInitMD $ \sc send -> do
liftIO $ do liftIO $ do
checkMD "Server request metadata mismatch" checkMD "Server request metadata mismatch"
clientInitMD (requestMetadataRecv sc) clientInitMD (metadata sc)
optionalPayload sc @?= clientPay payload sc @?= clientPay
forM_ pays $ \p -> send p `is` Right () forM_ pays $ \p -> send p `is` Right ()
return (dummyMeta, StatusOk, "dtls") return (dummyMeta, StatusOk, "dtls")
r @?= Right () r @?= Right ()
@ -212,8 +212,8 @@ testServerStreamingUnregistered =
r <- U.serverWriter s call serverInitMD $ \sc send -> do r <- U.serverWriter s call serverInitMD $ \sc send -> do
liftIO $ do liftIO $ do
checkMD "Server request metadata mismatch" checkMD "Server request metadata mismatch"
clientInitMD (requestMetadataRecv sc) clientInitMD (metadata sc)
optionalPayload sc @?= clientPay payload sc @?= clientPay
forM_ pays $ \p -> send p `is` Right () forM_ pays $ \p -> send p `is` Right ()
return (dummyMeta, StatusOk, "dtls") return (dummyMeta, StatusOk, "dtls")
r @?= Right () r @?= Right ()
@ -241,7 +241,7 @@ testClientStreaming =
let rm = head (cstreamingMethods s) let rm = head (cstreamingMethods s)
eea <- serverReader s rm serverInitMD $ \sc recv -> do eea <- serverReader s rm serverInitMD $ \sc recv -> do
liftIO $ checkMD "Client request metadata mismatch" liftIO $ checkMD "Client request metadata mismatch"
clientInitMD (requestMetadataRecv sc) clientInitMD (metadata sc)
forM_ pays $ \p -> recv `is` Right (Just p) forM_ pays $ \p -> recv `is` Right (Just p)
recv `is` Right Nothing recv `is` Right Nothing
return (Just serverRsp, trailMD, serverStatus, serverDtls) return (Just serverRsp, trailMD, serverStatus, serverDtls)
@ -269,7 +269,7 @@ testClientStreamingUnregistered =
server s = U.withServerCallAsync s $ \call -> do server s = U.withServerCallAsync s $ \call -> do
eea <- U.serverReader s call serverInitMD $ \sc recv -> do eea <- U.serverReader s call serverInitMD $ \sc recv -> do
liftIO $ checkMD "Client request metadata mismatch" liftIO $ checkMD "Client request metadata mismatch"
clientInitMD (requestMetadataRecv sc) clientInitMD (metadata sc)
forM_ pays $ \p -> recv `is` Right (Just p) forM_ pays $ \p -> recv `is` Right (Just p)
recv `is` Right Nothing recv `is` Right Nothing
return (Just serverRsp, trailMD, serverStatus, serverDtls) return (Just serverRsp, trailMD, serverStatus, serverDtls)
@ -301,7 +301,7 @@ testBiDiStreaming =
let rm = head (bidiStreamingMethods s) let rm = head (bidiStreamingMethods s)
eea <- serverRW s rm serverInitMD $ \sc recv send -> do eea <- serverRW s rm serverInitMD $ \sc recv send -> do
liftIO $ checkMD "Client request metadata mismatch" liftIO $ checkMD "Client request metadata mismatch"
clientInitMD (requestMetadataRecv sc) clientInitMD (metadata sc)
recv `is` Right (Just "cw0") recv `is` Right (Just "cw0")
send "sw0" `is` Right () send "sw0" `is` Right ()
recv `is` Right (Just "cw1") recv `is` Right (Just "cw1")
@ -336,7 +336,7 @@ testBiDiStreamingUnregistered =
server s = U.withServerCallAsync s $ \call -> do server s = U.withServerCallAsync s $ \call -> do
eea <- U.serverRW s call serverInitMD $ \sc recv send -> do eea <- U.serverRW s call serverInitMD $ \sc recv send -> do
liftIO $ checkMD "Client request metadata mismatch" liftIO $ checkMD "Client request metadata mismatch"
clientInitMD (requestMetadataRecv sc) clientInitMD (metadata sc)
recv `is` Right (Just "cw0") recv `is` Right (Just "cw0")
send "sw0" `is` Right () send "sw0" `is` Right ()
recv `is` Right (Just "cw1") recv `is` Right (Just "cw1")
@ -412,7 +412,7 @@ testSlowServer =
result @?= badStatus StatusDeadlineExceeded result @?= badStatus StatusDeadlineExceeded
server s = do server s = do
let rm = head (normalMethods s) let rm = head (normalMethods s)
serverHandleNormalCall s rm mempty $ \_ _ _ -> do serverHandleNormalCall s rm mempty $ \_ -> do
threadDelay (2*10^(6 :: Int)) threadDelay (2*10^(6 :: Int))
return dummyResp return dummyResp
return () return ()
@ -427,7 +427,7 @@ testServerCallExpirationCheck =
return () return ()
server s = do server s = do
let rm = head (normalMethods s) let rm = head (normalMethods s)
serverHandleNormalCall s rm mempty $ \c _ _ -> do serverHandleNormalCall s rm mempty $ \c -> do
exp1 <- serverCallIsExpired c exp1 <- serverCallIsExpired c
assertBool "Call isn't expired when handler starts" $ not exp1 assertBool "Call isn't expired when handler starts" $ not exp1
threadDelaySecs 1 threadDelaySecs 1
@ -451,8 +451,8 @@ testCustomUserAgent =
return () return ()
server = TestServer (serverConf (["/foo"],[],[],[])) $ \s -> do server = TestServer (serverConf (["/foo"],[],[],[])) $ \s -> do
let rm = head (normalMethods s) let rm = head (normalMethods s)
serverHandleNormalCall s rm mempty $ \_ _ meta -> do serverHandleNormalCall s rm mempty $ \c -> do
let ua = meta M.! "user-agent" let ua = (metadata c) M.! "user-agent"
assertBool "User agent prefix is present" $ isPrefixOf "prefix!" ua assertBool "User agent prefix is present" $ isPrefixOf "prefix!" ua
assertBool "User agent suffix is present" $ isSuffixOf "suffix!" ua assertBool "User agent suffix is present" $ isSuffixOf "suffix!" ua
return dummyResp return dummyResp
@ -472,8 +472,8 @@ testClientCompression =
return () return ()
server = TestServer (serverConf (["/foo"],[],[],[])) $ \s -> do server = TestServer (serverConf (["/foo"],[],[],[])) $ \s -> do
let rm = head (normalMethods s) let rm = head (normalMethods s)
serverHandleNormalCall s rm mempty $ \_ body _ -> do serverHandleNormalCall s rm mempty $ \c -> do
body @?= "hello" payload c @?= "hello"
return dummyResp return dummyResp
return () return ()
@ -500,8 +500,8 @@ testClientServerCompression =
[CompressionAlgArg GrpcCompressDeflate] [CompressionAlgArg GrpcCompressDeflate]
server = TestServer sconf $ \s -> do server = TestServer sconf $ \s -> do
let rm = head (normalMethods s) let rm = head (normalMethods s)
serverHandleNormalCall s rm dummyMeta $ \_sc body _ -> do serverHandleNormalCall s rm dummyMeta $ \sc -> do
body @?= "hello" payload sc @?= "hello"
return ("hello", dummyMeta, StatusOk, StatusDetails "") return ("hello", dummyMeta, StatusOk, StatusDetails "")
return () return ()
@ -517,9 +517,9 @@ dummyMeta = [("foo","bar")]
dummyResp :: (ByteString, MetadataMap, StatusCode, StatusDetails) dummyResp :: (ByteString, MetadataMap, StatusCode, StatusDetails)
dummyResp = ("", mempty, StatusOk, StatusDetails "") dummyResp = ("", mempty, StatusOk, StatusDetails "")
dummyHandler :: ServerCall a -> ByteString -> MetadataMap dummyHandler :: ServerCall a
-> IO (ByteString, MetadataMap, StatusCode, StatusDetails) -> IO (ByteString, MetadataMap, StatusCode, StatusDetails)
dummyHandler _ _ _ = return dummyResp dummyHandler _ = return dummyResp
dummyResult' :: StatusDetails dummyResult' :: StatusDetails
-> IO (ByteString, MetadataMap, StatusCode, StatusDetails) -> IO (ByteString, MetadataMap, StatusCode, StatusDetails)