From b1a6d8884565f5d577f3b8f7378208c20f8b2417 Mon Sep 17 00:00:00 2001 From: Andres Loeh Date: Sat, 9 Apr 2016 15:42:57 +0200 Subject: [PATCH] Revise the Router type to allow proper sharing. We've previously used functions in the Router type to provide information for subrouters. But this accesses the Requests too early, and breaks sharing of the router structure in general, causing the Router or large parts of the Router to be recomputed on every request. We now do not use functions anymore, and properly compute all static parts of the router first, and gain access to the request only in Delayed. This also turns the code used within Delayed into a proper monad now called DelayedIO, making some of the code using it a bit nicer. --- servant-server/CHANGELOG.md | 9 +- servant-server/src/Servant/Server.hs | 22 +- .../src/Servant/Server/Experimental/Auth.hs | 18 +- servant-server/src/Servant/Server/Internal.hs | 111 ++++----- .../src/Servant/Server/Internal/BasicAuth.hs | 11 +- .../src/Servant/Server/Internal/Router.hs | 151 ++++++------ .../Server/Internal/RoutingApplication.hs | 222 ++++++++++-------- .../test/Servant/Server/RouterSpec.hs | 12 +- .../UsingContextSpec/TestCombinators.hs | 3 +- servant-server/test/Servant/ServerSpec.hs | 2 +- 10 files changed, 292 insertions(+), 269 deletions(-) diff --git a/servant-server/CHANGELOG.md b/servant-server/CHANGELOG.md index 85a6d9e1..736c36bd 100644 --- a/servant-server/CHANGELOG.md +++ b/servant-server/CHANGELOG.md @@ -1,10 +1,11 @@ 0.7 --- -* The `Router` type has been changed. There are now more situations where - servers will make use of static lookups to efficiently route the request - to the correct endpoint. Functions `layout` and `layoutWithContext` have - been added to visualize the router layout for debugging purposes. Test +* The `Router` type has been changed. Static router tables should now + be properly shared between requests, drastically increasing the + number of situations where servers will be able to route requests + efficiently. Functions `layout` and `layoutWithContext` have been + added to visualize the router layout for debugging purposes. Test cases for expected router layouts have been added. * Export `throwError` from module `Servant` * Add `Handler` type synonym diff --git a/servant-server/src/Servant/Server.hs b/servant-server/src/Servant/Server.hs index b2cf7a66..bbba7c1b 100644 --- a/servant-server/src/Servant/Server.hs +++ b/servant-server/src/Servant/Server.hs @@ -132,21 +132,13 @@ serve p = serveWithContext p EmptyContext serveWithContext :: (HasServer layout context) => Proxy layout -> Context context -> Server layout -> Application -serveWithContext p context server = toApplication (runRouter (route p context d)) - where - d = Delayed r r r r (\ _ _ _ -> Route server) - r = return (Route ()) +serveWithContext p context server = + toApplication (runRouter (route p context (emptyDelayed (Route server)))) -- | The function 'layout' produces a textual description of the internal -- router layout for debugging purposes. Note that the router layout is -- determined just by the API, not by the handlers. -- --- This function makes certain assumptions about the well-behavedness of --- the 'HasServer' instances of the combinators which should be ok for the --- core servant constructions, but might not be satisfied for some other --- combinators provided elsewhere. It is possible that the function may --- crash for these. --- -- Example: -- -- For the following API @@ -168,7 +160,7 @@ serveWithContext p context server = toApplication (runRouter (route p context d) -- > │ └─ e/ -- > │ └─• -- > ├─ b/ --- > │ └─ / +-- > │ └─ / -- > │ ├─• -- > │ ┆ -- > │ └─• @@ -185,7 +177,7 @@ serveWithContext p context server = toApplication (runRouter (route p context d) -- -- [@─•@] Leaves reflect endpoints. -- --- [@\/@] This is a delayed capture of a path component. +-- [@\/@] This is a delayed capture of a path component. -- -- [@\@] This is a part of the API we do not know anything about. -- @@ -200,10 +192,8 @@ layout p = layoutWithContext p EmptyContext -- | Variant of 'layout' that takes an additional 'Context'. layoutWithContext :: (HasServer layout context) => Proxy layout -> Context context -> Text -layoutWithContext p context = routerLayout (route p context d) - where - d = Delayed r r r r (\ _ _ _ -> FailFatal err501) - r = return (Route ()) +layoutWithContext p context = + routerLayout (route p context (emptyDelayed (FailFatal err501))) -- Documentation diff --git a/servant-server/src/Servant/Server/Experimental/Auth.hs b/servant-server/src/Servant/Server/Experimental/Auth.hs index 86d4dc03..fd38ff1e 100644 --- a/servant-server/src/Servant/Server/Experimental/Auth.hs +++ b/servant-server/src/Servant/Server/Experimental/Auth.hs @@ -12,6 +12,7 @@ module Servant.Server.Experimental.Auth where +import Control.Monad.Trans (liftIO) import Control.Monad.Trans.Except (runExceptT) import Data.Proxy (Proxy (Proxy)) import Data.Typeable (Typeable) @@ -24,10 +25,11 @@ import Servant.Server.Internal (HasContextEntry, HasServer, ServerT, getContextEntry, route) -import Servant.Server.Internal.Router (Router' (WithRequest)) -import Servant.Server.Internal.RoutingApplication (RouteResult (FailFatal, Route), - addAuthCheck) -import Servant.Server.Internal.ServantErr (ServantErr, Handler) +import Servant.Server.Internal.RoutingApplication (addAuthCheck, + delayedFailFatal, + DelayedIO, + withRequest) +import Servant.Server.Internal.ServantErr (Handler) -- * General Auth @@ -57,8 +59,10 @@ instance ( HasServer api context type ServerT (AuthProtect tag :> api) m = AuthServerData (AuthProtect tag) -> ServerT api m - route Proxy context subserver = WithRequest $ \ request -> - route (Proxy :: Proxy api) context (subserver `addAuthCheck` authCheck request) + route Proxy context subserver = + route (Proxy :: Proxy api) context (subserver `addAuthCheck` withRequest authCheck) where + authHandler :: Request -> Handler (AuthServerData (AuthProtect tag)) authHandler = unAuthHandler (getContextEntry context) - authCheck = fmap (either FailFatal Route) . runExceptT . authHandler + authCheck :: Request -> DelayedIO (AuthServerData (AuthProtect tag)) + authCheck = (>>= either delayedFailFatal return) . liftIO . runExceptT . authHandler diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index eb3ca19c..2d378c9d 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -22,6 +22,7 @@ module Servant.Server.Internal , module Servant.Server.Internal.ServantErr ) where +import Control.Monad.Trans (liftIO) import qualified Data.ByteString as B import qualified Data.ByteString.Char8 as BC8 import qualified Data.ByteString.Lazy as BL @@ -70,7 +71,11 @@ import Servant.Server.Internal.ServantErr class HasServer layout context where type ServerT layout (m :: * -> *) :: * - route :: Proxy layout -> Context context -> Delayed (Server layout) -> Router + route :: + Proxy layout + -> Context context + -> Delayed env (Server layout) + -> Router env type Server layout = ServerT layout Handler @@ -92,7 +97,7 @@ instance (HasServer a context, HasServer b context) => HasServer (a :<|> b) cont type ServerT (a :<|> b) m = ServerT a m :<|> ServerT b m route Proxy context server = choice (route pa context ((\ (a :<|> _) -> a) <$> server)) - (route pb context ((\ (_ :<|> b) -> b) <$> server)) + (route pb context ((\ (_ :<|> b) -> b) <$> server)) where pa = Proxy :: Proxy a pb = Proxy :: Proxy b @@ -120,12 +125,12 @@ instance (KnownSymbol capture, FromHttpApiData a, HasServer sublayout context) a -> ServerT sublayout m route Proxy context d = - DynamicRouter $ \ first -> + CaptureRouter $ route (Proxy :: Proxy sublayout) context - (addCapture d $ case parseUrlPieceMaybe first :: Maybe a of - Nothing -> return $ Fail err400 - Just v -> return $ Route v + (addCapture d $ \ txt -> case parseUrlPieceMaybe txt :: Maybe a of + Nothing -> delayedFail err400 + Just v -> return v ) allowedMethodHead :: Method -> Request -> Bool @@ -144,41 +149,41 @@ processMethodRouter handleA status method headers request = case handleA of bdy = if allowedMethodHead method request then "" else body hdrs = (hContentType, cs contentT) : (fromMaybe [] headers) -methodCheck :: Method -> Request -> IO (RouteResult ()) +methodCheck :: Method -> Request -> DelayedIO () methodCheck method request - | allowedMethod method request = return $ Route () - | otherwise = return $ Fail err405 + | allowedMethod method request = return () + | otherwise = delayedFail err405 -acceptCheck :: (AllMime list) => Proxy list -> B.ByteString -> IO (RouteResult ()) +acceptCheck :: (AllMime list) => Proxy list -> B.ByteString -> DelayedIO () acceptCheck proxy accH - | canHandleAcceptH proxy (AcceptHeader accH) = return $ Route () - | otherwise = return $ FailFatal err406 + | canHandleAcceptH proxy (AcceptHeader accH) = return () + | otherwise = delayedFailFatal err406 methodRouter :: (AllCTRender ctypes a) => Method -> Proxy ctypes -> Status - -> Delayed (Handler a) - -> Router + -> Delayed env (Handler a) + -> Router env methodRouter method proxy status action = leafRouter route' where - route' request respond = + route' env request respond = let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request in runAction (action `addMethodCheck` methodCheck method request `addAcceptCheck` acceptCheck proxy accH - ) respond $ \ output -> do + ) env request respond $ \ output -> do let handleA = handleAcceptH proxy (AcceptHeader accH) output processMethodRouter handleA status method Nothing request methodRouterHeaders :: (GetHeaders (Headers h v), AllCTRender ctypes v) => Method -> Proxy ctypes -> Status - -> Delayed (Handler (Headers h v)) - -> Router + -> Delayed env (Handler (Headers h v)) + -> Router env methodRouterHeaders method proxy status action = leafRouter route' where - route' request respond = + route' env request respond = let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request in runAction (action `addMethodCheck` methodCheck method request `addAcceptCheck` acceptCheck proxy accH - ) respond $ \ output -> do + ) env request respond $ \ output -> do let headers = getHeaders output handleA = handleAcceptH proxy (AcceptHeader accH) (getResponse output) processMethodRouter handleA status method (Just headers) request @@ -230,8 +235,8 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout context) type ServerT (Header sym a :> sublayout) m = Maybe a -> ServerT sublayout m - route Proxy context subserver = WithRequest $ \ request -> - let mheader = parseHeaderMaybe =<< lookup str (requestHeaders request) + route Proxy context subserver = + let mheader req = parseHeaderMaybe =<< lookup str (requestHeaders req) in route (Proxy :: Proxy sublayout) context (passToServer subserver mheader) where str = fromString $ symbolVal (Proxy :: Proxy sym) @@ -262,10 +267,10 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout context) type ServerT (QueryParam sym a :> sublayout) m = Maybe a -> ServerT sublayout m - route Proxy context subserver = WithRequest $ \ request -> - let querytext = parseQueryText $ rawQueryString request - param = - case lookup paramname querytext of + route Proxy context subserver = + let querytext r = parseQueryText $ rawQueryString r + param r = + case lookup paramname (querytext r) of Nothing -> Nothing -- param absent from the query string Just Nothing -> Nothing -- param present with no value -> Nothing Just (Just v) -> parseQueryParamMaybe v -- if present, we try to convert to @@ -298,13 +303,13 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout context) type ServerT (QueryParams sym a :> sublayout) m = [a] -> ServerT sublayout m - route Proxy context subserver = WithRequest $ \ request -> - let querytext = parseQueryText $ rawQueryString request + route Proxy context subserver = + let querytext r = parseQueryText $ rawQueryString r -- if sym is "foo", we look for query string parameters -- named "foo" or "foo[]" and call parseQueryParam on the -- corresponding values - parameters = filter looksLikeParam querytext - values = mapMaybe (convert . snd) parameters + parameters r = filter looksLikeParam (querytext r) + values r = mapMaybe (convert . snd) (parameters r) in route (Proxy :: Proxy sublayout) context (passToServer subserver values) where paramname = cs $ symbolVal (Proxy :: Proxy sym) looksLikeParam (name, _) = name == paramname || name == (paramname <> "[]") @@ -329,9 +334,9 @@ instance (KnownSymbol sym, HasServer sublayout context) type ServerT (QueryFlag sym :> sublayout) m = Bool -> ServerT sublayout m - route Proxy context subserver = WithRequest $ \ request -> - let querytext = parseQueryText $ rawQueryString request - param = case lookup paramname querytext of + route Proxy context subserver = + let querytext r = parseQueryText $ rawQueryString r + param r = case lookup paramname (querytext r) of Just Nothing -> True -- param is there, with no value Just (Just v) -> examine v -- param with a value Nothing -> False -- param not in the query string @@ -352,8 +357,8 @@ instance HasServer Raw context where type ServerT Raw m = Application - route Proxy _ rawApplication = RawRouter $ \ request respond -> do - r <- runDelayed rawApplication + route Proxy _ rawApplication = RawRouter $ \ env request respond -> do + r <- runDelayed rawApplication env request case r of Route app -> app request (respond . Route) Fail a -> respond $ Fail a @@ -386,10 +391,10 @@ instance ( AllCTUnrender list a, HasServer sublayout context type ServerT (ReqBody list a :> sublayout) m = a -> ServerT sublayout m - route Proxy context subserver = WithRequest $ \ request -> - route (Proxy :: Proxy sublayout) context (addBodyCheck subserver (bodyCheck request)) + route Proxy context subserver = + route (Proxy :: Proxy sublayout) context (addBodyCheck subserver bodyCheck) where - bodyCheck request = do + bodyCheck = withRequest $ \ request -> do -- See HTTP RFC 2616, section 7.2.1 -- http://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1 -- See also "W3C Internet Media Type registration, consistency of use" @@ -397,11 +402,11 @@ instance ( AllCTUnrender list a, HasServer sublayout context let contentTypeH = fromMaybe "application/octet-stream" $ lookup hContentType $ requestHeaders request mrqbody <- handleCTypeH (Proxy :: Proxy list) (cs contentTypeH) - <$> lazyRequestBody request + <$> liftIO (lazyRequestBody request) case mrqbody of - Nothing -> return $ FailFatal err415 - Just (Left e) -> return $ FailFatal err400 { errBody = cs e } - Just (Right v) -> return $ Route v + Nothing -> delayedFailFatal err415 + Just (Left e) -> delayedFailFatal err400 { errBody = cs e } + Just (Right v) -> return v -- | Make sure the incoming request starts with @"/path"@, strip it and -- pass the rest of the request path to @sublayout@. @@ -418,28 +423,28 @@ instance (KnownSymbol path, HasServer sublayout context) => HasServer (path :> s instance HasServer api context => HasServer (RemoteHost :> api) context where type ServerT (RemoteHost :> api) m = SockAddr -> ServerT api m - route Proxy context subserver = WithRequest $ \req -> - route (Proxy :: Proxy api) context (passToServer subserver $ remoteHost req) + route Proxy context subserver = + route (Proxy :: Proxy api) context (passToServer subserver remoteHost) instance HasServer api context => HasServer (IsSecure :> api) context where type ServerT (IsSecure :> api) m = IsSecure -> ServerT api m - route Proxy context subserver = WithRequest $ \req -> - route (Proxy :: Proxy api) context (passToServer subserver $ secure req) + route Proxy context subserver = + route (Proxy :: Proxy api) context (passToServer subserver secure) where secure req = if isSecure req then Secure else NotSecure instance HasServer api context => HasServer (Vault :> api) context where type ServerT (Vault :> api) m = Vault -> ServerT api m - route Proxy context subserver = WithRequest $ \req -> - route (Proxy :: Proxy api) context (passToServer subserver $ vault req) + route Proxy context subserver = + route (Proxy :: Proxy api) context (passToServer subserver vault) instance HasServer api context => HasServer (HttpVersion :> api) context where type ServerT (HttpVersion :> api) m = HttpVersion -> ServerT api m - route Proxy context subserver = WithRequest $ \req -> - route (Proxy :: Proxy api) context (passToServer subserver $ httpVersion req) + route Proxy context subserver = + route (Proxy :: Proxy api) context (passToServer subserver httpVersion) -- | Basic Authentication instance ( KnownSymbol realm @@ -450,12 +455,12 @@ instance ( KnownSymbol realm type ServerT (BasicAuth realm usr :> api) m = usr -> ServerT api m - route Proxy context subserver = WithRequest $ \ request -> - route (Proxy :: Proxy api) context (subserver `addAuthCheck` authCheck request) + route Proxy context subserver = + route (Proxy :: Proxy api) context (subserver `addAuthCheck` authCheck) where realm = BC8.pack $ symbolVal (Proxy :: Proxy realm) basicAuthContext = getContextEntry context - authCheck req = runBasicAuth req realm basicAuthContext + authCheck = withRequest $ \ req -> runBasicAuth req realm basicAuthContext -- * helpers diff --git a/servant-server/src/Servant/Server/Internal/BasicAuth.hs b/servant-server/src/Servant/Server/Internal/BasicAuth.hs index fcd678b5..1fed931b 100644 --- a/servant-server/src/Servant/Server/Internal/BasicAuth.hs +++ b/servant-server/src/Servant/Server/Internal/BasicAuth.hs @@ -6,6 +6,7 @@ module Servant.Server.Internal.BasicAuth where import Control.Monad (guard) +import Control.Monad.Trans (liftIO) import qualified Data.ByteString as BS import Data.ByteString.Base64 (decodeLenient) import Data.Monoid ((<>)) @@ -57,13 +58,13 @@ decodeBAHdr req = do -- | Run and check basic authentication, returning the appropriate http error per -- the spec. -runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> IO (RouteResult usr) +runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> DelayedIO usr runBasicAuth req realm (BasicAuthCheck ba) = case decodeBAHdr req of Nothing -> plzAuthenticate - Just e -> ba e >>= \res -> case res of + Just e -> liftIO (ba e) >>= \res -> case res of BadPassword -> plzAuthenticate NoSuchUser -> plzAuthenticate - Unauthorized -> return $ FailFatal err403 - Authorized usr -> return $ Route usr - where plzAuthenticate = return $ FailFatal err401 { errHeaders = [mkBAChallengerHdr realm] } + Unauthorized -> delayedFailFatal err403 + Authorized usr -> return usr + where plzAuthenticate = delayedFailFatal err401 { errHeaders = [mkBAChallengerHdr realm] } diff --git a/servant-server/src/Servant/Server/Internal/Router.hs b/servant-server/src/Servant/Server/Internal/Router.hs index 04b661a3..3b69c04c 100644 --- a/servant-server/src/Servant/Server/Internal/Router.hs +++ b/servant-server/src/Servant/Server/Internal/Router.hs @@ -1,5 +1,7 @@ {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE CPP #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE OverloadedStrings #-} module Servant.Server.Internal.Router where @@ -8,36 +10,41 @@ import qualified Data.Map as M import Data.Monoid import Data.Text (Text) import qualified Data.Text as T -import Network.Wai (Request, Response, pathInfo) +import Network.Wai (Response, pathInfo) import Servant.Server.Internal.RoutingApplication import Servant.Server.Internal.ServantErr -type Router = Router' RoutingApplication +type Router env = Router' env RoutingApplication -- | Internal representation of a router. -data Router' a = - WithRequest (Request -> Router' a) - -- ^ current request is passed to the router - | StaticRouter (Map Text (Router' a)) [a] +-- +-- The first argument describes an environment type that is +-- expected as extra input by the routers at the leaves. The +-- environment is filled while running the router, with path +-- components that can be used to process captures. +-- +data Router' env a = + StaticRouter (Map Text (Router' env a)) [env -> a] -- ^ the map contains routers for subpaths (first path component used -- for lookup and removed afterwards), the list contains handlers -- for the empty path, to be tried in order - | DynamicRouter (Text -> Router' a) - -- ^ first path component passed to the function and removed afterwards - | RawRouter a + | CaptureRouter (Router' (Text, env) a) + -- ^ first path component is passed to the child router in its + -- environment and removed afterwards + | RawRouter (env -> a) -- ^ to be used for routes we do not know anything about - | Choice (Router' a) (Router' a) + | Choice (Router' env a) (Router' env a) -- ^ left-biased choice between two routers deriving Functor -- | Smart constructor for a single static path component. -pathRouter :: Text -> Router' a -> Router' a +pathRouter :: Text -> Router' env a -> Router' env a pathRouter t r = StaticRouter (M.singleton t r) [] -- | Smart constructor for a leaf, i.e., a router that expects -- the empty path. -- -leafRouter :: a -> Router' a +leafRouter :: (env -> a) -> Router' env a leafRouter l = StaticRouter M.empty [l] -- | Smart constructor for the choice between routers. @@ -46,40 +53,27 @@ leafRouter l = StaticRouter M.empty [l] -- * Two static routers can be joined by joining their maps -- and concatenating their leaf-lists. -- * Two dynamic routers can be joined by joining their codomains. --- * Two 'WithRequest' routers can be joined by passing them --- the same request and joining their codomains. --- * A 'WithRequest' router can be joined with anything else by --- passing the same request to both but ignoring it in the --- component that does not need it. -- * Choice nodes can be reordered. -- -choice :: Router -> Router -> Router +choice :: Router' env a -> Router' env a -> Router' env a choice (StaticRouter table1 ls1) (StaticRouter table2 ls2) = StaticRouter (M.unionWith choice table1 table2) (ls1 ++ ls2) -choice (DynamicRouter fun1) (DynamicRouter fun2) = - DynamicRouter (\ first -> choice (fun1 first) (fun2 first)) -choice (WithRequest router1) (WithRequest router2) = - WithRequest (\ request -> choice (router1 request) (router2 request)) -choice (WithRequest router1) router2 = - WithRequest (\ request -> choice (router1 request) router2) -choice router1 (WithRequest router2) = - WithRequest (\ request -> choice router1 (router2 request)) +choice (CaptureRouter router1) (CaptureRouter router2) = + CaptureRouter (choice router1 router2) choice router1 (Choice router2 router3) = Choice (choice router1 router2) router3 choice router1 router2 = Choice router1 router2 -- | Datatype used for representing and debugging the --- structure of a router. Abstracts from the functions --- being used in the actual router and the handlers at --- the leaves. +-- structure of a router. Abstracts from the handlers +-- at the leaves. -- -- Two 'Router's can be structurally compared by computing -- their 'RouterStructure' using 'routerStructure' and -- then testing for equality, see 'sameStructure'. -- data RouterStructure = - WithRequestStructure RouterStructure - | StaticRouterStructure (Map Text RouterStructure) Int - | DynamicRouterStructure RouterStructure + StaticRouterStructure (Map Text RouterStructure) Int + | CaptureRouterStructure RouterStructure | RawRouterStructure | ChoiceStructure RouterStructure RouterStructure deriving (Eq, Show) @@ -87,18 +81,15 @@ data RouterStructure = -- | Compute the structure of a router. -- -- Assumes that the request or text being passed --- in 'WithRequest' or 'DynamicRouter' does not +-- in 'WithRequest' or 'CaptureRouter' does not -- affect the structure of the underlying tree. -- -routerStructure :: Router' a -> RouterStructure -routerStructure (WithRequest f) = - WithRequestStructure $ - routerStructure (f (error "routerStructure: dummy request")) +routerStructure :: Router' env a -> RouterStructure routerStructure (StaticRouter m ls) = StaticRouterStructure (fmap routerStructure m) (length ls) -routerStructure (DynamicRouter f) = - DynamicRouterStructure $ - routerStructure (f (error "routerStructure: dummy text")) +routerStructure (CaptureRouter router) = + CaptureRouterStructure $ + routerStructure router routerStructure (RawRouter _) = RawRouterStructure routerStructure (Choice r1 r2) = @@ -108,21 +99,20 @@ routerStructure (Choice r1 r2) = -- | Compare the structure of two routers. -- -sameStructure :: Router' a -> Router' b -> Bool +sameStructure :: Router' env a -> Router' env b -> Bool sameStructure r1 r2 = routerStructure r1 == routerStructure r2 -- | Provide a textual representation of the -- structure of a router. -- -routerLayout :: Router' a -> Text +routerLayout :: Router' env a -> Text routerLayout router = T.unlines (["/"] ++ mkRouterLayout False (routerStructure router)) where mkRouterLayout :: Bool -> RouterStructure -> [Text] - mkRouterLayout c (WithRequestStructure r) = mkRouterLayout c r mkRouterLayout c (StaticRouterStructure m n) = mkSubTrees c (M.toList m) n - mkRouterLayout c (DynamicRouterStructure r) = mkSubTree c "" (mkRouterLayout False r) + mkRouterLayout c (CaptureRouterStructure r) = mkSubTree c "" (mkRouterLayout False r) mkRouterLayout c RawRouterStructure = if c then ["├─ "] else ["└─ "] mkRouterLayout c (ChoiceStructure r1 r2) = @@ -146,47 +136,54 @@ routerLayout router = mkSubTree False path children = ("└─ " <> path <> "/") : map (" " <>) children -- | Apply a transformation to the response of a `Router`. -tweakResponse :: (RouteResult Response -> RouteResult Response) -> Router -> Router +tweakResponse :: (RouteResult Response -> RouteResult Response) -> Router env -> Router env tweakResponse f = fmap (\a -> \req cont -> a req (cont . f)) -- | Interpret a router as an application. -runRouter :: Router -> RoutingApplication -runRouter (WithRequest router) request respond = - runRouter (router request) request respond -runRouter (StaticRouter table ls) request respond = - case pathInfo request of - [] -> runChoice ls request respond - -- This case is to handle trailing slashes. - [""] -> runChoice ls request respond - first : rest | Just router <- M.lookup first table - -> let request' = request { pathInfo = rest } - in runRouter router request' respond - _ -> respond $ Fail err404 -runRouter (DynamicRouter fun) request respond = - case pathInfo request of - [] -> respond $ Fail err404 - -- This case is to handle trailing slashes. - [""] -> respond $ Fail err404 - first : rest - -> let request' = request { pathInfo = rest } - in runRouter (fun first) request' respond -runRouter (RawRouter app) request respond = app request respond -runRouter (Choice r1 r2) request respond = - runChoice [runRouter r1, runRouter r2] request respond +runRouter :: Router () -> RoutingApplication +runRouter r = runRouterEnv r () + +runRouterEnv :: Router env -> env -> RoutingApplication +runRouterEnv router env request respond = + case router of + StaticRouter table ls -> + case pathInfo request of + [] -> runChoice ls env request respond + -- This case is to handle trailing slashes. + [""] -> runChoice ls env request respond + first : rest | Just router' <- M.lookup first table + -> let request' = request { pathInfo = rest } + in runRouterEnv router' env request' respond + _ -> respond $ Fail err404 + CaptureRouter router' -> + case pathInfo request of + [] -> respond $ Fail err404 + -- This case is to handle trailing slashes. + [""] -> respond $ Fail err404 + first : rest + -> let request' = request { pathInfo = rest } + in runRouterEnv router' (first, env) request' respond + RawRouter app -> + app env request respond + Choice r1 r2 -> + runChoice [runRouterEnv r1, runRouterEnv r2] env request respond -- | Try a list of routing applications in order. -- We stop as soon as one fails fatally or succeeds. -- If all fail normally, we pick the "best" error. -- -runChoice :: [RoutingApplication] -> RoutingApplication -runChoice [] _request respond = respond (Fail err404) -runChoice [r] request respond = r request respond -runChoice (r : rs) request respond = - r request $ \ response1 -> - case response1 of - Fail _ -> runChoice rs request $ \ response2 -> - respond $ highestPri response1 response2 - _ -> respond response1 +runChoice :: [env -> RoutingApplication] -> env -> RoutingApplication +runChoice ls = + case ls of + [] -> \ _ _ respond -> respond (Fail err404) + [r] -> r + (r : rs) -> + \ env request respond -> + r env request $ \ response1 -> + case response1 of + Fail _ -> runChoice rs env request $ \ response2 -> + respond $ highestPri response1 response2 + _ -> respond response1 where highestPri (Fail e1) (Fail e2) = if worseHTTPCode (errHTTPCode e1) (errHTTPCode e2) diff --git a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs index 99def4b8..5825531e 100644 --- a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs +++ b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs @@ -8,7 +8,10 @@ {-# LANGUAGE StandaloneDeriving #-} module Servant.Server.Internal.RoutingApplication where +import Control.Monad (ap, liftM) +import Control.Monad.Trans (MonadIO(..)) import Control.Monad.Trans.Except (runExceptT) +import Data.Text (Text) import Network.Wai (Application, Request, Response, ResponseReceived) import Prelude () @@ -95,113 +98,133 @@ toApplication ra request respond = ra request routingRespond -- The accept header check can be performed as the final -- computation in this block. It can cause a 406. -- -data Delayed c where - Delayed :: { capturesD :: IO (RouteResult captures) - , methodD :: IO (RouteResult ()) - , authD :: IO (RouteResult auth) - , bodyD :: IO (RouteResult body) - , serverD :: (captures -> auth -> body -> RouteResult c) - } -> Delayed c +data Delayed env c where + Delayed :: { capturesD :: env -> DelayedIO captures + , methodD :: DelayedIO () + , authD :: DelayedIO auth + , bodyD :: DelayedIO body + , serverD :: captures -> auth -> body -> Request -> RouteResult c + } -> Delayed env c -instance Functor Delayed where - fmap f Delayed{..} - = Delayed { capturesD = capturesD - , methodD = methodD - , authD = authD - , bodyD = bodyD - , serverD = (fmap.fmap.fmap.fmap) f serverD - } -- Note [Existential Record Update] +instance Functor (Delayed env) where + fmap f Delayed{..} = + Delayed + { serverD = \ c a b req -> f <$> serverD c a b req + , .. + } -- Note [Existential Record Update] + +-- | Computations used in a 'Delayed' can depend on the +-- incoming 'Request', may perform 'IO, and result in a +-- 'RouteResult, meaning they can either suceed, fail +-- (with the possibility to recover), or fail fatally. +-- +newtype DelayedIO a = DelayedIO { runDelayedIO :: Request -> IO (RouteResult a) } + +instance Functor DelayedIO where + fmap = liftM + +instance Applicative DelayedIO where + pure = return + (<*>) = ap + +instance Monad DelayedIO where + return x = DelayedIO (const $ return (Route x)) + DelayedIO m >>= f = + DelayedIO $ \ req -> do + r <- m req + case r of + Fail e -> return $ Fail e + FailFatal e -> return $ FailFatal e + Route a -> runDelayedIO (f a) req + +instance MonadIO DelayedIO where + liftIO m = DelayedIO (const $ Route <$> m) + +-- | A 'Delayed' without any stored checks. +emptyDelayed :: RouteResult a -> Delayed env a +emptyDelayed result = + Delayed (const r) r r r (\ _ _ _ _ -> result) + where + r = return () + +-- | Fail with the option to recover. +delayedFail :: ServantErr -> DelayedIO a +delayedFail err = DelayedIO (const $ return $ Fail err) + +-- | Fail fatally, i.e., without any option to recover. +delayedFailFatal :: ServantErr -> DelayedIO a +delayedFailFatal err = DelayedIO (const $ return $ FailFatal err) + +-- | Gain access to the incoming request. +withRequest :: (Request -> DelayedIO a) -> DelayedIO a +withRequest f = DelayedIO (\ req -> runDelayedIO (f req) req) -- | Add a capture to the end of the capture block. -addCapture :: Delayed (a -> b) - -> IO (RouteResult a) - -> Delayed b -addCapture Delayed{..} new - = Delayed { capturesD = combineRouteResults (,) capturesD new - , methodD = methodD - , authD = authD - , bodyD = bodyD - , serverD = \ (x, v) y z -> ($ v) <$> serverD x y z - } -- Note [Existential Record Update] +addCapture :: Delayed env (a -> b) + -> (Text -> DelayedIO a) + -> Delayed (Text, env) b +addCapture Delayed{..} new = + Delayed + { capturesD = \ (txt, env) -> (,) <$> capturesD env <*> new txt + , serverD = \ (x, v) a b req -> ($ v) <$> serverD x a b req + , .. + } -- Note [Existential Record Update] -- | Add a method check to the end of the method block. -addMethodCheck :: Delayed a - -> IO (RouteResult ()) - -> Delayed a -addMethodCheck Delayed{..} new - = Delayed { capturesD = capturesD - , methodD = combineRouteResults const methodD new - , authD = authD - , bodyD = bodyD - , serverD = serverD - } -- Note [Existential Record Update] +addMethodCheck :: Delayed env a + -> DelayedIO () + -> Delayed env a +addMethodCheck Delayed{..} new = + Delayed + { methodD = methodD <* new + , .. + } -- Note [Existential Record Update] -- | Add an auth check to the end of the auth block. -addAuthCheck :: Delayed (a -> b) - -> IO (RouteResult a) - -> Delayed b -addAuthCheck Delayed{..} new - = Delayed { capturesD = capturesD - , methodD = methodD - , authD = combineRouteResults (,) authD new - , bodyD = bodyD - , serverD = \ x (y, v) z -> ($ v) <$> serverD x y z - } -- Note [Existential Record Update] +addAuthCheck :: Delayed env (a -> b) + -> DelayedIO a + -> Delayed env b +addAuthCheck Delayed{..} new = + Delayed + { authD = (,) <$> authD <*> new + , serverD = \ c (y, v) b req -> ($ v) <$> serverD c y b req + , .. + } -- Note [Existential Record Update] -- | Add a body check to the end of the body block. -addBodyCheck :: Delayed (a -> b) - -> IO (RouteResult a) - -> Delayed b -addBodyCheck Delayed{..} new - = Delayed { capturesD = capturesD - , methodD = methodD - , authD = authD - , bodyD = combineRouteResults (,) bodyD new - , serverD = \ x y (z, v) -> ($ v) <$> serverD x y z - } -- Note [Existential Record Update] +addBodyCheck :: Delayed env (a -> b) + -> DelayedIO a + -> Delayed env b +addBodyCheck Delayed{..} new = + Delayed + { bodyD = (,) <$> bodyD <*> new + , serverD = \ c a (z, v) req -> ($ v) <$> serverD c a z req + , .. + } -- Note [Existential Record Update] -- | Add an accept header check to the end of the body block. -- The accept header check should occur after the body check, -- but this will be the case, because the accept header check -- is only scheduled by the method combinators. -addAcceptCheck :: Delayed a - -> IO (RouteResult ()) - -> Delayed a -addAcceptCheck Delayed{..} new - = Delayed { capturesD = capturesD - , methodD = methodD - , authD = authD - , bodyD = combineRouteResults const bodyD new - , serverD = serverD - } -- Note [Existential Record Update] +addAcceptCheck :: Delayed env a + -> DelayedIO () + -> Delayed env a +addAcceptCheck Delayed{..} new = + Delayed + { bodyD = bodyD <* new + , .. + } -- Note [Existential Record Update] -- | Many combinators extract information that is passed to -- the handler without the possibility of failure. In such a -- case, 'passToServer' can be used. -passToServer :: Delayed (a -> b) -> a -> Delayed b -passToServer d x = ($ x) <$> d - --- | The combination 'IO . RouteResult' is a monad, but we --- don't explicitly wrap it in a newtype in order to make it --- an instance. This is the '>>=' of that monad. --- --- We stop on the first error. -bindRouteResults :: IO (RouteResult a) -> (a -> IO (RouteResult b)) -> IO (RouteResult b) -bindRouteResults m f = do - r <- m - case r of - Fail e -> return $ Fail e - FailFatal e -> return $ FailFatal e - Route a -> f a - --- | Common special case of 'bindRouteResults', corresponding --- to 'liftM2'. -combineRouteResults :: (a -> b -> c) -> IO (RouteResult a) -> IO (RouteResult b) -> IO (RouteResult c) -combineRouteResults f m1 m2 = - m1 `bindRouteResults` \ a -> - m2 `bindRouteResults` \ b -> - return (Route (f a b)) +passToServer :: Delayed env (a -> b) -> (Request -> a) -> Delayed env b +passToServer Delayed{..} x = + Delayed + { serverD = \ c a b req -> ($ x req) <$> serverD c a b req + , .. + } -- Note [Existential Record Update] -- | Run a delayed server. Performs all scheduled operations -- in order, and passes the results from the capture and body @@ -209,24 +232,29 @@ combineRouteResults f m1 m2 = -- -- This should only be called once per request; otherwise the guarantees about -- effect and HTTP error ordering break down. -runDelayed :: Delayed a +runDelayed :: Delayed env a + -> env + -> Request -> IO (RouteResult a) -runDelayed Delayed{..} = - capturesD `bindRouteResults` \ c -> - methodD `bindRouteResults` \ _ -> - authD `bindRouteResults` \ a -> - bodyD `bindRouteResults` \ b -> - return (serverD c a b) +runDelayed Delayed{..} env = runDelayedIO $ do + c <- capturesD env + methodD + a <- authD + b <- bodyD + DelayedIO (\ req -> return $ serverD c a b req) -- | Runs a delayed server and the resulting action. -- Takes a continuation that lets us send a response. -- Also takes a continuation for how to turn the -- result of the delayed server into a response. -runAction :: Delayed (Handler a) +runAction :: Delayed env (Handler a) + -> env + -> Request -> (RouteResult Response -> IO r) -> (a -> RouteResult Response) -> IO r -runAction action respond k = runDelayed action >>= go >>= respond +runAction action env req respond k = + runDelayed action env req >>= go >>= respond where go (Fail e) = return $ Fail e go (FailFatal e) = return $ FailFatal e diff --git a/servant-server/test/Servant/Server/RouterSpec.hs b/servant-server/test/Servant/Server/RouterSpec.hs index 7ebd1a75..684361b2 100644 --- a/servant-server/test/Servant/Server/RouterSpec.hs +++ b/servant-server/test/Servant/Server/RouterSpec.hs @@ -25,9 +25,9 @@ routerSpec = do let app' :: Application app' = toApplication $ runRouter router' - router', router :: Router + router', router :: Router () router' = tweakResponse (fmap twk) router - router = leafRouter $ \_ cont -> cont (Route $ responseBuilder (Status 201 "") [] "") + router = leafRouter $ \_ _ cont -> cont (Route $ responseBuilder (Status 201 "") [] "") twk :: Response -> Response twk (ResponseBuilder (Status i s) hs b) = ResponseBuilder (Status (i + 1) s) hs b @@ -69,11 +69,9 @@ shouldHaveSameStructureAs p1 p2 = unless (sameStructure (makeTrivialRouter p1) (makeTrivialRouter p2)) $ expectationFailure ("expected:\n" ++ unpack (layout p2) ++ "\nbut got:\n" ++ unpack (layout p1)) -makeTrivialRouter :: (HasServer layout '[]) => Proxy layout -> Router -makeTrivialRouter p = route p EmptyContext d - where - d = Delayed r r r r (\ _ _ _ -> FailFatal err501) - r = return (Route ()) +makeTrivialRouter :: (HasServer layout '[]) => Proxy layout -> Router () +makeTrivialRouter p = + route p EmptyContext (emptyDelayed (FailFatal err501)) type End = Get '[JSON] () diff --git a/servant-server/test/Servant/Server/UsingContextSpec/TestCombinators.hs b/servant-server/test/Servant/Server/UsingContextSpec/TestCombinators.hs index 48595c9c..21999451 100644 --- a/servant-server/test/Servant/Server/UsingContextSpec/TestCombinators.hs +++ b/servant-server/test/Servant/Server/UsingContextSpec/TestCombinators.hs @@ -20,7 +20,6 @@ module Servant.Server.UsingContextSpec.TestCombinators where import GHC.TypeLits import Servant -import Servant.Server.Internal.RoutingApplication data ExtractFromContext @@ -31,7 +30,7 @@ instance (HasContextEntry context String, HasServer subApi context) => String -> ServerT subApi m route Proxy context delayed = - route subProxy context (fmap (inject context) delayed :: Delayed (Server subApi)) + route subProxy context (fmap (inject context) delayed) where subProxy :: Proxy subApi subProxy = Proxy diff --git a/servant-server/test/Servant/ServerSpec.hs b/servant-server/test/Servant/ServerSpec.hs index 5b4154d7..fc4eb1df 100644 --- a/servant-server/test/Servant/ServerSpec.hs +++ b/servant-server/test/Servant/ServerSpec.hs @@ -48,7 +48,7 @@ import Servant.API ((:<|>) (..), (:>), AuthProtect, Raw, RemoteHost, ReqBody, StdMethod (..), Verb, addHeader) import Servant.API.Internal.Test.ComprehensiveAPI -import Servant.Server (ServantErr (..), Server, Handler, err401, err403, +import Servant.Server (Server, Handler, err401, err403, err404, serve, serveWithContext, Context((:.), EmptyContext)) import Test.Hspec (Spec, context, describe, it,