diff --git a/servant-server/src/Servant/Server/CombinatorUtils.hs b/servant-server/src/Servant/Server/CombinatorUtils.hs index c7e8dd66..ee67e047 100644 --- a/servant-server/src/Servant/Server/CombinatorUtils.hs +++ b/servant-server/src/Servant/Server/CombinatorUtils.hs @@ -1,11 +1,16 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module Servant.Server.CombinatorUtils ( - RouteResult(..), - argumentCombinator, + CombinatorImplementation, + runCI, captureCombinator, + argumentCombinator, + -- * re-exports + RouteResult(..), ) where import Data.Proxy @@ -16,29 +21,41 @@ import Servant.API import Servant.Server import Servant.Server.Internal -argumentCombinator :: - forall api combinator arg context env . - (ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler), - HasServer api context) => - (Request -> RouteResult arg) - -> Proxy (combinator :> api) - -> Context context - -> Delayed env (Server (combinator :> api)) - -> Router' env RoutingApplication -argumentCombinator getArg Proxy context delayed = - route (Proxy :: Proxy api) context $ addBodyCheck delayed - (DelayedIO (return ())) $ \ () -> - withRequest $ \ request -> liftRouteResult (getArg request) +data CombinatorImplementation combinator arg api context where + CI :: (forall env . + Proxy (combinator :> api) + -> Context context + -> Delayed env (arg -> Server api) + -> Router' env RoutingApplication) + -> CombinatorImplementation combinator arg api context -captureCombinator :: - forall api combinator arg context env . - (HasServer api context) => - (Text -> RouteResult arg) +runCI :: CombinatorImplementation combinator arg api context -> Proxy (combinator :> api) -> Context context -> Delayed env (arg -> Server api) -> Router' env RoutingApplication -captureCombinator getArg Proxy context delayed = +runCI (CI i) = i + +captureCombinator :: + forall api combinator arg context . + (HasServer api context) => + (Text -> RouteResult arg) + -> CombinatorImplementation combinator arg api context +captureCombinator getArg = CI $ \ Proxy context delayed -> CaptureRouter $ route (Proxy :: Proxy api) context $ addCapture delayed $ \ captured -> (liftRouteResult (getArg captured)) + +argumentCombinator :: + forall api combinator arg context . + (ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler), + HasServer api context) => + (Request -> RouteResult arg) + -> CombinatorImplementation combinator arg api context +argumentCombinator getArg = CI $ \ Proxy context delayed -> + route (Proxy :: Proxy api) context $ + addBodyCheck delayed contentTypeCheck bodyCheck + where + contentTypeCheck = return () + bodyCheck () = withRequest $ \ request -> + liftRouteResult (getArg request) diff --git a/servant-server/test/Servant/Server/CombinatorUtilsSpec.hs b/servant-server/test/Servant/Server/CombinatorUtilsSpec.hs index 68f997c1..32fbd2c8 100644 --- a/servant-server/test/Servant/Server/CombinatorUtilsSpec.hs +++ b/servant-server/test/Servant/Server/CombinatorUtilsSpec.hs @@ -27,7 +27,7 @@ import Servant.Server.CombinatorUtils runApp :: Application -> Request -> IO Response runApp app req = do mvar <- newMVar Nothing - app req $ \ response -> do + ResponseReceived <- app req $ \ response -> do modifyMVar mvar $ \ Nothing -> return $ (Just response, ResponseReceived) modifyMVar mvar $ \mResponse -> do @@ -100,7 +100,7 @@ data FooHeader instance HasServer api context => HasServer (FooHeader :> api) context where type ServerT (FooHeader :> api) m = String -> ServerT api m - route = argumentCombinator getCustom + route = runCI $ argumentCombinator getCustom getCustom :: Request -> RouteResult String getCustom request = case lookup "FooHeader" (requestHeaders request) of @@ -111,7 +111,7 @@ data StringCapture instance HasServer api context => HasServer (StringCapture :> api) context where type ServerT (StringCapture :> api) m = String -> ServerT api m - route = captureCombinator getCapture + route = runCI $ captureCombinator getCapture getCapture :: Text -> RouteResult String getCapture = Route . cs