diff --git a/cbits/grpc_haskell.c b/cbits/grpc_haskell.c index d3a6ddb..29c4579 100644 --- a/cbits/grpc_haskell.c +++ b/cbits/grpc_haskell.c @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -158,7 +159,7 @@ void metadata_array_destroy(grpc_metadata_array **arr){ } grpc_metadata* metadata_alloc(size_t n){ - grpc_metadata *retval = malloc(sizeof(grpc_metadata)*n); + grpc_metadata *retval = calloc(n,sizeof(grpc_metadata)); return retval; } @@ -463,3 +464,85 @@ void destroy_arg_array(grpc_arg* args, size_t n){ } free(args); } + +grpc_auth_property_iterator* grpc_auth_context_property_iterator_( + const grpc_auth_context* ctx){ + + grpc_auth_property_iterator* i = malloc(sizeof(grpc_auth_property_iterator)); + *i = grpc_auth_context_property_iterator(ctx); + return i; +} + +grpc_server_credentials* ssl_server_credentials_create_internal( + const char* pem_root_certs, const char* pem_key, const char* pem_cert, + grpc_ssl_client_certificate_request_type force_client_auth){ + + grpc_ssl_pem_key_cert_pair pair = {pem_key, pem_cert}; + grpc_server_credentials* creds = grpc_ssl_server_credentials_create_ex( + pem_root_certs, &pair, 1, force_client_auth, NULL); + return creds; +} + +grpc_channel_credentials* grpc_ssl_credentials_create_internal( + const char* pem_root_certs, const char* pem_key, const char* pem_cert){ + + grpc_channel_credentials* creds; + if(pem_key && pem_cert){ + grpc_ssl_pem_key_cert_pair pair = {pem_key, pem_cert}; + creds = grpc_ssl_credentials_create(pem_root_certs, &pair, NULL); + } + else{ + creds = grpc_ssl_credentials_create(pem_root_certs, NULL, NULL); + } + return creds; +} + +void grpc_server_credentials_set_auth_metadata_processor_( + grpc_server_credentials* creds, grpc_auth_metadata_processor* p){ + + grpc_server_credentials_set_auth_metadata_processor(creds, *p); +} + +grpc_auth_metadata_processor* mk_auth_metadata_processor( + void (*process)(void *state, grpc_auth_context *context, + const grpc_metadata *md, size_t num_md, + grpc_process_auth_metadata_done_cb cb, void *user_data)){ + + //TODO: figure out when to free this. + grpc_auth_metadata_processor* p = malloc(sizeof(grpc_auth_metadata_processor)); + p->process = process; + p->destroy = NULL; + p->state = NULL; + return p; +} + +grpc_call_credentials* grpc_metadata_credentials_create_from_plugin_( + grpc_metadata_credentials_plugin* plugin){ + + return grpc_metadata_credentials_create_from_plugin(*plugin, NULL); +} + +//This is a hack to work around GHC being unable to deal with raw struct params. +//This callback is registered as the get_metadata callback for the call, and its +//only job is to cast the void* state pointer to the correct function pointer +//type and call the Haskell function with it. +void metadata_dispatcher(void *state, grpc_auth_metadata_context context, + grpc_credentials_plugin_metadata_cb cb, void *user_data){ + + ((haskell_get_metadata*)state)(&context, cb, user_data); +} + +grpc_metadata_credentials_plugin* mk_metadata_client_plugin( + haskell_get_metadata* f){ + + //TODO: figure out when to free this. + grpc_metadata_credentials_plugin* p = + malloc(sizeof(grpc_metadata_credentials_plugin)); + + p->get_metadata = metadata_dispatcher; + p->destroy = NULL; + p->state = f; + p->type = "grpc-haskell custom credentials"; + + return p; +} diff --git a/examples/echo/echo-client/Main.hs b/examples/echo/echo-client/Main.hs index ade9dfd..31c0b60 100644 --- a/examples/echo/echo-client/Main.hs +++ b/examples/echo/echo-client/Main.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE OverloadedLists #-} {-# OPTIONS_GHC -fno-warn-missing-signatures #-} {-# OPTIONS_GHC -fno-warn-unused-binds #-} @@ -22,7 +23,7 @@ addMethod = MethodName "/echo.Add/DoAdd" _unregistered c = U.clientRequest c echoMethod 1 "hi" mempty regMain = withGRPC $ \g -> - withClient g (ClientConfig "localhost" 50051 []) $ \c -> do + withClient g (ClientConfig "localhost" 50051 [] Nothing) $ \c -> do rm <- clientRegisterMethodNormal c echoMethod replicateM_ 100000 $ clientRequest c rm 5 "hi" mempty >>= \case Left e -> fail $ "Got client error: " ++ show e @@ -42,7 +43,7 @@ instance Message AddResponse -- TODO: Create Network.GRPC.HighLevel.Client w/ request variants highlevelMain = withGRPC $ \g -> - withClient g (ClientConfig "localhost" 50051 []) $ \c -> do + withClient g (ClientConfig "localhost" 50051 [] Nothing) $ \c -> do rm <- clientRegisterMethodNormal c echoMethod rmAdd <- clientRegisterMethodNormal c addMethod let oneThread = replicateM_ 10000 $ body c rm rmAdd @@ -71,4 +72,5 @@ highlevelMain = withGRPC $ \g -> | dec == AddResponse (x + y) -> return () | otherwise -> fail $ "Got wrong add answer: " ++ show dec ++ "expected: " ++ show x ++ " + " ++ show y ++ " = " ++ show (x+y) +main :: IO () main = highlevelMain diff --git a/examples/echo/echo-server/Main.hs b/examples/echo/echo-server/Main.hs index 241a374..a9c6fa9 100644 --- a/examples/echo/echo-server/Main.hs +++ b/examples/echo/echo-server/Main.hs @@ -20,6 +20,7 @@ import qualified Network.GRPC.HighLevel.Server.Unregistered as U import Network.GRPC.LowLevel import qualified Network.GRPC.LowLevel.Call.Unregistered as U import qualified Network.GRPC.LowLevel.Server.Unregistered as U +import Network.GRPC.Unsafe.Security as U serverMeta :: MetadataMap serverMeta = [("test_meta", "test_meta_value")] @@ -123,4 +124,4 @@ main :: IO () main = highlevelMainUnregistered defConfig :: ServerConfig -defConfig = ServerConfig "localhost" 50051 [] [] [] [] [] +defConfig = ServerConfig "localhost" 50051 [] [] [] [] [] Nothing diff --git a/examples/hellos/hellos-client/Main.hs b/examples/hellos/hellos-client/Main.hs index 4519c5f..751d56a 100644 --- a/examples/hellos/hellos-client/Main.hs +++ b/examples/hellos/hellos-client/Main.hs @@ -110,7 +110,7 @@ doHelloBi c n = do highlevelMain :: IO () highlevelMain = withGRPC $ \g -> - withClient g (ClientConfig "localhost" 50051 []) $ \c -> do + withClient g (ClientConfig "localhost" 50051 [] Nothing) $ \c -> do let n = 100000 putStrLn "-------------- HelloSS --------------" doHelloSS c n diff --git a/examples/hellos/hellos-server/Main.hs b/examples/hellos/hellos-server/Main.hs index d14256a..82601c9 100644 --- a/examples/hellos/hellos-server/Main.hs +++ b/examples/hellos/hellos-server/Main.hs @@ -80,4 +80,4 @@ main :: IO () main = highlevelMainUnregistered defConfig :: ServerConfig -defConfig = ServerConfig "localhost" 50051 [] [] [] [] [] +defConfig = ServerConfig "localhost" 50051 [] [] [] [] [] Nothing diff --git a/grpc-haskell.cabal b/grpc-haskell.cabal index d5ba88a..628487c 100644 --- a/grpc-haskell.cabal +++ b/grpc-haskell.cabal @@ -55,6 +55,7 @@ library Network.GRPC.Unsafe.Metadata Network.GRPC.Unsafe.Op Network.GRPC.Unsafe + Network.GRPC.Unsafe.Security Network.GRPC.LowLevel Network.GRPC.LowLevel.Server.Unregistered Network.GRPC.LowLevel.Client.Unregistered diff --git a/include/grpc_haskell.h b/include/grpc_haskell.h index 6773189..7093b1f 100644 --- a/include/grpc_haskell.h +++ b/include/grpc_haskell.h @@ -2,6 +2,7 @@ #define GRPC_HASKELL #include +#include #include #include #include @@ -164,4 +165,40 @@ void create_int_arg(grpc_arg* args, size_t i, void destroy_arg_array(grpc_arg* args, size_t n); +grpc_auth_property_iterator* grpc_auth_context_property_iterator_( + const grpc_auth_context* ctx); + +grpc_server_credentials* ssl_server_credentials_create_internal( + const char* pem_root_certs, const char* pem_key, const char* pem_cert, + grpc_ssl_client_certificate_request_type force_client_auth); + +grpc_channel_credentials* grpc_ssl_credentials_create_internal( + const char* pem_root_certs, const char* pem_key, const char* pem_cert); + +void grpc_server_credentials_set_auth_metadata_processor_( + grpc_server_credentials* creds, grpc_auth_metadata_processor* p); + +//packs a Haskell server-side auth processor function pointer into the +//appropriate struct expected by gRPC. +grpc_auth_metadata_processor* mk_auth_metadata_processor( + void (*process)(void *state, grpc_auth_context *context, + const grpc_metadata *md, size_t num_md, + grpc_process_auth_metadata_done_cb cb, void *user_data)); + +grpc_call_credentials* grpc_metadata_credentials_create_from_plugin_( + grpc_metadata_credentials_plugin* plugin); + +//type of the callback used to create auth metadata on the client +typedef void (*get_metadata) + (void *state, grpc_auth_metadata_context context, + grpc_credentials_plugin_metadata_cb cb, void *user_data); + +//type of the Haskell callback that we use to create auth metadata on the client +typedef void haskell_get_metadata(grpc_auth_metadata_context*, + grpc_credentials_plugin_metadata_cb, + void*); + +grpc_metadata_credentials_plugin* mk_metadata_client_plugin( + haskell_get_metadata* f); + #endif //GRPC_HASKELL diff --git a/src/Network/GRPC/HighLevel.hs b/src/Network/GRPC/HighLevel.hs index a230a3d..f0cf411 100644 --- a/src/Network/GRPC/HighLevel.hs +++ b/src/Network/GRPC/HighLevel.hs @@ -1,10 +1,50 @@ module Network.GRPC.HighLevel ( + +-- * Types + MetadataMap(..) +, MethodName(..) +, StatusDetails(..) +, StatusCode(..) +, GRPCIOError(..) + -- * Server -Handler(..) +, Handler(..) , ServerOptions(..) , defaultOptions , serverLoop +, ServerCall(..) +, serverCallCancel +, serverCallIsExpired + +-- * Client +, NormalRequestResult(..) +, ClientCall +, clientCallCancel + +-- * Client and Server Auth +, AuthContext +, AuthProperty(..) +, getAuthProperties +, addAuthProperty + +-- * Server Auth +, ServerSSLConfig(..) +, ProcessMeta +, AuthProcessorResult(..) +, SslClientCertificateRequestType(..) + +-- * Client Auth +, ClientSSLConfig(..) +, ClientSSLKeyCertPair(..) +, ClientMetadataCreate +, ClientMetadataCreateResult(..) +, AuthMetadataContext(..) + +-- * Streaming utilities +, StreamSend +, StreamRecv ) where -import Network.GRPC.HighLevel.Server +import Network.GRPC.HighLevel.Server +import Network.GRPC.LowLevel diff --git a/src/Network/GRPC/HighLevel/Server.hs b/src/Network/GRPC/HighLevel/Server.hs index 03b9e46..475a36c 100644 --- a/src/Network/GRPC/HighLevel/Server.hs +++ b/src/Network/GRPC/HighLevel/Server.hs @@ -151,6 +151,7 @@ data ServerOptions = ServerOptions , optUserAgentPrefix :: String , optUserAgentSuffix :: String , optInitialMetadata :: MetadataMap + , optSSLConfig :: Maybe ServerSSLConfig } defaultOptions :: ServerOptions @@ -164,6 +165,7 @@ defaultOptions = ServerOptions , optUserAgentPrefix = "grpc-haskell/0.0.0" , optUserAgentSuffix = "" , optInitialMetadata = mempty + , optSSLConfig = Nothing } serverLoop :: ServerOptions -> IO () diff --git a/src/Network/GRPC/HighLevel/Server/Unregistered.hs b/src/Network/GRPC/HighLevel/Server/Unregistered.hs index 467fd2b..24d3956 100644 --- a/src/Network/GRPC/HighLevel/Server/Unregistered.hs +++ b/src/Network/GRPC/HighLevel/Server/Unregistered.hs @@ -92,4 +92,5 @@ serverLoop ServerOptions{..} = do [ UserAgentPrefix optUserAgentPrefix , UserAgentSuffix optUserAgentSuffix ] + , sslConfig = optSSLConfig } diff --git a/src/Network/GRPC/LowLevel.hs b/src/Network/GRPC/LowLevel.hs index 880b813..63269a5 100644 --- a/src/Network/GRPC/LowLevel.hs +++ b/src/Network/GRPC/LowLevel.hs @@ -51,6 +51,25 @@ GRPC , serverRW -- for bidirectional streaming , ServerRWHandlerLL +-- * Client and Server Auth +, AuthContext +, AuthProperty(..) +, getAuthProperties +, addAuthProperty + +-- * Server Auth +, ServerSSLConfig(..) +, ProcessMeta +, AuthProcessorResult(..) +, SslClientCertificateRequestType(..) + +-- * Client Auth +, ClientSSLConfig(..) +, ClientSSLKeyCertPair(..) +, ClientMetadataCreate +, ClientMetadataCreateResult(..) +, AuthMetadataContext(..) + -- * Client , ClientConfig(..) , Client @@ -63,6 +82,7 @@ GRPC , clientRegisterMethodServerStreaming , clientRegisterMethodBiDiStreaming , clientRequest +, clientRequestParent , clientReader -- for server streaming , clientWriter -- for client streaming , clientRW -- for bidirectional streaming @@ -80,15 +100,25 @@ GRPC ) where -import Network.GRPC.LowLevel.GRPC -import Network.GRPC.LowLevel.Server -import Network.GRPC.LowLevel.CompletionQueue -import Network.GRPC.LowLevel.Op -import Network.GRPC.LowLevel.Client import Network.GRPC.LowLevel.Call +import Network.GRPC.LowLevel.Client +import Network.GRPC.LowLevel.CompletionQueue +import Network.GRPC.LowLevel.GRPC +import Network.GRPC.LowLevel.Op +import Network.GRPC.LowLevel.Server -import Network.GRPC.Unsafe (ConnectivityState(..)) -import Network.GRPC.Unsafe.Op (StatusCode(..)) -import Network.GRPC.Unsafe.ChannelArgs(Arg(..) - , CompressionAlgorithm(..) - , CompressionLevel(..)) +import Network.GRPC.Unsafe (ConnectivityState (..)) +import Network.GRPC.Unsafe.ChannelArgs (Arg (..), CompressionAlgorithm (..), + CompressionLevel (..)) +import Network.GRPC.Unsafe.ChannelArgs (Arg (..), CompressionAlgorithm (..)) +import Network.GRPC.Unsafe.Op (StatusCode (..)) +import Network.GRPC.Unsafe.Security (AuthContext, + AuthMetadataContext (..), + AuthProcessorResult (..), + AuthProperty (..), + ClientMetadataCreate, + ClientMetadataCreateResult (..), + ProcessMeta, + SslClientCertificateRequestType (..), + addAuthProperty, + getAuthProperties) diff --git a/src/Network/GRPC/LowLevel/Call.hs b/src/Network/GRPC/LowLevel/Call.hs index 9d9cca9..8f71bd0 100644 --- a/src/Network/GRPC/LowLevel/Call.hs +++ b/src/Network/GRPC/LowLevel/Call.hs @@ -4,6 +4,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE KindSignatures #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -17,7 +18,9 @@ module Network.GRPC.LowLevel.Call where import Control.Monad.Managed (Managed, managed) import Control.Exception (bracket) import Data.ByteString (ByteString) +import Data.ByteString.Char8 (pack) import Data.List (intersperse) +import Data.Monoid import Data.String (IsString) import Foreign.Marshal.Alloc (free, malloc) import Foreign.Ptr (Ptr, nullPtr) @@ -58,21 +61,21 @@ extractPayload (RegisteredMethodServerStreaming _ _ _) p = peek p >>= C.copyByteBufferToByteString extractPayload (RegisteredMethodBiDiStreaming _ _ _) _ = return () -newtype MethodName = MethodName {unMethodName :: String} +newtype MethodName = MethodName {unMethodName :: ByteString} deriving (Show, Eq, IsString) -newtype Host = Host {unHost :: String} +newtype Host = Host {unHost :: ByteString} deriving (Show, Eq, IsString) newtype Port = Port {unPort :: Int} deriving (Eq, Num, Show) -newtype Endpoint = Endpoint {unEndpoint :: String} +newtype Endpoint = Endpoint {unEndpoint :: ByteString} deriving (Show, Eq, IsString) -- | Given a hostname and port, produces a "host:port" string endpoint :: Host -> Port -> Endpoint -endpoint (Host h) (Port p) = Endpoint (h ++ ":" ++ show p) +endpoint (Host h) (Port p) = Endpoint (h <> ":" <> pack (show p)) -- | Represents a registered method. Methods can optionally be registered in -- order to make the C-level request/response code simpler. Before making or diff --git a/src/Network/GRPC/LowLevel/Client.hs b/src/Network/GRPC/LowLevel/Client.hs index 48e5ad5..a13d3ab 100644 --- a/src/Network/GRPC/LowLevel/Client.hs +++ b/src/Network/GRPC/LowLevel/Client.hs @@ -16,6 +16,7 @@ import Control.Concurrent.MVar import Control.Monad import Control.Monad.IO.Class import Control.Monad.Trans.Except +import qualified Data.ByteString as B import Data.ByteString (ByteString) import Network.GRPC.LowLevel.Call import Network.GRPC.LowLevel.CompletionQueue @@ -25,6 +26,7 @@ import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe.ChannelArgs as C import qualified Network.GRPC.Unsafe.Constants as C import qualified Network.GRPC.Unsafe.Op as C +import qualified Network.GRPC.Unsafe.Security as C import qualified Network.GRPC.Unsafe.Time as C -- | Represents the context needed to perform client-side gRPC operations. @@ -33,25 +35,72 @@ data Client = Client {clientChannel :: C.Channel, clientConfig :: ClientConfig } +data ClientSSLKeyCertPair = ClientSSLKeyCertPair + {clientPrivateKey :: FilePath, + clientCert :: FilePath} + +-- | SSL configuration for the client. It's perfectly acceptable for both fields +-- to be 'Nothing', in which case default fallbacks will be used for the server +-- root cert. +data ClientSSLConfig = ClientSSLConfig + {serverRootCert :: Maybe FilePath, + -- ^ Path to the server root certificate. If 'Nothing', gRPC will attempt to + -- fall back to a default. + clientSSLKeyCertPair :: Maybe ClientSSLKeyCertPair, + -- ^ The client's private key and cert, if available. + clientMetadataPlugin :: Maybe C.ClientMetadataCreate + -- ^ Optional plugin for attaching additional metadata to each call. + } + -- | Configuration necessary to set up a client. data ClientConfig = ClientConfig {serverHost :: Host, serverPort :: Port, - clientArgs :: [C.Arg] + clientArgs :: [C.Arg], -- ^ Optional arguments for setting up the -- channel on the client. Supplying an empty -- list will cause the channel to use gRPC's -- default options. + clientSSLConfig :: Maybe ClientSSLConfig + -- ^ If 'Nothing', the client will use an + -- insecure connection to the server. + -- Otherwise, will use the supplied config to + -- connect using SSL. } clientEndpoint :: ClientConfig -> Endpoint clientEndpoint ClientConfig{..} = endpoint serverHost serverPort +addMetadataCreds :: C.ChannelCredentials + -> Maybe C.ClientMetadataCreate + -> IO C.ChannelCredentials +addMetadataCreds c Nothing = return c +addMetadataCreds c (Just create) = do + callCreds <- C.createCustomCallCredentials create + C.compositeChannelCredentialsCreate c callCreds C.reserved + +createChannel :: ClientConfig -> C.GrpcChannelArgs -> IO C.Channel +createChannel conf@ClientConfig{..} chanargs = + case clientSSLConfig of + Nothing -> C.grpcInsecureChannelCreate e chanargs C.reserved + Just (ClientSSLConfig rootCertPath Nothing plugin) -> + do rootCert <- mapM B.readFile rootCertPath + C.withChannelCredentials rootCert Nothing Nothing $ \creds -> do + creds' <- addMetadataCreds creds plugin + C.secureChannelCreate creds' e chanargs C.reserved + Just (ClientSSLConfig x (Just (ClientSSLKeyCertPair y z)) plugin) -> + do rootCert <- mapM B.readFile x + privKey <- Just <$> B.readFile y + clientCert <- Just <$> B.readFile z + C.withChannelCredentials rootCert privKey clientCert $ \creds -> do + creds' <- addMetadataCreds creds plugin + C.secureChannelCreate creds' e chanargs C.reserved + where (Endpoint e) = clientEndpoint conf + createClient :: GRPC -> ClientConfig -> IO Client -createClient grpc clientConfig@ClientConfig{..} = - C.withChannelArgs clientArgs $ \chanargs -> do - let Endpoint e = clientEndpoint clientConfig - clientChannel <- C.grpcInsecureChannelCreate e chanargs C.reserved +createClient grpc clientConfig = + C.withChannelArgs (clientArgs clientConfig) $ \chanargs -> do + clientChannel <- createChannel clientConfig chanargs clientCQ <- createCompletionQueue grpc return Client{..} @@ -162,7 +211,7 @@ withClientCall cl rm tm = withClientCallParent cl rm tm Nothing withClientCallParent :: Client -> RegisteredMethod mt -> TimeoutSeconds - -> Maybe (ServerCall a) + -> Maybe (ServerCall b) -- ^ Optional parent call for cascading cancellation -> (ClientCall -> IO (Either GRPCIOError a)) -> IO (Either GRPCIOError a) @@ -359,16 +408,33 @@ clientRW' (clientCQ -> cq) (unsafeCC -> c) initMeta f = runExceptT $ do -- | Make a request of the given method with the given body. Returns the -- server's response. -clientRequest :: Client - -> RegisteredMethod 'Normal - -> TimeoutSeconds - -> ByteString - -- ^ The body of the request - -> MetadataMap - -- ^ Metadata to send with the request - -> IO (Either GRPCIOError NormalRequestResult) -clientRequest cl@(clientCQ -> cq) rm tm body initMeta = - withClientCall cl rm tm (fmap join . go) +clientRequest + :: Client + -> RegisteredMethod 'Normal + -> TimeoutSeconds + -> ByteString + -- ^ The body of the request + -> MetadataMap + -- ^ Metadata to send with the request + -> IO (Either GRPCIOError NormalRequestResult) +clientRequest c = clientRequestParent c Nothing + +-- | Like 'clientRequest', but allows the user to supply an optional parent +-- call, so that call cancellation can be propagated from the parent to the +-- child. This is intended for servers that call other servers. +clientRequestParent + :: Client + -> Maybe (ServerCall a) + -- ^ optional parent call + -> RegisteredMethod 'Normal + -> TimeoutSeconds + -> ByteString + -- ^ The body of the request + -> MetadataMap + -- ^ Metadata to send with the request + -> IO (Either GRPCIOError NormalRequestResult) +clientRequestParent cl@(clientCQ -> cq) p rm tm body initMeta = + withClientCallParent cl rm tm p (fmap join . go) where go (unsafeCC -> c) = -- NB: the send and receive operations below *must* be in separate diff --git a/src/Network/GRPC/LowLevel/GRPC.hs b/src/Network/GRPC/LowLevel/GRPC.hs index 15a7e07..203bb9e 100644 --- a/src/Network/GRPC/LowLevel/GRPC.hs +++ b/src/Network/GRPC/LowLevel/GRPC.hs @@ -11,7 +11,7 @@ GRPC , grpcDebug' , threadDelaySecs , C.MetadataMap(..) -, StatusDetails(..) +, C.StatusDetails(..) ) where import Control.Concurrent (threadDelay, myThreadId) @@ -26,9 +26,6 @@ import qualified Network.GRPC.Unsafe.Op as C import qualified Network.GRPC.Unsafe.Metadata as C import Proto3.Wire.Decode (ParseError) -newtype StatusDetails = StatusDetails ByteString - deriving (Eq, IsString, Monoid, Show) - -- | Functions as a proof that the gRPC core has been started. The gRPC core -- must be initialized to create any gRPC state, so this is a requirement for -- the server and client create/start functions. @@ -54,7 +51,7 @@ data GRPCIOError = GRPCIOCallError C.CallError -- ^ Thrown if a 'CompletionQueue' fails to shut down in a -- reasonable amount of time. | GRPCIOUnknownError - | GRPCIOBadStatusCode C.StatusCode StatusDetails + | GRPCIOBadStatusCode C.StatusCode C.StatusDetails | GRPCIODecodeError ParseError | GRPCIOInternalUnexpectedRecv String -- debugging description diff --git a/src/Network/GRPC/LowLevel/Op.hs b/src/Network/GRPC/LowLevel/Op.hs index c895753..7346775 100644 --- a/src/Network/GRPC/LowLevel/Op.hs +++ b/src/Network/GRPC/LowLevel/Op.hs @@ -65,16 +65,13 @@ defaultStatusStringLen = 128 -- | Allocates and initializes the 'Opcontext' corresponding to the given 'Op'. createOpContext :: Op -> IO OpContext createOpContext (OpSendInitialMetadata m) = - OpSendInitialMetadataContext - <$> C.createMetadata m - <*> return (length $ toList m) + uncurry OpSendInitialMetadataContext <$> C.createMetadata m createOpContext (OpSendMessage bs) = fmap OpSendMessageContext (C.createByteBuffer bs) createOpContext (OpSendCloseFromClient) = return OpSendCloseFromClientContext createOpContext (OpSendStatusFromServer m code (StatusDetails str)) = - OpSendStatusFromServerContext + uncurry OpSendStatusFromServerContext <$> C.createMetadata m - <*> return (length $ toList m) <*> return code <*> return str createOpContext OpRecvInitialMetadata = diff --git a/src/Network/GRPC/LowLevel/Server.hs b/src/Network/GRPC/LowLevel/Server.hs index 152b9b8..c0dacbd 100644 --- a/src/Network/GRPC/LowLevel/Server.hs +++ b/src/Network/GRPC/LowLevel/Server.hs @@ -31,6 +31,7 @@ import Control.Monad import Control.Monad.IO.Class import Control.Monad.Trans.Except import Data.ByteString (ByteString) +import qualified Data.ByteString as B import qualified Data.Set as S import Network.GRPC.LowLevel.Call import Network.GRPC.LowLevel.CompletionQueue (CompletionQueue, @@ -45,6 +46,7 @@ import Network.GRPC.LowLevel.Op import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe.ChannelArgs as C import qualified Network.GRPC.Unsafe.Op as C +import qualified Network.GRPC.Unsafe.Security as C -- | Wraps various gRPC state needed to run a server. data Server = Server @@ -104,6 +106,16 @@ forkServer Server{..} f = do tid <- myThreadId atomically $ modifyTVar' outstandingForks (S.delete tid) +-- | Configuration for SSL. +data ServerSSLConfig = ServerSSLConfig + {clientRootCert :: Maybe FilePath, + serverPrivateKey :: FilePath, + serverCert :: FilePath, + clientCertRequest :: C.SslClientCertificateRequestType, + -- ^ Whether to request a certificate from the client, and what to do with it + -- if received. + customMetadataProcessor :: Maybe C.ProcessMeta} + -- | Configuration needed to start a server. data ServerConfig = ServerConfig { host :: Host @@ -119,18 +131,35 @@ data ServerConfig = ServerConfig , serverArgs :: [C.Arg] -- ^ Optional arguments for setting up the channel on the server. Supplying an -- empty list will cause the channel to use gRPC's default options. + , sslConfig :: Maybe ServerSSLConfig + -- ^ Server-side SSL configuration. If 'Nothing', the server will use an + -- insecure connection. } - deriving (Show, Eq) serverEndpoint :: ServerConfig -> Endpoint serverEndpoint ServerConfig{..} = endpoint host port +addPort :: C.Server -> ServerConfig -> IO Int +addPort server conf@ServerConfig{..} = + case sslConfig of + Nothing -> C.grpcServerAddInsecureHttp2Port server e + Just ServerSSLConfig{..} -> + do crc <- mapM B.readFile clientRootCert + spk <- B.readFile serverPrivateKey + sc <- B.readFile serverCert + C.withServerCredentials crc spk sc clientCertRequest $ \creds -> do + case customMetadataProcessor of + Just p -> C.setMetadataProcessor creds p + Nothing -> return () + C.serverAddSecureHttp2Port server e creds + where e = unEndpoint $ serverEndpoint conf + startServer :: GRPC -> ServerConfig -> IO Server startServer grpc conf@ServerConfig{..} = C.withChannelArgs serverArgs $ \args -> do let e = serverEndpoint conf server <- C.grpcServerCreate args C.reserved - actualPort <- C.grpcServerAddInsecureHttp2Port server (unEndpoint e) + actualPort <- addPort server conf when (actualPort /= unPort port) $ error $ "Unable to bind port: " ++ show port cq <- createCompletionQueue grpc diff --git a/src/Network/GRPC/Unsafe.chs b/src/Network/GRPC/Unsafe.chs index 02669f6..e6a7944 100644 --- a/src/Network/GRPC/Unsafe.chs +++ b/src/Network/GRPC/Unsafe.chs @@ -1,15 +1,19 @@ -{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE StandaloneDeriving #-} module Network.GRPC.Unsafe where import Control.Exception (bracket) import Control.Monad +import Data.ByteString (ByteString, useAsCString, packCString) + import Foreign.C.String (CString, peekCString) import Foreign.C.Types import Foreign.Marshal.Alloc (free) import Foreign.Ptr import Foreign.Storable +import GHC.Exts (IsString(..)) {#import Network.GRPC.Unsafe.Time#} import Network.GRPC.Unsafe.Constants @@ -25,6 +29,9 @@ import Network.GRPC.Unsafe.Constants {#context prefix = "grpc" #} +newtype StatusDetails = StatusDetails {unStatusDetails :: ByteString} + deriving (Eq, IsString, Monoid, Show) + {#pointer *grpc_completion_queue as CompletionQueue newtype #} deriving instance Show CompletionQueue @@ -169,7 +176,7 @@ castPeek p = do -- 'grpcInsecureChannelCreate' is the one that is actually used. {#fun grpc_channel_create_call_ as ^ {`Channel', `Call', fromIntegral `PropagationMask', `CompletionQueue', - `String', `String', `CTimeSpecPtr',unReserved `Reserved'} + useAsCString* `ByteString', useAsCString* `ByteString', `CTimeSpecPtr',unReserved `Reserved'} -> `Call'#} -- | Create a channel (on the client) to the server. The first argument is @@ -178,10 +185,10 @@ castPeek p = do -- expose any functions for creating channel args, since they are entirely -- undocumented. {#fun grpc_insecure_channel_create as ^ - {`String', `GrpcChannelArgs', unReserved `Reserved'} -> `Channel'#} + {useAsCString* `ByteString', `GrpcChannelArgs', unReserved `Reserved'} -> `Channel'#} {#fun grpc_channel_register_call as ^ - {`Channel', `String', `String',unReserved `Reserved'} + {`Channel', useAsCString* `ByteString',useAsCString* `ByteString',unReserved `Reserved'} -> `CallHandle' CallHandle#} {#fun grpc_channel_create_registered_call_ as ^ @@ -237,13 +244,13 @@ getPeerPeek cstr = do {`GrpcChannelArgs',unReserved `Reserved'} -> `Server'#} {#fun grpc_server_register_method_ as ^ - {`Server', `String', `String', `ServerRegisterMethodPayloadHandling'} -> `CallHandle' CallHandle#} + {`Server',useAsCString* `ByteString',useAsCString* `ByteString', `ServerRegisterMethodPayloadHandling'} -> `CallHandle' CallHandle#} {#fun grpc_server_register_completion_queue as ^ {`Server', `CompletionQueue', unReserved `Reserved'} -> `()'#} {#fun grpc_server_add_insecure_http2_port as ^ - {`Server', `String'} -> `Int'#} + {`Server', useAsCString* `ByteString'} -> `Int'#} -- | Starts a server. To shut down the server, call these in order: -- 'grpcServerShutdownAndNotify', 'grpcServerCancelAllCalls', @@ -280,8 +287,8 @@ getPeerPeek cstr = do `CompletionQueue',unTag `Tag'} -> `CallError'#} -{#fun unsafe call_details_get_method as ^ {`CallDetails'} -> `String'#} +{#fun unsafe call_details_get_method as ^ {`CallDetails'} -> `ByteString' packCString* #} -{#fun unsafe call_details_get_host as ^ {`CallDetails'} -> `String'#} +{#fun unsafe call_details_get_host as ^ {`CallDetails'} -> `ByteString' packCString* #} {#fun call_details_get_deadline as ^ {`CallDetails'} -> `CTimeSpec' peek* #} diff --git a/src/Network/GRPC/Unsafe/Metadata.chs b/src/Network/GRPC/Unsafe/Metadata.chs index 6d3e70d..87d9352 100644 --- a/src/Network/GRPC/Unsafe/Metadata.chs +++ b/src/Network/GRPC/Unsafe/Metadata.chs @@ -135,13 +135,19 @@ getMetadataVal m i = do vStr <- getMetadataVal' m i vLen <- getMetadataValLen m i packCStringLen (vStr, vLen) -createMetadata :: MetadataMap -> IO MetadataKeyValPtr +createMetadata :: MetadataMap -> IO (MetadataKeyValPtr, Int) createMetadata m = do let indexedKeyVals = zip [0..] $ toList m l = length indexedKeyVals metadata <- metadataAlloc l forM_ indexedKeyVals $ \(i,(k,v)) -> setMetadataKeyVal k v metadata i - return metadata + return (metadata, l) + +withPopulatedMetadataKeyValPtr :: MetadataMap + -> ((MetadataKeyValPtr, Int) -> IO a) + -> IO a +withPopulatedMetadataKeyValPtr m = bracket (createMetadata m) + (metadataFree . fst) getAllMetadataArray :: MetadataArray -> IO MetadataMap getAllMetadataArray m = do diff --git a/src/Network/GRPC/Unsafe/Security.chs b/src/Network/GRPC/Unsafe/Security.chs new file mode 100644 index 0000000..10727b9 --- /dev/null +++ b/src/Network/GRPC/Unsafe/Security.chs @@ -0,0 +1,398 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TupleSections #-} + +module Network.GRPC.Unsafe.Security where + +import Control.Exception (bracket) +import Data.ByteString (ByteString, useAsCString, packCString, packCStringLen) +import Data.Coerce (coerce) +import Foreign.C.String (CString, withCString) +import Foreign.C.Types +import Foreign.Storable +import Foreign.Marshal.Alloc (free) +import Foreign.Ptr (nullPtr, FunPtr, Ptr, castPtr) + +#include +#include +#include + +{#import Network.GRPC.Unsafe#} +{#import Network.GRPC.Unsafe.ChannelArgs#} +{#import Network.GRPC.Unsafe.Metadata#} +{#import Network.GRPC.Unsafe.Op#} + +{#context prefix = "grpc_"#} + +-- * Types + +-- | Context for auth. This is essentially just a set of key-value pairs that +-- can be mutated. +-- Note: it appears that any keys set or modified on this object do not +-- appear in the AuthContext of the peer, so you must send along auth info +-- in the metadata. It's currently unclear to us what +-- the purpose of modifying this is, but we offer the ability for the sake of +-- completeness. +{#pointer *auth_context as ^ newtype#} + +deriving instance Show AuthContext + +instance Storable AuthContext where + sizeOf (AuthContext p) = sizeOf p + alignment (AuthContext p) = alignment p + peek p = AuthContext <$> peek (castPtr p) + poke p (AuthContext r) = poke (castPtr p) r + +{#pointer *auth_property_iterator as ^ newtype#} + +{#pointer *call_credentials as ^ newtype#} + +{#pointer *channel_credentials as ^ newtype#} + +{#pointer *server_credentials as ^ newtype#} + +withAuthPropertyIterator :: AuthContext + -> (AuthPropertyIterator -> IO a) + -> IO a +withAuthPropertyIterator ctx = bracket (authContextPropertyIterator ctx) + (free . coerce) + +-- | Represents one key/value pair in an 'AuthContext'. +data AuthProperty = AuthProperty + {authPropName :: ByteString, + authPropValue :: ByteString} + deriving (Show, Eq) + +marshalAuthProperty :: Ptr AuthProperty -> IO AuthProperty +marshalAuthProperty p = do + n <- packCString =<< ({# get auth_property->name #} p) + vl <- fromIntegral <$> {# get auth_property->value_length #} p + v <- packCStringLen . (,vl) =<< {#get auth_property->value #} p + return $ AuthProperty n v + +-- | The context which a client-side auth metadata plugin sees when it runs. +data AuthMetadataContext = AuthMetadataContext + {serviceURL :: ByteString, + -- ^ The URL of the service the current call is going to. + methodName :: ByteString, + -- ^ The method that is being invoked with the current call. It appears that + -- the gRPC 0.15 core is not populating this correctly, because it's an empty + -- string in my tests so far. + channelAuthContext :: AuthContext + } + deriving Show + +authMetadataContextMarshal :: Ptr AuthMetadataContext -> IO AuthMetadataContext +authMetadataContextMarshal p = + AuthMetadataContext + <$> (({#get auth_metadata_context->service_url #} p) >>= packCString) + <*> (({#get auth_metadata_context->method_name #} p) >>= packCString) + <*> ({#get auth_metadata_context->channel_auth_context#} p) + +{#pointer *metadata_credentials_plugin as ^ newtype#} + +{#pointer *auth_metadata_processor as ^ newtype#} + +{#enum ssl_client_certificate_request_type as ^ {underscoreToCase} + deriving (Eq, Ord, Bounded, Show)#} + +-- * Auth Contexts + +-- | If used, the 'AuthContext' must be released with 'AuthContextRelease'. +{#fun unsafe call_auth_context as ^ + {`Call'} -> `AuthContext'#} + +{#fun unsafe auth_context_release as ^ + {`AuthContext'} -> `()'#} + +{#fun unsafe auth_context_add_cstring_property as addAuthProperty' + {`AuthContext', + useAsCString* `ByteString', + useAsCString* `ByteString'} + -> `()'#} + +-- | Adds a new property to the given 'AuthContext'. +addAuthProperty :: AuthContext -> AuthProperty -> IO () +addAuthProperty ctx prop = + addAuthProperty' ctx (authPropName prop) (authPropValue prop) + +{- +TODO: The functions for getting and setting peer identities cause +unpredictable crashes when used in conjunction with other, more general +auth property getter/setter functions. If we end needing these, we should +investigate further. + +coercePack :: Ptr a -> IO ByteString +coercePack = packCString . coerce + +{#fun unsafe grpc_auth_context_peer_identity_property_name + as getPeerIdentityPropertyName + {`AuthContext'} -> `ByteString' coercePack* #} + +{#fun unsafe auth_context_set_peer_identity_property_name + as setPeerIdentity + {`AuthContext', useAsCString* `ByteString'} -> `()'#} + +{#fun unsafe auth_context_peer_identity as getPeerIdentity + {`AuthContext'} -> `ByteString' coercePack* #} +-} +-- * Property Iteration + +{#fun unsafe auth_context_property_iterator_ as ^ + {`AuthContext'} -> `AuthPropertyIterator'#} + +{#fun unsafe auth_property_iterator_next as ^ + {`AuthPropertyIterator'} -> `Ptr AuthProperty' coerce#} + +getAuthProperties :: AuthContext -> IO [AuthProperty] +getAuthProperties ctx = withAuthPropertyIterator ctx $ \i -> do + go i + where go :: AuthPropertyIterator -> IO [AuthProperty] + go i = do p <- authPropertyIteratorNext i + if p == nullPtr + then return [] + else do props <- go i + prop <- marshalAuthProperty p + return (prop:props) + +-- * Channel Credentials + +{#fun unsafe channel_credentials_release as ^ + {`ChannelCredentials'} -> `()'#} + +{#fun unsafe composite_channel_credentials_create as ^ + {`ChannelCredentials', `CallCredentials',unReserved `Reserved'} + -> `ChannelCredentials'#} + +{#fun unsafe ssl_credentials_create_internal as ^ + {`CString', `CString', `CString'} -> `ChannelCredentials'#} + +sslChannelCredentialsCreate :: Maybe ByteString + -> Maybe ByteString + -> Maybe ByteString + -> IO ChannelCredentials +sslChannelCredentialsCreate (Just s) Nothing Nothing = + useAsCString s $ \s' -> sslCredentialsCreateInternal s' nullPtr nullPtr +sslChannelCredentialsCreate Nothing (Just s1) (Just s2) = + useAsCString s1 $ \s1' -> useAsCString s2 $ \s2' -> + sslCredentialsCreateInternal nullPtr s1' s2' +sslChannelCredentialsCreate (Just s1) (Just s2) (Just s3) = + useAsCString s1 $ \s1' -> useAsCString s2 $ \s2' -> useAsCString s3 $ \s3' -> + sslCredentialsCreateInternal s1' s2' s3' +sslChannelCredentialsCreate (Just s1) _ _ = + useAsCString s1 $ \s1' -> + sslCredentialsCreateInternal s1' nullPtr nullPtr +sslChannelCredentialsCreate _ _ _ = + sslCredentialsCreateInternal nullPtr nullPtr nullPtr + +withChannelCredentials :: Maybe ByteString + -> Maybe ByteString + -> Maybe ByteString + -> (ChannelCredentials -> IO a) + -> IO a +withChannelCredentials x y z = bracket (sslChannelCredentialsCreate x y z) + channelCredentialsRelease + +-- * Call Credentials + +{#fun call_set_credentials as ^ + {`Call', `CallCredentials'} -> `CallError'#} + +{#fun unsafe call_credentials_release as ^ + {`CallCredentials'} -> `()'#} + +{#fun unsafe composite_call_credentials_create as ^ + {`CallCredentials', `CallCredentials', unReserved `Reserved'} + -> `CallCredentials'#} + +-- * Server Credentials + +{#fun unsafe server_credentials_release as ^ + {`ServerCredentials'} -> `()'#} + +{#fun ssl_server_credentials_create_internal as ^ + {`CString', + useAsCString* `ByteString', + useAsCString* `ByteString', + `SslClientCertificateRequestType'} + -> `ServerCredentials'#} + +sslServerCredentialsCreate :: Maybe ByteString + -- ^ PEM encoding of the client root certificates. + -- Can be 'Nothing' if SSL authentication of + -- clients is not desired. + -> ByteString + -- ^ Server private key. + -> ByteString + -- ^ Server certificate. + -> SslClientCertificateRequestType + -- ^ How to handle client certificates. + -> IO ServerCredentials +sslServerCredentialsCreate Nothing k c t = + sslServerCredentialsCreateInternal nullPtr k c t +sslServerCredentialsCreate (Just cc) k c t = + useAsCString cc $ \cc' -> sslServerCredentialsCreateInternal cc' k c t + +withServerCredentials :: Maybe ByteString + -- ^ PEM encoding of the client root certificates. + -- Can be 'Nothing' if SSL authentication of + -- clients is not desired. + -> ByteString + -- ^ Server private key. + -> ByteString + -- ^ Server certificate. + -> SslClientCertificateRequestType + -- ^ How to handle client certificates. + -> (ServerCredentials -> IO a) + -> IO a +withServerCredentials a b c d = bracket (sslServerCredentialsCreate a b c d) + serverCredentialsRelease + +-- * Creating Secure Clients/Servers + +{#fun server_add_secure_http2_port as ^ + {`Server',useAsCString* `ByteString', `ServerCredentials'} -> `Int'#} + +{#fun secure_channel_create as ^ + {`ChannelCredentials',useAsCString* `ByteString', `GrpcChannelArgs', unReserved `Reserved'} + -> `Channel'#} + +-- * Custom metadata processing -- server side + +-- | Type synonym for the raw function pointer we pass to C to handle custom +-- server-side metadata auth processing. +type CAuthProcess = Ptr () + -> AuthContext + -> MetadataKeyValPtr + -> CSize + -> FunPtr CDoneCallback + -> Ptr () + -> IO () + +foreign import ccall "wrapper" + mkAuthProcess :: CAuthProcess -> IO (FunPtr CAuthProcess) + +type CDoneCallback = Ptr () + -> MetadataKeyValPtr + -> CSize + -> MetadataKeyValPtr + -> CSize + -> CInt -- ^ statuscode + -> CString -- ^ status details + -> IO () + +foreign import ccall "dynamic" + unwrapDoneCallback :: FunPtr CDoneCallback -> CDoneCallback + +{#fun server_credentials_set_auth_metadata_processor_ as ^ + {`ServerCredentials', `AuthMetadataProcessor'} -> `()'#} + +foreign import ccall "grpc_haskell.h mk_auth_metadata_processor" + mkAuthMetadataProcessor :: FunPtr CAuthProcess -> IO AuthMetadataProcessor + +data AuthProcessorResult = AuthProcessorResult + { resultConsumedMetadata :: MetadataMap + -- ^ Metadata to remove from the request before passing to the handler. + , resultResponseMetadata :: MetadataMap + -- ^ Metadata to add to the response. + , resultStatus :: StatusCode + -- ^ StatusOk if auth was successful. Using any other status code here will + -- cause the request to be rejected without reaching a handler. + -- For rejected requests, it's suggested that this + -- be StatusUnauthenticated or StatusPermissionDenied. + -- NOTE: if you are using the low-level interface and the request is rejected, + -- then handling functions in the low-level + -- interface such as 'serverHandleNormalCall' will not unblock until they + -- receive another request that is not rejected. So, if you write a buggy + -- auth plugin that rejects all requests, your server could hang. + , resultStatusDetails :: StatusDetails} + +-- | A custom auth metadata processor. This can be used to implement customized +-- auth schemes based on the metadata in the request. +type ProcessMeta = AuthContext + -> MetadataMap + -> IO AuthProcessorResult + +convertProcessor :: ProcessMeta -> CAuthProcess +convertProcessor f = \_state authCtx inMeta numMeta callBack userDataPtr -> do + meta <- getAllMetadata inMeta (fromIntegral numMeta) + AuthProcessorResult{..} <- f authCtx meta + let cb = unwrapDoneCallback callBack + let status = (fromEnum resultStatus) + withPopulatedMetadataKeyValPtr resultConsumedMetadata $ \(conMeta, conLen) -> + withPopulatedMetadataKeyValPtr resultResponseMetadata $ \(resMeta, resLen) -> + useAsCString (unStatusDetails resultStatusDetails) $ \dtls -> do + cb userDataPtr + conMeta + (fromIntegral conLen) + resMeta + (fromIntegral resLen) + (fromIntegral status) + dtls + +-- | Sets the custom metadata processor for the given server credentials. +setMetadataProcessor :: ServerCredentials -> ProcessMeta -> IO () +setMetadataProcessor creds processor = do + let rawProcessor = convertProcessor processor + rawProcessorPtr <- mkAuthProcess rawProcessor + metaProcessor <- mkAuthMetadataProcessor rawProcessorPtr + serverCredentialsSetAuthMetadataProcessor creds metaProcessor + +-- * Client-side metadata plugins + +type CGetMetadata = Ptr AuthMetadataContext + -> FunPtr CGetMetadataCallBack + -> Ptr () + -- ^ user data ptr (opaque, but must be passed on) + -> IO () + +foreign import ccall "wrapper" + mkCGetMetadata :: CGetMetadata -> IO (FunPtr CGetMetadata) + +type CGetMetadataCallBack = Ptr () + -> MetadataKeyValPtr + -> CSize + -> CInt + -> CString + -> IO () + +foreign import ccall "dynamic" + unwrapGetMetadataCallback :: FunPtr CGetMetadataCallBack + -> CGetMetadataCallBack + +data ClientMetadataCreateResult = ClientMetadataCreateResult + { clientResultMetadata :: MetadataMap + -- ^ Additional metadata to add to the call. + , clientResultStatus :: StatusCode + -- ^ if not 'StatusOk', causes the call to fail with the given status code. + -- NOTE: if the auth fails, the call will not get sent to the server. So, if + -- you're writing a test, your server might wait for a request forever. + , clientResultDetails :: StatusDetails } + +-- | Optional plugin for attaching custom auth metadata to each call. +type ClientMetadataCreate = AuthMetadataContext + -> IO ClientMetadataCreateResult + +convertMetadataCreate :: ClientMetadataCreate -> CGetMetadata +convertMetadataCreate f = \authCtxPtr doneCallback userDataPtr -> do + authCtx <- authMetadataContextMarshal authCtxPtr + ClientMetadataCreateResult{..} <- f authCtx + let cb = unwrapGetMetadataCallback doneCallback + withPopulatedMetadataKeyValPtr clientResultMetadata $ \(meta,metaLen) -> + useAsCString (unStatusDetails clientResultDetails) $ \details -> do + let status = fromIntegral $ fromEnum clientResultStatus + cb userDataPtr meta (fromIntegral metaLen) status details + +foreign import ccall "grpc_haskell.h mk_metadata_client_plugin" + mkMetadataClientPlugin :: FunPtr CGetMetadata -> IO MetadataCredentialsPlugin + +{#fun metadata_credentials_create_from_plugin_ as ^ + {`MetadataCredentialsPlugin'} -> `CallCredentials' #} + +createCustomCallCredentials :: ClientMetadataCreate -> IO CallCredentials +createCustomCallCredentials create = do + let rawCreate = convertMetadataCreate create + rawCreatePtr <- mkCGetMetadata rawCreate + plugin <- mkMetadataClientPlugin rawCreatePtr + metadataCredentialsCreateFromPlugin plugin diff --git a/stack.yaml b/stack.yaml index f2e1440..e00abf4 100644 --- a/stack.yaml +++ b/stack.yaml @@ -9,7 +9,7 @@ packages: - '.' - location: git: git@github.mv.awakenetworks.net:awakenetworks/protobuf-wire.git - commit: 676a99af41a664660d269c475832301873062a37 + commit: 38b4af244934310edacd3defce5929f44c72d5c9 extra-dep: true - location: git: git@github.com:awakenetworks/proto3-wire.git diff --git a/tests/LowLevelTests.hs b/tests/LowLevelTests.hs index 5ecee59..75c24b5 100644 --- a/tests/LowLevelTests.hs +++ b/tests/LowLevelTests.hs @@ -19,6 +19,7 @@ import Control.Monad.Managed import Data.ByteString (ByteString, isPrefixOf, isSuffixOf) +import Data.List (find) import qualified Data.Map.Strict as M import qualified Data.Set as S import GHC.Exts (fromList, toList) @@ -46,6 +47,9 @@ lowLevelTests = testGroup "Unit tests of low-level Haskell library" , testServerCreateDestroy , testMixRegisteredUnregistered , testPayload + , testSSL + , testAuthMetadataTransfer + , testServerAuthProcessorCancel , testPayloadUnregistered , testServerCancel , testGoaway @@ -152,6 +156,231 @@ testPayload = return ("reply test", dummyMeta, StatusOk, "details string") r @?= Right () +testSSL :: TestTree +testSSL = + csTest' "request/response using SSL" client server + where + clientConf = stdClientConf + {clientSSLConfig = Just (ClientSSLConfig + (Just "tests/ssl/localhost.crt") + Nothing + Nothing) + } + client = TestClient clientConf $ \c -> do + rm <- clientRegisterMethodNormal c "/foo" + clientRequest c rm 10 "hi" mempty >>= do + checkReqRslt $ \NormalRequestResult{..} -> do + rspCode @?= StatusOk + rspBody @?= "reply test" + + serverConf = defServerConf + {sslConfig = Just (ServerSSLConfig + Nothing + "tests/ssl/localhost.key" + "tests/ssl/localhost.crt" + SslDontRequestClientCertificate + Nothing) + } + server = TestServer serverConf $ \s -> do + r <- U.serverHandleNormalCall s mempty $ \U.ServerCall{..} body -> do + body @?= "hi" + return ("reply test", mempty, StatusOk, "") + r @?= Right () + +-- NOTE: With auth plugin tests, if an auth plugin decides someone isn't +-- authenticated, then the call never happens from the perspective of +-- the server, so the server will continue to block waiting for a call. So, if +-- these tests hang forever, it's because auth failed and the server is still +-- waiting for a successfully authenticated call to come in. + +testServerAuthProcessorCancel :: TestTree +testServerAuthProcessorCancel = + csTest' "request rejection by auth processor" client server + where + clientConf = stdClientConf + {clientSSLConfig = Just (ClientSSLConfig + (Just "tests/ssl/localhost.crt") + Nothing + Nothing) + } + client = TestClient clientConf $ \c -> do + rm <- clientRegisterMethodNormal c "/foo" + r <- clientRequest c rm 10 "hi" mempty + -- TODO: using checkReqRslt on this first result causes the test to hang! + r @?= Left (GRPCIOBadStatusCode StatusUnauthenticated "denied!") + clientRequest c rm 10 "hi" [("foo","bar")] >>= do + checkReqRslt $ \NormalRequestResult{..} -> do + rspCode @?= StatusOk + rspBody @?= "reply test" + + serverProcessor = Just $ \_ m -> do + let (status, details) = if M.member "foo" (unMap m) + then (StatusOk, "") + else (StatusUnauthenticated, "denied!") + return $ AuthProcessorResult mempty mempty status details + + serverConf = defServerConf + {sslConfig = Just (ServerSSLConfig + Nothing + "tests/ssl/localhost.key" + "tests/ssl/localhost.crt" + SslDontRequestClientCertificate + serverProcessor) + } + server = TestServer serverConf $ \s -> do + r <- U.serverHandleNormalCall s mempty $ \U.ServerCall{..} body -> do + checkMD "Handler only sees requests with good metadata" + [("foo","bar")] + metadata + return ("reply test", mempty, StatusOk, "") + r @?= Right () + +testAuthMetadataTransfer :: TestTree +testAuthMetadataTransfer = + csTest' "Auth metadata changes sent from client to server" client server + where + plugin :: ClientMetadataCreate + plugin authMetaCtx = do + let authCtx = (channelAuthContext authMetaCtx) + + addAuthProperty authCtx (AuthProperty "foo1" "bar1") + print "getting properties" + newProps <- getAuthProperties authCtx + print "got properties" + let addedProp = find ((== "foo1") . authPropName) newProps + addedProp @?= Just (AuthProperty "foo1" "bar1") + return $ ClientMetadataCreateResult [("foo","bar")] StatusOk "" + clientConf = stdClientConf + {clientSSLConfig = Just (ClientSSLConfig + (Just "tests/ssl/localhost.crt") + Nothing + (Just plugin)) + } + client = TestClient clientConf $ \c -> do + rm <- clientRegisterMethodNormal c "/foo" + clientRequest c rm 10 "hi" mempty >>= do + checkReqRslt $ \NormalRequestResult{..} -> do + rspCode @?= StatusOk + rspBody @?= "reply test" + + serverProcessor :: Maybe ProcessMeta + serverProcessor = Just $ \authCtx m -> do + let expected = fromList [("foo","bar")] + + props <- getAuthProperties authCtx + let clientProp = find ((== "foo1") . authPropName) props + assertBool "server plugin doesn't see auth properties set by client" + (clientProp == Nothing) + checkMD "server plugin sees metadata added by client plugin" expected m + return $ AuthProcessorResult mempty mempty StatusOk "" + + serverConf = defServerConf + {sslConfig = Just (ServerSSLConfig + Nothing + "tests/ssl/localhost.key" + "tests/ssl/localhost.crt" + SslDontRequestClientCertificate + serverProcessor) + } + server = TestServer serverConf $ \s -> do + r <- U.serverHandleNormalCall s mempty $ \U.ServerCall{..} body -> do + body @?= "hi" + return ("reply test", mempty, StatusOk, "") + r @?= Right () + +-- TODO: auth metadata doesn't propagate from parent calls to child calls. +-- Once we implement our own system for doing so, update this test and add it +-- to the tests list. +testAuthMetadataPropagate :: TestTree +testAuthMetadataPropagate = testCase "auth metadata inherited by children" $ do + c <- async client + s1 <- async server + s2 <- async server2 + wait c + wait s1 + wait s2 + return () + where + clientPlugin _ = + return $ ClientMetadataCreateResult [("foo","bar")] StatusOk "" + clientConf = stdClientConf + {clientSSLConfig = Just (ClientSSLConfig + (Just "tests/ssl/localhost.crt") + Nothing + (Just clientPlugin)) + } + client = do + threadDelaySecs 3 + withGRPC $ \g -> withClient g clientConf $ \c -> do + rm <- clientRegisterMethodNormal c "/foo" + clientRequest c rm 10 "hi" mempty >>= do + checkReqRslt $ \NormalRequestResult{..} -> do + rspCode @?= StatusOk + rspBody @?= "reply test" + + server1ServerPlugin _ctx md = do + checkMD "server1 sees client's auth metadata." [("foo","bar")] md + -- TODO: add response meta to check, and consume meta to see what happens. + return $ AuthProcessorResult mempty mempty StatusOk "" + + server1ServerConf = defServerConf + {sslConfig = Just (ServerSSLConfig + Nothing + "tests/ssl/localhost.key" + "tests/ssl/localhost.crt" + SslDontRequestClientCertificate + (Just server1ServerPlugin)), + methodsToRegisterNormal = ["/foo"] + } + + server1ClientPlugin _ = + return $ ClientMetadataCreateResult [("foo1","bar1")] StatusOk "" + + server1ClientConf = stdClientConf + {clientSSLConfig = Just (ClientSSLConfig + (Just "tests/ssl/localhost.crt") + Nothing + (Just server1ClientPlugin)), + serverPort = 50052 + } + + server = do + threadDelaySecs 2 + withGRPC $ \g -> withServer g server1ServerConf $ \s -> + withClient g server1ClientConf $ \c -> do + let rm = head (normalMethods s) + serverHandleNormalCall s rm mempty $ \call -> do + rmc <- clientRegisterMethodNormal c "/foo" + res <- clientRequestParent c (Just call) rmc 10 "hi" mempty + case res of + Left _ -> + error "got bad result from server2" + Right (NormalRequestResult{..}) -> + return (rspBody, mempty, StatusOk, "") + + server2ServerPlugin _ctx md = do + print md + checkMD "server2 sees server1's auth metadata." [("foo1","bar1")] md + --TODO: this assert fails + checkMD "server2 sees client's auth metadata." [("foo","bar")] md + return $ AuthProcessorResult mempty mempty StatusOk "" + + server2ServerConf = defServerConf + {sslConfig = Just (ServerSSLConfig + Nothing + "tests/ssl/localhost.key" + "tests/ssl/localhost.crt" + SslDontRequestClientCertificate + (Just server2ServerPlugin)), + methodsToRegisterNormal = ["/foo"], + port = 50052 + } + + server2 = withGRPC $ \g -> withServer g server2ServerConf $ \s -> do + let rm = head (normalMethods s) + serverHandleNormalCall s rm mempty $ \call -> do + return ("server2 reply", mempty, StatusOk, "") + testServerCancel :: TestTree testServerCancel = csTest "server cancel call" client server (["/foo"],[],[],[]) @@ -446,7 +675,7 @@ testCustomUserAgent = where clientArgs = [UserAgentPrefix "prefix!", UserAgentSuffix "suffix!"] client = - TestClient (ClientConfig "localhost" 50051 clientArgs) $ + TestClient (ClientConfig "localhost" 50051 clientArgs Nothing) $ \c -> do rm <- clientRegisterMethodNormal c "/foo" void $ clientRequest c rm 4 "" mempty server = TestServer (serverConf (["/foo"],[],[],[])) $ \s -> do @@ -469,7 +698,8 @@ testClientCompression = TestClient (ClientConfig "localhost" 50051 - [CompressionAlgArg GrpcCompressDeflate]) $ \c -> do + [CompressionAlgArg GrpcCompressDeflate] + Nothing) $ \c -> do rm <- clientRegisterMethodNormal c "/foo" void $ clientRequest c rm 1 "hello" mempty server = TestServer (serverConf (["/foo"],[],[],[])) $ \s -> do @@ -486,6 +716,7 @@ testClientServerCompression = cconf = ClientConfig "localhost" 50051 [CompressionAlgArg GrpcCompressDeflate] + Nothing client = TestClient cconf $ \c -> do rm <- clientRegisterMethodNormal c "/foo" clientRequest c rm 1 "hello" mempty >>= do @@ -500,6 +731,7 @@ testClientServerCompression = 50051 ["/foo"] [] [] [] [CompressionAlgArg GrpcCompressDeflate] + Nothing server = TestServer sconf $ \s -> do let rm = head (normalMethods s) serverHandleNormalCall s rm dummyMeta $ \sc -> do @@ -514,6 +746,7 @@ testClientServerCompressionLvl = cconf = ClientConfig "localhost" 50051 [CompressionLevelArg GrpcCompressLevelHigh] + Nothing client = TestClient cconf $ \c -> do rm <- clientRegisterMethodNormal c "/foo" clientRequest c rm 1 "hello" mempty >>= do @@ -528,6 +761,7 @@ testClientServerCompressionLvl = 50051 ["/foo"] [] [] [] [CompressionLevelArg GrpcCompressLevelLow] + Nothing server = TestServer sconf $ \s -> do let rm = head (normalMethods s) serverHandleNormalCall s rm dummyMeta $ \sc -> do @@ -622,7 +856,7 @@ stdTestClient :: (Client -> IO ()) -> TestClient stdTestClient = TestClient stdClientConf stdClientConf :: ClientConfig -stdClientConf = ClientConfig "localhost" 50051 [] +stdClientConf = ClientConfig "localhost" 50051 [] Nothing data TestServer = TestServer ServerConfig (Server -> IO ()) @@ -631,7 +865,7 @@ runTestServer (TestServer conf f) = runManaged $ mgdGRPC >>= mgdServer conf >>= liftIO . f defServerConf :: ServerConfig -defServerConf = ServerConfig "localhost" 50051 [] [] [] [] [] +defServerConf = ServerConfig "localhost" 50051 [] [] [] [] [] Nothing serverConf :: ([MethodName],[MethodName],[MethodName],[MethodName]) -> ServerConfig diff --git a/tests/LowLevelTests/Op.hs b/tests/LowLevelTests/Op.hs index 36ab621..db53e2d 100644 --- a/tests/LowLevelTests/Op.hs +++ b/tests/LowLevelTests/Op.hs @@ -60,10 +60,10 @@ withClientServerUnaryCall grpc f = do f (c, s, cc, sc) serverConf :: ServerConfig -serverConf = ServerConfig "localhost" 50051 [("/foo")] [] [] [] [] +serverConf = ServerConfig "localhost" 50051 [("/foo")] [] [] [] [] Nothing clientConf :: ClientConfig -clientConf = ClientConfig "localhost" 50051 [] +clientConf = ClientConfig "localhost" 50051 [] Nothing clientEmptySendOps :: [Op] clientEmptySendOps = [OpSendInitialMetadata mempty, diff --git a/tests/UnsafeTests.hs b/tests/UnsafeTests.hs index e36d554..3e59bf4 100644 --- a/tests/UnsafeTests.hs +++ b/tests/UnsafeTests.hs @@ -17,6 +17,7 @@ import Network.GRPC.Unsafe.Metadata import Network.GRPC.Unsafe.Slice import Network.GRPC.Unsafe.Time import Network.GRPC.Unsafe.ChannelArgs +import Network.GRPC.Unsafe.Security import System.Clock import Test.Tasty import Test.Tasty.HUnit as HU (testCase, (@?=), @@ -38,6 +39,8 @@ unsafeTests = testGroup "Unit tests for unsafe C bindings" , testCreateDestroyMetadataKeyVals , testCreateDestroyDeadline , testCreateDestroyChannelArgs + , testCreateDestroyClientCreds + , testCreateDestroyServerCreds ] unsafeProperties :: TestTree @@ -178,6 +181,18 @@ testCreateDestroyChannelArgs = testCase "Create/destroy channel args" $ grpc $ withChannelArgs [CompressionAlgArg GrpcCompressDeflate] $ const $ return () +testCreateDestroyClientCreds :: TestTree +testCreateDestroyClientCreds = testCase "Create/destroy client credentials" $ + grpc $ withChannelCredentials Nothing Nothing Nothing $ const $ return () + +testCreateDestroyServerCreds :: TestTree +testCreateDestroyServerCreds = testCase "Create/destroy server credentials" $ + grpc $ withServerCredentials Nothing + "tests/ssl/testServerKey.pem" + "tests/ssl/testServerCert.pem" + SslDontRequestClientCertificate + $ const $ return () + assertCqEventComplete :: Event -> IO () assertCqEventComplete e = do eventCompletionType e HU.@?= OpComplete diff --git a/tests/ssl/localhost.crt b/tests/ssl/localhost.crt new file mode 100644 index 0000000..8f9261f --- /dev/null +++ b/tests/ssl/localhost.crt @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIID2jCCAsKgAwIBAgIJAJ72m1gMaVpLMA0GCSqGSIb3DQEBCwUAMFExCzAJBgNV +BAYTAlVTMQswCQYDVQQIEwJDQTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQ +dHkgTHRkMRIwEAYDVQQDEwlsb2NhbGhvc3QwHhcNMTYwODA0MDAwOTU5WhcNMTcw +ODA0MDAwOTU5WjBRMQswCQYDVQQGEwJVUzELMAkGA1UECBMCQ0ExITAfBgNVBAoT +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAxMJbG9jYWxob3N0MIIB +IjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4G6kgozW9U5HZ2SuRBSwwtmm +VHj7CKRF/u3z3DdgcBCpMbhzbirN1XSJhISUpMDG1d6UIsUJhI/sWffjb9ext92E +/hb+RWDx/0qzQ67Xq+yKBDGPVDvMHTAaq9SoJ/oUABK1HuBlxfAup5zAvJ3hI7oI +eNogcCA3v72vyxXMF2DnczLKcw+/m3OuPxwoykPC/PNMttO4edXd6dP3pjO4COd0 +dVrAVi4lq9Ltrw29ybmUYCCVkIXX8ulMIcHVksrOtMN4Qny5lwJ8/+itm61JzfeH +5Q7Eq6v6AY5i892hwUdqWyy5NkR5ty4+WH3O42ilSdJHwSMvcwz+kFWedliiAwID +AQABo4G0MIGxMB0GA1UdDgQWBBSZIcayO/mzNeZWVpNDbfQAoOWvwzCBgQYDVR0j +BHoweIAUmSHGsjv5szXmVlaTQ230AKDlr8OhVaRTMFExCzAJBgNVBAYTAlVTMQsw +CQYDVQQIEwJDQTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRIw +EAYDVQQDEwlsb2NhbGhvc3SCCQCe9ptYDGlaSzAMBgNVHRMEBTADAQH/MA0GCSqG +SIb3DQEBCwUAA4IBAQC112EBNF8k9V1arvWV0qQqznD7PGbCHK004+VopX063dO6 +731FsMBUXqFY/82A83INUsdhchTQJe5+o6MdgvpvExjb0XDbn68/lMyYW9cWfE3/ +LRur9xSSUFNnBlG+oGg3UHdlW4cgkTsqIMryf3GE4Dka/21slySpozkBSPTfLFd+ +44bZ+Ixn5QVn6Z79Q1dtev70RYFOAufCbKFLPaCXCNjy3CT/P5Qmd6QjUT4TYiHM +6Fko18jA+Mz6Gz53PypieyClRM6mxnKlKm3idPMJ/6TbjrH3nDAJQpNk7tDycAQG +z44Eb5mrUF/bi+VhFQxjWVGBqgdr9E65//Y2p+aJ +-----END CERTIFICATE----- diff --git a/tests/ssl/localhost.key b/tests/ssl/localhost.key new file mode 100644 index 0000000..ab7af8e --- /dev/null +++ b/tests/ssl/localhost.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA4G6kgozW9U5HZ2SuRBSwwtmmVHj7CKRF/u3z3DdgcBCpMbhz +birN1XSJhISUpMDG1d6UIsUJhI/sWffjb9ext92E/hb+RWDx/0qzQ67Xq+yKBDGP +VDvMHTAaq9SoJ/oUABK1HuBlxfAup5zAvJ3hI7oIeNogcCA3v72vyxXMF2DnczLK +cw+/m3OuPxwoykPC/PNMttO4edXd6dP3pjO4COd0dVrAVi4lq9Ltrw29ybmUYCCV +kIXX8ulMIcHVksrOtMN4Qny5lwJ8/+itm61JzfeH5Q7Eq6v6AY5i892hwUdqWyy5 +NkR5ty4+WH3O42ilSdJHwSMvcwz+kFWedliiAwIDAQABAoIBAAPpKai0t5IyuP/O +O1MoYSQkEsfqC8XMxwN4NgWiWWXahHg/VJOY31lW3IaYoNZ2HYDgjghFErNipqWy +sh3izk/75jNfRzMCS3U9Yf5N76gpSQZlrq+zEw13Jx/TZtK7gtm1eb59/ogCdW2q +R5mBzsiGl1szwdjyVsZakdOiH5pQp5lhh8DSxZnWGcIa71oRnIlg5bhigSxQcwYH +zL27fda4E/DKLD61jpNBUM5URyem3jaN109SeeuCzZ4uqg17VdQP+v13cR0jwjXl +Rk31dd6LbkPJrYfGqSfOtfVlgP2BrkF80nIz5sPR6XxaYxNJEB2Ae9fyLaughbC3 +3aW2F0ECgYEA/p7MZHDvuelxPJLYzhlLILpqJ5aQO5ZxM0wV0SMddNjCFQgKHYUA +dh1V4SrbxIFAHZCzMBShSGd892sbjFw1atoxR3cXhqYm9atGkfAafEyZSuYkfD7z +qFMdVvqz9W2y6G0IEmQ+FQdRXVUktYqqoBYKEHQTw2Vne+lD9BxHqzMCgYEA4aX3 +2n8X22pXuAP7SzkMrFmHGvmoA6dmm7aqQCwOvKPV3o9c9Rz938t8ljw36vO72SAd +n4HE5hOOJsMjxHkTFC9ZdtvjDSZ2Wa2E6968iVtqnDfd4eDL1kizATBtu8qv6Mkj +Uk2wGIhB6QYRC4WLfiBeji6Rhox7hRCHDC8VrfECgYAEXezbfCRgZ+SNSWd2gXCM +ayYO78Ihg38FhjSJlbSXoHATtEOYJgPQAsjKR9XlFOJon2azWGc7uqqmA6xBSAOS +hZN6ykwY/xiD9iALuLZ7k0S9yBywFNRQ+rvyFfKoLu12lwggaJ+39JwsoZ0zj+FF +RZt+lL8SBtczhNipgyKniQKBgQCLUf4GWhJQ1wfyBgNSHpdEksJoVVz3ZJRl8BKO +LKWsszuJftrWPGBDnU//Mo8T8gk5tiqUIXuA0vIh3LpoxJiTDekfCgTWSbkpBYnc +WFnwNRFOAvBsVm+Ejr53LX+TQ0H5aLb2SiFABGEtjyFXC81kwnefMgGIIZTiQ6Ie +U7P/AQKBgHP/QYORrIik+UXs4YBqMyCoJR+uzBnQKq1plXNS5xTMmgogbyPTHTZ2 +CRc9q2EJeuoo2whubWWkf7k6bmdRAaroLk6vc8AtvClUzCsNskAGghnUl53B/MzP +/K0BEGyZJetiuMxMdOejJ3eXu8reGZa1K1wlQFmri4feROTHwU+b +-----END RSA PRIVATE KEY-----