From 2ad0465df6ff5b30e67c74c5a24434e921fe66be Mon Sep 17 00:00:00 2001 From: Connor Clark Date: Tue, 24 May 2016 13:34:50 -0700 Subject: [PATCH] Begin safe low-level Haskell layer (#7) * grpc_server_request_call * basic slice functionality * rename function to emphasize side effects * add docs * ByteBuffer function bindings * replace unsafeCoerce with more specific function, add docs, tests. * add newtypes for Tag and Reserved void pointers * manually fix request_registered_call binding * use nocode keyword to fix Ptr () problems * decouple copying Slice from freeing slice * Add time ops * remove nocode decls * Start Op module, fix c2hs preprocessing order * metadata manipulation operations * metadata free function, test * helper functions for constructing ops of each type * bindings for op creation functions * finish up Op creation functions, implement Op destruction, add docs. * tweak documentation * rework Op creation functions to work with an array of ops, for ease of use with grpc_call_start_batch * forgot to change return types * wrap hook lines, fix types to op creation functions * implement part of the payload test * hideous, but working, end to end test * bindings for connectivity state checks, split test into two threads * various cleanup * rename Core to Unsafe for emphasis, clean up tests more * begin safe low-level facilities * begin completion queue and server stuff * Finish server start/stop, cq start/stop, add tests * facilities for safely executing op batches * reorganize LowLevel modules, begin explicit export list * client functionality, stub payload test, various refactors * tweak cabal file, add test * add more documentation * doc tweaks * begin refactor to improve CompletionQueue safety * export only thread-safe CQ functions, add registered call creation and other CQ utilities * begin refactor to use GRPCIO monad, fix missing push semaphore, fix mem leak in server calls * switch to explicit Either where needed * add crashing tests, continue fleshing out serverHandleNormalCall * fix haddock error, finish first draft of request handling function * reduce GHC warnings * non-registered client request helpers * initial request/response test working * don't pass tags around; generate where needed * server call bracket functions * correct order of semaphore acquisition and shutdown check * simple debug flag logging, simplify Call type * fix various registered method issues (but still not working) * cleanup * delete old code * remove old todo * use MetadataMap synonym pervasively * more comments * update TODOs * tweak safety caveat * docs tweaks * improve haddocks * add casts to eliminate clang warnings, remove unused function * update options to eliminate cabal warnings * remove outdated todo * remove unneeded exports from CompletionQueue * rename to GRPCIOCallError, re-add create/shutdown exports (needed for Server module) * newtypes for hosts and method names * more newtypes * more debug logging * Fix flag name collision * instrument uses of free * more debug * switch to STM for completion queue stuff * reduce warnings * more debugging, create/destroy call tests * refactor, fix failure cleanup for server call creation. More tests passing. * formatting tweaks --- cbits/grpc_haskell.c | 74 ++-- grpc-haskell.cabal | 26 ++ include/grpc_haskell.h | 11 + src/Network/GRPC/LowLevel.hs | 60 ++++ src/Network/GRPC/LowLevel/Call.hs | 109 ++++++ src/Network/GRPC/LowLevel/Client.hs | 223 ++++++++++++ src/Network/GRPC/LowLevel/CompletionQueue.hs | 352 +++++++++++++++++++ src/Network/GRPC/LowLevel/GRPC.hs | 85 +++++ src/Network/GRPC/LowLevel/Op.hs | 203 +++++++++++ src/Network/GRPC/LowLevel/Server.hs | 251 +++++++++++++ src/Network/GRPC/Unsafe.chs | 31 +- src/Network/GRPC/Unsafe/ByteBuffer.chs | 5 + src/Network/GRPC/Unsafe/Metadata.chs | 26 ++ src/Network/GRPC/Unsafe/Op.chs | 8 +- src/Network/GRPC/Unsafe/Slice.chs | 3 - src/Network/GRPC/Unsafe/Time.chs | 5 +- tests/LowLevelTests.hs | 158 +++++++++ tests/Properties.hs | 39 +- 18 files changed, 1630 insertions(+), 39 deletions(-) create mode 100644 src/Network/GRPC/LowLevel.hs create mode 100644 src/Network/GRPC/LowLevel/Call.hs create mode 100644 src/Network/GRPC/LowLevel/Client.hs create mode 100644 src/Network/GRPC/LowLevel/CompletionQueue.hs create mode 100644 src/Network/GRPC/LowLevel/GRPC.hs create mode 100644 src/Network/GRPC/LowLevel/Op.hs create mode 100644 src/Network/GRPC/LowLevel/Server.hs create mode 100644 tests/LowLevelTests.hs diff --git a/cbits/grpc_haskell.c b/cbits/grpc_haskell.c index af429d6..320da31 100644 --- a/cbits/grpc_haskell.c +++ b/cbits/grpc_haskell.c @@ -7,6 +7,13 @@ #include #include +void grpc_haskell_free(char *debugMsg, void *ptr){ + #ifdef GRPC_HASKELL_DEBUG + printf("C wrapper: freeing %s, ptr: %p\n", debugMsg, ptr); + #endif + free(ptr); +} + grpc_event *grpc_completion_queue_next_(grpc_completion_queue *cq, gpr_timespec *deadline, void *reserved) { @@ -49,9 +56,13 @@ gpr_slice* gpr_slice_from_copied_string_(const char *source){ return retval; } +void gpr_slice_unref_(gpr_slice* slice){ + gpr_slice_unref(*slice); +} + void free_slice(gpr_slice *slice){ gpr_slice_unref(*slice); - free(slice); + grpc_haskell_free("free_slice", slice); } grpc_byte_buffer **create_receiving_byte_buffer(){ @@ -62,7 +73,7 @@ grpc_byte_buffer **create_receiving_byte_buffer(){ void destroy_receiving_byte_buffer(grpc_byte_buffer **bb){ grpc_byte_buffer_destroy(*bb); - free(bb); + grpc_haskell_free("destroy_receiving_byte_buffer", bb); } grpc_byte_buffer_reader *byte_buffer_reader_create(grpc_byte_buffer *buffer){ @@ -73,7 +84,7 @@ grpc_byte_buffer_reader *byte_buffer_reader_create(grpc_byte_buffer *buffer){ void byte_buffer_reader_destroy(grpc_byte_buffer_reader *reader){ grpc_byte_buffer_reader_destroy(reader); - free(reader); + grpc_haskell_free("byte_buffer_reader_destroy", reader); } gpr_slice *grpc_byte_buffer_reader_readall_(grpc_byte_buffer_reader *reader){ @@ -83,7 +94,7 @@ gpr_slice *grpc_byte_buffer_reader_readall_(grpc_byte_buffer_reader *reader){ } void timespec_destroy(gpr_timespec* t){ - free(t); + grpc_haskell_free("timespec_destroy", t); } gpr_timespec* gpr_inf_future_(gpr_clock_type t){ @@ -125,8 +136,8 @@ grpc_metadata_array** metadata_array_create(){ void metadata_array_destroy(grpc_metadata_array **arr){ grpc_metadata_array_destroy(*arr); - free(*arr); - free(arr); + grpc_haskell_free("metadata_array_destroy1", *arr); + grpc_haskell_free("metadata_array_destroy1", arr); } grpc_metadata* metadata_alloc(size_t n){ @@ -135,7 +146,7 @@ grpc_metadata* metadata_alloc(size_t n){ } void metadata_free(grpc_metadata* m){ - free(m); + grpc_haskell_free("metadata_free", m); } void set_metadata_key_val(char *key, char *val, grpc_metadata *arr, size_t i){ @@ -174,8 +185,10 @@ void op_array_destroy(grpc_op* op_array, size_t n){ case GRPC_OP_SEND_CLOSE_FROM_CLIENT: break; case GRPC_OP_SEND_STATUS_FROM_SERVER: - free(op->data.send_status_from_server.trailing_metadata); - free(op->data.send_status_from_server.status_details); + grpc_haskell_free("op_array_destroy: GRPC_OP_SEND_STATUS_FROM_SERVER", + op->data.send_status_from_server.trailing_metadata); + grpc_haskell_free("op_array_destroy: GRPC_OP_SEND_STATUS_FROM_SERVER", + (char*)(op->data.send_status_from_server.status_details)); break; case GRPC_OP_RECV_INITIAL_METADATA: break; @@ -187,7 +200,7 @@ void op_array_destroy(grpc_op* op_array, size_t n){ break; } } - free(op_array); + grpc_haskell_free("op_array_destroy", op_array); } void op_send_initial_metadata(grpc_op *op_array, size_t i, @@ -280,17 +293,7 @@ void op_send_status_server(grpc_op *op_array, size_t i, op->data.send_status_from_server.status = status; op->data.send_status_from_server.status_details = malloc(sizeof(char)*(strlen(details) + 1)); - strcpy(op->data.send_status_from_server.status_details, details); - op->flags = 0; - op->reserved = NULL; -} - -void op_send_ok_status_server(grpc_op *op_array, size_t i){ - grpc_op *op = op_array + i; - op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; - op->data.send_status_from_server.trailing_metadata_count = 0; - op->data.send_status_from_server.status = GRPC_STATUS_OK; - op->data.send_status_from_server.status_details = "OK"; + strcpy((char*)(op->data.send_status_from_server.status_details), details); op->flags = 0; op->reserved = NULL; } @@ -299,8 +302,12 @@ grpc_status_code* create_status_code_ptr(){ return malloc(sizeof(grpc_status_code)); } +grpc_status_code deref_status_code_ptr(grpc_status_code* p){ + return *p; +} + void destroy_status_code_ptr(grpc_status_code* p){ - free(p); + grpc_haskell_free("destroy_status_code_ptr", p); } grpc_call_details* create_call_details(){ @@ -311,7 +318,7 @@ grpc_call_details* create_call_details(){ void destroy_call_details(grpc_call_details* cd){ grpc_call_details_destroy(cd); - free(cd); + grpc_haskell_free("destroy_call_details", cd); } void grpc_channel_watch_connectivity_state_(grpc_channel *channel, @@ -323,3 +330,24 @@ void grpc_channel_watch_connectivity_state_(grpc_channel *channel, grpc_channel_watch_connectivity_state(channel, last_observed_state, *deadline, cq, tag); } + +grpc_metadata* metadata_array_get_metadata(grpc_metadata_array* arr){ + return arr->metadata; +} + +size_t metadata_array_get_count(grpc_metadata_array* arr){ + return arr->count; +} + +grpc_call* grpc_channel_create_registered_call_( + grpc_channel *channel, grpc_call *parent_call, uint32_t propagation_mask, + grpc_completion_queue *completion_queue, void *registered_call_handle, + gpr_timespec *deadline, void *reserved){ + #ifdef GRPC_HASKELL_DEBUG + printf("calling grpc_channel_create_registered_call with deadline %p\n", + deadline); + #endif + return grpc_channel_create_registered_call(channel, parent_call, + propagation_mask, completion_queue, registered_call_handle, + *deadline, reserved); +} diff --git a/grpc-haskell.cabal b/grpc-haskell.cabal index 3f4c8a9..4cadf7c 100644 --- a/grpc-haskell.cabal +++ b/grpc-haskell.cabal @@ -13,11 +13,18 @@ build-type: Simple cabal-version: >=1.10 extra-source-files: cbits, include +Flag Debug + Description: Adds debug logging. + Manual: True + Default: False + library build-depends: base ==4.8.* , clock ==0.6.* , bytestring ==0.10.* + , stm == 2.4.* + , containers ==0.5.* c-sources: cbits/grpc_haskell.c exposed-modules: @@ -29,6 +36,14 @@ library Network.GRPC.Unsafe.Metadata Network.GRPC.Unsafe.Op Network.GRPC.Unsafe + Network.GRPC.LowLevel + other-modules: + Network.GRPC.LowLevel.CompletionQueue + Network.GRPC.LowLevel.GRPC + Network.GRPC.LowLevel.Op + Network.GRPC.LowLevel.Server + Network.GRPC.LowLevel.Call + Network.GRPC.LowLevel.Client extra-libraries: grpc includes: @@ -44,6 +59,11 @@ library ghc-options: -Wall -fwarn-incomplete-patterns include-dirs: include hs-source-dirs: src + default-extensions: CPP + + if flag(debug) + CPP-Options: -DDEBUG + CC-Options: -DGRPC_HASKELL_DEBUG test-suite test build-depends: @@ -55,8 +75,14 @@ test-suite test , async , tasty >= 0.11 && <0.12 , tasty-hunit >= 0.9 && <0.10 + , containers ==0.5.* + other-modules: LowLevelTests default-language: Haskell2010 ghc-options: -Wall -fwarn-incomplete-patterns -g -threaded hs-source-dirs: tests main-is: Properties.hs type: exitcode-stdio-1.0 + extensions: CPP + + if flag(debug) + GHC-Options: -DDEBUG diff --git a/include/grpc_haskell.h b/include/grpc_haskell.h index faadfbb..aaa77f8 100644 --- a/include/grpc_haskell.h +++ b/include/grpc_haskell.h @@ -99,6 +99,8 @@ void op_send_status_server(grpc_op *op_array, size_t i, grpc_status_code* create_status_code_ptr(); +grpc_status_code deref_status_code_ptr(grpc_status_code* p); + void destroy_status_code_ptr(grpc_status_code* p); grpc_call_details* create_call_details(); @@ -112,4 +114,13 @@ void grpc_channel_watch_connectivity_state_(grpc_channel *channel, grpc_completion_queue *cq, void *tag); +grpc_metadata* metadata_array_get_metadata(grpc_metadata_array* arr); + +size_t metadata_array_get_count(grpc_metadata_array* arr); + +grpc_call* grpc_channel_create_registered_call_( + grpc_channel *channel, grpc_call *parent_call, uint32_t propagation_mask, + grpc_completion_queue *completion_queue, void *registered_call_handle, + gpr_timespec *deadline, void *reserved); + #endif //GRPC_HASKELL diff --git a/src/Network/GRPC/LowLevel.hs b/src/Network/GRPC/LowLevel.hs new file mode 100644 index 0000000..ae348a5 --- /dev/null +++ b/src/Network/GRPC/LowLevel.hs @@ -0,0 +1,60 @@ +-- | Low-level safe interface to gRPC. By "safe", we mean: +-- 1. all gRPC objects are guaranteed to be cleaned up correctly. +-- 2. all functions are thread-safe. +-- 3. all functions leave gRPC in a consistent, safe state. +-- These guarantees only apply to the functions exported by this module, +-- and not to helper functions in submodules that aren't exported here. + +{-# LANGUAGE RecordWildCards #-} + +module Network.GRPC.LowLevel ( +-- * Important types +GRPC +, withGRPC +, GRPCIOError(..) +, StatusCode(..) + +-- * Completion queue utilities +, CompletionQueue +, withCompletionQueue + +-- * Calls +, GRPCMethodType(..) +, RegisteredMethod +, Call +, NormalRequestResult(..) + +-- * Server +, ServerConfig(..) +, Server +, registeredMethods +, withServer +, serverHandleNormalRegisteredCall +, serverHandleNormalCall +, withServerCall +, withServerRegisteredCall + +-- * Client +, ClientConfig(..) +, Client +, withClient +, clientRegisterMethod +, clientRegisteredRequest +, clientRequest +, withClientCall + +-- * Ops +, runOps +, Op(..) +, OpRecvResult(..) + +) where + +import Network.GRPC.LowLevel.GRPC +import Network.GRPC.LowLevel.Server +import Network.GRPC.LowLevel.CompletionQueue +import Network.GRPC.LowLevel.Op +import Network.GRPC.LowLevel.Client +import Network.GRPC.LowLevel.Call + +import Network.GRPC.Unsafe.Op (StatusCode(..)) diff --git a/src/Network/GRPC/LowLevel/Call.hs b/src/Network/GRPC/LowLevel/Call.hs new file mode 100644 index 0000000..15559db --- /dev/null +++ b/src/Network/GRPC/LowLevel/Call.hs @@ -0,0 +1,109 @@ +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} + +module Network.GRPC.LowLevel.Call where + +import Control.Monad +import Data.String (IsString) +import Foreign.Marshal.Alloc (free) +import Foreign.Ptr (Ptr, castPtr) +import Foreign.Storable (peek) + +import qualified Network.GRPC.Unsafe as C +import qualified Network.GRPC.Unsafe.Time as C +import qualified Network.GRPC.Unsafe.Metadata as C +import qualified Network.GRPC.Unsafe.ByteBuffer as C + +import Network.GRPC.LowLevel.GRPC (grpcDebug) + +-- | Models the four types of RPC call supported by gRPC. We currently only +-- support the first alternative, and only in a preliminary fashion. +data GRPCMethodType = Normal | ClientStreaming | ServerStreaming | BiDiStreaming + deriving (Show, Eq, Ord, Enum) + +newtype MethodName = MethodName {unMethodName :: String} + deriving (Show, Eq, IsString) + +newtype Host = Host {unHost :: String} + deriving (Show, Eq, IsString) + +-- | Represents a registered method. Methods can optionally be registered in +-- order to make the C-level request/response code simpler. +-- Before making or awaiting a registered call, the +-- method must be registered with the client (see 'clientRegisterMethod') and +-- the server (see 'serverRegisterMethod'). +-- Contains state for identifying that method in the underlying gRPC library. +data RegisteredMethod = RegisteredMethod {methodType :: GRPCMethodType, + methodName :: MethodName, + methodHost :: Host, + methodHandle :: C.CallHandle} + +-- | Represents one GRPC call (i.e. request). This type is used on both the +-- client and server. Contains pointers to all the necessary C state needed to +-- send and respond to a call. +-- This is used to associate send/receive 'Op's with a request. +-- There are separate functions for creating these depending on whether the +-- method is registered and whether the call is on the client or server side. +data Call = ClientCall {internalCall :: C.Call} + | ServerCall + {internalCall :: C.Call, + requestMetadataRecv :: (Ptr C.MetadataArray), + optionalPayload :: Maybe (Ptr C.ByteBuffer), + parentPtr :: Maybe (Ptr C.Call), + callDetails :: Maybe (C.CallDetails), + -- ^ used on the server for non-registered calls + --, to identify the endpoint being used. + callDeadline :: Maybe C.CTimeSpecPtr + } + +debugCall :: Call -> IO () +#ifdef DEBUG +debugCall (ClientCall (C.Call ptr)) = + grpcDebug $ "debugCall: client call: " ++ (show ptr) +debugCall call@(ServerCall (C.Call ptr) _ _ _ _ _) = do + grpcDebug $ "debugCall: server call: " ++ (show ptr) + grpcDebug $ "debugCall: metadata ptr: " ++ show (requestMetadataRecv call) + metadataArr <- peek (requestMetadataRecv call) + metadata <- C.getAllMetadataArray metadataArr + grpcDebug $ "debugCall: metadata received: " ++ (show metadata) + forM_ (optionalPayload call) $ \payloadPtr -> do + grpcDebug $ "debugCall: payload ptr: " ++ show payloadPtr + payload <- peek payloadPtr + bs <- C.copyByteBufferToByteString payload + grpcDebug $ "debugCall: payload contents: " ++ show bs + forM_ (parentPtr call) $ \parentPtr' -> do + grpcDebug $ "debugCall: parent ptr: " ++ show parentPtr' + (C.Call parent) <- peek parentPtr' + grpcDebug $ "debugCall: parent: " ++ show parent + forM_ (callDetails call) $ \(C.CallDetails callDetailsPtr) -> do + grpcDebug $ "debugCall: callDetails ptr: " ++ show callDetailsPtr + --TODO: need functions for getting data out of call_details. + forM_ (callDeadline call) $ \timespecptr -> do + grpcDebug $ "debugCall: deadline ptr: " ++ show timespecptr + timespec <- peek timespecptr + grpcDebug $ "debugCall: deadline: " ++ show (C.timeSpec timespec) +#else +{-# INLINE debugCall #-} +debugCall = const $ return () +#endif + +-- | Destroys a 'Call'. +destroyCall :: Call -> IO () +destroyCall ClientCall{..} = do + grpcDebug "Destroying client-side call object." + C.grpcCallDestroy internalCall +destroyCall call@ServerCall{..} = do + grpcDebug "destroyCall: entered." + debugCall call + grpcDebug $ "Destroying server-side call object: " ++ show internalCall + C.grpcCallDestroy internalCall + grpcDebug $ "destroying metadata array: " ++ show requestMetadataRecv + C.metadataArrayDestroy requestMetadataRecv + grpcDebug $ "destroying optional payload" ++ show optionalPayload + forM_ optionalPayload C.destroyReceivingByteBuffer + grpcDebug $ "freeing parentPtr: " ++ show parentPtr + forM_ parentPtr free + grpcDebug $ "destroying call details" ++ show callDetails + forM_ callDetails C.destroyCallDetails + grpcDebug $ "destroying deadline." ++ show callDeadline + forM_ callDeadline C.timespecDestroy diff --git a/src/Network/GRPC/LowLevel/Client.hs b/src/Network/GRPC/LowLevel/Client.hs new file mode 100644 index 0000000..eea43c1 --- /dev/null +++ b/src/Network/GRPC/LowLevel/Client.hs @@ -0,0 +1,223 @@ +{-# LANGUAGE RecordWildCards #-} + +module Network.GRPC.LowLevel.Client where + +import Control.Exception (bracket, finally) +import Data.ByteString (ByteString) +import Foreign.Ptr (nullPtr) +import qualified Network.GRPC.Unsafe as C +import qualified Network.GRPC.Unsafe.Time as C +import qualified Network.GRPC.Unsafe.Constants as C +import qualified Network.GRPC.Unsafe.Op as C + +import Network.GRPC.LowLevel.GRPC +import Network.GRPC.LowLevel.CompletionQueue +import Network.GRPC.LowLevel.Call +import Network.GRPC.LowLevel.Op + +-- | Represents the context needed to perform client-side gRPC operations. +data Client = Client {clientChannel :: C.Channel, + clientCQ :: CompletionQueue} + +-- | Configuration necessary to set up a client. +data ClientConfig = ClientConfig {clientServerHost :: Host, + clientServerPort :: Int} + +createClient :: GRPC -> ClientConfig -> IO Client +createClient grpc ClientConfig{..} = do + let hostPort = (unHost clientServerHost) ++ ":" ++ (show clientServerPort) + chan <- C.grpcInsecureChannelCreate hostPort nullPtr C.reserved + cq <- createCompletionQueue grpc + return $ Client chan cq + +destroyClient :: Client -> IO () +destroyClient Client{..} = do + shutdownResult <- shutdownCompletionQueue clientCQ + case shutdownResult of + Left x -> do putStrLn $ "Failed to stop client CQ: " ++ show x + putStrLn $ "Trying to shut down anyway." + Right _ -> return () + C.grpcChannelDestroy clientChannel + +withClient :: GRPC -> ClientConfig -> (Client -> IO a) -> IO a +withClient grpc config = bracket (createClient grpc config) + (\c -> grpcDebug "withClient: destroying." + >> destroyClient c) + +-- | Register a method on the client so that we can call it with +-- 'clientRegisteredRequest'. +clientRegisterMethod :: Client + -> MethodName + -- ^ method name, e.g. "/foo" + -> Host + -- ^ host name, e.g. "localhost" + -> GRPCMethodType + -> IO RegisteredMethod +clientRegisterMethod Client{..} name host Normal = do + handle <- C.grpcChannelRegisterCall clientChannel (unMethodName name) + (unHost host) C.reserved + return $ RegisteredMethod Normal name host handle +clientRegisterMethod _ _ _ _ = error "Streaming methods not yet implemented." + +-- | Create a new call on the client for a registered method. +-- Returns 'Left' if the CQ is shutting down or if the job to create a call +-- timed out. +clientCreateRegisteredCall :: Client -> RegisteredMethod -> TimeoutSeconds + -> IO (Either GRPCIOError Call) +clientCreateRegisteredCall Client{..} RegisteredMethod{..} timeout = do + let parentCall = C.Call nullPtr --Unsure what this does. null is safe, though. + C.withDeadlineSeconds timeout $ \deadline -> do + channelCreateRegisteredCall clientChannel parentCall C.propagateDefaults + clientCQ methodHandle deadline + +-- TODO: the error-handling refactor made this quite ugly. It could be fixed +-- by switching to ExceptT IO. +-- | Handles safe creation and cleanup of a client call +withClientRegisteredCall :: Client -> RegisteredMethod -> TimeoutSeconds + -> (Call + -> IO (Either GRPCIOError a)) + -> IO (Either GRPCIOError a) +withClientRegisteredCall client regmethod timeout f = do + createResult <- clientCreateRegisteredCall client regmethod timeout + case createResult of + Left x -> return $ Left x + Right call -> f call `finally` logDestroy call + where logDestroy c = grpcDebug "withClientRegisteredCall: destroying." + >> destroyCall c + +-- | Create a call on the client for an endpoint without using the +-- method registration machinery. In practice, we'll probably only use the +-- registered method version, but we include this for completeness and testing. +clientCreateCall :: Client + -> MethodName + -- ^ The method name + -> Host + -- ^ The host. + -> TimeoutSeconds + -> IO (Either GRPCIOError Call) +clientCreateCall Client{..} method host timeout = do + let parentCall = C.Call nullPtr + C.withDeadlineSeconds timeout $ \deadline -> do + channelCreateCall clientChannel parentCall C.propagateDefaults + clientCQ method host deadline + +withClientCall :: Client -> MethodName -> Host -> TimeoutSeconds + -> (Call -> IO (Either GRPCIOError a)) + -> IO (Either GRPCIOError a) +withClientCall client method host timeout f = do + createResult <- clientCreateCall client method host timeout + case createResult of + Left x -> return $ Left x + Right call -> f call `finally` logDestroy call + where logDestroy c = grpcDebug "withClientCall: destroying." + >> destroyCall c + +data NormalRequestResult = NormalRequestResult + ByteString + MetadataMap --init metadata + MetadataMap --trailing metadata + C.StatusCode + deriving (Show, Eq) + +-- | Function for assembling call result when the 'MethodType' is 'Normal'. +compileNormalRequestResults :: [OpRecvResult] -> NormalRequestResult +compileNormalRequestResults + --TODO: consider using more precise type instead of match. + -- Whether we do so depends on whether this layer of abstraction is supposed + -- to be a safe interface to the gRPC C core library, or something that makes + -- core use cases easy. + [OpRecvInitialMetadataResult m, + OpRecvMessageResult body, + OpRecvStatusOnClientResult m2 status] + = NormalRequestResult body m m2 status +compileNormalRequestResults _ = + --TODO: impossible case should be enforced by more precise types. + error "non-normal request input to compileNormalRequestResults." + + +-- | Make a request of the given method with the given body. Returns the +-- server's response. TODO: This is preliminary until we figure out how many +-- different variations on sending request ops will be needed for full gRPC +-- functionality. +clientRegisteredRequest :: Client + -> RegisteredMethod + -> TimeoutSeconds + -- ^ Timeout of both the grpc_call and the + -- max time to wait for the completion of the batch. + -- TODO: I think we will need to decouple the + -- lifetime of the call from the queue deadline once + -- we expose functionality for streaming calls, where + -- one call object persists across many batches. + -> ByteString + -- ^ The body of the request. + -> MetadataMap + -- ^ Metadata to send with the request. + -> IO (Either GRPCIOError NormalRequestResult) +clientRegisteredRequest client@(Client{..}) rm@(RegisteredMethod{..}) + timeLimit body meta = + case methodType of + Normal -> withClientRegisteredCall client rm timeLimit $ \call -> do + grpcDebug "clientRegisteredRequest: created call." + debugCall call + --TODO: doing one op at a time to debug. Some were hanging. + let op1 = [OpSendInitialMetadata meta] + res1 <- runOps call clientCQ op1 timeLimit + grpcDebug $ "finished res1: " ++ show res1 + let op2 = [OpSendMessage body] + res2 <- runOps call clientCQ op2 timeLimit + grpcDebug $ "finished res2: " ++ show res2 + let op3 = [OpSendCloseFromClient] + res3 <- runOps call clientCQ op3 timeLimit + grpcDebug $ "finished res3: " ++ show res3 + let op4 = [OpRecvMessage] + res4 <- runOps call clientCQ op4 timeLimit + grpcDebug $ "finished res4: " ++ show res4 + let op5 = [OpRecvStatusOnClient] + res5 <- runOps call clientCQ op5 timeLimit + grpcDebug $ "finished res5: " ++ show res5 + let results = do + r1 <- res1 + r2 <- res2 + r3 <- res3 + r4 <- res4 + r5 <- res5 + return $ r1 ++ r2 ++ r3 ++ r4 ++ r5 + case results of + Left x -> return $ Left x + Right rs -> return $ + Right $ compileNormalRequestResults rs + _ -> error "Streaming methods not yet implemented." + +-- | Makes a normal (non-streaming) request without needing to register a method +-- first. Probably only useful for testing. TODO: This is preliminary, like +-- 'clientRegisteredRequest'. +clientRequest :: Client + -> MethodName + -- ^ Method name, e.g. "/foo" + -> Host + -- ^ Host. Not sure if used. + -> TimeoutSeconds + -> ByteString + -- ^ Request body. + -> MetadataMap + -- ^ Request metadata. + -> IO (Either GRPCIOError NormalRequestResult) +clientRequest client@(Client{..}) (MethodName method) (Host host) + timeLimit body meta = do + withClientCall client (MethodName method) (Host host) timeLimit $ \call -> do + let ops = clientNormalRequestOps body meta + results <- runOps call clientCQ ops timeLimit + grpcDebug "clientRequest: ops ran." + case results of + Left x -> return $ Left x + Right rs -> return $ Right $ compileNormalRequestResults rs + + +clientNormalRequestOps :: ByteString -> MetadataMap -> [Op] +clientNormalRequestOps body metadata = + [OpSendInitialMetadata metadata, + OpSendMessage body, + OpSendCloseFromClient, + OpRecvInitialMetadata, + OpRecvMessage, + OpRecvStatusOnClient] diff --git a/src/Network/GRPC/LowLevel/CompletionQueue.hs b/src/Network/GRPC/LowLevel/CompletionQueue.hs new file mode 100644 index 0000000..e8d4b3a --- /dev/null +++ b/src/Network/GRPC/LowLevel/CompletionQueue.hs @@ -0,0 +1,352 @@ +-- | Unlike most of the other internal low-level modules, we don't export +-- everything here. There are several things in here that, if accessed, could +-- cause race conditions, so we only expose functions that are thread safe. +-- However, some of the functions we export here can cause memory leaks if used +-- improperly. + +{-# LANGUAGE RecordWildCards #-} + +module Network.GRPC.LowLevel.CompletionQueue ( + CompletionQueue + , withCompletionQueue + , createCompletionQueue + , shutdownCompletionQueue + , pluck + , startBatch + , channelCreateRegisteredCall + , channelCreateCall + , TimeoutSeconds + , eventSuccess + , serverRegisterCompletionQueue + , serverShutdownAndNotify + , serverRequestRegisteredCall + , serverRequestCall + , newTag +) where + +import Control.Concurrent (forkIO, threadDelay) +import Control.Concurrent.STM (atomically, retry, check) +import Control.Concurrent.STM.TVar (TVar, newTVarIO, modifyTVar', + readTVar, writeTVar) +import Control.Exception (bracket) +import Data.IORef (IORef, newIORef, atomicModifyIORef') +import Foreign.Marshal.Alloc (malloc, free) +import Foreign.Ptr (nullPtr, plusPtr) +import Foreign.Storable (peek) +import qualified Network.GRPC.Unsafe as C +import qualified Network.GRPC.Unsafe.Constants as C +import qualified Network.GRPC.Unsafe.Time as C +import qualified Network.GRPC.Unsafe.Op as C +import qualified Network.GRPC.Unsafe.Metadata as C +import System.Timeout (timeout) + +import Network.GRPC.LowLevel.GRPC +import Network.GRPC.LowLevel.Call + +-- NOTE: the concurrency requirements for a CompletionQueue are a little +-- complicated. There are two read operations: next and pluck. We can either +-- call next on a CQ or call pluck up to 'maxCompletionQueuePluckers' times +-- concurrently, but we can't mix next and pluck calls. Fortunately, we only +-- need to use next when we are shutting down the queue. Thus, we do two things +-- to shut down: +-- 1. Set the shuttingDown 'TVar' to 'True'. When this is set, no new pluck +-- calls will be allowed to start. +-- 2. Wait until no threads are plucking, as counted by 'currentPluckers'. +-- This logic can be seen in 'pluck' and 'shutdownCompletionQueue'. + +-- NOTE: There is one more possible race condition: pushing work onto the queue +-- after we begin to shut down. +-- Solution: another counter, which must reach zero before the shutdown +-- can start. + +-- TODO: 'currentPushers' currently imposes an arbitrary limit on the number of +-- concurrent pushers to the CQ, but I don't know what the limit should be set +-- to. I haven't found any documentation that suggests there is a limit imposed +-- by the gRPC library, but there might be. Need to investigate further. + +-- | Wraps the state necessary to use a gRPC completion queue. Completion queues +-- are used to wait for batches gRPC operations ('Op's) to finish running, as +-- well as wait for various other operations, such as server shutdown, pinging, +-- checking to see if we've been disconnected, and so forth. +data CompletionQueue = CompletionQueue {unsafeCQ :: C.CompletionQueue, + -- ^ All access to this field must be + -- guarded by a check of 'shuttingDown'. + currentPluckers :: TVar Int, + -- ^ Used to limit the number of + -- concurrent calls to pluck on this + -- queue. + -- The max value is set by gRPC in + -- 'C.maxCompletionQueuePluckers' + currentPushers :: TVar Int, + -- ^ Used to prevent new work from + -- being pushed onto the queue when + -- the queue begins to shut down. + shuttingDown :: TVar Bool, + -- ^ Used to prevent new pluck calls on + -- the queue when the queue begins to + -- shut down. + nextTag :: IORef Int + -- ^ Used to supply unique tags for work + -- items pushed onto the queue. + } + +-- | Create a new 'C.Tag' for identifying work items on the 'CompletionQueue'. +-- This will eventually wrap around after reaching @maxBound :: Int@, but from a +-- practical perspective, that should be safe. +newTag :: CompletionQueue -> IO C.Tag +newTag CompletionQueue{..} = do + i <- atomicModifyIORef' nextTag (\i -> (i+1,i)) + return $ C.Tag $ plusPtr nullPtr i + +maxWorkPushers :: Int +maxWorkPushers = 100 --TODO: figure out what this should be. + +data CQOpType = Push | Pluck deriving (Show, Eq, Enum) + +getCount :: CQOpType -> CompletionQueue -> TVar Int +getCount Push = currentPushers +getCount Pluck = currentPluckers + +getLimit :: CQOpType -> Int +getLimit Push = maxWorkPushers +getLimit Pluck = C.maxCompletionQueuePluckers + +-- | Safely brackets an operation that pushes work onto or plucks results from +-- the given 'CompletionQueue'. +withPermission :: CQOpType + -> CompletionQueue + -> IO (Either GRPCIOError a) + -> IO (Either GRPCIOError a) +withPermission op cq f = + bracket acquire release doOp + where acquire = atomically $ do + isShuttingDown <- readTVar (shuttingDown cq) + if isShuttingDown + then return False + else do currCount <- readTVar $ getCount op cq + if currCount < getLimit op + then modifyTVar' (getCount op cq) (+1) >> return True + else retry + doOp gotResource = if gotResource + then f + else return $ Left GRPCIOShutdown + release gotResource = + if gotResource + then atomically $ modifyTVar' (getCount op cq) (subtract 1) + else return () + +withCompletionQueue :: GRPC -> (CompletionQueue -> IO a) -> IO a +withCompletionQueue grpc = bracket (createCompletionQueue grpc) + shutdownCompletionQueue + +createCompletionQueue :: GRPC -> IO CompletionQueue +createCompletionQueue _ = do + unsafeCQ <- C.grpcCompletionQueueCreate C.reserved + currentPluckers <- newTVarIO 0 + currentPushers <- newTVarIO 0 + shuttingDown <- newTVarIO False + nextTag <- newIORef minBound + return $ CompletionQueue{..} + +type TimeoutSeconds = Int + +-- | Translate 'C.Event' to an error. The caller is responsible for ensuring +-- that the event actually corresponds to an error condition; a successful event +-- will be translated to a 'GRPCIOUnknownError'. +eventToError :: C.Event -> (Either GRPCIOError a) +eventToError (C.Event C.QueueShutdown _ _) = Left GRPCIOShutdown +eventToError (C.Event C.QueueTimeout _ _) = Left GRPCIOTimeout +eventToError _ = Left GRPCIOUnknownError + +isFailedEvent :: C.Event -> Bool +isFailedEvent C.Event{..} = (eventCompletionType /= C.OpComplete) + || not eventSuccess + +-- | Waits for the given number of seconds for the given tag to appear on the +-- completion queue. Throws 'GRPCIOShutdown' if the completion queue is shutting +--down and cannot handle new requests. +pluck :: CompletionQueue -> C.Tag -> TimeoutSeconds + -> IO (Either GRPCIOError ()) +pluck cq@CompletionQueue{..} tag waitSeconds = do + grpcDebug $ "pluck: called with tag: " ++ show tag + ++ " and wait: " ++ show waitSeconds + withPermission Pluck cq $ do + C.withDeadlineSeconds waitSeconds $ \deadline -> do + ev <- C.grpcCompletionQueuePluck unsafeCQ tag deadline C.reserved + grpcDebug $ "pluck: finished. Event: " ++ show ev + if isFailedEvent ev + then return $ eventToError ev + else return $ Right () + +-- TODO: I'm thinking it might be easier to use 'Either' uniformly everywhere +-- even when it's isomorphic to 'Maybe'. If that doesn't turn out to be the +-- case, switch these to 'Maybe'. +-- | Very simple wrapper around 'grpcCallStartBatch'. Throws 'GRPCIOShutdown' +-- without calling 'grpcCallStartBatch' if the queue is shutting down. +-- Throws 'CallError' if 'grpcCallStartBatch' returns a non-OK code. +startBatch :: CompletionQueue -> C.Call -> C.OpArray -> Int -> C.Tag + -> IO (Either GRPCIOError ()) +startBatch cq@CompletionQueue{..} call opArray opArraySize tag = + withPermission Push cq $ fmap throwIfCallError $ do + grpcDebug "startBatch: calling grpc_call_start_batch." + res <- C.grpcCallStartBatch call opArray opArraySize tag C.reserved + grpcDebug "startBatch: grpc_call_start_batch call returned." + return res + + +-- | Shuts down the completion queue. See the comment above 'CompletionQueue' +-- for the strategy we use to ensure that no one tries to use the +-- queue after we begin the shutdown process. Errors with +-- 'GRPCIOShutdownFailure' if the queue can't be shut down within 5 seconds. +shutdownCompletionQueue :: CompletionQueue -> IO (Either GRPCIOError ()) +shutdownCompletionQueue (CompletionQueue{..}) = do + atomically $ writeTVar shuttingDown True + atomically $ readTVar currentPushers >>= \x -> check (x == 0) + atomically $ readTVar currentPluckers >>= \x -> check (x == 0) + --drain the queue + C.grpcCompletionQueueShutdown unsafeCQ + loopRes <- timeout (5*10^6) drainLoop + case loopRes of + Nothing -> return $ Left GRPCIOShutdownFailure + Just () -> C.grpcCompletionQueueDestroy unsafeCQ >> return (Right ()) + + where drainLoop :: IO () + drainLoop = do + deadline <- C.secondsToDeadline 1 + ev <- C.grpcCompletionQueueNext unsafeCQ deadline C.reserved + case (C.eventCompletionType ev) of + C.QueueShutdown -> return () + C.QueueTimeout -> drainLoop + C.OpComplete -> drainLoop + +-- | Returns true iff the given grpc_event was a success. +eventSuccess :: C.Event -> Bool +eventSuccess (C.Event C.OpComplete True _) = True +eventSuccess _ = False + +channelCreateRegisteredCall :: C.Channel -> C.Call -> C.PropagationMask + -> CompletionQueue -> C.CallHandle + -> C.CTimeSpecPtr -> IO (Either GRPCIOError Call) +channelCreateRegisteredCall + chan parent mask cq@CompletionQueue{..} handle deadline = + withPermission Push cq $ do + call <- C.grpcChannelCreateRegisteredCall chan parent mask unsafeCQ + handle deadline C.reserved + return $ Right $ ClientCall call + +channelCreateCall :: C.Channel -> C.Call -> C.PropagationMask -> CompletionQueue + -> MethodName -> Host -> C.CTimeSpecPtr + -> IO (Either GRPCIOError Call) +channelCreateCall + chan parent mask cq@CompletionQueue{..} (MethodName methodName) (Host host) + deadline = + withPermission Push cq $ do + call <- C.grpcChannelCreateCall chan parent mask unsafeCQ methodName host + deadline C.reserved + return $ Right $ ClientCall call + +-- | Create the call object to handle a registered call. +serverRequestRegisteredCall :: C.Server -> CompletionQueue -> TimeoutSeconds + -> RegisteredMethod + -> IO (Either GRPCIOError Call) +serverRequestRegisteredCall + server cq@CompletionQueue{..} timeLimit RegisteredMethod{..} = + withPermission Push cq $ do + -- TODO: Is gRPC supposed to populate this deadline? + -- NOTE: the below stuff is freed when we free the call we return. + deadline <- C.secondsToDeadline timeLimit + callPtr <- malloc + metadataArrayPtr <- C.metadataArrayCreate + metadataArray <- peek metadataArrayPtr + bbPtr <- malloc + tag <- newTag cq + callError <- C.grpcServerRequestRegisteredCall + server methodHandle callPtr deadline + metadataArray bbPtr unsafeCQ unsafeCQ tag + grpcDebug $ "serverRequestRegisteredCall: callError: " + ++ show callError + if callError /= C.CallOk + then do grpcDebug "serverRequestRegisteredCall: callError. cleaning up" + failureCleanup deadline callPtr metadataArrayPtr bbPtr + return $ Left $ GRPCIOCallError callError + else do pluckResult <- pluck cq tag timeLimit + grpcDebug "serverRequestRegisteredCall: finished pluck." + case pluckResult of + Left x -> do + grpcDebug "serverRequestRegisteredCall: cleanup pluck err" + failureCleanup deadline callPtr metadataArrayPtr bbPtr + return $ Left x + Right () -> do + rawCall <- peek callPtr + let assembledCall = ServerCall rawCall metadataArrayPtr + (Just bbPtr) Nothing Nothing + (Just deadline) + return $ Right assembledCall + -- TODO: see TODO for failureCleanup in serverRequestCall. + where failureCleanup deadline callPtr metadataArrayPtr bbPtr = forkIO $ do + threadDelay (30*10^6) + grpcDebug "serverRequestRegisteredCall: doing delayed cleanup." + C.timespecDestroy deadline + free callPtr + C.metadataArrayDestroy metadataArrayPtr + free bbPtr + +serverRequestCall :: C.Server -> CompletionQueue -> TimeoutSeconds + -> IO (Either GRPCIOError Call) +serverRequestCall server cq@CompletionQueue{..} timeLimit = + withPermission Push cq $ do + callPtr <- malloc + grpcDebug $ "serverRequestCall: callPtr is " ++ show callPtr + callDetails <- C.createCallDetails + metadataArrayPtr <- C.metadataArrayCreate + metadataArray <- peek metadataArrayPtr + tag <- newTag cq + callError <- C.grpcServerRequestCall server callPtr callDetails + metadataArray unsafeCQ unsafeCQ tag + grpcDebug $ "serverRequestCall: callError was " ++ show callError + if callError /= C.CallOk + then do grpcDebug "serverRequestCall: got call error; cleaning up." + failureCleanup callPtr callDetails metadataArrayPtr + return $ Left $ GRPCIOCallError callError + else do pluckResult <- pluck cq tag timeLimit + grpcDebug $ "serverRequestCall: pluckResult was " + ++ show pluckResult + case pluckResult of + Left x -> do + grpcDebug "serverRequestCall: pluck error; cleaning up." + failureCleanup callPtr callDetails + metadataArrayPtr + return $ Left x + Right () -> do + rawCall <- peek callPtr + let call = ServerCall rawCall + metadataArrayPtr + Nothing + Nothing + (Just callDetails) + Nothing + return $ Right call + + --TODO: the gRPC library appears to hold onto these pointers for a random + -- amount of time, even after returning from the only call that uses them. + -- This results in malloc errors if + -- gRPC tries to modify them after we free them. To work around it, + -- we sleep for a while before freeing the objects. We should find a + -- permanent solution that's more robust. + where failureCleanup callPtr callDetails metadataArrayPtr = forkIO $ do + threadDelay (30*10^6) + grpcDebug "serverRequestCall: doing delayed cleanup." + free callPtr + C.destroyCallDetails callDetails + C.metadataArrayDestroy metadataArrayPtr + return () + +-- | Register the server's completion queue. Must be done before the server is +-- started. +serverRegisterCompletionQueue :: C.Server -> CompletionQueue -> IO () +serverRegisterCompletionQueue server CompletionQueue{..} = + C.grpcServerRegisterCompletionQueue server unsafeCQ C.reserved + +serverShutdownAndNotify :: C.Server -> CompletionQueue -> C.Tag -> IO () +serverShutdownAndNotify server CompletionQueue{..} tag = + C.grpcServerShutdownAndNotify server unsafeCQ tag diff --git a/src/Network/GRPC/LowLevel/GRPC.hs b/src/Network/GRPC/LowLevel/GRPC.hs new file mode 100644 index 0000000..4d7bd8f --- /dev/null +++ b/src/Network/GRPC/LowLevel/GRPC.hs @@ -0,0 +1,85 @@ +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} + +module Network.GRPC.LowLevel.GRPC where +{- +-- TODO: remove if not needed +import Control.Monad.IO.Class (liftIO, MonadIO) +import Control.Monad.Except (ExceptT(..), runExceptT, throwError, + MonadError) +-} +import Control.Exception +import qualified Network.GRPC.Unsafe as C + +#ifdef DEBUG +import GHC.Conc (myThreadId) +#endif + +-- | Functions as a proof that the gRPC core has been started. The gRPC core +-- must be initialized to create any gRPC state, so this is a requirement for +-- the server and client create/start functions. +data GRPC = GRPC + +withGRPC :: (GRPC -> IO a) -> IO a +withGRPC = bracket (C.grpcInit >> return GRPC) (const C.grpcShutdown) + +-- | Describes all errors that can occur while running a GRPC-related IO action. +data GRPCIOError = GRPCIOCallError C.CallError + -- ^ Errors that can occur while the call is in flight. These + -- errors come from the core gRPC library directly. + | GRPCIOTimeout + -- ^ Indicates that we timed out while waiting for an + -- operation to complete on the 'CompletionQueue'. + | GRPCIOShutdown + -- ^ Indicates that the 'CompletionQueue' is shutting down + -- and no more work can be processed. This can happen if the + -- client or server is shutting down. + | GRPCIOShutdownFailure + -- ^ Thrown if a 'CompletionQueue' fails to shut down in a + -- reasonable amount of time. + | GRPCIOUnknownError + deriving (Show, Eq) + +throwIfCallError :: C.CallError -> Either GRPCIOError () +throwIfCallError C.CallOk = Right () +throwIfCallError x = Left $ GRPCIOCallError x + +grpcDebug :: String -> IO () +{-# INLINE grpcDebug #-} +#ifdef DEBUG +grpcDebug str = do tid <- myThreadId + putStrLn $ (show tid) ++ ": " ++ str +#else +grpcDebug str = return () +#endif + +{- +-- TODO: remove this once finally decided on whether to use it. +-- | Monad for running gRPC operations. +newtype GRPCIO a = GRPCIO {unGRPCIO :: ExceptT GRPCIOError IO a} + deriving (Functor, Applicative, Monad, MonadIO) + +deriving instance MonadError GRPCIOError GRPCIO + +runGRPCIO :: GRPCIO a -> IO (Either GRPCIOError a) +runGRPCIO = runExceptT . unGRPCIO + +unrunGRPCIO :: IO (Either GRPCIOError a) -> GRPCIO a +unrunGRPCIO = GRPCIO . ExceptT + +continueFrom :: (a -> GRPCIO b) -> (Either GRPCIOError a) -> GRPCIO b +continueFrom f (Left x) = throwError x +continueFrom f (Right x) = f x + +wrapGRPC :: Either GRPCIOError a -> GRPCIO a +wrapGRPC (Left x) = throwError x +wrapGRPC (Right x) = return x + +grpcBracket :: GRPCIO a -> (a -> GRPCIO b) -> (a -> GRPCIO c) -> GRPCIO c +grpcBracket create destroy f = unrunGRPCIO $ do + let createAction = runGRPCIO create + let fAction = runGRPCIO . continueFrom f + let destroyAction = runGRPCIO . continueFrom destroy + bracket createAction destroyAction fAction +-} diff --git a/src/Network/GRPC/LowLevel/Op.hs b/src/Network/GRPC/LowLevel/Op.hs new file mode 100644 index 0000000..4f788bc --- /dev/null +++ b/src/Network/GRPC/LowLevel/Op.hs @@ -0,0 +1,203 @@ +{-# LANGUAGE RecordWildCards #-} + +module Network.GRPC.LowLevel.Op where + +import Control.Exception +import qualified Data.ByteString as B +import qualified Data.Map.Strict as M +import Data.Maybe (catMaybes) +import Foreign.C.Types (CInt) +import Foreign.Marshal.Alloc (malloc, free) +import Foreign.Ptr (Ptr) +import Foreign.Storable (peek) +import qualified Network.GRPC.Unsafe as C +import qualified Network.GRPC.Unsafe.Metadata as C +import qualified Network.GRPC.Unsafe.ByteBuffer as C +import qualified Network.GRPC.Unsafe.Op as C + +import Network.GRPC.LowLevel.GRPC +import Network.GRPC.LowLevel.CompletionQueue +import Network.GRPC.LowLevel.Call + +type MetadataMap = M.Map B.ByteString B.ByteString + +-- | Sum describing all possible send and receive operations that can be batched +-- and executed by gRPC. Usually these are processed in a handful of +-- combinations depending on the 'MethodType' of the call being run. +data Op = OpSendInitialMetadata MetadataMap + | OpSendMessage B.ByteString + | OpSendCloseFromClient + | OpSendStatusFromServer MetadataMap C.StatusCode --TODO: Issue #6 + | OpRecvInitialMetadata + | OpRecvMessage + | OpRecvStatusOnClient + | OpRecvCloseOnServer + deriving (Eq, Show) + +-- | Container holding the pointers to the C and gRPC data needed to execute the +-- corresponding 'Op'. These are obviously unsafe, and should only be used with +-- 'withOpContexts'. +data OpContext = + OpSendInitialMetadataContext C.MetadataKeyValPtr Int + | OpSendMessageContext C.ByteBuffer + | OpSendCloseFromClientContext + | OpSendStatusFromServerContext C.MetadataKeyValPtr Int C.StatusCode + | OpRecvInitialMetadataContext (Ptr C.MetadataArray) + | OpRecvMessageContext (Ptr C.ByteBuffer) + | OpRecvStatusOnClientContext (Ptr C.MetadataArray) (Ptr C.StatusCode) + | OpRecvCloseOnServerContext (Ptr CInt) + +-- | Allocates and initializes the 'Opcontext' corresponding to the given 'Op'. +createOpContext :: Op -> IO OpContext +createOpContext (OpSendInitialMetadata m) = + OpSendInitialMetadataContext + <$> C.createMetadata m + <*> return (M.size m) +createOpContext (OpSendMessage bs) = + fmap OpSendMessageContext (C.createByteBuffer bs) +createOpContext (OpSendCloseFromClient) = return OpSendCloseFromClientContext +createOpContext (OpSendStatusFromServer m code) = + OpSendStatusFromServerContext + <$> C.createMetadata m + <*> return (M.size m) + <*> return code +createOpContext OpRecvInitialMetadata = + fmap OpRecvInitialMetadataContext C.metadataArrayCreate +createOpContext OpRecvMessage = + fmap OpRecvMessageContext C.createReceivingByteBuffer +createOpContext OpRecvStatusOnClient = + OpRecvStatusOnClientContext + <$> C.metadataArrayCreate + <*> C.createStatusCodePtr +createOpContext OpRecvCloseOnServer = + fmap OpRecvCloseOnServerContext $ malloc + +-- | Mutates the given raw array of ops at the given index according to the +-- given 'OpContext'. +setOpArray :: C.OpArray -> Int -> OpContext -> IO () +setOpArray arr i (OpSendInitialMetadataContext kvs l) = + C.opSendInitialMetadata arr i kvs l +setOpArray arr i (OpSendMessageContext bb) = + C.opSendMessage arr i bb +setOpArray arr i OpSendCloseFromClientContext = + C.opSendCloseClient arr i +setOpArray arr i (OpSendStatusFromServerContext kvs l code) = + C.opSendStatusServer arr i l kvs code "" --TODO: Issue #6 +setOpArray arr i (OpRecvInitialMetadataContext pmetadata) = + C.opRecvInitialMetadata arr i pmetadata +setOpArray arr i (OpRecvMessageContext pbb) = + C.opRecvMessage arr i pbb +setOpArray arr i (OpRecvStatusOnClientContext pmetadata pstatus) = do + pCString <- malloc --TODO: Issue #6 + C.opRecvStatusClient arr i pmetadata pstatus pCString 0 +setOpArray arr i (OpRecvCloseOnServerContext pcancelled) = do + C.opRecvCloseServer arr i pcancelled + +-- | Cleans up an 'OpContext'. +freeOpContext :: OpContext -> IO () +freeOpContext (OpSendInitialMetadataContext m _) = C.metadataFree m +freeOpContext (OpSendMessageContext bb) = C.grpcByteBufferDestroy bb +freeOpContext OpSendCloseFromClientContext = return () +freeOpContext (OpSendStatusFromServerContext metadata _ _) = + C.metadataFree metadata +freeOpContext (OpRecvInitialMetadataContext metadata) = + C.metadataArrayDestroy metadata +freeOpContext (OpRecvMessageContext pbb) = + C.destroyReceivingByteBuffer pbb +freeOpContext (OpRecvStatusOnClientContext metadata pcode) = + C.metadataArrayDestroy metadata + >> C.destroyStatusCodePtr pcode +freeOpContext (OpRecvCloseOnServerContext pcancelled) = + grpcDebug ("freeOpContext: freeing pcancelled: " ++ show pcancelled) + >> free pcancelled + +-- | Converts a list of 'Op's into the corresponding 'OpContext's and guarantees +-- they will be cleaned up correctly. +withOpContexts :: [Op] -> ([OpContext] -> IO a) -> IO a +withOpContexts ops = bracket (mapM createOpContext ops) + (mapM freeOpContext) + +withOpArray :: Int -> (C.OpArray -> IO a) -> IO a +withOpArray n = bracket (C.opArrayCreate n) + (flip C.opArrayDestroy n) + +-- | Container holding GC-managed results for 'Op's which receive data. +data OpRecvResult = + OpRecvInitialMetadataResult MetadataMap + | OpRecvMessageResult B.ByteString + | OpRecvStatusOnClientResult MetadataMap C.StatusCode + | OpRecvCloseOnServerResult Bool -- ^ True if call was cancelled. + deriving (Eq, Show) + +-- | For the given 'OpContext', if the 'Op' receives data, copies the data out +-- of the 'OpContext' and into GC-managed Haskell types. After this, it is safe +-- to destroy the 'OpContext'. +resultFromOpContext :: OpContext -> IO (Maybe OpRecvResult) +resultFromOpContext (OpRecvInitialMetadataContext pmetadata) = do + grpcDebug "resultFromOpContext: OpRecvInitialMetadataContext" + metadata <- peek pmetadata + metadataMap <- C.getAllMetadataArray metadata + return $ Just $ OpRecvInitialMetadataResult metadataMap +resultFromOpContext (OpRecvMessageContext pbb) = do + grpcDebug "resultFromOpContext: OpRecvMessageContext" + bb <- peek pbb + grpcDebug "resultFromOpContext: bytebuffer peeked." + bs <- C.copyByteBufferToByteString bb + grpcDebug "resultFromOpContext: bb copied." + return $ Just $ OpRecvMessageResult bs +resultFromOpContext (OpRecvStatusOnClientContext pmetadata pcode) = do + grpcDebug "resultFromOpContext: OpRecvStatusOnClientContext" + metadata <- peek pmetadata + metadataMap <- C.getAllMetadataArray metadata + code <- C.derefStatusCodePtr pcode + return $ Just $ OpRecvStatusOnClientResult metadataMap code +resultFromOpContext (OpRecvCloseOnServerContext pcancelled) = do + grpcDebug "resultFromOpContext: OpRecvCloseOnServerContext" + cancelled <- fmap (\x -> if x > 0 then True else False) + (peek pcancelled) + return $ Just $ OpRecvCloseOnServerResult cancelled +resultFromOpContext _ = do + grpcDebug "resultFromOpContext: saw non-result op type." + return Nothing + +--TODO: the list of 'Op's type is less specific than it could be. There are only +-- a few different sequences of 'Op's we will see in practice. Once we figure +-- out what those are, we should create a more specific sum type. This will also +-- allow us to make a more specific sum type to replace @[OpRecvResult]@, too. + +-- | For a given call, run the given 'Op's on the given completion queue with +-- the given tag. Blocks until the ops are complete or the given number of +-- seconds have elapsed. +runOps :: Call + -- ^ 'Call' that this batch is associated with. One call can be + -- associated with many batches. + -> CompletionQueue + -- ^ Queue on which our tag will be placed once our ops are done + -- running. + -> [Op] + -> TimeoutSeconds + -- ^ How long to block waiting for the tag to appear on the queue. + -- If we time out, the result of this action will be + -- @CallBatchError BatchTimeout@. + -> IO (Either GRPCIOError [OpRecvResult]) +runOps call cq ops timeLimit = + let l = length ops in + withOpArray l $ \opArray -> do + grpcDebug "runOps: created op array." + withOpContexts ops $ \contexts -> do + grpcDebug "runOps: allocated op contexts." + sequence_ $ zipWith (setOpArray opArray) [0..l-1] contexts + tag <- newTag cq + callError <- startBatch cq (internalCall call) opArray l tag + grpcDebug $ "runOps: called start_batch. callError: " + ++ (show callError) + case callError of + Left x -> return $ Left x + Right () -> do + ev <- pluck cq tag timeLimit + grpcDebug $ "runOps: pluck returned " ++ show ev + case ev of + Right () -> do + grpcDebug "runOps: got good op; starting." + fmap (Right . catMaybes) $ mapM resultFromOpContext contexts + Left err -> return $ Left err diff --git a/src/Network/GRPC/LowLevel/Server.hs b/src/Network/GRPC/LowLevel/Server.hs new file mode 100644 index 0000000..29f974d --- /dev/null +++ b/src/Network/GRPC/LowLevel/Server.hs @@ -0,0 +1,251 @@ +{-# LANGUAGE RecordWildCards #-} + +module Network.GRPC.LowLevel.Server where + +import Control.Exception (bracket, finally) +import Control.Monad +import Data.ByteString (ByteString) +import qualified Data.Map as M +import Foreign.Ptr (nullPtr) +import Foreign.Storable (peek) +import qualified Network.GRPC.Unsafe as C +import qualified Network.GRPC.Unsafe.Op as C + +import Network.GRPC.LowLevel.GRPC +import Network.GRPC.LowLevel.CompletionQueue (CompletionQueue, + pluck, serverRegisterCompletionQueue, serverShutdownAndNotify, + createCompletionQueue, shutdownCompletionQueue, TimeoutSeconds, + serverRequestRegisteredCall, serverRequestCall) +import Network.GRPC.LowLevel.Call +import Network.GRPC.LowLevel.Op + +import qualified Network.GRPC.Unsafe.ByteBuffer as C +import qualified Network.GRPC.Unsafe.Metadata as C + +-- | Wraps various gRPC state needed to run a server. +data Server = Server {internalServer :: C.Server, serverCQ :: CompletionQueue, + registeredMethods :: [RegisteredMethod]} + +-- | Configuration needed to start a server. There might be more fields that +-- need to be added to this in the future. +data ServerConfig = + ServerConfig {hostName :: Host, + -- ^ Name of the host the server is running on. Not sure + -- how this is used. Setting to "localhost" works fine in tests. + port :: Int, + -- ^ Port to listen for requests on. + methodsToRegister :: [(MethodName, Host, GRPCMethodType)] + -- ^ List of (method name, method host, method type) tuples + -- specifying all methods to register. You can also handle + -- other unregistered methods with `serverHandleNormalCall`. + } + deriving (Show, Eq) + +startServer :: GRPC -> ServerConfig -> IO Server +startServer grpc ServerConfig{..} = do + server <- C.grpcServerCreate nullPtr C.reserved + let hostPort = (unHost hostName) ++ ":" ++ (show port) + actualPort <- C.grpcServerAddInsecureHttp2Port server hostPort + when (actualPort /= port) (error $ "Unable to bind port: " ++ (show port)) + cq <- createCompletionQueue grpc + serverRegisterCompletionQueue server cq + methods <- forM methodsToRegister $ + \(name, host, mtype) -> + serverRegisterMethod server name host mtype + C.grpcServerStart server + return $ Server server cq methods + +stopServer :: Server -> IO () +-- TODO: Do method handles need to be freed? +stopServer (Server server cq _) = do + grpcDebug "stopServer: calling shutdownNotify." + shutdownNotify + grpcDebug "stopServer: cancelling all calls." + C.grpcServerCancelAllCalls server + grpcDebug "stopServer: call grpc_server_destroy." + C.grpcServerDestroy server + grpcDebug "stopServer: shutting down CQ." + shutdownCQ + + where shutdownCQ = do + shutdownResult <- shutdownCompletionQueue cq + case shutdownResult of + Left _ -> do putStrLn "Warning: completion queue didn't shut down." + putStrLn "Trying to stop server anyway." + Right _ -> return () + shutdownNotify = do + let shutdownTag = C.tag 0 + serverShutdownAndNotify server cq shutdownTag + shutdownEvent <- pluck cq shutdownTag 30 + case shutdownEvent of + -- This case occurs when we pluck but the queue is already in the + -- 'shuttingDown' state, implying we already tried to shut down. + (Left GRPCIOShutdown) -> error "Called stopServer twice!" + (Left _) -> error "Failed to stop server." + (Right _) -> return () + +-- Uses 'bracket' to safely start and stop a server, even if exceptions occur. +withServer :: GRPC -> ServerConfig -> (Server -> IO a) -> IO a +withServer grpc cfg f = bracket (startServer grpc cfg) stopServer f + +-- | Register a method on a server. The 'RegisteredMethod' type can then be used +-- to wait for a request to arrive. Note: gRPC claims this must be called before +-- the server is started, so we do it during startup according to the +-- 'ServerConfig'. +serverRegisterMethod :: C.Server + -> MethodName + -- ^ method name, e.g. "/foo" + -> Host + -- ^ host name, e.g. "localhost". I have no idea + -- why this is needed since we have to supply a host + -- name to start a server in the first place. It doesn't + -- seem to have any effect, even if it's filled with + -- nonsense. + -> GRPCMethodType + -- ^ Type of method this will be. In the future, this + -- will be used to switch to the correct handling logic. + -- Currently, the only valid choice is 'Normal'. + -> IO RegisteredMethod +serverRegisterMethod internalServer name host Normal = do + handle <- C.grpcServerRegisterMethod internalServer + (unMethodName name) + (unHost host) + grpcDebug $ "registered method to handle " ++ show handle + return $ RegisteredMethod Normal name host handle +serverRegisterMethod _ _ _ _ = error "Streaming methods not implemented yet." + +-- | Create a 'Call' with which to wait for the invocation of a registered +-- method. +serverCreateRegisteredCall :: Server -> RegisteredMethod -> TimeoutSeconds + -> IO (Either GRPCIOError Call) +serverCreateRegisteredCall Server{..} rm timeLimit = + serverRequestRegisteredCall internalServer serverCQ timeLimit rm + +withServerRegisteredCall :: Server -> RegisteredMethod -> TimeoutSeconds + -> (Call + -> IO (Either GRPCIOError a)) + -> IO (Either GRPCIOError a) +withServerRegisteredCall server regmethod timeout f = do + createResult <- serverCreateRegisteredCall server regmethod timeout + case createResult of + Left x -> return $ Left x + Right call -> f call `finally` logDestroy call + where logDestroy c = grpcDebug "withServerRegisteredCall: destroying." + >> destroyCall c + +serverCreateCall :: Server -> TimeoutSeconds + -> IO (Either GRPCIOError Call) +serverCreateCall Server{..} timeLimit = + serverRequestCall internalServer serverCQ timeLimit + +withServerCall :: Server -> TimeoutSeconds + -> (Call -> IO (Either GRPCIOError a)) + -> IO (Either GRPCIOError a) +withServerCall server timeout f = do + createResult <- serverCreateCall server timeout + case createResult of + Left x -> return $ Left x + Right call -> f call `finally` logDestroy call + where logDestroy c = grpcDebug "withServerCall: destroying." + >> destroyCall c + +-- | Sequence of 'Op's needed to receive a normal (non-streaming) call. +serverOpsGetNormalCall :: MetadataMap -> [Op] +serverOpsGetNormalCall initMetadata = + [OpSendInitialMetadata initMetadata, + OpRecvMessage] + +-- | Sequence of 'Op's needed to respond to a normal (non-streaming) call. +serverOpsSendNormalResponse :: ByteString + -> MetadataMap + -> C.StatusCode + -> [Op] +serverOpsSendNormalResponse body metadata code = + [OpRecvCloseOnServer, + OpSendMessage body, + OpSendStatusFromServer metadata code] + +serverOpsSendNormalRegisteredResponse :: ByteString + -> MetadataMap + -- ^ initial metadata + -> MetadataMap + -- ^ trailing metadata + -> C.StatusCode + -> [Op] +serverOpsSendNormalRegisteredResponse body initMetadata trailingMeta code = + [OpSendInitialMetadata initMetadata, + OpRecvCloseOnServer, + OpSendMessage body, + OpSendStatusFromServer trailingMeta code] + +-- TODO: we will want to replace this with some more general concept that also +-- works with streaming calls in the future. +-- | Wait for and then handle a normal (non-streaming) call. +serverHandleNormalRegisteredCall :: Server + -> RegisteredMethod + -> TimeoutSeconds + -> MetadataMap + -- ^ Initial server metadata + -> (ByteString -> MetadataMap + -> IO (ByteString, + MetadataMap, + MetadataMap)) + -- ^ Handler function takes a request body and + -- metadata and returns a response body and + -- metadata. + -> IO (Either GRPCIOError ()) +serverHandleNormalRegisteredCall s@Server{..} rm timeLimit initMetadata f = do + -- TODO: we use this timeLimit twice, so the max time spent is 2*timeLimit. + -- Should we just hard-code time limits instead? Not sure if client + -- programmer cares, since this function will likely just be put in a loop + -- anyway. + withServerRegisteredCall s rm timeLimit $ \call -> do + grpcDebug "serverHandleNormalRegisteredCall: starting batch." + debugCall call + case optionalPayload call of + Nothing -> error "Impossible: not a registered call." --TODO: better types + Just payloadPtr -> do + payload <- peek payloadPtr + requestBody <- C.copyByteBufferToByteString payload + metadataArray <- peek $ requestMetadataRecv call + metadata <- C.getAllMetadataArray metadataArray + (respBody, initMeta, trailingMeta) <- f requestBody metadata + let status = C.GrpcStatusOk + let respOps = serverOpsSendNormalRegisteredResponse + respBody initMeta trailingMeta status + respOpsResults <- runOps call serverCQ respOps timeLimit + grpcDebug "serverHandleNormalRegisteredCall: finished response ops." + case respOpsResults of + Left x -> return $ Left x + Right _ -> return $ Right () + +-- TODO: This is preliminary. +-- We still need to provide the method name to the handler. +-- | Handle one unregistered call. +serverHandleNormalCall :: Server -> TimeoutSeconds + -> MetadataMap + -- ^ Initial metadata. + -> (ByteString -> MetadataMap + -> IO (ByteString, MetadataMap)) + -- ^ Handler function takes a request body and + -- metadata and returns a response body and metadata. + -> IO (Either GRPCIOError ()) +serverHandleNormalCall s@Server{..} timeLimit initMetadata f = do + withServerCall s timeLimit $ \call -> do + grpcDebug "serverHandleNormalCall: starting batch." + let recvOps = serverOpsGetNormalCall initMetadata + opResults <- runOps call serverCQ recvOps timeLimit + case opResults of + Left x -> return $ Left x + Right [OpRecvMessageResult body] -> do + --TODO: we need to get client metadata + (respBody, respMetadata) <- f body M.empty + let status = C.GrpcStatusOk + let respOps = serverOpsSendNormalResponse respBody respMetadata status + respOpsResults <- runOps call serverCQ respOps timeLimit + case respOpsResults of + Left x -> do grpcDebug "serverHandleNormalCall: resp failed." + return $ Left x + Right _ -> grpcDebug "serverHandleNormalCall: ops done." + >> return (Right ()) + x -> error $ "impossible pattern match: " ++ show x diff --git a/src/Network/GRPC/Unsafe.chs b/src/Network/GRPC/Unsafe.chs index f8e4d18..185bf92 100644 --- a/src/Network/GRPC/Unsafe.chs +++ b/src/Network/GRPC/Unsafe.chs @@ -29,8 +29,15 @@ import Network.GRPC.Unsafe.Constants -- | Represents a pointer to a call. To users of the gRPC core library, this -- type is abstract; we have no access to its fields. {#pointer *grpc_call as Call newtype #} + +instance Show Call where + show (Call ptr) = show ptr + {#pointer *grpc_call_details as CallDetails newtype #} +instance Show CallDetails where + show (CallDetails ptr) = show ptr + {#fun create_call_details as ^ {} -> `CallDetails'#} {#fun destroy_call_details as ^ {`CallDetails'} -> `()'#} @@ -66,6 +73,11 @@ instance Storable Tag where peek p = fmap Tag (peek (castPtr p)) poke p (Tag r) = poke (castPtr p) r +-- | A 'CallHandle' is an identifier used to refer to a registered call. Create +-- one on the client with 'grpcChannelRegisterCall', and on the server with +-- 'grpcServerRegisterMethod'. +newtype CallHandle = CallHandle {unCallHandle :: Ptr ()} deriving (Show, Eq) + -- | 'Reserved' is an as-yet unused void pointer param to several gRPC -- functions. Create one with 'reserved'. newtype Reserved = Reserved {unReserved :: Ptr ()} @@ -158,6 +170,14 @@ castPeek p = peek (castPtr p) {#fun grpc_insecure_channel_create as ^ {`String', `ChannelArgsPtr',unReserved `Reserved'} -> `Channel'#} +{#fun grpc_channel_register_call as ^ + {`Channel', `String', `String',unReserved `Reserved'} + -> `CallHandle' CallHandle#} + +{#fun grpc_channel_create_registered_call_ as ^ + {`Channel', `Call', fromIntegral `PropagationMask', `CompletionQueue', + unCallHandle `CallHandle', `CTimeSpecPtr', unReserved `Reserved'} -> `Call'#} + -- | get the current connectivity state of the given channel. The 'Bool' is -- True if we should try to connect the channel. {#fun grpc_channel_check_connectivity_state as ^ @@ -190,11 +210,17 @@ castPeek p = peek (castPtr p) {#fun grpc_call_destroy as ^ {`Call'} -> `()'#} +--TODO: we need to free this string with gpr_free! +{#fun grpc_call_get_peer as ^ {`Call'} -> `String' #} + -- Server stuff {#fun grpc_server_create as ^ {`ChannelArgsPtr',unReserved `Reserved'} -> `Server'#} +{#fun grpc_server_register_method as ^ + {`Server', `String', `String'} -> `CallHandle' CallHandle#} + {#fun grpc_server_register_completion_queue as ^ {`Server', `CompletionQueue', unReserved `Reserved'} -> `()'#} @@ -229,6 +255,7 @@ castPeek p = peek (castPtr p) -- | TODO: I am not yet sure how this function is supposed to be used. {#fun grpc_server_request_registered_call as ^ - {`Server',unTag `Tag',id `Ptr Call', `CTimeSpecPtr', `MetadataArray' id, - id `Ptr ByteBuffer', `CompletionQueue', `CompletionQueue',unTag `Tag'} + {`Server',unCallHandle `CallHandle',id `Ptr Call', `CTimeSpecPtr', + `MetadataArray', id `Ptr ByteBuffer', `CompletionQueue', + `CompletionQueue',unTag `Tag'} -> `CallError'#} diff --git a/src/Network/GRPC/Unsafe/ByteBuffer.chs b/src/Network/GRPC/Unsafe/ByteBuffer.chs index 36bc2fc..39d573f 100644 --- a/src/Network/GRPC/Unsafe/ByteBuffer.chs +++ b/src/Network/GRPC/Unsafe/ByteBuffer.chs @@ -74,11 +74,16 @@ withByteBufferPtr {#fun grpc_raw_byte_buffer_from_reader as ^ {`ByteBufferReader'} -> `ByteBuffer'#} +-- TODO: Issue #5 withByteStringAsByteBuffer :: B.ByteString -> (ByteBuffer -> IO a) -> IO a withByteStringAsByteBuffer bs f = do bracket (byteStringToSlice bs) freeSlice $ \slice -> do bracket (grpcRawByteBufferCreate slice 1) grpcByteBufferDestroy f +-- TODO: Issue #5 +createByteBuffer :: B.ByteString -> IO ByteBuffer +createByteBuffer bs = byteStringToSlice bs >>= flip grpcRawByteBufferCreate 1 + copyByteBufferToByteString :: ByteBuffer -> IO B.ByteString copyByteBufferToByteString bb = do bracket (byteBufferReaderCreate bb) byteBufferReaderDestroy $ \bbr -> do diff --git a/src/Network/GRPC/Unsafe/Metadata.chs b/src/Network/GRPC/Unsafe/Metadata.chs index 4603940..da6174d 100644 --- a/src/Network/GRPC/Unsafe/Metadata.chs +++ b/src/Network/GRPC/Unsafe/Metadata.chs @@ -3,6 +3,7 @@ module Network.GRPC.Unsafe.Metadata where import Control.Exception import Control.Monad import Data.ByteString (ByteString, useAsCString, packCString) +import Data.Map.Strict as M import Foreign.C.String import Foreign.Ptr import Foreign.Storable @@ -24,6 +25,11 @@ import Foreign.Storable -- and length from this type. {#pointer *grpc_metadata_array as MetadataArray newtype#} +{#fun metadata_array_get_metadata as ^ + {`MetadataArray'} -> `MetadataKeyValPtr'#} + +{#fun metadata_array_get_count as ^ {`MetadataArray'} -> `Int'#} + instance Storable MetadataArray where sizeOf (MetadataArray r) = sizeOf r alignment (MetadataArray r) = alignment r @@ -68,3 +74,23 @@ getMetadataKey m = getMetadataKey' m >=> packCString getMetadataVal :: MetadataKeyValPtr -> Int -> IO ByteString getMetadataVal m = getMetadataVal' m >=> packCString + +createMetadata :: M.Map ByteString ByteString -> IO MetadataKeyValPtr +createMetadata m = do + let l = M.size m + let indexedKeyVals = zip [0..] $ M.toList m + metadata <- metadataAlloc l + forM_ indexedKeyVals $ \(i,(k,v)) -> setMetadataKeyVal k v metadata i + return metadata + +getAllMetadataArray :: MetadataArray -> IO (M.Map ByteString ByteString) +getAllMetadataArray m = do + kvs <- metadataArrayGetMetadata m + l <- metadataArrayGetCount m + getAllMetadata kvs l + +getAllMetadata :: MetadataKeyValPtr -> Int -> IO (M.Map ByteString ByteString) +getAllMetadata m count = do + let indices = [0..count-1] + fmap M.fromList $ forM indices $ + \i -> liftM2 (,) (getMetadataKey m i) (getMetadataVal m i) diff --git a/src/Network/GRPC/Unsafe/Op.chs b/src/Network/GRPC/Unsafe/Op.chs index 9f6346e..0d7ccf2 100644 --- a/src/Network/GRPC/Unsafe/Op.chs +++ b/src/Network/GRPC/Unsafe/Op.chs @@ -1,8 +1,6 @@ module Network.GRPC.Unsafe.Op where import Control.Exception -import Control.Monad -import qualified Data.ByteString as B import Foreign.C.String import Foreign.C.Types import Foreign.Ptr @@ -14,8 +12,8 @@ import Foreign.Ptr #include #include -{#enum grpc_op_type as OpType {underscoreToCase} deriving (Eq)#} -{#enum grpc_status_code as StatusCode {underscoreToCase} deriving (Eq)#} +{#enum grpc_op_type as OpType {underscoreToCase} deriving (Eq, Show)#} +{#enum grpc_status_code as StatusCode {underscoreToCase} deriving (Eq, Show)#} -- NOTE: We don't alloc the space for the enum in Haskell because enum size is -- implementation-dependent. See: @@ -24,6 +22,8 @@ import Foreign.Ptr -- receive a status code from the server with 'opRecvStatusClient'. {#fun create_status_code_ptr as ^ {} -> `Ptr StatusCode' castPtr#} +{#fun deref_status_code_ptr as ^ {castPtr `Ptr StatusCode'} -> `StatusCode'#} + {#fun destroy_status_code_ptr as ^ {castPtr `Ptr StatusCode'} -> `()' #} -- | Represents an array of ops to be passed to 'grpcCallStartBatch'. diff --git a/src/Network/GRPC/Unsafe/Slice.chs b/src/Network/GRPC/Unsafe/Slice.chs index 160a15d..b6cb89c 100644 --- a/src/Network/GRPC/Unsafe/Slice.chs +++ b/src/Network/GRPC/Unsafe/Slice.chs @@ -4,12 +4,9 @@ module Network.GRPC.Unsafe.Slice where #include import qualified Data.ByteString as B -import Control.Applicative -import Data.Word import Foreign.C.String import Foreign.C.Types import Foreign.Ptr -import Foreign.Marshal.Alloc -- | A 'Slice' is gRPC's string type. We can easily convert these to and from -- ByteStrings. This type is a pointer to a C type. diff --git a/src/Network/GRPC/Unsafe/Time.chs b/src/Network/GRPC/Unsafe/Time.chs index 6da98c3..ab88d8f 100644 --- a/src/Network/GRPC/Unsafe/Time.chs +++ b/src/Network/GRPC/Unsafe/Time.chs @@ -1,6 +1,6 @@ module Network.GRPC.Unsafe.Time where -import Control.Applicative +import Control.Exception (bracket) import Control.Monad import Foreign.C.Types import Foreign.Storable @@ -43,6 +43,9 @@ instance Storable CTimeSpec where -- future. {#fun seconds_to_deadline as ^ {`Int'} -> `CTimeSpecPtr'#} +withDeadlineSeconds :: Int -> (CTimeSpecPtr -> IO a) -> IO a +withDeadlineSeconds i = bracket (secondsToDeadline i) timespecDestroy + -- | Returns a GprClockMonotonic representing a deadline n milliseconds -- in the future. {#fun millis_to_deadline as ^ {`Int'} -> `CTimeSpecPtr'#} diff --git a/tests/LowLevelTests.hs b/tests/LowLevelTests.hs new file mode 100644 index 0000000..4a570ba --- /dev/null +++ b/tests/LowLevelTests.hs @@ -0,0 +1,158 @@ +{-# LANGUAGE OverloadedStrings #-} + +module LowLevelTests where + +import Control.Concurrent.Async (withAsync, wait) +import Data.ByteString (ByteString) +import qualified Data.Map as M +import Network.GRPC.LowLevel +import Test.Tasty +import Test.Tasty.HUnit ((@?=), testCase) + +lowLevelTests :: TestTree +lowLevelTests = testGroup "Unit tests of low-level Haskell library" + [ testGRPCBracket + , testCompletionQueueCreateDestroy + , testServerCreateDestroy + , testClientCreateDestroy + , testWithServerCall + , testWithClientCall + --, testPayloadLowLevel --TODO: currently crashing from free on unalloced ptr + --, testClientRequestNoServer --TODO: succeeds when no other tests run. + , testServerAwaitNoClient + --, testPayloadLowLevelUnregistered --TODO: succeeds when no other tests run. + ] + +dummyMeta :: M.Map ByteString ByteString +dummyMeta = M.fromList [("foo","bar")] + +testGRPCBracket :: TestTree +testGRPCBracket = testCase "No errors starting and stopping GRPC" $ + withGRPC $ const $ return () + +testCompletionQueueCreateDestroy :: TestTree +testCompletionQueueCreateDestroy = + testCase "No errors creating and destroying a CQ" $ withGRPC $ \grpc -> + withCompletionQueue grpc $ const (return ()) + +testServerCreateDestroy :: TestTree +testServerCreateDestroy = + testCase "No errors when starting and stopping a server" $ + withGRPC $ \grpc -> withServer grpc (ServerConfig "localhost" 50051 []) + (const $ return ()) + +testClientCreateDestroy :: TestTree +testClientCreateDestroy = + testCase "No errors when starting and stopping a client" $ + withGRPC $ \grpc -> withClient grpc (ClientConfig "localhost" 50051) + (const $ return ()) + +testPayloadLowLevelServer :: GRPC -> IO () +testPayloadLowLevelServer grpc = do + let conf = (ServerConfig "localhost" 50051 [("/foo", "localhost", Normal)]) + withServer grpc conf $ \server -> do + let method = head (registeredMethods server) + result <- serverHandleNormalRegisteredCall server method 11 M.empty $ + \reqBody reqMeta -> return ("reply test", dummyMeta, dummyMeta) + case result of + Left err -> error $ show err + Right _ -> return () + +testPayloadLowLevelClient :: GRPC -> IO () +testPayloadLowLevelClient grpc = + withClient grpc (ClientConfig "localhost" 50051) $ \client -> do + method <- clientRegisterMethod client "/foo" "localhost" Normal + putStrLn "registered method on client." + reqResult <- clientRegisteredRequest client method 10 "Hello!" M.empty + case reqResult of + Left x -> error $ "Client got error: " ++ show x + Right (NormalRequestResult respBody initMeta trailingMeta respCode) -> do + respBody @?= "reply test" + respCode @?= GrpcStatusOk + +testPayloadLowLevelClientUnregistered :: GRPC -> IO () +testPayloadLowLevelClientUnregistered grpc = do + withClient grpc (ClientConfig "localhost" 50051) $ \client -> do + reqResult <- clientRequest client "/foo" "localhost" 10 "Hello!" M.empty + case reqResult of + Left x -> error $ "Client got error: " ++ show x + Right (NormalRequestResult respBody initMeta trailingMeta respCode) -> do + respBody @?= "reply test" + respCode @?= GrpcStatusOk + +testPayloadLowLevelServerUnregistered :: GRPC -> IO () +testPayloadLowLevelServerUnregistered grpc = do + withServer grpc (ServerConfig "localhost" 50051 []) $ \server -> do + result <- serverHandleNormalCall server 11 M.empty $ + \reqBody reqMeta -> return ("reply test", M.empty) + case result of + Left x -> error $ show x + Right _ -> return () + +testClientRequestNoServer :: TestTree +testClientRequestNoServer = testCase "request times out when no server " $ do + withGRPC $ \grpc -> do + withClient grpc (ClientConfig "localhost" 50051) $ \client -> do + method <- clientRegisterMethod client "/foo" "localhost" Normal + reqResult <- clientRegisteredRequest client method 1 "Hello" M.empty + reqResult @?= (Left GRPCIOTimeout) + +testServerAwaitNoClient :: TestTree +testServerAwaitNoClient = testCase "server wait times out when no client " $ do + withGRPC $ \grpc -> do + let conf = (ServerConfig "localhost" 50051 [("/foo", "localhost", Normal)]) + withServer grpc conf $ \server -> do + let method = head (registeredMethods server) + result <- serverHandleNormalRegisteredCall server method 1 M.empty $ + \_ _ -> return ("", M.empty, M.empty) + result @?= Left GRPCIOTimeout + +testServerUnregisteredAwaitNoClient :: TestTree +testServerUnregisteredAwaitNoClient = + testCase "server wait times out when no client -- unregistered method " $ do + withGRPC $ \grpc -> do + let conf = ServerConfig "localhost" 50051 [] + withServer grpc conf $ \server -> do + result <- serverHandleNormalCall server 10 M.empty $ + \_ _ -> return ("", M.empty) + case result of + Left err -> error $ show err + Right _ -> return () + +testPayloadLowLevel :: TestTree +testPayloadLowLevel = testCase "LowLevel Haskell library request/response " $ do + withGRPC $ \grpc -> do + withAsync (testPayloadLowLevelServer grpc) $ \a1 -> do + withAsync (testPayloadLowLevelClient grpc) $ \a2 -> do + wait a1 + wait a2 + +testPayloadLowLevelUnregistered :: TestTree +testPayloadLowLevelUnregistered = + testCase "LowLevel Haskell library unregistered request/response " $ do + withGRPC $ \grpc -> do + withAsync (testPayloadLowLevelServerUnregistered grpc) $ \a1 -> + withAsync (testPayloadLowLevelClientUnregistered grpc) $ \a2 -> do + wait a1 + wait a2 + +testWithServerCall :: TestTree +testWithServerCall = + testCase "Creating and destroying a call: no errors. " $ + withGRPC $ \grpc -> do + let conf = ServerConfig "localhost" 50051 [] + withServer grpc conf $ \server -> do + result <- withServerCall server 1 $ const $ return $ Right () + result @?= Left GRPCIOTimeout + +testWithClientCall :: TestTree +testWithClientCall = + testCase "Creating and destroying a client call: no errors. " $ + withGRPC $ \grpc -> do + let conf = ClientConfig "localhost" 50051 + withClient grpc conf $ \client -> do + result <- withClientCall client "foo" "localhost" 10 $ + const $ return $ Right () + case result of + Left err -> error $ show err + Right _ -> return () diff --git a/tests/Properties.hs b/tests/Properties.hs index e866a5f..6208951 100644 --- a/tests/Properties.hs +++ b/tests/Properties.hs @@ -9,14 +9,14 @@ import Network.GRPC.Unsafe.Metadata import Network.GRPC.Unsafe.Op import Network.GRPC.Unsafe.Constants import qualified Data.ByteString as B -import Data.Time.Clock.POSIX -import GHC.Exts import Foreign.Marshal.Alloc import Foreign.Storable import Foreign.Ptr import Test.Tasty import Test.Tasty.HUnit as HU +import LowLevelTests + roundtripSlice :: B.ByteString -> TestTree roundtripSlice bs = testCase "Slice C bindings roundtrip" $ do slice <- byteStringToSlice bs @@ -205,13 +205,40 @@ testPayload = testCase "low-level C bindings request/response " $ do grpcShutdown putStrLn "Done." -unitTests :: TestTree -unitTests = testGroup "Unit tests" +testCreateDestroyMetadata :: TestTree +testCreateDestroyMetadata = testCase "create/destroy metadataArrayPtr " $ do + grpcInit + withMetadataArrayPtr $ const (return ()) + grpcShutdown + +testCreateDestroyMetadataKeyVals :: TestTree +testCreateDestroyMetadataKeyVals = testCase "create/destroy metadata k/vs " $ do + grpcInit + withMetadataKeyValPtr 10 $ const (return ()) + grpcShutdown + +testCreateDestroyDeadline :: TestTree +testCreateDestroyDeadline = testCase "create/destroy deadline " $ do + grpcInit + withDeadlineSeconds 10 $ const (return ()) + grpcShutdown + +unsafeTests :: TestTree +unsafeTests = testGroup "Unit tests for unsafe C bindings." [testPayload, roundtripSlice "Hello, world!", roundtripByteBuffer "Hwaet! We gardena in geardagum...", testMetadata, - testNow] + testNow, + testCreateDestroyMetadata, + testCreateDestroyMetadataKeyVals, + testCreateDestroyDeadline + ] + +allTests :: TestTree +allTests = testGroup "All tests" + [ unsafeTests, + lowLevelTests] main :: IO () -main = defaultMain unitTests +main = defaultMain allTests