diff --git a/servant-server/CHANGELOG.md b/servant-server/CHANGELOG.md index 85a6d9e1..736c36bd 100644 --- a/servant-server/CHANGELOG.md +++ b/servant-server/CHANGELOG.md @@ -1,10 +1,11 @@ 0.7 --- -* The `Router` type has been changed. There are now more situations where - servers will make use of static lookups to efficiently route the request - to the correct endpoint. Functions `layout` and `layoutWithContext` have - been added to visualize the router layout for debugging purposes. Test +* The `Router` type has been changed. Static router tables should now + be properly shared between requests, drastically increasing the + number of situations where servers will be able to route requests + efficiently. Functions `layout` and `layoutWithContext` have been + added to visualize the router layout for debugging purposes. Test cases for expected router layouts have been added. * Export `throwError` from module `Servant` * Add `Handler` type synonym diff --git a/servant-server/src/Servant/Server.hs b/servant-server/src/Servant/Server.hs index b2cf7a66..bbba7c1b 100644 --- a/servant-server/src/Servant/Server.hs +++ b/servant-server/src/Servant/Server.hs @@ -132,21 +132,13 @@ serve p = serveWithContext p EmptyContext serveWithContext :: (HasServer layout context) => Proxy layout -> Context context -> Server layout -> Application -serveWithContext p context server = toApplication (runRouter (route p context d)) - where - d = Delayed r r r r (\ _ _ _ -> Route server) - r = return (Route ()) +serveWithContext p context server = + toApplication (runRouter (route p context (emptyDelayed (Route server)))) -- | The function 'layout' produces a textual description of the internal -- router layout for debugging purposes. Note that the router layout is -- determined just by the API, not by the handlers. -- --- This function makes certain assumptions about the well-behavedness of --- the 'HasServer' instances of the combinators which should be ok for the --- core servant constructions, but might not be satisfied for some other --- combinators provided elsewhere. It is possible that the function may --- crash for these. --- -- Example: -- -- For the following API @@ -168,7 +160,7 @@ serveWithContext p context server = toApplication (runRouter (route p context d) -- > │ └─ e/ -- > │ └─• -- > ├─ b/ --- > │ └─ / +-- > │ └─ / -- > │ ├─• -- > │ ┆ -- > │ └─• @@ -185,7 +177,7 @@ serveWithContext p context server = toApplication (runRouter (route p context d) -- -- [@─•@] Leaves reflect endpoints. -- --- [@\/@] This is a delayed capture of a path component. +-- [@\/@] This is a delayed capture of a path component. -- -- [@\@] This is a part of the API we do not know anything about. -- @@ -200,10 +192,8 @@ layout p = layoutWithContext p EmptyContext -- | Variant of 'layout' that takes an additional 'Context'. layoutWithContext :: (HasServer layout context) => Proxy layout -> Context context -> Text -layoutWithContext p context = routerLayout (route p context d) - where - d = Delayed r r r r (\ _ _ _ -> FailFatal err501) - r = return (Route ()) +layoutWithContext p context = + routerLayout (route p context (emptyDelayed (FailFatal err501))) -- Documentation diff --git a/servant-server/src/Servant/Server/Experimental/Auth.hs b/servant-server/src/Servant/Server/Experimental/Auth.hs index 86d4dc03..fd38ff1e 100644 --- a/servant-server/src/Servant/Server/Experimental/Auth.hs +++ b/servant-server/src/Servant/Server/Experimental/Auth.hs @@ -12,6 +12,7 @@ module Servant.Server.Experimental.Auth where +import Control.Monad.Trans (liftIO) import Control.Monad.Trans.Except (runExceptT) import Data.Proxy (Proxy (Proxy)) import Data.Typeable (Typeable) @@ -24,10 +25,11 @@ import Servant.Server.Internal (HasContextEntry, HasServer, ServerT, getContextEntry, route) -import Servant.Server.Internal.Router (Router' (WithRequest)) -import Servant.Server.Internal.RoutingApplication (RouteResult (FailFatal, Route), - addAuthCheck) -import Servant.Server.Internal.ServantErr (ServantErr, Handler) +import Servant.Server.Internal.RoutingApplication (addAuthCheck, + delayedFailFatal, + DelayedIO, + withRequest) +import Servant.Server.Internal.ServantErr (Handler) -- * General Auth @@ -57,8 +59,10 @@ instance ( HasServer api context type ServerT (AuthProtect tag :> api) m = AuthServerData (AuthProtect tag) -> ServerT api m - route Proxy context subserver = WithRequest $ \ request -> - route (Proxy :: Proxy api) context (subserver `addAuthCheck` authCheck request) + route Proxy context subserver = + route (Proxy :: Proxy api) context (subserver `addAuthCheck` withRequest authCheck) where + authHandler :: Request -> Handler (AuthServerData (AuthProtect tag)) authHandler = unAuthHandler (getContextEntry context) - authCheck = fmap (either FailFatal Route) . runExceptT . authHandler + authCheck :: Request -> DelayedIO (AuthServerData (AuthProtect tag)) + authCheck = (>>= either delayedFailFatal return) . liftIO . runExceptT . authHandler diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index eb3ca19c..2d378c9d 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -22,6 +22,7 @@ module Servant.Server.Internal , module Servant.Server.Internal.ServantErr ) where +import Control.Monad.Trans (liftIO) import qualified Data.ByteString as B import qualified Data.ByteString.Char8 as BC8 import qualified Data.ByteString.Lazy as BL @@ -70,7 +71,11 @@ import Servant.Server.Internal.ServantErr class HasServer layout context where type ServerT layout (m :: * -> *) :: * - route :: Proxy layout -> Context context -> Delayed (Server layout) -> Router + route :: + Proxy layout + -> Context context + -> Delayed env (Server layout) + -> Router env type Server layout = ServerT layout Handler @@ -92,7 +97,7 @@ instance (HasServer a context, HasServer b context) => HasServer (a :<|> b) cont type ServerT (a :<|> b) m = ServerT a m :<|> ServerT b m route Proxy context server = choice (route pa context ((\ (a :<|> _) -> a) <$> server)) - (route pb context ((\ (_ :<|> b) -> b) <$> server)) + (route pb context ((\ (_ :<|> b) -> b) <$> server)) where pa = Proxy :: Proxy a pb = Proxy :: Proxy b @@ -120,12 +125,12 @@ instance (KnownSymbol capture, FromHttpApiData a, HasServer sublayout context) a -> ServerT sublayout m route Proxy context d = - DynamicRouter $ \ first -> + CaptureRouter $ route (Proxy :: Proxy sublayout) context - (addCapture d $ case parseUrlPieceMaybe first :: Maybe a of - Nothing -> return $ Fail err400 - Just v -> return $ Route v + (addCapture d $ \ txt -> case parseUrlPieceMaybe txt :: Maybe a of + Nothing -> delayedFail err400 + Just v -> return v ) allowedMethodHead :: Method -> Request -> Bool @@ -144,41 +149,41 @@ processMethodRouter handleA status method headers request = case handleA of bdy = if allowedMethodHead method request then "" else body hdrs = (hContentType, cs contentT) : (fromMaybe [] headers) -methodCheck :: Method -> Request -> IO (RouteResult ()) +methodCheck :: Method -> Request -> DelayedIO () methodCheck method request - | allowedMethod method request = return $ Route () - | otherwise = return $ Fail err405 + | allowedMethod method request = return () + | otherwise = delayedFail err405 -acceptCheck :: (AllMime list) => Proxy list -> B.ByteString -> IO (RouteResult ()) +acceptCheck :: (AllMime list) => Proxy list -> B.ByteString -> DelayedIO () acceptCheck proxy accH - | canHandleAcceptH proxy (AcceptHeader accH) = return $ Route () - | otherwise = return $ FailFatal err406 + | canHandleAcceptH proxy (AcceptHeader accH) = return () + | otherwise = delayedFailFatal err406 methodRouter :: (AllCTRender ctypes a) => Method -> Proxy ctypes -> Status - -> Delayed (Handler a) - -> Router + -> Delayed env (Handler a) + -> Router env methodRouter method proxy status action = leafRouter route' where - route' request respond = + route' env request respond = let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request in runAction (action `addMethodCheck` methodCheck method request `addAcceptCheck` acceptCheck proxy accH - ) respond $ \ output -> do + ) env request respond $ \ output -> do let handleA = handleAcceptH proxy (AcceptHeader accH) output processMethodRouter handleA status method Nothing request methodRouterHeaders :: (GetHeaders (Headers h v), AllCTRender ctypes v) => Method -> Proxy ctypes -> Status - -> Delayed (Handler (Headers h v)) - -> Router + -> Delayed env (Handler (Headers h v)) + -> Router env methodRouterHeaders method proxy status action = leafRouter route' where - route' request respond = + route' env request respond = let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request in runAction (action `addMethodCheck` methodCheck method request `addAcceptCheck` acceptCheck proxy accH - ) respond $ \ output -> do + ) env request respond $ \ output -> do let headers = getHeaders output handleA = handleAcceptH proxy (AcceptHeader accH) (getResponse output) processMethodRouter handleA status method (Just headers) request @@ -230,8 +235,8 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout context) type ServerT (Header sym a :> sublayout) m = Maybe a -> ServerT sublayout m - route Proxy context subserver = WithRequest $ \ request -> - let mheader = parseHeaderMaybe =<< lookup str (requestHeaders request) + route Proxy context subserver = + let mheader req = parseHeaderMaybe =<< lookup str (requestHeaders req) in route (Proxy :: Proxy sublayout) context (passToServer subserver mheader) where str = fromString $ symbolVal (Proxy :: Proxy sym) @@ -262,10 +267,10 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout context) type ServerT (QueryParam sym a :> sublayout) m = Maybe a -> ServerT sublayout m - route Proxy context subserver = WithRequest $ \ request -> - let querytext = parseQueryText $ rawQueryString request - param = - case lookup paramname querytext of + route Proxy context subserver = + let querytext r = parseQueryText $ rawQueryString r + param r = + case lookup paramname (querytext r) of Nothing -> Nothing -- param absent from the query string Just Nothing -> Nothing -- param present with no value -> Nothing Just (Just v) -> parseQueryParamMaybe v -- if present, we try to convert to @@ -298,13 +303,13 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout context) type ServerT (QueryParams sym a :> sublayout) m = [a] -> ServerT sublayout m - route Proxy context subserver = WithRequest $ \ request -> - let querytext = parseQueryText $ rawQueryString request + route Proxy context subserver = + let querytext r = parseQueryText $ rawQueryString r -- if sym is "foo", we look for query string parameters -- named "foo" or "foo[]" and call parseQueryParam on the -- corresponding values - parameters = filter looksLikeParam querytext - values = mapMaybe (convert . snd) parameters + parameters r = filter looksLikeParam (querytext r) + values r = mapMaybe (convert . snd) (parameters r) in route (Proxy :: Proxy sublayout) context (passToServer subserver values) where paramname = cs $ symbolVal (Proxy :: Proxy sym) looksLikeParam (name, _) = name == paramname || name == (paramname <> "[]") @@ -329,9 +334,9 @@ instance (KnownSymbol sym, HasServer sublayout context) type ServerT (QueryFlag sym :> sublayout) m = Bool -> ServerT sublayout m - route Proxy context subserver = WithRequest $ \ request -> - let querytext = parseQueryText $ rawQueryString request - param = case lookup paramname querytext of + route Proxy context subserver = + let querytext r = parseQueryText $ rawQueryString r + param r = case lookup paramname (querytext r) of Just Nothing -> True -- param is there, with no value Just (Just v) -> examine v -- param with a value Nothing -> False -- param not in the query string @@ -352,8 +357,8 @@ instance HasServer Raw context where type ServerT Raw m = Application - route Proxy _ rawApplication = RawRouter $ \ request respond -> do - r <- runDelayed rawApplication + route Proxy _ rawApplication = RawRouter $ \ env request respond -> do + r <- runDelayed rawApplication env request case r of Route app -> app request (respond . Route) Fail a -> respond $ Fail a @@ -386,10 +391,10 @@ instance ( AllCTUnrender list a, HasServer sublayout context type ServerT (ReqBody list a :> sublayout) m = a -> ServerT sublayout m - route Proxy context subserver = WithRequest $ \ request -> - route (Proxy :: Proxy sublayout) context (addBodyCheck subserver (bodyCheck request)) + route Proxy context subserver = + route (Proxy :: Proxy sublayout) context (addBodyCheck subserver bodyCheck) where - bodyCheck request = do + bodyCheck = withRequest $ \ request -> do -- See HTTP RFC 2616, section 7.2.1 -- http://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1 -- See also "W3C Internet Media Type registration, consistency of use" @@ -397,11 +402,11 @@ instance ( AllCTUnrender list a, HasServer sublayout context let contentTypeH = fromMaybe "application/octet-stream" $ lookup hContentType $ requestHeaders request mrqbody <- handleCTypeH (Proxy :: Proxy list) (cs contentTypeH) - <$> lazyRequestBody request + <$> liftIO (lazyRequestBody request) case mrqbody of - Nothing -> return $ FailFatal err415 - Just (Left e) -> return $ FailFatal err400 { errBody = cs e } - Just (Right v) -> return $ Route v + Nothing -> delayedFailFatal err415 + Just (Left e) -> delayedFailFatal err400 { errBody = cs e } + Just (Right v) -> return v -- | Make sure the incoming request starts with @"/path"@, strip it and -- pass the rest of the request path to @sublayout@. @@ -418,28 +423,28 @@ instance (KnownSymbol path, HasServer sublayout context) => HasServer (path :> s instance HasServer api context => HasServer (RemoteHost :> api) context where type ServerT (RemoteHost :> api) m = SockAddr -> ServerT api m - route Proxy context subserver = WithRequest $ \req -> - route (Proxy :: Proxy api) context (passToServer subserver $ remoteHost req) + route Proxy context subserver = + route (Proxy :: Proxy api) context (passToServer subserver remoteHost) instance HasServer api context => HasServer (IsSecure :> api) context where type ServerT (IsSecure :> api) m = IsSecure -> ServerT api m - route Proxy context subserver = WithRequest $ \req -> - route (Proxy :: Proxy api) context (passToServer subserver $ secure req) + route Proxy context subserver = + route (Proxy :: Proxy api) context (passToServer subserver secure) where secure req = if isSecure req then Secure else NotSecure instance HasServer api context => HasServer (Vault :> api) context where type ServerT (Vault :> api) m = Vault -> ServerT api m - route Proxy context subserver = WithRequest $ \req -> - route (Proxy :: Proxy api) context (passToServer subserver $ vault req) + route Proxy context subserver = + route (Proxy :: Proxy api) context (passToServer subserver vault) instance HasServer api context => HasServer (HttpVersion :> api) context where type ServerT (HttpVersion :> api) m = HttpVersion -> ServerT api m - route Proxy context subserver = WithRequest $ \req -> - route (Proxy :: Proxy api) context (passToServer subserver $ httpVersion req) + route Proxy context subserver = + route (Proxy :: Proxy api) context (passToServer subserver httpVersion) -- | Basic Authentication instance ( KnownSymbol realm @@ -450,12 +455,12 @@ instance ( KnownSymbol realm type ServerT (BasicAuth realm usr :> api) m = usr -> ServerT api m - route Proxy context subserver = WithRequest $ \ request -> - route (Proxy :: Proxy api) context (subserver `addAuthCheck` authCheck request) + route Proxy context subserver = + route (Proxy :: Proxy api) context (subserver `addAuthCheck` authCheck) where realm = BC8.pack $ symbolVal (Proxy :: Proxy realm) basicAuthContext = getContextEntry context - authCheck req = runBasicAuth req realm basicAuthContext + authCheck = withRequest $ \ req -> runBasicAuth req realm basicAuthContext -- * helpers diff --git a/servant-server/src/Servant/Server/Internal/BasicAuth.hs b/servant-server/src/Servant/Server/Internal/BasicAuth.hs index fcd678b5..1fed931b 100644 --- a/servant-server/src/Servant/Server/Internal/BasicAuth.hs +++ b/servant-server/src/Servant/Server/Internal/BasicAuth.hs @@ -6,6 +6,7 @@ module Servant.Server.Internal.BasicAuth where import Control.Monad (guard) +import Control.Monad.Trans (liftIO) import qualified Data.ByteString as BS import Data.ByteString.Base64 (decodeLenient) import Data.Monoid ((<>)) @@ -57,13 +58,13 @@ decodeBAHdr req = do -- | Run and check basic authentication, returning the appropriate http error per -- the spec. -runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> IO (RouteResult usr) +runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> DelayedIO usr runBasicAuth req realm (BasicAuthCheck ba) = case decodeBAHdr req of Nothing -> plzAuthenticate - Just e -> ba e >>= \res -> case res of + Just e -> liftIO (ba e) >>= \res -> case res of BadPassword -> plzAuthenticate NoSuchUser -> plzAuthenticate - Unauthorized -> return $ FailFatal err403 - Authorized usr -> return $ Route usr - where plzAuthenticate = return $ FailFatal err401 { errHeaders = [mkBAChallengerHdr realm] } + Unauthorized -> delayedFailFatal err403 + Authorized usr -> return usr + where plzAuthenticate = delayedFailFatal err401 { errHeaders = [mkBAChallengerHdr realm] } diff --git a/servant-server/src/Servant/Server/Internal/Router.hs b/servant-server/src/Servant/Server/Internal/Router.hs index 04b661a3..3b69c04c 100644 --- a/servant-server/src/Servant/Server/Internal/Router.hs +++ b/servant-server/src/Servant/Server/Internal/Router.hs @@ -1,5 +1,7 @@ {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE CPP #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE OverloadedStrings #-} module Servant.Server.Internal.Router where @@ -8,36 +10,41 @@ import qualified Data.Map as M import Data.Monoid import Data.Text (Text) import qualified Data.Text as T -import Network.Wai (Request, Response, pathInfo) +import Network.Wai (Response, pathInfo) import Servant.Server.Internal.RoutingApplication import Servant.Server.Internal.ServantErr -type Router = Router' RoutingApplication +type Router env = Router' env RoutingApplication -- | Internal representation of a router. -data Router' a = - WithRequest (Request -> Router' a) - -- ^ current request is passed to the router - | StaticRouter (Map Text (Router' a)) [a] +-- +-- The first argument describes an environment type that is +-- expected as extra input by the routers at the leaves. The +-- environment is filled while running the router, with path +-- components that can be used to process captures. +-- +data Router' env a = + StaticRouter (Map Text (Router' env a)) [env -> a] -- ^ the map contains routers for subpaths (first path component used -- for lookup and removed afterwards), the list contains handlers -- for the empty path, to be tried in order - | DynamicRouter (Text -> Router' a) - -- ^ first path component passed to the function and removed afterwards - | RawRouter a + | CaptureRouter (Router' (Text, env) a) + -- ^ first path component is passed to the child router in its + -- environment and removed afterwards + | RawRouter (env -> a) -- ^ to be used for routes we do not know anything about - | Choice (Router' a) (Router' a) + | Choice (Router' env a) (Router' env a) -- ^ left-biased choice between two routers deriving Functor -- | Smart constructor for a single static path component. -pathRouter :: Text -> Router' a -> Router' a +pathRouter :: Text -> Router' env a -> Router' env a pathRouter t r = StaticRouter (M.singleton t r) [] -- | Smart constructor for a leaf, i.e., a router that expects -- the empty path. -- -leafRouter :: a -> Router' a +leafRouter :: (env -> a) -> Router' env a leafRouter l = StaticRouter M.empty [l] -- | Smart constructor for the choice between routers. @@ -46,40 +53,27 @@ leafRouter l = StaticRouter M.empty [l] -- * Two static routers can be joined by joining their maps -- and concatenating their leaf-lists. -- * Two dynamic routers can be joined by joining their codomains. --- * Two 'WithRequest' routers can be joined by passing them --- the same request and joining their codomains. --- * A 'WithRequest' router can be joined with anything else by --- passing the same request to both but ignoring it in the --- component that does not need it. -- * Choice nodes can be reordered. -- -choice :: Router -> Router -> Router +choice :: Router' env a -> Router' env a -> Router' env a choice (StaticRouter table1 ls1) (StaticRouter table2 ls2) = StaticRouter (M.unionWith choice table1 table2) (ls1 ++ ls2) -choice (DynamicRouter fun1) (DynamicRouter fun2) = - DynamicRouter (\ first -> choice (fun1 first) (fun2 first)) -choice (WithRequest router1) (WithRequest router2) = - WithRequest (\ request -> choice (router1 request) (router2 request)) -choice (WithRequest router1) router2 = - WithRequest (\ request -> choice (router1 request) router2) -choice router1 (WithRequest router2) = - WithRequest (\ request -> choice router1 (router2 request)) +choice (CaptureRouter router1) (CaptureRouter router2) = + CaptureRouter (choice router1 router2) choice router1 (Choice router2 router3) = Choice (choice router1 router2) router3 choice router1 router2 = Choice router1 router2 -- | Datatype used for representing and debugging the --- structure of a router. Abstracts from the functions --- being used in the actual router and the handlers at --- the leaves. +-- structure of a router. Abstracts from the handlers +-- at the leaves. -- -- Two 'Router's can be structurally compared by computing -- their 'RouterStructure' using 'routerStructure' and -- then testing for equality, see 'sameStructure'. -- data RouterStructure = - WithRequestStructure RouterStructure - | StaticRouterStructure (Map Text RouterStructure) Int - | DynamicRouterStructure RouterStructure + StaticRouterStructure (Map Text RouterStructure) Int + | CaptureRouterStructure RouterStructure | RawRouterStructure | ChoiceStructure RouterStructure RouterStructure deriving (Eq, Show) @@ -87,18 +81,15 @@ data RouterStructure = -- | Compute the structure of a router. -- -- Assumes that the request or text being passed --- in 'WithRequest' or 'DynamicRouter' does not +-- in 'WithRequest' or 'CaptureRouter' does not -- affect the structure of the underlying tree. -- -routerStructure :: Router' a -> RouterStructure -routerStructure (WithRequest f) = - WithRequestStructure $ - routerStructure (f (error "routerStructure: dummy request")) +routerStructure :: Router' env a -> RouterStructure routerStructure (StaticRouter m ls) = StaticRouterStructure (fmap routerStructure m) (length ls) -routerStructure (DynamicRouter f) = - DynamicRouterStructure $ - routerStructure (f (error "routerStructure: dummy text")) +routerStructure (CaptureRouter router) = + CaptureRouterStructure $ + routerStructure router routerStructure (RawRouter _) = RawRouterStructure routerStructure (Choice r1 r2) = @@ -108,21 +99,20 @@ routerStructure (Choice r1 r2) = -- | Compare the structure of two routers. -- -sameStructure :: Router' a -> Router' b -> Bool +sameStructure :: Router' env a -> Router' env b -> Bool sameStructure r1 r2 = routerStructure r1 == routerStructure r2 -- | Provide a textual representation of the -- structure of a router. -- -routerLayout :: Router' a -> Text +routerLayout :: Router' env a -> Text routerLayout router = T.unlines (["/"] ++ mkRouterLayout False (routerStructure router)) where mkRouterLayout :: Bool -> RouterStructure -> [Text] - mkRouterLayout c (WithRequestStructure r) = mkRouterLayout c r mkRouterLayout c (StaticRouterStructure m n) = mkSubTrees c (M.toList m) n - mkRouterLayout c (DynamicRouterStructure r) = mkSubTree c "" (mkRouterLayout False r) + mkRouterLayout c (CaptureRouterStructure r) = mkSubTree c "" (mkRouterLayout False r) mkRouterLayout c RawRouterStructure = if c then ["├─ "] else ["└─ "] mkRouterLayout c (ChoiceStructure r1 r2) = @@ -146,47 +136,54 @@ routerLayout router = mkSubTree False path children = ("└─ " <> path <> "/") : map (" " <>) children -- | Apply a transformation to the response of a `Router`. -tweakResponse :: (RouteResult Response -> RouteResult Response) -> Router -> Router +tweakResponse :: (RouteResult Response -> RouteResult Response) -> Router env -> Router env tweakResponse f = fmap (\a -> \req cont -> a req (cont . f)) -- | Interpret a router as an application. -runRouter :: Router -> RoutingApplication -runRouter (WithRequest router) request respond = - runRouter (router request) request respond -runRouter (StaticRouter table ls) request respond = - case pathInfo request of - [] -> runChoice ls request respond - -- This case is to handle trailing slashes. - [""] -> runChoice ls request respond - first : rest | Just router <- M.lookup first table - -> let request' = request { pathInfo = rest } - in runRouter router request' respond - _ -> respond $ Fail err404 -runRouter (DynamicRouter fun) request respond = - case pathInfo request of - [] -> respond $ Fail err404 - -- This case is to handle trailing slashes. - [""] -> respond $ Fail err404 - first : rest - -> let request' = request { pathInfo = rest } - in runRouter (fun first) request' respond -runRouter (RawRouter app) request respond = app request respond -runRouter (Choice r1 r2) request respond = - runChoice [runRouter r1, runRouter r2] request respond +runRouter :: Router () -> RoutingApplication +runRouter r = runRouterEnv r () + +runRouterEnv :: Router env -> env -> RoutingApplication +runRouterEnv router env request respond = + case router of + StaticRouter table ls -> + case pathInfo request of + [] -> runChoice ls env request respond + -- This case is to handle trailing slashes. + [""] -> runChoice ls env request respond + first : rest | Just router' <- M.lookup first table + -> let request' = request { pathInfo = rest } + in runRouterEnv router' env request' respond + _ -> respond $ Fail err404 + CaptureRouter router' -> + case pathInfo request of + [] -> respond $ Fail err404 + -- This case is to handle trailing slashes. + [""] -> respond $ Fail err404 + first : rest + -> let request' = request { pathInfo = rest } + in runRouterEnv router' (first, env) request' respond + RawRouter app -> + app env request respond + Choice r1 r2 -> + runChoice [runRouterEnv r1, runRouterEnv r2] env request respond -- | Try a list of routing applications in order. -- We stop as soon as one fails fatally or succeeds. -- If all fail normally, we pick the "best" error. -- -runChoice :: [RoutingApplication] -> RoutingApplication -runChoice [] _request respond = respond (Fail err404) -runChoice [r] request respond = r request respond -runChoice (r : rs) request respond = - r request $ \ response1 -> - case response1 of - Fail _ -> runChoice rs request $ \ response2 -> - respond $ highestPri response1 response2 - _ -> respond response1 +runChoice :: [env -> RoutingApplication] -> env -> RoutingApplication +runChoice ls = + case ls of + [] -> \ _ _ respond -> respond (Fail err404) + [r] -> r + (r : rs) -> + \ env request respond -> + r env request $ \ response1 -> + case response1 of + Fail _ -> runChoice rs env request $ \ response2 -> + respond $ highestPri response1 response2 + _ -> respond response1 where highestPri (Fail e1) (Fail e2) = if worseHTTPCode (errHTTPCode e1) (errHTTPCode e2) diff --git a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs index 99def4b8..5825531e 100644 --- a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs +++ b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs @@ -8,7 +8,10 @@ {-# LANGUAGE StandaloneDeriving #-} module Servant.Server.Internal.RoutingApplication where +import Control.Monad (ap, liftM) +import Control.Monad.Trans (MonadIO(..)) import Control.Monad.Trans.Except (runExceptT) +import Data.Text (Text) import Network.Wai (Application, Request, Response, ResponseReceived) import Prelude () @@ -95,113 +98,133 @@ toApplication ra request respond = ra request routingRespond -- The accept header check can be performed as the final -- computation in this block. It can cause a 406. -- -data Delayed c where - Delayed :: { capturesD :: IO (RouteResult captures) - , methodD :: IO (RouteResult ()) - , authD :: IO (RouteResult auth) - , bodyD :: IO (RouteResult body) - , serverD :: (captures -> auth -> body -> RouteResult c) - } -> Delayed c +data Delayed env c where + Delayed :: { capturesD :: env -> DelayedIO captures + , methodD :: DelayedIO () + , authD :: DelayedIO auth + , bodyD :: DelayedIO body + , serverD :: captures -> auth -> body -> Request -> RouteResult c + } -> Delayed env c -instance Functor Delayed where - fmap f Delayed{..} - = Delayed { capturesD = capturesD - , methodD = methodD - , authD = authD - , bodyD = bodyD - , serverD = (fmap.fmap.fmap.fmap) f serverD - } -- Note [Existential Record Update] +instance Functor (Delayed env) where + fmap f Delayed{..} = + Delayed + { serverD = \ c a b req -> f <$> serverD c a b req + , .. + } -- Note [Existential Record Update] + +-- | Computations used in a 'Delayed' can depend on the +-- incoming 'Request', may perform 'IO, and result in a +-- 'RouteResult, meaning they can either suceed, fail +-- (with the possibility to recover), or fail fatally. +-- +newtype DelayedIO a = DelayedIO { runDelayedIO :: Request -> IO (RouteResult a) } + +instance Functor DelayedIO where + fmap = liftM + +instance Applicative DelayedIO where + pure = return + (<*>) = ap + +instance Monad DelayedIO where + return x = DelayedIO (const $ return (Route x)) + DelayedIO m >>= f = + DelayedIO $ \ req -> do + r <- m req + case r of + Fail e -> return $ Fail e + FailFatal e -> return $ FailFatal e + Route a -> runDelayedIO (f a) req + +instance MonadIO DelayedIO where + liftIO m = DelayedIO (const $ Route <$> m) + +-- | A 'Delayed' without any stored checks. +emptyDelayed :: RouteResult a -> Delayed env a +emptyDelayed result = + Delayed (const r) r r r (\ _ _ _ _ -> result) + where + r = return () + +-- | Fail with the option to recover. +delayedFail :: ServantErr -> DelayedIO a +delayedFail err = DelayedIO (const $ return $ Fail err) + +-- | Fail fatally, i.e., without any option to recover. +delayedFailFatal :: ServantErr -> DelayedIO a +delayedFailFatal err = DelayedIO (const $ return $ FailFatal err) + +-- | Gain access to the incoming request. +withRequest :: (Request -> DelayedIO a) -> DelayedIO a +withRequest f = DelayedIO (\ req -> runDelayedIO (f req) req) -- | Add a capture to the end of the capture block. -addCapture :: Delayed (a -> b) - -> IO (RouteResult a) - -> Delayed b -addCapture Delayed{..} new - = Delayed { capturesD = combineRouteResults (,) capturesD new - , methodD = methodD - , authD = authD - , bodyD = bodyD - , serverD = \ (x, v) y z -> ($ v) <$> serverD x y z - } -- Note [Existential Record Update] +addCapture :: Delayed env (a -> b) + -> (Text -> DelayedIO a) + -> Delayed (Text, env) b +addCapture Delayed{..} new = + Delayed + { capturesD = \ (txt, env) -> (,) <$> capturesD env <*> new txt + , serverD = \ (x, v) a b req -> ($ v) <$> serverD x a b req + , .. + } -- Note [Existential Record Update] -- | Add a method check to the end of the method block. -addMethodCheck :: Delayed a - -> IO (RouteResult ()) - -> Delayed a -addMethodCheck Delayed{..} new - = Delayed { capturesD = capturesD - , methodD = combineRouteResults const methodD new - , authD = authD - , bodyD = bodyD - , serverD = serverD - } -- Note [Existential Record Update] +addMethodCheck :: Delayed env a + -> DelayedIO () + -> Delayed env a +addMethodCheck Delayed{..} new = + Delayed + { methodD = methodD <* new + , .. + } -- Note [Existential Record Update] -- | Add an auth check to the end of the auth block. -addAuthCheck :: Delayed (a -> b) - -> IO (RouteResult a) - -> Delayed b -addAuthCheck Delayed{..} new - = Delayed { capturesD = capturesD - , methodD = methodD - , authD = combineRouteResults (,) authD new - , bodyD = bodyD - , serverD = \ x (y, v) z -> ($ v) <$> serverD x y z - } -- Note [Existential Record Update] +addAuthCheck :: Delayed env (a -> b) + -> DelayedIO a + -> Delayed env b +addAuthCheck Delayed{..} new = + Delayed + { authD = (,) <$> authD <*> new + , serverD = \ c (y, v) b req -> ($ v) <$> serverD c y b req + , .. + } -- Note [Existential Record Update] -- | Add a body check to the end of the body block. -addBodyCheck :: Delayed (a -> b) - -> IO (RouteResult a) - -> Delayed b -addBodyCheck Delayed{..} new - = Delayed { capturesD = capturesD - , methodD = methodD - , authD = authD - , bodyD = combineRouteResults (,) bodyD new - , serverD = \ x y (z, v) -> ($ v) <$> serverD x y z - } -- Note [Existential Record Update] +addBodyCheck :: Delayed env (a -> b) + -> DelayedIO a + -> Delayed env b +addBodyCheck Delayed{..} new = + Delayed + { bodyD = (,) <$> bodyD <*> new + , serverD = \ c a (z, v) req -> ($ v) <$> serverD c a z req + , .. + } -- Note [Existential Record Update] -- | Add an accept header check to the end of the body block. -- The accept header check should occur after the body check, -- but this will be the case, because the accept header check -- is only scheduled by the method combinators. -addAcceptCheck :: Delayed a - -> IO (RouteResult ()) - -> Delayed a -addAcceptCheck Delayed{..} new - = Delayed { capturesD = capturesD - , methodD = methodD - , authD = authD - , bodyD = combineRouteResults const bodyD new - , serverD = serverD - } -- Note [Existential Record Update] +addAcceptCheck :: Delayed env a + -> DelayedIO () + -> Delayed env a +addAcceptCheck Delayed{..} new = + Delayed + { bodyD = bodyD <* new + , .. + } -- Note [Existential Record Update] -- | Many combinators extract information that is passed to -- the handler without the possibility of failure. In such a -- case, 'passToServer' can be used. -passToServer :: Delayed (a -> b) -> a -> Delayed b -passToServer d x = ($ x) <$> d - --- | The combination 'IO . RouteResult' is a monad, but we --- don't explicitly wrap it in a newtype in order to make it --- an instance. This is the '>>=' of that monad. --- --- We stop on the first error. -bindRouteResults :: IO (RouteResult a) -> (a -> IO (RouteResult b)) -> IO (RouteResult b) -bindRouteResults m f = do - r <- m - case r of - Fail e -> return $ Fail e - FailFatal e -> return $ FailFatal e - Route a -> f a - --- | Common special case of 'bindRouteResults', corresponding --- to 'liftM2'. -combineRouteResults :: (a -> b -> c) -> IO (RouteResult a) -> IO (RouteResult b) -> IO (RouteResult c) -combineRouteResults f m1 m2 = - m1 `bindRouteResults` \ a -> - m2 `bindRouteResults` \ b -> - return (Route (f a b)) +passToServer :: Delayed env (a -> b) -> (Request -> a) -> Delayed env b +passToServer Delayed{..} x = + Delayed + { serverD = \ c a b req -> ($ x req) <$> serverD c a b req + , .. + } -- Note [Existential Record Update] -- | Run a delayed server. Performs all scheduled operations -- in order, and passes the results from the capture and body @@ -209,24 +232,29 @@ combineRouteResults f m1 m2 = -- -- This should only be called once per request; otherwise the guarantees about -- effect and HTTP error ordering break down. -runDelayed :: Delayed a +runDelayed :: Delayed env a + -> env + -> Request -> IO (RouteResult a) -runDelayed Delayed{..} = - capturesD `bindRouteResults` \ c -> - methodD `bindRouteResults` \ _ -> - authD `bindRouteResults` \ a -> - bodyD `bindRouteResults` \ b -> - return (serverD c a b) +runDelayed Delayed{..} env = runDelayedIO $ do + c <- capturesD env + methodD + a <- authD + b <- bodyD + DelayedIO (\ req -> return $ serverD c a b req) -- | Runs a delayed server and the resulting action. -- Takes a continuation that lets us send a response. -- Also takes a continuation for how to turn the -- result of the delayed server into a response. -runAction :: Delayed (Handler a) +runAction :: Delayed env (Handler a) + -> env + -> Request -> (RouteResult Response -> IO r) -> (a -> RouteResult Response) -> IO r -runAction action respond k = runDelayed action >>= go >>= respond +runAction action env req respond k = + runDelayed action env req >>= go >>= respond where go (Fail e) = return $ Fail e go (FailFatal e) = return $ FailFatal e diff --git a/servant-server/test/Servant/Server/RouterSpec.hs b/servant-server/test/Servant/Server/RouterSpec.hs index 7ebd1a75..684361b2 100644 --- a/servant-server/test/Servant/Server/RouterSpec.hs +++ b/servant-server/test/Servant/Server/RouterSpec.hs @@ -25,9 +25,9 @@ routerSpec = do let app' :: Application app' = toApplication $ runRouter router' - router', router :: Router + router', router :: Router () router' = tweakResponse (fmap twk) router - router = leafRouter $ \_ cont -> cont (Route $ responseBuilder (Status 201 "") [] "") + router = leafRouter $ \_ _ cont -> cont (Route $ responseBuilder (Status 201 "") [] "") twk :: Response -> Response twk (ResponseBuilder (Status i s) hs b) = ResponseBuilder (Status (i + 1) s) hs b @@ -69,11 +69,9 @@ shouldHaveSameStructureAs p1 p2 = unless (sameStructure (makeTrivialRouter p1) (makeTrivialRouter p2)) $ expectationFailure ("expected:\n" ++ unpack (layout p2) ++ "\nbut got:\n" ++ unpack (layout p1)) -makeTrivialRouter :: (HasServer layout '[]) => Proxy layout -> Router -makeTrivialRouter p = route p EmptyContext d - where - d = Delayed r r r r (\ _ _ _ -> FailFatal err501) - r = return (Route ()) +makeTrivialRouter :: (HasServer layout '[]) => Proxy layout -> Router () +makeTrivialRouter p = + route p EmptyContext (emptyDelayed (FailFatal err501)) type End = Get '[JSON] () diff --git a/servant-server/test/Servant/Server/UsingContextSpec/TestCombinators.hs b/servant-server/test/Servant/Server/UsingContextSpec/TestCombinators.hs index 48595c9c..21999451 100644 --- a/servant-server/test/Servant/Server/UsingContextSpec/TestCombinators.hs +++ b/servant-server/test/Servant/Server/UsingContextSpec/TestCombinators.hs @@ -20,7 +20,6 @@ module Servant.Server.UsingContextSpec.TestCombinators where import GHC.TypeLits import Servant -import Servant.Server.Internal.RoutingApplication data ExtractFromContext @@ -31,7 +30,7 @@ instance (HasContextEntry context String, HasServer subApi context) => String -> ServerT subApi m route Proxy context delayed = - route subProxy context (fmap (inject context) delayed :: Delayed (Server subApi)) + route subProxy context (fmap (inject context) delayed) where subProxy :: Proxy subApi subProxy = Proxy diff --git a/servant-server/test/Servant/ServerSpec.hs b/servant-server/test/Servant/ServerSpec.hs index 5b4154d7..fc4eb1df 100644 --- a/servant-server/test/Servant/ServerSpec.hs +++ b/servant-server/test/Servant/ServerSpec.hs @@ -48,7 +48,7 @@ import Servant.API ((:<|>) (..), (:>), AuthProtect, Raw, RemoteHost, ReqBody, StdMethod (..), Verb, addHeader) import Servant.API.Internal.Test.ComprehensiveAPI -import Servant.Server (ServantErr (..), Server, Handler, err401, err403, +import Servant.Server (Server, Handler, err401, err403, err404, serve, serveWithContext, Context((:.), EmptyContext)) import Test.Hspec (Spec, context, describe, it,