{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.GRPC.HighLevel.Server.Unregistered where

import Control.Arrow
import Control.Concurrent.Async (async, wait)
import qualified Control.Exception as CE
import Control.Monad
import Data.Foldable (find)
import Network.GRPC.HighLevel.Server
import Network.GRPC.LowLevel
import qualified Network.GRPC.LowLevel.Call.Unregistered as U
import qualified Network.GRPC.LowLevel.Server.Unregistered as U
import Proto3.Suite.Class

-- | Serve incoming calls forever, routing each one by method name.
--
-- Every call obtained through 'U.withServerCallAsync' is matched against the
-- supplied handlers by comparing 'anyHandlerMethodName' with the call's
-- 'U.callMethod'. The first match wins; handlers are searched in the order
-- normal, client-streaming, server-streaming, bidirectional. A call with no
-- matching handler is answered with 'StatusNotFound' and the details
-- @"unknown method"@.
dispatchLoop
  :: Server                       -- ^ server to take calls from
  -> (String -> IO ())            -- ^ logger handed to 'handleCallError'
  -> MetadataMap                  -- ^ metadata passed to every low-level serving function
  -> [Handler 'Normal]            -- ^ unary handlers
  -> [Handler 'ClientStreaming]   -- ^ client-streaming handlers
  -> [Handler 'ServerStreaming]   -- ^ server-streaming handlers
  -> [Handler 'BiDiStreaming]     -- ^ bidirectional-streaming handlers
  -> IO ()
dispatchLoop s logger md hN hC hS hB =
  forever $ U.withServerCallAsync s $ \call ->
    case lookupHandler call of
      Nothing ->
        unknownHandler call
      Just (AnyHandler (UnaryHandler _ h)) ->
        unaryHandler call h
      Just (AnyHandler (ClientStreamHandler _ h)) ->
        csHandler call h
      Just (AnyHandler (ServerStreamHandler _ h)) ->
        ssHandler call h
      Just (AnyHandler (BiDiStreamHandler _ h)) ->
        bdHandler call h
  where
    -- All handlers behind one existential wrapper so a single lookup works.
    registry =
      concat
        [ map AnyHandler hN
        , map AnyHandler hC
        , map AnyHandler hS
        , map AnyHandler hB
        ]

    -- First handler whose method name equals the method of the given call.
    lookupHandler call =
      find (\candidate -> anyHandlerMethodName candidate == U.callMethod call) registry

    -- Unary call: the request payload handed over by the low-level layer
    -- replaces the payload slot of the converted call.
    unaryHandler
      :: (Message a, Message b)
      => U.ServerCall -> ServerHandler a b -> IO ()
    unaryHandler call h =
      handleError $
        U.serverHandleNormalCall' s call md $ \_ payload ->
          convertServerHandler h (payload <$ U.convertCall call)

    -- Client-streaming call.
    csHandler
      :: (Message a, Message b)
      => U.ServerCall -> ServerReaderHandler a b -> IO ()
    csHandler call h =
      handleError (U.serverReader s call md (convertServerReaderHandler h))

    -- Server-streaming call.
    ssHandler
      :: (Message a, Message b)
      => U.ServerCall -> ServerWriterHandler a b -> IO ()
    ssHandler call h =
      handleError (U.serverWriter s call md (convertServerWriterHandler h))

    -- Bidirectional-streaming call.
    bdHandler
      :: (Message a, Message b)
      => U.ServerCall -> ServerRWHandler a b -> IO ()
    bdHandler call h =
      handleError (U.serverRW s call md (convertServerRWHandler h))

    -- Reply for a method nobody registered. The result of the low-level
    -- call is discarded and this path does not go through 'handleError'.
    unknownHandler :: U.ServerCall -> IO ()
    unknownHandler call =
      void $ U.serverHandleNormalCall' s call md $ \_ _ ->
        pure (mempty, mempty, StatusNotFound, StatusDetails "unknown method")

    -- Run an action, turning any exception it throws into a
    -- 'GRPCIOHandlerException' carrying the shown exception, and pass the
    -- outcome to 'handleCallError' together with the logger.
    handleError :: IO a -> IO ()
    handleError action =
      CE.try action >>= handleCallError logger . left asHandlerException
      where
        asHandlerException (e :: CE.SomeException) =
          GRPCIOHandlerException (show e)

-- | Start a gRPC server described by the given options and block until the
-- dispatch loop finishes.
--
-- No methods are registered with the low-level server; all routing is done
-- by 'dispatchLoop' using the handler lists from the options.
serverLoop :: ServerOptions -> IO ()
serverLoop ServerOptions{..} = do
  -- The loop runs in its own thread so that the serverLoop thread can be
  -- killed: without the fork we would block on a foreign call, which cannot
  -- be interrupted.
  worker <- async runServer
  wait worker
  where
    runServer =
      withGRPC $ \grpc ->
        withServer grpc config $ \server ->
          dispatchLoop
            server
            optLogger
            optInitialMetadata
            optNormalHandlers
            optClientStreamHandlers
            optServerStreamHandlers
            optBiDiStreamHandlers

    -- The compression argument is present only when 'optUseCompression' is set.
    compressionArgs =
      [CompressionAlgArg GrpcCompressDeflate | optUseCompression]

    userAgentArgs =
      [ UserAgentPrefix optUserAgentPrefix
      , UserAgentSuffix optUserAgentSuffix
      ]

    config =
      ServerConfig
        { host = optServerHost
        , port = optServerPort
        , methodsToRegisterNormal = []
        , methodsToRegisterClientStreaming = []
        , methodsToRegisterServerStreaming = []
        , methodsToRegisterBiDiStreaming = []
        , serverArgs = compressionArgs ++ userAgentArgs
        , sslConfig = optSSLConfig
        }