From e4a28e9e4be79106238bb2ecc9c0f7c0016fc7f9 Mon Sep 17 00:00:00 2001 From: Connor Clark Date: Thu, 14 Jul 2016 09:53:28 -0700 Subject: [PATCH] High-level unregistered concurrent interface (#41) * remove parent ptr from unregistered calls -- unneeded * begin unregistered high level server loop * undo changes to highlevel server, add mkConfig for unregistered server * move call CQ create/destroy into call create/destroy * async normal call function * preliminary unregistered server loop for non-streaming methods * working unregistered highlevel example * loop counters for benchmarking * changes for benchmarking, add ruby example server for benchmarking * async version of withCall, refactor unregistered server loop to handle all method types * unregistered client streaming * add remaining streaming modes * unregistered server streaming test * unregistered streaming tests * add error logging * fix bug in add example * remove old TODOs * fix bug: don't assume slices are null-terminated * add TODO re: unregistered client streaming functions --- cbits/grpc_haskell.c | 4 +- examples/echo/echo-client/Main.hs | 17 ++- examples/echo/echo-cpp/echo-client.cc | 11 +- examples/echo/echo-cpp/echo-server.cc | 8 ++ examples/echo/echo-python/echo_client.py | 9 +- examples/echo/echo-ruby/echo-server.rb | 26 +++++ examples/echo/echo-ruby/echo.rb | 23 ++++ examples/echo/echo-ruby/echo_services.rb | 40 +++++++ examples/echo/echo-server/Main.hs | 64 +++++----- grpc-haskell.cabal | 1 + include/grpc_haskell.h | 2 +- src/Network/GRPC/HighLevel/Server.hs | 47 +++++--- .../GRPC/HighLevel/Server/Unregistered.hs | 105 +++++++++++++++++ src/Network/GRPC/LowLevel/Call.hs | 3 +- .../GRPC/LowLevel/Call/Unregistered.hs | 21 ++-- src/Network/GRPC/LowLevel/CompletionQueue.hs | 31 ----- .../GRPC/LowLevel/CompletionQueue/Internal.hs | 32 ++++- .../LowLevel/CompletionQueue/Unregistered.hs | 1 - src/Network/GRPC/LowLevel/Op.hs | 13 ++- src/Network/GRPC/LowLevel/Server.hs | 8 +- .../GRPC/LowLevel/Server/Unregistered.hs | 109 +++++++++++++++--- src/Network/GRPC/Unsafe/Slice.chs | 4 +- tests/LowLevelTests.hs | 97 ++++++++++++++++ tests/UnsafeTests.hs | 3 +- 24 files changed, 548 insertions(+), 131 deletions(-) create mode 100644 examples/echo/echo-ruby/echo-server.rb create mode 100644 examples/echo/echo-ruby/echo.rb create mode 100644 examples/echo/echo-ruby/echo_services.rb create mode 100644 src/Network/GRPC/HighLevel/Server/Unregistered.hs diff --git a/cbits/grpc_haskell.c b/cbits/grpc_haskell.c index 80942d7..cf16b80 100644 --- a/cbits/grpc_haskell.c +++ b/cbits/grpc_haskell.c @@ -50,10 +50,10 @@ uint8_t *gpr_slice_start_(gpr_slice *slice){ return GPR_SLICE_START_PTR(*slice); } -gpr_slice* gpr_slice_from_copied_string_(const char *source){ +gpr_slice* gpr_slice_from_copied_buffer_(const char *source, size_t len){ gpr_slice* retval = malloc(sizeof(gpr_slice)); //note: 'gpr_slice_from_copied_string' handles allocating space for 'source'. - *retval = gpr_slice_from_copied_string(source); + *retval = gpr_slice_from_copied_buffer(source, len); return retval; } diff --git a/examples/echo/echo-client/Main.hs b/examples/echo/echo-client/Main.hs index fc71a54..77b285a 100644 --- a/examples/echo/echo-client/Main.hs +++ b/examples/echo/echo-client/Main.hs @@ -7,13 +7,16 @@ import Control.Monad import qualified Data.ByteString.Lazy as BL import Data.Protobuf.Wire.Class +import Data.Protobuf.Wire.Types import qualified Data.Text as T import Data.Word import GHC.Generics (Generic) import Network.GRPC.LowLevel import qualified Network.GRPC.LowLevel.Client.Unregistered as U +import Proto3.Wire.Decode (ParseError) echoMethod = MethodName "/echo.Echo/DoEcho" +addMethod = MethodName "/echo.Add/DoAdd" _unregistered c = U.clientRequest c echoMethod 1 "hi" mempty @@ -30,9 +33,9 @@ regMain = withGRPC $ \g -> -- TODO: Put these in a common location (or just hack around it until CG is working) data EchoRequest = EchoRequest {message :: T.Text} deriving (Show, Eq, Ord, Generic) instance Message EchoRequest -data AddRequest = AddRequest {addX :: Word32, addY :: Word32} deriving (Show, Eq, Ord, Generic) +data AddRequest = AddRequest {addX :: Fixed Word32, addY :: Fixed Word32} deriving (Show, Eq, Ord, Generic) instance Message AddRequest -data AddResponse = AddResponse {answer :: Word32} deriving (Show, Eq, Ord, Generic) +data AddResponse = AddResponse {answer :: Fixed Word32} deriving (Show, Eq, Ord, Generic) instance Message AddResponse -- TODO: Create Network.GRPC.HighLevel.Client w/ request variants @@ -49,5 +52,15 @@ highlevelMain = withGRPC $ \g -> Right dec | dec == pay -> return () | otherwise -> error $ "Got unexpected payload: " ++ show dec + rmAdd <- clientRegisterMethodNormal c addMethod + let addPay = AddRequest 1 2 + addEnc = BL.toStrict . toLazyByteString $ addPay + replicateM_ 1 $ clientRequest c rmAdd 5 addEnc mempty >>= \case + Left e -> error $ "Got client error on add request: " ++ show e + Right r -> case fromByteString (rspBody r) of + Left e -> error $ "failed to decode add response: " ++ show e + Right dec + | dec == AddResponse 3 -> return () + | otherwise -> error $ "Got wrong add answer: " ++ show dec main = highlevelMain diff --git a/examples/echo/echo-cpp/echo-client.cc b/examples/echo/echo-cpp/echo-client.cc index cda410e..628d674 100644 --- a/examples/echo/echo-cpp/echo-client.cc +++ b/examples/echo/echo-cpp/echo-client.cc @@ -52,22 +52,23 @@ private: }; int main(){ - /* + EchoClient client(grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials())); string msg("hi"); - for(int i = 0; i < 100000; i++){ + /* + while(true){ Status status = client.DoEcho(msg); if(!status.ok()){ cout<<"Error: "< +#include +#include #include #include "echo.grpc.pb.h" @@ -13,9 +15,15 @@ using grpc::Status; using echo::EchoRequest; using echo::Echo; +atomic_int reqCount; + class EchoServiceImpl final : public Echo::Service { Status DoEcho(ServerContext* ctx, const EchoRequest* req, EchoRequest* resp) override { + reqCount++; + if(reqCount % 100 == 0){ + cout<set_message(req->message()); return Status::OK; } diff --git a/examples/echo/echo-python/echo_client.py b/examples/echo/echo-python/echo_client.py index 46fc45e..6ab09df 100644 --- a/examples/echo/echo-python/echo_client.py +++ b/examples/echo/echo-python/echo_client.py @@ -3,11 +3,10 @@ from grpc.beta import implementations import echo_pb2 def main(): - for _ in xrange(1000): - channel = implementations.insecure_channel('localhost', 50051) - stub = echo_pb2.beta_create_Echo_stub(channel) - message = echo_pb2.EchoRequest(message='foo') - response = stub.DoEcho(message, 15) + channel = implementations.insecure_channel('localhost', 50051) + stub = echo_pb2.beta_create_Echo_stub(channel) + message = echo_pb2.EchoRequest(message='foo') + response = stub.DoEcho(message, 15) if __name__ == '__main__': main() diff --git a/examples/echo/echo-ruby/echo-server.rb b/examples/echo/echo-ruby/echo-server.rb new file mode 100644 index 0000000..52136c6 --- /dev/null +++ b/examples/echo/echo-ruby/echo-server.rb @@ -0,0 +1,26 @@ +this_dir = File.expand_path(File.dirname(__FILE__)) +$LOAD_PATH.unshift(this_dir) + +require 'grpc' +require 'echo_services' + +$i = 0 + +class EchoServer < Echo::Echo::Service + def do_echo(echo_req, _unused_call) + $i = $i+1 + if $i % 100 == 0 + puts($i) + end + return echo_req + end +end + +def main + s = GRPC::RpcServer.new + s.add_http2_port('0.0.0.0:50051', :this_port_is_insecure) + s.handle(EchoServer) + s.run_till_terminated +end + +main diff --git a/examples/echo/echo-ruby/echo.rb b/examples/echo/echo-ruby/echo.rb new file mode 100644 index 0000000..88caba1 --- /dev/null +++ b/examples/echo/echo-ruby/echo.rb @@ -0,0 +1,23 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: echo.proto + +require 'google/protobuf' + +Google::Protobuf::DescriptorPool.generated_pool.build do + add_message "echo.EchoRequest" do + optional :message, :string, 1 + end + add_message "echo.AddRequest" do + optional :addX, :fixed32, 1 + optional :addY, :fixed32, 2 + end + add_message "echo.AddResponse" do + optional :answer, :fixed32, 1 + end +end + +module Echo + EchoRequest = Google::Protobuf::DescriptorPool.generated_pool.lookup("echo.EchoRequest").msgclass + AddRequest = Google::Protobuf::DescriptorPool.generated_pool.lookup("echo.AddRequest").msgclass + AddResponse = Google::Protobuf::DescriptorPool.generated_pool.lookup("echo.AddResponse").msgclass +end diff --git a/examples/echo/echo-ruby/echo_services.rb b/examples/echo/echo-ruby/echo_services.rb new file mode 100644 index 0000000..0bb3f59 --- /dev/null +++ b/examples/echo/echo-ruby/echo_services.rb @@ -0,0 +1,40 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# Source: echo.proto for package 'echo' + +require 'grpc' +require 'echo' + +module Echo + module Echo + + # TODO: add proto service documentation here + class Service + + include GRPC::GenericService + + self.marshal_class_method = :encode + self.unmarshal_class_method = :decode + self.service_name = 'echo.Echo' + + rpc :DoEcho, EchoRequest, EchoRequest + end + + Stub = Service.rpc_stub_class + end + module Add + + # TODO: add proto service documentation here + class Service + + include GRPC::GenericService + + self.marshal_class_method = :encode + self.unmarshal_class_method = :decode + self.service_name = 'echo.Add' + + rpc :DoAdd, AddRequest, AddResponse + end + + Stub = Service.rpc_stub_class + end +end diff --git a/examples/echo/echo-server/Main.hs b/examples/echo/echo-server/Main.hs index 2d0070e..e0f0368 100644 --- a/examples/echo/echo-server/Main.hs +++ b/examples/echo/echo-server/Main.hs @@ -11,10 +11,12 @@ import Control.Concurrent.Async import Control.Monad import Data.ByteString (ByteString) import Data.Protobuf.Wire.Class +import Data.Protobuf.Wire.Types import qualified Data.Text as T import Data.Word import GHC.Generics (Generic) import Network.GRPC.HighLevel.Server +import qualified Network.GRPC.HighLevel.Server.Unregistered as U import Network.GRPC.LowLevel import qualified Network.GRPC.LowLevel.Call.Unregistered as U import qualified Network.GRPC.LowLevel.Server.Unregistered as U @@ -80,44 +82,48 @@ regMainThreaded = do -- TODO: Put these in a common location (or just hack around it until CG is working) data EchoRequest = EchoRequest {message :: T.Text} deriving (Show, Eq, Ord, Generic) instance Message EchoRequest -data AddRequest = AddRequest {addX :: Word32, addY :: Word32} deriving (Show, Eq, Ord, Generic) + +echoHandler :: Handler 'Normal +echoHandler = + UnaryHandler "/echo.Echo/DoEcho" $ + \_c body m -> do + return ( body :: EchoRequest + , m + , StatusOk + , StatusDetails "" + ) + +data AddRequest = AddRequest {addX :: Fixed Word32 + , addY :: Fixed Word32} + deriving (Show, Eq, Ord, Generic) instance Message AddRequest -data AddResponse = AddResponse {answer :: Word32} deriving (Show, Eq, Ord, Generic) +data AddResponse = AddResponse {answer :: Fixed Word32} + deriving (Show, Eq, Ord, Generic) instance Message AddResponse +addHandler :: Handler 'Normal +addHandler = + UnaryHandler "/echo.Add/DoAdd" $ + \_c b m -> do + --tputStrLn $ "UnaryHandler for DoAdd hit, b=" ++ show b + print (addX b) + print (addY b) + return ( AddResponse $ addX b + addY b + , m + , StatusOk + , StatusDetails "" + ) + highlevelMain :: IO () highlevelMain = serverLoop defaultOptions{optNormalHandlers = [echoHandler, addHandler]} - where echoHandler = - UnaryHandler "/echo.Echo/DoEcho" $ - \_c body m -> do - tputStrLn $ "UnaryHandler for DoEcho hit, body=" ++ show body - return ( body :: EchoRequest - , m - , StatusOk - , StatusDetails "" - ) - addHandler = - --TODO: I can't get this one to execute. Is the generated method - --name different? - -- static const char* Add_method_names[] = { - -- "/echo.Add/DoAdd", - -- }; - - UnaryHandler "/echo.Add/DoAdd" $ - \_c b m -> do - tputStrLn $ "UnaryHandler for DoAdd hit, b=" ++ show b - print (addX b) - print (addY b) - return ( AddResponse $ addX b + addY b - , m - , StatusOk - , StatusDetails "" - ) +highlevelMainUnregistered :: IO () +highlevelMainUnregistered = + U.serverLoop defaultOptions{optNormalHandlers = [echoHandler, addHandler]} main :: IO () -main = highlevelMain +main = highlevelMainUnregistered defConfig :: ServerConfig defConfig = ServerConfig "localhost" 50051 [] [] [] [] [] diff --git a/grpc-haskell.cabal b/grpc-haskell.cabal index f29fb32..ecb9d4a 100644 --- a/grpc-haskell.cabal +++ b/grpc-haskell.cabal @@ -67,6 +67,7 @@ library Network.GRPC.LowLevel.Client Network.GRPC.HighLevel Network.GRPC.HighLevel.Server + Network.GRPC.HighLevel.Server.Unregistered extra-libraries: grpc includes: diff --git a/include/grpc_haskell.h b/include/grpc_haskell.h index 907eae8..5a773cc 100644 --- a/include/grpc_haskell.h +++ b/include/grpc_haskell.h @@ -26,7 +26,7 @@ size_t gpr_slice_length_(gpr_slice *slice); uint8_t *gpr_slice_start_(gpr_slice *slice); -gpr_slice* gpr_slice_from_copied_string_(const char *source); +gpr_slice* gpr_slice_from_copied_buffer_(const char *source, size_t len); void free_slice(gpr_slice *slice); diff --git a/src/Network/GRPC/HighLevel/Server.hs b/src/Network/GRPC/HighLevel/Server.hs index 0b9b1cd..e4ec226 100644 --- a/src/Network/GRPC/HighLevel/Server.hs +++ b/src/Network/GRPC/HighLevel/Server.hs @@ -1,9 +1,10 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RecordWildCards #-} module Network.GRPC.HighLevel.Server where @@ -28,7 +29,7 @@ convertServerHandler :: (Message a, Message b) => ServerHandler' a b -> ServerHandler convertServerHandler f c bs m = case fromByteString bs of - Left{} -> error "TODO: find a way to keep this from killing the server." + Left x -> error $ "Failed to deserialize message: " ++ show x Right x -> do (y, tm, sc, sd) <- f c x m return (toBS y, tm, sc, sd) @@ -88,7 +89,7 @@ convertSend s = s . toBS toBS :: Message a => a -> ByteString toBS = BL.toStrict . toLazyByteString -data Handler a where +data Handler (a :: GRPCMethodType) where UnaryHandler :: (Message c, Message d) => MethodName @@ -113,6 +114,11 @@ data Handler a where -> ServerRWHandler' c d -> Handler 'BiDiStreaming +data AnyHandler = forall (a :: GRPCMethodType) . AnyHandler (Handler a) + +anyHandlerMethodName :: AnyHandler -> MethodName +anyHandlerMethodName (AnyHandler m) = handlerMethodName m + handlerMethodName :: Handler a -> MethodName handlerMethodName (UnaryHandler m _) = m handlerMethodName (ClientStreamHandler m _) = m @@ -142,21 +148,28 @@ handleCallError (Left GRPCIOShutdown) = return () handleCallError (Left x) = logAskReport x -loopWError :: IO (Either GRPCIOError a) -> IO () -loopWError f = forever $ f >>= handleCallError +loopWError :: Int + -> IO (Either GRPCIOError a) + -> IO () +loopWError i f = do + when (i `mod` 100 == 0) $ putStrLn $ "i = " ++ show i + f >>= handleCallError + loopWError (i + 1) f --TODO: options for setting initial/trailing metadata -handleLoop :: Server -> (Handler a, RegisteredMethod a) -> IO () +handleLoop :: Server + -> (Handler a, RegisteredMethod a) + -> IO () handleLoop s (UnaryHandler _ f, rm) = - loopWError $ do - grpcDebug' "handleLoop about to block on serverHandleNormalCall" + loopWError 0 $ do + --grpcDebug' "handleLoop about to block on serverHandleNormalCall" serverHandleNormalCall s rm mempty $ convertServerHandler f handleLoop s (ClientStreamHandler _ f, rm) = - loopWError $ serverReader s rm mempty $ convertServerReaderHandler f + loopWError 0 $ serverReader s rm mempty $ convertServerReaderHandler f handleLoop s (ServerStreamHandler _ f, rm) = - loopWError $ serverWriter s rm mempty $ convertServerWriterHandler f + loopWError 0 $ serverWriter s rm mempty $ convertServerWriterHandler f handleLoop s (BiDiStreamHandler _ f, rm) = - loopWError $ serverRW s rm mempty $ convertServerRWHandler f + loopWError 0 $ serverRW s rm mempty $ convertServerRWHandler f data ServerOptions = ServerOptions {optNormalHandlers :: [Handler 'Normal], @@ -194,7 +207,7 @@ serverLoop opts = asyncsCS <- mapM async $ map loop rmsCS asyncsSS <- mapM async $ map loop rmsSS asyncsB <- mapM async $ map loop rmsB - asyncUnk <- async $ loopWError $ unknownHandler server + asyncUnk <- async $ loopWError 0 $ unknownHandler server waitAnyCancel $ asyncUnk : asyncsN ++ asyncsCS ++ asyncsSS ++ asyncsB return () where diff --git a/src/Network/GRPC/HighLevel/Server/Unregistered.hs b/src/Network/GRPC/HighLevel/Server/Unregistered.hs new file mode 100644 index 0000000..1d63a3b --- /dev/null +++ b/src/Network/GRPC/HighLevel/Server/Unregistered.hs @@ -0,0 +1,105 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RecordWildCards #-} + +module Network.GRPC.HighLevel.Server.Unregistered where + +import Control.Applicative ((<|>)) +import Control.Concurrent.Async +import Control.Monad +import Data.ByteString (ByteString) +import Data.Protobuf.Wire.Class +import Data.Foldable (find) +import Network.GRPC.HighLevel.Server +import Network.GRPC.LowLevel +import Network.GRPC.LowLevel.GRPC +import Network.GRPC.LowLevel.Call +import qualified Network.GRPC.LowLevel.Server.Unregistered as U +import qualified Network.GRPC.LowLevel.Call.Unregistered as U + +dispatchLoop :: Server + -> [Handler 'Normal] + -> [Handler 'ClientStreaming] + -> [Handler 'ServerStreaming] + -> [Handler 'BiDiStreaming] + -> IO () +dispatchLoop server hN hC hS hB = + forever $ U.withServerCallAsync server $ \call -> do + case findHandler call allHandlers of + Just (AnyHandler (UnaryHandler _ h)) -> unaryHandler call h + Just (AnyHandler (ClientStreamHandler _ h)) -> csHandler call h + Just (AnyHandler (ServerStreamHandler _ h)) -> ssHandler call h + Just (AnyHandler (BiDiStreamHandler _ h)) -> bdHandler call h + Nothing -> unknownHandler call + where allHandlers = map AnyHandler hN + ++ map AnyHandler hC + ++ map AnyHandler hS + ++ map AnyHandler hB + findHandler call = find ((== (U.callMethod call)) + . anyHandlerMethodName) + unknownHandler call = + void $ U.serverHandleNormalCall' server call mempty $ \_ _ -> + return (mempty + , mempty + , StatusNotFound + , StatusDetails "unknown method") + handleError f = f >>= handleCallError + unaryHandler :: (Message a, Message b) => + U.ServerCall + -> ServerHandler' a b + -> IO () + unaryHandler call h = + handleError $ + U.serverHandleNormalCall' server call mempty $ \call' bs -> do + let h' = convertServerHandler h + h' (fmap (const bs) $ U.convertCall call) + bs + (U.requestMetadataRecv call) + csHandler :: (Message a, Message b) => + U.ServerCall + -> ServerReaderHandler' a b + -> IO () + csHandler call h = + handleError $ + U.serverReader server call mempty (convertServerReaderHandler h) + ssHandler :: (Message a, Message b) => + U.ServerCall + -> ServerWriterHandler' a b + -> IO () + ssHandler call h = + handleError $ + U.serverWriter server call mempty (convertServerWriterHandler h) + bdHandler :: (Message a, Message b) => + U.ServerCall + -> ServerRWHandler' a b + -> IO () + bdHandler call h = + handleError $ + U.serverRW server call mempty (convertServerRWHandler h) + +serverLoop :: ServerOptions -> IO () +serverLoop opts@ServerOptions{..} = + withGRPC $ \grpc -> + withServer grpc (mkConfig opts) $ \server -> do + dispatchLoop server + optNormalHandlers + optClientStreamHandlers + optServerStreamHandlers + optBiDiStreamHandlers + where + mkConfig ServerOptions{..} = + ServerConfig + { host = "localhost" + , port = optServerPort + , methodsToRegisterNormal = [] + , methodsToRegisterClientStreaming = [] + , methodsToRegisterServerStreaming = [] + , methodsToRegisterBiDiStreaming = [] + , serverArgs = + ([CompressionAlgArg GrpcCompressDeflate | optUseCompression] + ++ + [UserAgentPrefix optUserAgentPrefix + , UserAgentSuffix optUserAgentSuffix]) + } diff --git a/src/Network/GRPC/LowLevel/Call.hs b/src/Network/GRPC/LowLevel/Call.hs index d850522..f4ff133 100644 --- a/src/Network/GRPC/LowLevel/Call.hs +++ b/src/Network/GRPC/LowLevel/Call.hs @@ -207,8 +207,9 @@ destroyClientCall cc = do C.grpcCallDestroy (unsafeCC cc) destroyServerCall :: ServerCall a -> IO () -destroyServerCall sc@ServerCall{ unsafeSC = c } = do +destroyServerCall sc@ServerCall{ unsafeSC = c, .. } = do grpcDebug "destroyServerCall(R): entered." debugServerCall sc + _ <- shutdownCompletionQueue callCQ grpcDebug $ "Destroying server-side call object: " ++ show c C.grpcCallDestroy c diff --git a/src/Network/GRPC/LowLevel/Call/Unregistered.hs b/src/Network/GRPC/LowLevel/Call/Unregistered.hs index a125cf2..7bb2653 100644 --- a/src/Network/GRPC/LowLevel/Call/Unregistered.hs +++ b/src/Network/GRPC/LowLevel/Call/Unregistered.hs @@ -8,8 +8,8 @@ import Foreign.Ptr (Ptr) #ifdef DEBUG import Foreign.Storable (peek) #endif -import Network.GRPC.LowLevel.Call (Host (..), - MethodName (..)) +import qualified Network.GRPC.LowLevel.Call as Reg +import Network.GRPC.LowLevel.CompletionQueue import Network.GRPC.LowLevel.CompletionQueue.Internal import Network.GRPC.LowLevel.GRPC (MetadataMap, grpcDebug) @@ -23,12 +23,15 @@ data ServerCall = ServerCall { unsafeSC :: C.Call , callCQ :: CompletionQueue , requestMetadataRecv :: MetadataMap - , parentPtr :: Maybe (Ptr C.Call) , callDeadline :: TimeSpec - , callMethod :: MethodName - , callHost :: Host + , callMethod :: Reg.MethodName + , callHost :: Reg.Host } +convertCall :: ServerCall -> Reg.ServerCall () +convertCall ServerCall{..} = + Reg.ServerCall unsafeSC callCQ requestMetadataRecv () callDeadline + serverCallCancel :: ServerCall -> C.StatusCode -> String -> IO () serverCallCancel sc code reason = C.grpcCallCancelWithStatus (unsafeSC sc) code reason C.reserved @@ -42,11 +45,6 @@ debugServerCall ServerCall{..} = do dbug $ "server call: " ++ show ptr dbug $ "metadata: " ++ show requestMetadataRecv - forM_ parentPtr $ \parentPtr' -> do - dbug $ "parent ptr: " ++ show parentPtr' - C.Call parent <- peek parentPtr' - dbug $ "parent: " ++ show parent - dbug $ "deadline: " ++ show callDeadline dbug $ "method: " ++ show callMethod dbug $ "host: " ++ show callHost @@ -60,6 +58,5 @@ destroyServerCall call@ServerCall{..} = do grpcDebug "destroyServerCall(U): entered." debugServerCall call grpcDebug $ "Destroying server-side call object: " ++ show unsafeSC + shutdownCompletionQueue callCQ C.grpcCallDestroy unsafeSC - grpcDebug $ "freeing parentPtr: " ++ show parentPtr - forM_ parentPtr free diff --git a/src/Network/GRPC/LowLevel/CompletionQueue.hs b/src/Network/GRPC/LowLevel/CompletionQueue.hs index fbdc2b0..90878b8 100644 --- a/src/Network/GRPC/LowLevel/CompletionQueue.hs +++ b/src/Network/GRPC/LowLevel/CompletionQueue.hs @@ -61,7 +61,6 @@ import qualified Network.GRPC.Unsafe.Time as C import System.Clock (Clock (..), getTime) import System.Info (os) -import System.Timeout (timeout) withCompletionQueue :: GRPC -> (CompletionQueue -> IO a) -> IO a withCompletionQueue grpc = bracket (createCompletionQueue grpc) @@ -89,36 +88,6 @@ startBatch cq@CompletionQueue{..} call opArray opArraySize tag = grpcDebug "startBatch: grpc_call_start_batch call returned." return res - --- | Shuts down the completion queue. See the comment above 'CompletionQueue' --- for the strategy we use to ensure that no one tries to use the --- queue after we begin the shutdown process. Errors with --- 'GRPCIOShutdownFailure' if the queue can't be shut down within 5 seconds. -shutdownCompletionQueue :: CompletionQueue -> IO (Either GRPCIOError ()) -shutdownCompletionQueue CompletionQueue{..} = do - atomically $ writeTVar shuttingDown True - atomically $ do - readTVar currentPushers >>= check . (==0) - readTVar currentPluckers >>= check . (==0) - --drain the queue - C.grpcCompletionQueueShutdown unsafeCQ - loopRes <- timeout (5*10^(6::Int)) drainLoop - grpcDebug $ "Got CQ loop shutdown result of: " ++ show loopRes - case loopRes of - Nothing -> return $ Left GRPCIOShutdownFailure - Just () -> C.grpcCompletionQueueDestroy unsafeCQ >> return (Right ()) - - where drainLoop :: IO () - drainLoop = do - grpcDebug "drainLoop: before next() call" - ev <- C.withDeadlineSeconds 1 $ \deadline -> - C.grpcCompletionQueueNext unsafeCQ deadline C.reserved - grpcDebug $ "drainLoop: next() call got " ++ show ev - case C.eventCompletionType ev of - C.QueueShutdown -> return () - C.QueueTimeout -> drainLoop - C.OpComplete -> drainLoop - channelCreateCall :: C.Channel -> Maybe (ServerCall a) -> C.PropagationMask diff --git a/src/Network/GRPC/LowLevel/CompletionQueue/Internal.hs b/src/Network/GRPC/LowLevel/CompletionQueue/Internal.hs index e0215ff..464dddf 100644 --- a/src/Network/GRPC/LowLevel/CompletionQueue/Internal.hs +++ b/src/Network/GRPC/LowLevel/CompletionQueue/Internal.hs @@ -2,7 +2,7 @@ module Network.GRPC.LowLevel.CompletionQueue.Internal where -import Control.Concurrent.STM (atomically, retry) +import Control.Concurrent.STM (atomically, retry, check) import Control.Concurrent.STM.TVar (TVar, modifyTVar', readTVar, writeTVar) import Control.Exception (bracket) @@ -13,6 +13,7 @@ import Network.GRPC.LowLevel.GRPC import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe.Constants as C import qualified Network.GRPC.Unsafe.Time as C +import System.Timeout (timeout) -- NOTE: the concurrency requirements for a CompletionQueue are a little -- complicated. There are two read operations: next and pluck. We can either @@ -143,3 +144,32 @@ getCount Pluck = currentPluckers getLimit :: CQOpType -> Int getLimit Push = maxWorkPushers getLimit Pluck = C.maxCompletionQueuePluckers + +-- | Shuts down the completion queue. See the comment above 'CompletionQueue' +-- for the strategy we use to ensure that no one tries to use the +-- queue after we begin the shutdown process. Errors with +-- 'GRPCIOShutdownFailure' if the queue can't be shut down within 5 seconds. +shutdownCompletionQueue :: CompletionQueue -> IO (Either GRPCIOError ()) +shutdownCompletionQueue CompletionQueue{..} = do + atomically $ writeTVar shuttingDown True + atomically $ do + readTVar currentPushers >>= check . (==0) + readTVar currentPluckers >>= check . (==0) + --drain the queue + C.grpcCompletionQueueShutdown unsafeCQ + loopRes <- timeout (5*10^(6::Int)) drainLoop + grpcDebug $ "Got CQ loop shutdown result of: " ++ show loopRes + case loopRes of + Nothing -> return $ Left GRPCIOShutdownFailure + Just () -> C.grpcCompletionQueueDestroy unsafeCQ >> return (Right ()) + + where drainLoop :: IO () + drainLoop = do + grpcDebug "drainLoop: before next() call" + ev <- C.withDeadlineSeconds 1 $ \deadline -> + C.grpcCompletionQueueNext unsafeCQ deadline C.reserved + grpcDebug $ "drainLoop: next() call got " ++ show ev + case C.eventCompletionType ev of + C.QueueShutdown -> return () + C.QueueTimeout -> drainLoop + C.OpComplete -> drainLoop diff --git a/src/Network/GRPC/LowLevel/CompletionQueue/Unregistered.hs b/src/Network/GRPC/LowLevel/CompletionQueue/Unregistered.hs index f217ba5..e909672 100644 --- a/src/Network/GRPC/LowLevel/CompletionQueue/Unregistered.hs +++ b/src/Network/GRPC/LowLevel/CompletionQueue/Unregistered.hs @@ -61,7 +61,6 @@ serverRequestCall s scq ccq = <$> peek call <*> return ccq <*> C.getAllMetadataArray md - <*> return Nothing <*> (C.timeSpec <$> C.callDetailsGetDeadline cd) <*> (MethodName <$> C.callDetailsGetMethod cd) <*> (Host <$> C.callDetailsGetHost cd) diff --git a/src/Network/GRPC/LowLevel/Op.hs b/src/Network/GRPC/LowLevel/Op.hs index 616743a..b8a5273 100644 --- a/src/Network/GRPC/LowLevel/Op.hs +++ b/src/Network/GRPC/LowLevel/Op.hs @@ -175,9 +175,10 @@ resultFromOpContext (OpRecvMessageContext pbb) = do grpcDebug "resultFromOpContext: OpRecvMessageContext" bb@(C.ByteBuffer bbptr) <- peek pbb if bbptr == nullPtr - then return $ Just $ OpRecvMessageResult Nothing + then do grpcDebug "resultFromOpContext: WARNING: got empty message." + return $ Just $ OpRecvMessageResult Nothing else do bs <- C.copyByteBufferToByteString bb - grpcDebug "resultFromOpContext: bb copied." + grpcDebug $ "resultFromOpContext: bb copied: " ++ show bs return $ Just $ OpRecvMessageResult (Just bs) resultFromOpContext (OpRecvStatusOnClientContext pmetadata pcode pstr) = do grpcDebug "resultFromOpContext: OpRecvStatusOnClientContext" @@ -294,6 +295,14 @@ recvStatusOnClient c cq = runOps' c cq [OpRecvStatusOnClient] >>= \case -> return (md, st, StatusDetails ds) _ -> throwE (GRPCIOInternalUnexpectedRecv "recvStatusOnClient") +recvInitialMessage :: RecvSingle ByteString +recvInitialMessage c cq = runOps' c cq [OpRecvMessage] >>= \case + [OpRecvMessageResult (Just bs)] + -> return bs + [OpRecvMessageResult Nothing] + -> throwE (GRPCIOInternalUnexpectedRecv "recvInitialMessage: no message.") + _ -> throwE (GRPCIOInternalUnexpectedRecv "recvInitialMessage") + -------------------------------------------------------------------------------- -- Streaming types and helpers diff --git a/src/Network/GRPC/LowLevel/Server.hs b/src/Network/GRPC/LowLevel/Server.hs index ff53736..f36c43d 100644 --- a/src/Network/GRPC/LowLevel/Server.hs +++ b/src/Network/GRPC/LowLevel/Server.hs @@ -228,17 +228,17 @@ serverRegisterMethodBiDiStreaming internalServer meth e = do -- method. serverCreateCall :: Server -> RegisteredMethod mt - -> CompletionQueue -- ^ call CQ -> IO (Either GRPCIOError (ServerCall (MethodPayload mt))) -serverCreateCall Server{..} rm = serverRequestCall rm unsafeServer serverCQ +serverCreateCall Server{..} rm = do + callCQ <- createCompletionQueue serverGRPC + serverRequestCall rm unsafeServer serverCQ callCQ withServerCall :: Server -> RegisteredMethod mt -> (ServerCall (MethodPayload mt) -> IO (Either GRPCIOError a)) -> IO (Either GRPCIOError a) withServerCall s rm f = - withCompletionQueue (serverGRPC s) $ - serverCreateCall s rm >=> \case + serverCreateCall s rm >>= \case Left e -> return (Left e) Right c -> do debugServerCall c diff --git a/src/Network/GRPC/LowLevel/Server/Unregistered.hs b/src/Network/GRPC/LowLevel/Server/Unregistered.hs index 51445dc..218c80b 100644 --- a/src/Network/GRPC/LowLevel/Server/Unregistered.hs +++ b/src/Network/GRPC/LowLevel/Server/Unregistered.hs @@ -3,33 +3,63 @@ module Network.GRPC.LowLevel.Server.Unregistered where +import Control.Concurrent (forkIO) import Control.Exception (finally) import Control.Monad +import Control.Monad.Trans.Except import Data.ByteString (ByteString) import Network.GRPC.LowLevel.Call.Unregistered -import Network.GRPC.LowLevel.CompletionQueue (CompletionQueue, withCompletionQueue) +import Network.GRPC.LowLevel.CompletionQueue (CompletionQueue + , withCompletionQueue + , createCompletionQueue) import Network.GRPC.LowLevel.CompletionQueue.Unregistered (serverRequestCall) import Network.GRPC.LowLevel.GRPC -import Network.GRPC.LowLevel.Op (Op (..), OpRecvResult (..), - runOps) -import Network.GRPC.LowLevel.Server (Server (..)) +import Network.GRPC.LowLevel.Op (Op (..) + , OpRecvResult (..) + , runOps + , runStreamingProxy + , streamRecv + , streamSend + , runOps' + , sendInitialMetadata + , sendStatusFromServer + , recvInitialMessage) +import Network.GRPC.LowLevel.Server (Server (..) + , ServerReaderHandler + , ServerWriterHandler + , ServerRWHandler) import qualified Network.GRPC.Unsafe.Op as C serverCreateCall :: Server - -> CompletionQueue -- ^ call CQ -> IO (Either GRPCIOError ServerCall) -serverCreateCall Server{..} = serverRequestCall unsafeServer serverCQ +serverCreateCall Server{..} = do + callCQ <- createCompletionQueue serverGRPC + serverRequestCall unsafeServer serverCQ callCQ withServerCall :: Server -> (ServerCall -> IO (Either GRPCIOError a)) -> IO (Either GRPCIOError a) withServerCall s f = - withCompletionQueue (serverGRPC s) $ - serverCreateCall s >=> \case - Left e -> return (Left e) - Right c -> f c `finally` do - grpcDebug "withServerCall: destroying." - destroyServerCall c + serverCreateCall s >>= \case + Left e -> return (Left e) + Right c -> f c `finally` do + grpcDebug "withServerCall: destroying." + destroyServerCall c + +-- | Gets a call and then forks the given function on a new thread, with the +-- new call as input. Blocks until a call is received, then returns immediately. +-- Handles cleaning up the call safely. +-- Because this function doesn't wait for the handler to return, it cannot +-- return errors. +withServerCallAsync :: Server + -> (ServerCall -> IO ()) + -> IO () +withServerCallAsync s f = + serverCreateCall s >>= \case + Left e -> return () + Right c -> void $ forkIO (f c `finally` do + grpcDebug "withServerCallAsync: destroying." + destroyServerCall c) -- | Sequence of 'Op's needed to receive a normal (non-streaming) call. -- TODO: We have to put 'OpRecvCloseOnServer' in the response ops, or else the @@ -63,9 +93,16 @@ serverHandleNormalCall :: Server -> MetadataMap -- ^ Initial server metadata. -> ServerHandler -> IO (Either GRPCIOError ()) -serverHandleNormalCall s initMeta f = withServerCall s go - where - go sc@ServerCall{ unsafeSC = c, callCQ = cq, .. } = do +serverHandleNormalCall s initMeta f = + withServerCall s $ \c -> serverHandleNormalCall' s c initMeta f + +serverHandleNormalCall' :: Server + -> ServerCall + -> MetadataMap -- ^ Initial server metadata. + -> ServerHandler + -> IO (Either GRPCIOError ()) +serverHandleNormalCall' + s sc@ServerCall{ unsafeSC = c, callCQ = cq, .. } initMeta f = do grpcDebug "serverHandleNormalCall(U): starting batch." runOps c cq [ OpSendInitialMetadata initMeta @@ -92,3 +129,45 @@ serverHandleNormalCall s initMeta f = withServerCall s go grpcDebug "serverHandleNormalCall(U): ops done." return $ Right () x -> error $ "impossible pattern match: " ++ show x + +serverReader :: Server + -> ServerCall + -> MetadataMap -- ^ initial server metadata + -> ServerReaderHandler + -> IO (Either GRPCIOError ()) +serverReader s sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f = + runExceptT $ do + (mmsg, trailMeta, st, ds) <- + runStreamingProxy "serverReader" c ccq (f (convertCall sc) streamRecv) + runOps' c ccq ( OpSendInitialMetadata initMeta + : OpSendStatusFromServer trailMeta st ds + : maybe [] ((:[]) . OpSendMessage) mmsg + ) + return () + +serverWriter :: Server + -> ServerCall + -> MetadataMap + -- ^ Initial server metadata + -> ServerWriterHandler + -> IO (Either GRPCIOError ()) +serverWriter s sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f = + runExceptT $ do + bs <- recvInitialMessage c ccq + sendInitialMetadata c ccq initMeta + let regCall = fmap (const bs) (convertCall sc) + st <- runStreamingProxy "serverWriter" c ccq (f regCall streamSend) + sendStatusFromServer c ccq st + +serverRW :: Server + -> ServerCall + -> MetadataMap + -- ^ initial server metadata + -> ServerRWHandler + -> IO (Either GRPCIOError ()) +serverRW s sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f = + runExceptT $ do + sendInitialMetadata c ccq initMeta + let regCall = convertCall sc + st <- runStreamingProxy "serverRW" c ccq (f regCall streamRecv streamSend) + sendStatusFromServer c ccq st diff --git a/src/Network/GRPC/Unsafe/Slice.chs b/src/Network/GRPC/Unsafe/Slice.chs index 47df3fd..c33e3ef 100644 --- a/src/Network/GRPC/Unsafe/Slice.chs +++ b/src/Network/GRPC/Unsafe/Slice.chs @@ -33,7 +33,7 @@ deriving instance Show Slice -- slice. {#fun gpr_slice_start_ as ^ {`Slice'} -> `Ptr CChar' castPtr #} -{#fun gpr_slice_from_copied_string_ as ^ {`CString'} -> `Slice'#} +{#fun gpr_slice_from_copied_buffer_ as ^ {`CString', `Int'} -> `Slice'#} -- | Properly cleans up all memory used by a 'Slice'. Danger: the Slice should -- not be used after this function is called on it. @@ -52,4 +52,4 @@ sliceToByteString slice = do -- | Copies a 'ByteString' to a 'Slice'. byteStringToSlice :: B.ByteString -> IO Slice -byteStringToSlice bs = B.useAsCString bs gprSliceFromCopiedString +byteStringToSlice bs = B.useAsCStringLen bs $ uncurry gprSliceFromCopiedBuffer diff --git a/tests/LowLevelTests.hs b/tests/LowLevelTests.hs index 7271be5..f54bde3 100644 --- a/tests/LowLevelTests.hs +++ b/tests/LowLevelTests.hs @@ -49,8 +49,11 @@ lowLevelTests = testGroup "Unit tests of low-level Haskell library" , testClientCompression , testClientServerCompression , testClientStreaming + , testClientStreamingUnregistered , testServerStreaming + , testServerStreamingUnregistered , testBiDiStreaming + , testBiDiStreamingUnregistered ] testGRPCBracket :: TestTree @@ -184,6 +187,37 @@ testServerStreaming = return (dummyMeta, StatusOk, "dtls") r @?= Right () +-- TODO: these unregistered streaming tests are basically the same as the +-- registered ones. Reduce duplication. +-- TODO: Once client-side unregistered streaming functions are added, switch +-- to using them in these tests. +testServerStreamingUnregistered :: TestTree +testServerStreamingUnregistered = + csTest "unregistered server streaming" client server ([],[],[],[]) + where + clientInitMD = [("client","initmd")] + serverInitMD = [("server","initmd")] + clientPay = "FEED ME!" + pays = ["ONE", "TWO", "THREE", "FOUR"] :: [ByteString] + + client c = do + rm <- clientRegisterMethodServerStreaming c "/feed" + eea <- clientReader c rm 10 clientPay clientInitMD $ \initMD recv -> do + liftIO $ checkMD "Server initial metadata mismatch" serverInitMD initMD + forM_ pays $ \p -> recv `is` Right (Just p) + recv `is` Right Nothing + eea @?= Right (dummyMeta, StatusOk, "dtls") + + server s = U.withServerCallAsync s $ \call -> do + r <- U.serverWriter s call serverInitMD $ \sc send -> do + liftIO $ do + checkMD "Server request metadata mismatch" + clientInitMD (requestMetadataRecv sc) + optionalPayload sc @?= clientPay + forM_ pays $ \p -> send p `is` Right () + return (dummyMeta, StatusOk, "dtls") + r @?= Right () + testClientStreaming :: TestTree testClientStreaming = csTest "client streaming" client server ([],["/slurp"],[],[]) @@ -213,6 +247,34 @@ testClientStreaming = return (Just serverRsp, trailMD, serverStatus, serverDtls) eea @?= Right () +testClientStreamingUnregistered :: TestTree +testClientStreamingUnregistered = + csTest "unregistered client streaming" client server ([],[],[],[]) + where + clientInitMD = [("a","b")] + serverInitMD = [("x","y")] + trailMD = dummyMeta + serverRsp = "serverReader reply" + serverDtls = "deets" + serverStatus = StatusOk + pays = ["P_ONE", "P_TWO", "P_THREE"] :: [ByteString] + + client c = do + rm <- clientRegisterMethodClientStreaming c "/slurp" + eea <- clientWriter c rm 10 clientInitMD $ \send -> do + -- liftIO $ checkMD "Server initial metadata mismatch" serverInitMD initMD + forM_ pays $ \p -> send p `is` Right () + eea @?= Right (Just serverRsp, serverInitMD, trailMD, serverStatus, serverDtls) + + server s = U.withServerCallAsync s $ \call -> do + eea <- U.serverReader s call serverInitMD $ \sc recv -> do + liftIO $ checkMD "Client request metadata mismatch" + clientInitMD (requestMetadataRecv sc) + forM_ pays $ \p -> recv `is` Right (Just p) + recv `is` Right Nothing + return (Just serverRsp, trailMD, serverStatus, serverDtls) + eea @?= Right () + testBiDiStreaming :: TestTree testBiDiStreaming = csTest "bidirectional streaming" client server ([],[],[],["/bidi"]) @@ -249,6 +311,41 @@ testBiDiStreaming = return (trailMD, serverStatus, serverDtls) eea @?= Right () +testBiDiStreamingUnregistered :: TestTree +testBiDiStreamingUnregistered = + csTest "unregistered bidirectional streaming" client server ([],[],[],[]) + where + clientInitMD = [("bidi-streaming","client")] + serverInitMD = [("bidi-streaming","server")] + trailMD = dummyMeta + serverStatus = StatusOk + serverDtls = "deets" + is act x = act >>= liftIO . (@?= x) + + client c = do + rm <- clientRegisterMethodBiDiStreaming c "/bidi" + eea <- clientRW c rm 10 clientInitMD $ \initMD recv send -> do + send "cw0" `is` Right () + recv `is` Right (Just "sw0") + send "cw1" `is` Right () + recv `is` Right (Just "sw1") + recv `is` Right (Just "sw2") + return () + eea @?= Right (trailMD, serverStatus, serverDtls) + + server s = U.withServerCallAsync s $ \call -> do + eea <- U.serverRW s call serverInitMD $ \sc recv send -> do + liftIO $ checkMD "Client request metadata mismatch" + clientInitMD (requestMetadataRecv sc) + recv `is` Right (Just "cw0") + send "sw0" `is` Right () + recv `is` Right (Just "cw1") + send "sw1" `is` Right () + send "sw2" `is` Right () + recv `is` Right Nothing + return (trailMD, serverStatus, serverDtls) + eea @?= Right () + -------------------------------------------------------------------------------- -- Unregistered tests diff --git a/tests/UnsafeTests.hs b/tests/UnsafeTests.hs index a43a237..29593dc 100644 --- a/tests/UnsafeTests.hs +++ b/tests/UnsafeTests.hs @@ -26,6 +26,7 @@ import Test.Tasty.HUnit as HU (testCase, (@?=), unsafeTests :: TestTree unsafeTests = testGroup "Unit tests for unsafe C bindings" [ roundtripSlice "Hello, world!" + , roundtripSlice "\NULabc\NUL" , roundtripByteBuffer "Hwaet! We gardena in geardagum..." , roundtripSlice largeByteString , roundtripByteBuffer largeByteString @@ -45,7 +46,7 @@ roundtripSlice :: B.ByteString -> TestTree roundtripSlice bs = testCase "ByteString slice roundtrip" $ do slice <- byteStringToSlice bs unslice <- sliceToByteString slice - bs HU.@?= unslice + unslice HU.@?= bs freeSlice slice roundtripByteBuffer :: B.ByteString -> TestTree