Merge pull request #457 from kosmikus/fix-router-sharing

Fix router sharing
This commit is contained in:
Andres Löh 2016-04-12 10:04:05 +02:00
commit 15143cc900
10 changed files with 292 additions and 269 deletions

View file

@ -1,10 +1,11 @@
0.7 0.7
--- ---
* The `Router` type has been changed. There are now more situations where * The `Router` type has been changed. Static router tables should now
servers will make use of static lookups to efficiently route the request be properly shared between requests, drastically increasing the
to the correct endpoint. Functions `layout` and `layoutWithContext` have number of situations where servers will be able to route requests
been added to visualize the router layout for debugging purposes. Test efficiently. Functions `layout` and `layoutWithContext` have been
added to visualize the router layout for debugging purposes. Test
cases for expected router layouts have been added. cases for expected router layouts have been added.
* Export `throwError` from module `Servant` * Export `throwError` from module `Servant`
* Add `Handler` type synonym * Add `Handler` type synonym

View file

@ -132,21 +132,13 @@ serve p = serveWithContext p EmptyContext
serveWithContext :: (HasServer layout context) serveWithContext :: (HasServer layout context)
=> Proxy layout -> Context context -> Server layout -> Application => Proxy layout -> Context context -> Server layout -> Application
serveWithContext p context server = toApplication (runRouter (route p context d)) serveWithContext p context server =
where toApplication (runRouter (route p context (emptyDelayed (Route server))))
d = Delayed r r r r (\ _ _ _ -> Route server)
r = return (Route ())
-- | The function 'layout' produces a textual description of the internal -- | The function 'layout' produces a textual description of the internal
-- router layout for debugging purposes. Note that the router layout is -- router layout for debugging purposes. Note that the router layout is
-- determined just by the API, not by the handlers. -- determined just by the API, not by the handlers.
-- --
-- This function makes certain assumptions about the well-behavedness of
-- the 'HasServer' instances of the combinators which should be ok for the
-- core servant constructions, but might not be satisfied for some other
-- combinators provided elsewhere. It is possible that the function may
-- crash for these.
--
-- Example: -- Example:
-- --
-- For the following API -- For the following API
@ -168,7 +160,7 @@ serveWithContext p context server = toApplication (runRouter (route p context d)
-- > │ └─ e/ -- > │ └─ e/
-- > │ └─• -- > │ └─•
-- > ├─ b/ -- > ├─ b/
-- > │ └─ <dyn>/ -- > │ └─ <capture>/
-- > │ ├─• -- > │ ├─•
-- > │ ┆ -- > │ ┆
-- > │ └─• -- > │ └─•
@ -185,7 +177,7 @@ serveWithContext p context server = toApplication (runRouter (route p context d)
-- --
-- [@─•@] Leaves reflect endpoints. -- [@─•@] Leaves reflect endpoints.
-- --
-- [@\<dyn\>/@] This is a delayed capture of a path component. -- [@\<capture\>/@] This is a delayed capture of a path component.
-- --
-- [@\<raw\>@] This is a part of the API we do not know anything about. -- [@\<raw\>@] This is a part of the API we do not know anything about.
-- --
@ -200,10 +192,8 @@ layout p = layoutWithContext p EmptyContext
-- | Variant of 'layout' that takes an additional 'Context'. -- | Variant of 'layout' that takes an additional 'Context'.
layoutWithContext :: (HasServer layout context) layoutWithContext :: (HasServer layout context)
=> Proxy layout -> Context context -> Text => Proxy layout -> Context context -> Text
layoutWithContext p context = routerLayout (route p context d) layoutWithContext p context =
where routerLayout (route p context (emptyDelayed (FailFatal err501)))
d = Delayed r r r r (\ _ _ _ -> FailFatal err501)
r = return (Route ())
-- Documentation -- Documentation

View file

@ -12,6 +12,7 @@
module Servant.Server.Experimental.Auth where module Servant.Server.Experimental.Auth where
import Control.Monad.Trans (liftIO)
import Control.Monad.Trans.Except (runExceptT) import Control.Monad.Trans.Except (runExceptT)
import Data.Proxy (Proxy (Proxy)) import Data.Proxy (Proxy (Proxy))
import Data.Typeable (Typeable) import Data.Typeable (Typeable)
@ -24,10 +25,11 @@ import Servant.Server.Internal (HasContextEntry,
HasServer, ServerT, HasServer, ServerT,
getContextEntry, getContextEntry,
route) route)
import Servant.Server.Internal.Router (Router' (WithRequest)) import Servant.Server.Internal.RoutingApplication (addAuthCheck,
import Servant.Server.Internal.RoutingApplication (RouteResult (FailFatal, Route), delayedFailFatal,
addAuthCheck) DelayedIO,
import Servant.Server.Internal.ServantErr (ServantErr, Handler) withRequest)
import Servant.Server.Internal.ServantErr (Handler)
-- * General Auth -- * General Auth
@ -57,8 +59,10 @@ instance ( HasServer api context
type ServerT (AuthProtect tag :> api) m = type ServerT (AuthProtect tag :> api) m =
AuthServerData (AuthProtect tag) -> ServerT api m AuthServerData (AuthProtect tag) -> ServerT api m
route Proxy context subserver = WithRequest $ \ request -> route Proxy context subserver =
route (Proxy :: Proxy api) context (subserver `addAuthCheck` authCheck request) route (Proxy :: Proxy api) context (subserver `addAuthCheck` withRequest authCheck)
where where
authHandler :: Request -> Handler (AuthServerData (AuthProtect tag))
authHandler = unAuthHandler (getContextEntry context) authHandler = unAuthHandler (getContextEntry context)
authCheck = fmap (either FailFatal Route) . runExceptT . authHandler authCheck :: Request -> DelayedIO (AuthServerData (AuthProtect tag))
authCheck = (>>= either delayedFailFatal return) . liftIO . runExceptT . authHandler

View file

@ -22,6 +22,7 @@ module Servant.Server.Internal
, module Servant.Server.Internal.ServantErr , module Servant.Server.Internal.ServantErr
) where ) where
import Control.Monad.Trans (liftIO)
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC8 import qualified Data.ByteString.Char8 as BC8
import qualified Data.ByteString.Lazy as BL import qualified Data.ByteString.Lazy as BL
@ -70,7 +71,11 @@ import Servant.Server.Internal.ServantErr
class HasServer layout context where class HasServer layout context where
type ServerT layout (m :: * -> *) :: * type ServerT layout (m :: * -> *) :: *
route :: Proxy layout -> Context context -> Delayed (Server layout) -> Router route ::
Proxy layout
-> Context context
-> Delayed env (Server layout)
-> Router env
type Server layout = ServerT layout Handler type Server layout = ServerT layout Handler
@ -92,7 +97,7 @@ instance (HasServer a context, HasServer b context) => HasServer (a :<|> b) cont
type ServerT (a :<|> b) m = ServerT a m :<|> ServerT b m type ServerT (a :<|> b) m = ServerT a m :<|> ServerT b m
route Proxy context server = choice (route pa context ((\ (a :<|> _) -> a) <$> server)) route Proxy context server = choice (route pa context ((\ (a :<|> _) -> a) <$> server))
(route pb context ((\ (_ :<|> b) -> b) <$> server)) (route pb context ((\ (_ :<|> b) -> b) <$> server))
where pa = Proxy :: Proxy a where pa = Proxy :: Proxy a
pb = Proxy :: Proxy b pb = Proxy :: Proxy b
@ -120,12 +125,12 @@ instance (KnownSymbol capture, FromHttpApiData a, HasServer sublayout context)
a -> ServerT sublayout m a -> ServerT sublayout m
route Proxy context d = route Proxy context d =
DynamicRouter $ \ first -> CaptureRouter $
route (Proxy :: Proxy sublayout) route (Proxy :: Proxy sublayout)
context context
(addCapture d $ case parseUrlPieceMaybe first :: Maybe a of (addCapture d $ \ txt -> case parseUrlPieceMaybe txt :: Maybe a of
Nothing -> return $ Fail err400 Nothing -> delayedFail err400
Just v -> return $ Route v Just v -> return v
) )
allowedMethodHead :: Method -> Request -> Bool allowedMethodHead :: Method -> Request -> Bool
@ -144,41 +149,41 @@ processMethodRouter handleA status method headers request = case handleA of
bdy = if allowedMethodHead method request then "" else body bdy = if allowedMethodHead method request then "" else body
hdrs = (hContentType, cs contentT) : (fromMaybe [] headers) hdrs = (hContentType, cs contentT) : (fromMaybe [] headers)
methodCheck :: Method -> Request -> IO (RouteResult ()) methodCheck :: Method -> Request -> DelayedIO ()
methodCheck method request methodCheck method request
| allowedMethod method request = return $ Route () | allowedMethod method request = return ()
| otherwise = return $ Fail err405 | otherwise = delayedFail err405
acceptCheck :: (AllMime list) => Proxy list -> B.ByteString -> IO (RouteResult ()) acceptCheck :: (AllMime list) => Proxy list -> B.ByteString -> DelayedIO ()
acceptCheck proxy accH acceptCheck proxy accH
| canHandleAcceptH proxy (AcceptHeader accH) = return $ Route () | canHandleAcceptH proxy (AcceptHeader accH) = return ()
| otherwise = return $ FailFatal err406 | otherwise = delayedFailFatal err406
methodRouter :: (AllCTRender ctypes a) methodRouter :: (AllCTRender ctypes a)
=> Method -> Proxy ctypes -> Status => Method -> Proxy ctypes -> Status
-> Delayed (Handler a) -> Delayed env (Handler a)
-> Router -> Router env
methodRouter method proxy status action = leafRouter route' methodRouter method proxy status action = leafRouter route'
where where
route' request respond = route' env request respond =
let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request
in runAction (action `addMethodCheck` methodCheck method request in runAction (action `addMethodCheck` methodCheck method request
`addAcceptCheck` acceptCheck proxy accH `addAcceptCheck` acceptCheck proxy accH
) respond $ \ output -> do ) env request respond $ \ output -> do
let handleA = handleAcceptH proxy (AcceptHeader accH) output let handleA = handleAcceptH proxy (AcceptHeader accH) output
processMethodRouter handleA status method Nothing request processMethodRouter handleA status method Nothing request
methodRouterHeaders :: (GetHeaders (Headers h v), AllCTRender ctypes v) methodRouterHeaders :: (GetHeaders (Headers h v), AllCTRender ctypes v)
=> Method -> Proxy ctypes -> Status => Method -> Proxy ctypes -> Status
-> Delayed (Handler (Headers h v)) -> Delayed env (Handler (Headers h v))
-> Router -> Router env
methodRouterHeaders method proxy status action = leafRouter route' methodRouterHeaders method proxy status action = leafRouter route'
where where
route' request respond = route' env request respond =
let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request
in runAction (action `addMethodCheck` methodCheck method request in runAction (action `addMethodCheck` methodCheck method request
`addAcceptCheck` acceptCheck proxy accH `addAcceptCheck` acceptCheck proxy accH
) respond $ \ output -> do ) env request respond $ \ output -> do
let headers = getHeaders output let headers = getHeaders output
handleA = handleAcceptH proxy (AcceptHeader accH) (getResponse output) handleA = handleAcceptH proxy (AcceptHeader accH) (getResponse output)
processMethodRouter handleA status method (Just headers) request processMethodRouter handleA status method (Just headers) request
@ -230,8 +235,8 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout context)
type ServerT (Header sym a :> sublayout) m = type ServerT (Header sym a :> sublayout) m =
Maybe a -> ServerT sublayout m Maybe a -> ServerT sublayout m
route Proxy context subserver = WithRequest $ \ request -> route Proxy context subserver =
let mheader = parseHeaderMaybe =<< lookup str (requestHeaders request) let mheader req = parseHeaderMaybe =<< lookup str (requestHeaders req)
in route (Proxy :: Proxy sublayout) context (passToServer subserver mheader) in route (Proxy :: Proxy sublayout) context (passToServer subserver mheader)
where str = fromString $ symbolVal (Proxy :: Proxy sym) where str = fromString $ symbolVal (Proxy :: Proxy sym)
@ -262,10 +267,10 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout context)
type ServerT (QueryParam sym a :> sublayout) m = type ServerT (QueryParam sym a :> sublayout) m =
Maybe a -> ServerT sublayout m Maybe a -> ServerT sublayout m
route Proxy context subserver = WithRequest $ \ request -> route Proxy context subserver =
let querytext = parseQueryText $ rawQueryString request let querytext r = parseQueryText $ rawQueryString r
param = param r =
case lookup paramname querytext of case lookup paramname (querytext r) of
Nothing -> Nothing -- param absent from the query string Nothing -> Nothing -- param absent from the query string
Just Nothing -> Nothing -- param present with no value -> Nothing Just Nothing -> Nothing -- param present with no value -> Nothing
Just (Just v) -> parseQueryParamMaybe v -- if present, we try to convert to Just (Just v) -> parseQueryParamMaybe v -- if present, we try to convert to
@ -298,13 +303,13 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout context)
type ServerT (QueryParams sym a :> sublayout) m = type ServerT (QueryParams sym a :> sublayout) m =
[a] -> ServerT sublayout m [a] -> ServerT sublayout m
route Proxy context subserver = WithRequest $ \ request -> route Proxy context subserver =
let querytext = parseQueryText $ rawQueryString request let querytext r = parseQueryText $ rawQueryString r
-- if sym is "foo", we look for query string parameters -- if sym is "foo", we look for query string parameters
-- named "foo" or "foo[]" and call parseQueryParam on the -- named "foo" or "foo[]" and call parseQueryParam on the
-- corresponding values -- corresponding values
parameters = filter looksLikeParam querytext parameters r = filter looksLikeParam (querytext r)
values = mapMaybe (convert . snd) parameters values r = mapMaybe (convert . snd) (parameters r)
in route (Proxy :: Proxy sublayout) context (passToServer subserver values) in route (Proxy :: Proxy sublayout) context (passToServer subserver values)
where paramname = cs $ symbolVal (Proxy :: Proxy sym) where paramname = cs $ symbolVal (Proxy :: Proxy sym)
looksLikeParam (name, _) = name == paramname || name == (paramname <> "[]") looksLikeParam (name, _) = name == paramname || name == (paramname <> "[]")
@ -329,9 +334,9 @@ instance (KnownSymbol sym, HasServer sublayout context)
type ServerT (QueryFlag sym :> sublayout) m = type ServerT (QueryFlag sym :> sublayout) m =
Bool -> ServerT sublayout m Bool -> ServerT sublayout m
route Proxy context subserver = WithRequest $ \ request -> route Proxy context subserver =
let querytext = parseQueryText $ rawQueryString request let querytext r = parseQueryText $ rawQueryString r
param = case lookup paramname querytext of param r = case lookup paramname (querytext r) of
Just Nothing -> True -- param is there, with no value Just Nothing -> True -- param is there, with no value
Just (Just v) -> examine v -- param with a value Just (Just v) -> examine v -- param with a value
Nothing -> False -- param not in the query string Nothing -> False -- param not in the query string
@ -352,8 +357,8 @@ instance HasServer Raw context where
type ServerT Raw m = Application type ServerT Raw m = Application
route Proxy _ rawApplication = RawRouter $ \ request respond -> do route Proxy _ rawApplication = RawRouter $ \ env request respond -> do
r <- runDelayed rawApplication r <- runDelayed rawApplication env request
case r of case r of
Route app -> app request (respond . Route) Route app -> app request (respond . Route)
Fail a -> respond $ Fail a Fail a -> respond $ Fail a
@ -386,10 +391,10 @@ instance ( AllCTUnrender list a, HasServer sublayout context
type ServerT (ReqBody list a :> sublayout) m = type ServerT (ReqBody list a :> sublayout) m =
a -> ServerT sublayout m a -> ServerT sublayout m
route Proxy context subserver = WithRequest $ \ request -> route Proxy context subserver =
route (Proxy :: Proxy sublayout) context (addBodyCheck subserver (bodyCheck request)) route (Proxy :: Proxy sublayout) context (addBodyCheck subserver bodyCheck)
where where
bodyCheck request = do bodyCheck = withRequest $ \ request -> do
-- See HTTP RFC 2616, section 7.2.1 -- See HTTP RFC 2616, section 7.2.1
-- http://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1 -- http://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1
-- See also "W3C Internet Media Type registration, consistency of use" -- See also "W3C Internet Media Type registration, consistency of use"
@ -397,11 +402,11 @@ instance ( AllCTUnrender list a, HasServer sublayout context
let contentTypeH = fromMaybe "application/octet-stream" let contentTypeH = fromMaybe "application/octet-stream"
$ lookup hContentType $ requestHeaders request $ lookup hContentType $ requestHeaders request
mrqbody <- handleCTypeH (Proxy :: Proxy list) (cs contentTypeH) mrqbody <- handleCTypeH (Proxy :: Proxy list) (cs contentTypeH)
<$> lazyRequestBody request <$> liftIO (lazyRequestBody request)
case mrqbody of case mrqbody of
Nothing -> return $ FailFatal err415 Nothing -> delayedFailFatal err415
Just (Left e) -> return $ FailFatal err400 { errBody = cs e } Just (Left e) -> delayedFailFatal err400 { errBody = cs e }
Just (Right v) -> return $ Route v Just (Right v) -> return v
-- | Make sure the incoming request starts with @"/path"@, strip it and -- | Make sure the incoming request starts with @"/path"@, strip it and
-- pass the rest of the request path to @sublayout@. -- pass the rest of the request path to @sublayout@.
@ -418,28 +423,28 @@ instance (KnownSymbol path, HasServer sublayout context) => HasServer (path :> s
instance HasServer api context => HasServer (RemoteHost :> api) context where instance HasServer api context => HasServer (RemoteHost :> api) context where
type ServerT (RemoteHost :> api) m = SockAddr -> ServerT api m type ServerT (RemoteHost :> api) m = SockAddr -> ServerT api m
route Proxy context subserver = WithRequest $ \req -> route Proxy context subserver =
route (Proxy :: Proxy api) context (passToServer subserver $ remoteHost req) route (Proxy :: Proxy api) context (passToServer subserver remoteHost)
instance HasServer api context => HasServer (IsSecure :> api) context where instance HasServer api context => HasServer (IsSecure :> api) context where
type ServerT (IsSecure :> api) m = IsSecure -> ServerT api m type ServerT (IsSecure :> api) m = IsSecure -> ServerT api m
route Proxy context subserver = WithRequest $ \req -> route Proxy context subserver =
route (Proxy :: Proxy api) context (passToServer subserver $ secure req) route (Proxy :: Proxy api) context (passToServer subserver secure)
where secure req = if isSecure req then Secure else NotSecure where secure req = if isSecure req then Secure else NotSecure
instance HasServer api context => HasServer (Vault :> api) context where instance HasServer api context => HasServer (Vault :> api) context where
type ServerT (Vault :> api) m = Vault -> ServerT api m type ServerT (Vault :> api) m = Vault -> ServerT api m
route Proxy context subserver = WithRequest $ \req -> route Proxy context subserver =
route (Proxy :: Proxy api) context (passToServer subserver $ vault req) route (Proxy :: Proxy api) context (passToServer subserver vault)
instance HasServer api context => HasServer (HttpVersion :> api) context where instance HasServer api context => HasServer (HttpVersion :> api) context where
type ServerT (HttpVersion :> api) m = HttpVersion -> ServerT api m type ServerT (HttpVersion :> api) m = HttpVersion -> ServerT api m
route Proxy context subserver = WithRequest $ \req -> route Proxy context subserver =
route (Proxy :: Proxy api) context (passToServer subserver $ httpVersion req) route (Proxy :: Proxy api) context (passToServer subserver httpVersion)
-- | Basic Authentication -- | Basic Authentication
instance ( KnownSymbol realm instance ( KnownSymbol realm
@ -450,12 +455,12 @@ instance ( KnownSymbol realm
type ServerT (BasicAuth realm usr :> api) m = usr -> ServerT api m type ServerT (BasicAuth realm usr :> api) m = usr -> ServerT api m
route Proxy context subserver = WithRequest $ \ request -> route Proxy context subserver =
route (Proxy :: Proxy api) context (subserver `addAuthCheck` authCheck request) route (Proxy :: Proxy api) context (subserver `addAuthCheck` authCheck)
where where
realm = BC8.pack $ symbolVal (Proxy :: Proxy realm) realm = BC8.pack $ symbolVal (Proxy :: Proxy realm)
basicAuthContext = getContextEntry context basicAuthContext = getContextEntry context
authCheck req = runBasicAuth req realm basicAuthContext authCheck = withRequest $ \ req -> runBasicAuth req realm basicAuthContext
-- * helpers -- * helpers

View file

@ -6,6 +6,7 @@
module Servant.Server.Internal.BasicAuth where module Servant.Server.Internal.BasicAuth where
import Control.Monad (guard) import Control.Monad (guard)
import Control.Monad.Trans (liftIO)
import qualified Data.ByteString as BS import qualified Data.ByteString as BS
import Data.ByteString.Base64 (decodeLenient) import Data.ByteString.Base64 (decodeLenient)
import Data.Monoid ((<>)) import Data.Monoid ((<>))
@ -57,13 +58,13 @@ decodeBAHdr req = do
-- | Run and check basic authentication, returning the appropriate http error per -- | Run and check basic authentication, returning the appropriate http error per
-- the spec. -- the spec.
runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> IO (RouteResult usr) runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> DelayedIO usr
runBasicAuth req realm (BasicAuthCheck ba) = runBasicAuth req realm (BasicAuthCheck ba) =
case decodeBAHdr req of case decodeBAHdr req of
Nothing -> plzAuthenticate Nothing -> plzAuthenticate
Just e -> ba e >>= \res -> case res of Just e -> liftIO (ba e) >>= \res -> case res of
BadPassword -> plzAuthenticate BadPassword -> plzAuthenticate
NoSuchUser -> plzAuthenticate NoSuchUser -> plzAuthenticate
Unauthorized -> return $ FailFatal err403 Unauthorized -> delayedFailFatal err403
Authorized usr -> return $ Route usr Authorized usr -> return usr
where plzAuthenticate = return $ FailFatal err401 { errHeaders = [mkBAChallengerHdr realm] } where plzAuthenticate = delayedFailFatal err401 { errHeaders = [mkBAChallengerHdr realm] }

View file

@ -1,5 +1,7 @@
{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE CPP #-} {-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
module Servant.Server.Internal.Router where module Servant.Server.Internal.Router where
@ -8,36 +10,41 @@ import qualified Data.Map as M
import Data.Monoid import Data.Monoid
import Data.Text (Text) import Data.Text (Text)
import qualified Data.Text as T import qualified Data.Text as T
import Network.Wai (Request, Response, pathInfo) import Network.Wai (Response, pathInfo)
import Servant.Server.Internal.RoutingApplication import Servant.Server.Internal.RoutingApplication
import Servant.Server.Internal.ServantErr import Servant.Server.Internal.ServantErr
type Router = Router' RoutingApplication type Router env = Router' env RoutingApplication
-- | Internal representation of a router. -- | Internal representation of a router.
data Router' a = --
WithRequest (Request -> Router' a) -- The first argument describes an environment type that is
-- ^ current request is passed to the router -- expected as extra input by the routers at the leaves. The
| StaticRouter (Map Text (Router' a)) [a] -- environment is filled while running the router, with path
-- components that can be used to process captures.
--
data Router' env a =
StaticRouter (Map Text (Router' env a)) [env -> a]
-- ^ the map contains routers for subpaths (first path component used -- ^ the map contains routers for subpaths (first path component used
-- for lookup and removed afterwards), the list contains handlers -- for lookup and removed afterwards), the list contains handlers
-- for the empty path, to be tried in order -- for the empty path, to be tried in order
| DynamicRouter (Text -> Router' a) | CaptureRouter (Router' (Text, env) a)
-- ^ first path component passed to the function and removed afterwards -- ^ first path component is passed to the child router in its
| RawRouter a -- environment and removed afterwards
| RawRouter (env -> a)
-- ^ to be used for routes we do not know anything about -- ^ to be used for routes we do not know anything about
| Choice (Router' a) (Router' a) | Choice (Router' env a) (Router' env a)
-- ^ left-biased choice between two routers -- ^ left-biased choice between two routers
deriving Functor deriving Functor
-- | Smart constructor for a single static path component. -- | Smart constructor for a single static path component.
pathRouter :: Text -> Router' a -> Router' a pathRouter :: Text -> Router' env a -> Router' env a
pathRouter t r = StaticRouter (M.singleton t r) [] pathRouter t r = StaticRouter (M.singleton t r) []
-- | Smart constructor for a leaf, i.e., a router that expects -- | Smart constructor for a leaf, i.e., a router that expects
-- the empty path. -- the empty path.
-- --
leafRouter :: a -> Router' a leafRouter :: (env -> a) -> Router' env a
leafRouter l = StaticRouter M.empty [l] leafRouter l = StaticRouter M.empty [l]
-- | Smart constructor for the choice between routers. -- | Smart constructor for the choice between routers.
@ -46,40 +53,27 @@ leafRouter l = StaticRouter M.empty [l]
-- * Two static routers can be joined by joining their maps -- * Two static routers can be joined by joining their maps
-- and concatenating their leaf-lists. -- and concatenating their leaf-lists.
-- * Two dynamic routers can be joined by joining their codomains. -- * Two dynamic routers can be joined by joining their codomains.
-- * Two 'WithRequest' routers can be joined by passing them
-- the same request and joining their codomains.
-- * A 'WithRequest' router can be joined with anything else by
-- passing the same request to both but ignoring it in the
-- component that does not need it.
-- * Choice nodes can be reordered. -- * Choice nodes can be reordered.
-- --
choice :: Router -> Router -> Router choice :: Router' env a -> Router' env a -> Router' env a
choice (StaticRouter table1 ls1) (StaticRouter table2 ls2) = choice (StaticRouter table1 ls1) (StaticRouter table2 ls2) =
StaticRouter (M.unionWith choice table1 table2) (ls1 ++ ls2) StaticRouter (M.unionWith choice table1 table2) (ls1 ++ ls2)
choice (DynamicRouter fun1) (DynamicRouter fun2) = choice (CaptureRouter router1) (CaptureRouter router2) =
DynamicRouter (\ first -> choice (fun1 first) (fun2 first)) CaptureRouter (choice router1 router2)
choice (WithRequest router1) (WithRequest router2) =
WithRequest (\ request -> choice (router1 request) (router2 request))
choice (WithRequest router1) router2 =
WithRequest (\ request -> choice (router1 request) router2)
choice router1 (WithRequest router2) =
WithRequest (\ request -> choice router1 (router2 request))
choice router1 (Choice router2 router3) = Choice (choice router1 router2) router3 choice router1 (Choice router2 router3) = Choice (choice router1 router2) router3
choice router1 router2 = Choice router1 router2 choice router1 router2 = Choice router1 router2
-- | Datatype used for representing and debugging the -- | Datatype used for representing and debugging the
-- structure of a router. Abstracts from the functions -- structure of a router. Abstracts from the handlers
-- being used in the actual router and the handlers at -- at the leaves.
-- the leaves.
-- --
-- Two 'Router's can be structurally compared by computing -- Two 'Router's can be structurally compared by computing
-- their 'RouterStructure' using 'routerStructure' and -- their 'RouterStructure' using 'routerStructure' and
-- then testing for equality, see 'sameStructure'. -- then testing for equality, see 'sameStructure'.
-- --
data RouterStructure = data RouterStructure =
WithRequestStructure RouterStructure StaticRouterStructure (Map Text RouterStructure) Int
| StaticRouterStructure (Map Text RouterStructure) Int | CaptureRouterStructure RouterStructure
| DynamicRouterStructure RouterStructure
| RawRouterStructure | RawRouterStructure
| ChoiceStructure RouterStructure RouterStructure | ChoiceStructure RouterStructure RouterStructure
deriving (Eq, Show) deriving (Eq, Show)
@ -87,18 +81,15 @@ data RouterStructure =
-- | Compute the structure of a router. -- | Compute the structure of a router.
-- --
-- Assumes that the request or text being passed -- Assumes that the request or text being passed
-- in 'WithRequest' or 'DynamicRouter' does not -- in 'WithRequest' or 'CaptureRouter' does not
-- affect the structure of the underlying tree. -- affect the structure of the underlying tree.
-- --
routerStructure :: Router' a -> RouterStructure routerStructure :: Router' env a -> RouterStructure
routerStructure (WithRequest f) =
WithRequestStructure $
routerStructure (f (error "routerStructure: dummy request"))
routerStructure (StaticRouter m ls) = routerStructure (StaticRouter m ls) =
StaticRouterStructure (fmap routerStructure m) (length ls) StaticRouterStructure (fmap routerStructure m) (length ls)
routerStructure (DynamicRouter f) = routerStructure (CaptureRouter router) =
DynamicRouterStructure $ CaptureRouterStructure $
routerStructure (f (error "routerStructure: dummy text")) routerStructure router
routerStructure (RawRouter _) = routerStructure (RawRouter _) =
RawRouterStructure RawRouterStructure
routerStructure (Choice r1 r2) = routerStructure (Choice r1 r2) =
@ -108,21 +99,20 @@ routerStructure (Choice r1 r2) =
-- | Compare the structure of two routers. -- | Compare the structure of two routers.
-- --
sameStructure :: Router' a -> Router' b -> Bool sameStructure :: Router' env a -> Router' env b -> Bool
sameStructure r1 r2 = sameStructure r1 r2 =
routerStructure r1 == routerStructure r2 routerStructure r1 == routerStructure r2
-- | Provide a textual representation of the -- | Provide a textual representation of the
-- structure of a router. -- structure of a router.
-- --
routerLayout :: Router' a -> Text routerLayout :: Router' env a -> Text
routerLayout router = routerLayout router =
T.unlines (["/"] ++ mkRouterLayout False (routerStructure router)) T.unlines (["/"] ++ mkRouterLayout False (routerStructure router))
where where
mkRouterLayout :: Bool -> RouterStructure -> [Text] mkRouterLayout :: Bool -> RouterStructure -> [Text]
mkRouterLayout c (WithRequestStructure r) = mkRouterLayout c r
mkRouterLayout c (StaticRouterStructure m n) = mkSubTrees c (M.toList m) n mkRouterLayout c (StaticRouterStructure m n) = mkSubTrees c (M.toList m) n
mkRouterLayout c (DynamicRouterStructure r) = mkSubTree c "<dyn>" (mkRouterLayout False r) mkRouterLayout c (CaptureRouterStructure r) = mkSubTree c "<capture>" (mkRouterLayout False r)
mkRouterLayout c RawRouterStructure = mkRouterLayout c RawRouterStructure =
if c then ["├─ <raw>"] else ["└─ <raw>"] if c then ["├─ <raw>"] else ["└─ <raw>"]
mkRouterLayout c (ChoiceStructure r1 r2) = mkRouterLayout c (ChoiceStructure r1 r2) =
@ -146,47 +136,54 @@ routerLayout router =
mkSubTree False path children = ("└─ " <> path <> "/") : map (" " <>) children mkSubTree False path children = ("└─ " <> path <> "/") : map (" " <>) children
-- | Apply a transformation to the response of a `Router`. -- | Apply a transformation to the response of a `Router`.
tweakResponse :: (RouteResult Response -> RouteResult Response) -> Router -> Router tweakResponse :: (RouteResult Response -> RouteResult Response) -> Router env -> Router env
tweakResponse f = fmap (\a -> \req cont -> a req (cont . f)) tweakResponse f = fmap (\a -> \req cont -> a req (cont . f))
-- | Interpret a router as an application. -- | Interpret a router as an application.
runRouter :: Router -> RoutingApplication runRouter :: Router () -> RoutingApplication
runRouter (WithRequest router) request respond = runRouter r = runRouterEnv r ()
runRouter (router request) request respond
runRouter (StaticRouter table ls) request respond = runRouterEnv :: Router env -> env -> RoutingApplication
case pathInfo request of runRouterEnv router env request respond =
[] -> runChoice ls request respond case router of
-- This case is to handle trailing slashes. StaticRouter table ls ->
[""] -> runChoice ls request respond case pathInfo request of
first : rest | Just router <- M.lookup first table [] -> runChoice ls env request respond
-> let request' = request { pathInfo = rest } -- This case is to handle trailing slashes.
in runRouter router request' respond [""] -> runChoice ls env request respond
_ -> respond $ Fail err404 first : rest | Just router' <- M.lookup first table
runRouter (DynamicRouter fun) request respond = -> let request' = request { pathInfo = rest }
case pathInfo request of in runRouterEnv router' env request' respond
[] -> respond $ Fail err404 _ -> respond $ Fail err404
-- This case is to handle trailing slashes. CaptureRouter router' ->
[""] -> respond $ Fail err404 case pathInfo request of
first : rest [] -> respond $ Fail err404
-> let request' = request { pathInfo = rest } -- This case is to handle trailing slashes.
in runRouter (fun first) request' respond [""] -> respond $ Fail err404
runRouter (RawRouter app) request respond = app request respond first : rest
runRouter (Choice r1 r2) request respond = -> let request' = request { pathInfo = rest }
runChoice [runRouter r1, runRouter r2] request respond in runRouterEnv router' (first, env) request' respond
RawRouter app ->
app env request respond
Choice r1 r2 ->
runChoice [runRouterEnv r1, runRouterEnv r2] env request respond
-- | Try a list of routing applications in order. -- | Try a list of routing applications in order.
-- We stop as soon as one fails fatally or succeeds. -- We stop as soon as one fails fatally or succeeds.
-- If all fail normally, we pick the "best" error. -- If all fail normally, we pick the "best" error.
-- --
runChoice :: [RoutingApplication] -> RoutingApplication runChoice :: [env -> RoutingApplication] -> env -> RoutingApplication
runChoice [] _request respond = respond (Fail err404) runChoice ls =
runChoice [r] request respond = r request respond case ls of
runChoice (r : rs) request respond = [] -> \ _ _ respond -> respond (Fail err404)
r request $ \ response1 -> [r] -> r
case response1 of (r : rs) ->
Fail _ -> runChoice rs request $ \ response2 -> \ env request respond ->
respond $ highestPri response1 response2 r env request $ \ response1 ->
_ -> respond response1 case response1 of
Fail _ -> runChoice rs env request $ \ response2 ->
respond $ highestPri response1 response2
_ -> respond response1
where where
highestPri (Fail e1) (Fail e2) = highestPri (Fail e1) (Fail e2) =
if worseHTTPCode (errHTTPCode e1) (errHTTPCode e2) if worseHTTPCode (errHTTPCode e1) (errHTTPCode e2)

View file

@ -8,7 +8,10 @@
{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneDeriving #-}
module Servant.Server.Internal.RoutingApplication where module Servant.Server.Internal.RoutingApplication where
import Control.Monad (ap, liftM)
import Control.Monad.Trans (MonadIO(..))
import Control.Monad.Trans.Except (runExceptT) import Control.Monad.Trans.Except (runExceptT)
import Data.Text (Text)
import Network.Wai (Application, Request, import Network.Wai (Application, Request,
Response, ResponseReceived) Response, ResponseReceived)
import Prelude () import Prelude ()
@ -95,113 +98,133 @@ toApplication ra request respond = ra request routingRespond
-- The accept header check can be performed as the final -- The accept header check can be performed as the final
-- computation in this block. It can cause a 406. -- computation in this block. It can cause a 406.
-- --
data Delayed c where data Delayed env c where
Delayed :: { capturesD :: IO (RouteResult captures) Delayed :: { capturesD :: env -> DelayedIO captures
, methodD :: IO (RouteResult ()) , methodD :: DelayedIO ()
, authD :: IO (RouteResult auth) , authD :: DelayedIO auth
, bodyD :: IO (RouteResult body) , bodyD :: DelayedIO body
, serverD :: (captures -> auth -> body -> RouteResult c) , serverD :: captures -> auth -> body -> Request -> RouteResult c
} -> Delayed c } -> Delayed env c
instance Functor Delayed where instance Functor (Delayed env) where
fmap f Delayed{..} fmap f Delayed{..} =
= Delayed { capturesD = capturesD Delayed
, methodD = methodD { serverD = \ c a b req -> f <$> serverD c a b req
, authD = authD , ..
, bodyD = bodyD } -- Note [Existential Record Update]
, serverD = (fmap.fmap.fmap.fmap) f serverD
} -- Note [Existential Record Update] -- | Computations used in a 'Delayed' can depend on the
-- incoming 'Request', may perform 'IO, and result in a
-- 'RouteResult, meaning they can either suceed, fail
-- (with the possibility to recover), or fail fatally.
--
newtype DelayedIO a = DelayedIO { runDelayedIO :: Request -> IO (RouteResult a) }
instance Functor DelayedIO where
fmap = liftM
instance Applicative DelayedIO where
pure = return
(<*>) = ap
instance Monad DelayedIO where
return x = DelayedIO (const $ return (Route x))
DelayedIO m >>= f =
DelayedIO $ \ req -> do
r <- m req
case r of
Fail e -> return $ Fail e
FailFatal e -> return $ FailFatal e
Route a -> runDelayedIO (f a) req
instance MonadIO DelayedIO where
liftIO m = DelayedIO (const $ Route <$> m)
-- | A 'Delayed' without any stored checks.
emptyDelayed :: RouteResult a -> Delayed env a
emptyDelayed result =
Delayed (const r) r r r (\ _ _ _ _ -> result)
where
r = return ()
-- | Fail with the option to recover.
delayedFail :: ServantErr -> DelayedIO a
delayedFail err = DelayedIO (const $ return $ Fail err)
-- | Fail fatally, i.e., without any option to recover.
delayedFailFatal :: ServantErr -> DelayedIO a
delayedFailFatal err = DelayedIO (const $ return $ FailFatal err)
-- | Gain access to the incoming request.
withRequest :: (Request -> DelayedIO a) -> DelayedIO a
withRequest f = DelayedIO (\ req -> runDelayedIO (f req) req)
-- | Add a capture to the end of the capture block. -- | Add a capture to the end of the capture block.
addCapture :: Delayed (a -> b) addCapture :: Delayed env (a -> b)
-> IO (RouteResult a) -> (Text -> DelayedIO a)
-> Delayed b -> Delayed (Text, env) b
addCapture Delayed{..} new addCapture Delayed{..} new =
= Delayed { capturesD = combineRouteResults (,) capturesD new Delayed
, methodD = methodD { capturesD = \ (txt, env) -> (,) <$> capturesD env <*> new txt
, authD = authD , serverD = \ (x, v) a b req -> ($ v) <$> serverD x a b req
, bodyD = bodyD , ..
, serverD = \ (x, v) y z -> ($ v) <$> serverD x y z } -- Note [Existential Record Update]
} -- Note [Existential Record Update]
-- | Add a method check to the end of the method block. -- | Add a method check to the end of the method block.
addMethodCheck :: Delayed a addMethodCheck :: Delayed env a
-> IO (RouteResult ()) -> DelayedIO ()
-> Delayed a -> Delayed env a
addMethodCheck Delayed{..} new addMethodCheck Delayed{..} new =
= Delayed { capturesD = capturesD Delayed
, methodD = combineRouteResults const methodD new { methodD = methodD <* new
, authD = authD , ..
, bodyD = bodyD } -- Note [Existential Record Update]
, serverD = serverD
} -- Note [Existential Record Update]
-- | Add an auth check to the end of the auth block. -- | Add an auth check to the end of the auth block.
addAuthCheck :: Delayed (a -> b) addAuthCheck :: Delayed env (a -> b)
-> IO (RouteResult a) -> DelayedIO a
-> Delayed b -> Delayed env b
addAuthCheck Delayed{..} new addAuthCheck Delayed{..} new =
= Delayed { capturesD = capturesD Delayed
, methodD = methodD { authD = (,) <$> authD <*> new
, authD = combineRouteResults (,) authD new , serverD = \ c (y, v) b req -> ($ v) <$> serverD c y b req
, bodyD = bodyD , ..
, serverD = \ x (y, v) z -> ($ v) <$> serverD x y z } -- Note [Existential Record Update]
} -- Note [Existential Record Update]
-- | Add a body check to the end of the body block. -- | Add a body check to the end of the body block.
addBodyCheck :: Delayed (a -> b) addBodyCheck :: Delayed env (a -> b)
-> IO (RouteResult a) -> DelayedIO a
-> Delayed b -> Delayed env b
addBodyCheck Delayed{..} new addBodyCheck Delayed{..} new =
= Delayed { capturesD = capturesD Delayed
, methodD = methodD { bodyD = (,) <$> bodyD <*> new
, authD = authD , serverD = \ c a (z, v) req -> ($ v) <$> serverD c a z req
, bodyD = combineRouteResults (,) bodyD new , ..
, serverD = \ x y (z, v) -> ($ v) <$> serverD x y z } -- Note [Existential Record Update]
} -- Note [Existential Record Update]
-- | Add an accept header check to the end of the body block. -- | Add an accept header check to the end of the body block.
-- The accept header check should occur after the body check, -- The accept header check should occur after the body check,
-- but this will be the case, because the accept header check -- but this will be the case, because the accept header check
-- is only scheduled by the method combinators. -- is only scheduled by the method combinators.
addAcceptCheck :: Delayed a addAcceptCheck :: Delayed env a
-> IO (RouteResult ()) -> DelayedIO ()
-> Delayed a -> Delayed env a
addAcceptCheck Delayed{..} new addAcceptCheck Delayed{..} new =
= Delayed { capturesD = capturesD Delayed
, methodD = methodD { bodyD = bodyD <* new
, authD = authD , ..
, bodyD = combineRouteResults const bodyD new } -- Note [Existential Record Update]
, serverD = serverD
} -- Note [Existential Record Update]
-- | Many combinators extract information that is passed to -- | Many combinators extract information that is passed to
-- the handler without the possibility of failure. In such a -- the handler without the possibility of failure. In such a
-- case, 'passToServer' can be used. -- case, 'passToServer' can be used.
passToServer :: Delayed (a -> b) -> a -> Delayed b passToServer :: Delayed env (a -> b) -> (Request -> a) -> Delayed env b
passToServer d x = ($ x) <$> d passToServer Delayed{..} x =
Delayed
-- | The combination 'IO . RouteResult' is a monad, but we { serverD = \ c a b req -> ($ x req) <$> serverD c a b req
-- don't explicitly wrap it in a newtype in order to make it , ..
-- an instance. This is the '>>=' of that monad. } -- Note [Existential Record Update]
--
-- We stop on the first error.
bindRouteResults :: IO (RouteResult a) -> (a -> IO (RouteResult b)) -> IO (RouteResult b)
bindRouteResults m f = do
r <- m
case r of
Fail e -> return $ Fail e
FailFatal e -> return $ FailFatal e
Route a -> f a
-- | Common special case of 'bindRouteResults', corresponding
-- to 'liftM2'.
combineRouteResults :: (a -> b -> c) -> IO (RouteResult a) -> IO (RouteResult b) -> IO (RouteResult c)
combineRouteResults f m1 m2 =
m1 `bindRouteResults` \ a ->
m2 `bindRouteResults` \ b ->
return (Route (f a b))
-- | Run a delayed server. Performs all scheduled operations -- | Run a delayed server. Performs all scheduled operations
-- in order, and passes the results from the capture and body -- in order, and passes the results from the capture and body
@ -209,24 +232,29 @@ combineRouteResults f m1 m2 =
-- --
-- This should only be called once per request; otherwise the guarantees about -- This should only be called once per request; otherwise the guarantees about
-- effect and HTTP error ordering break down. -- effect and HTTP error ordering break down.
runDelayed :: Delayed a runDelayed :: Delayed env a
-> env
-> Request
-> IO (RouteResult a) -> IO (RouteResult a)
runDelayed Delayed{..} = runDelayed Delayed{..} env = runDelayedIO $ do
capturesD `bindRouteResults` \ c -> c <- capturesD env
methodD `bindRouteResults` \ _ -> methodD
authD `bindRouteResults` \ a -> a <- authD
bodyD `bindRouteResults` \ b -> b <- bodyD
return (serverD c a b) DelayedIO (\ req -> return $ serverD c a b req)
-- | Runs a delayed server and the resulting action. -- | Runs a delayed server and the resulting action.
-- Takes a continuation that lets us send a response. -- Takes a continuation that lets us send a response.
-- Also takes a continuation for how to turn the -- Also takes a continuation for how to turn the
-- result of the delayed server into a response. -- result of the delayed server into a response.
runAction :: Delayed (Handler a) runAction :: Delayed env (Handler a)
-> env
-> Request
-> (RouteResult Response -> IO r) -> (RouteResult Response -> IO r)
-> (a -> RouteResult Response) -> (a -> RouteResult Response)
-> IO r -> IO r
runAction action respond k = runDelayed action >>= go >>= respond runAction action env req respond k =
runDelayed action env req >>= go >>= respond
where where
go (Fail e) = return $ Fail e go (Fail e) = return $ Fail e
go (FailFatal e) = return $ FailFatal e go (FailFatal e) = return $ FailFatal e

View file

@ -25,9 +25,9 @@ routerSpec = do
let app' :: Application let app' :: Application
app' = toApplication $ runRouter router' app' = toApplication $ runRouter router'
router', router :: Router router', router :: Router ()
router' = tweakResponse (fmap twk) router router' = tweakResponse (fmap twk) router
router = leafRouter $ \_ cont -> cont (Route $ responseBuilder (Status 201 "") [] "") router = leafRouter $ \_ _ cont -> cont (Route $ responseBuilder (Status 201 "") [] "")
twk :: Response -> Response twk :: Response -> Response
twk (ResponseBuilder (Status i s) hs b) = ResponseBuilder (Status (i + 1) s) hs b twk (ResponseBuilder (Status i s) hs b) = ResponseBuilder (Status (i + 1) s) hs b
@ -69,11 +69,9 @@ shouldHaveSameStructureAs p1 p2 =
unless (sameStructure (makeTrivialRouter p1) (makeTrivialRouter p2)) $ unless (sameStructure (makeTrivialRouter p1) (makeTrivialRouter p2)) $
expectationFailure ("expected:\n" ++ unpack (layout p2) ++ "\nbut got:\n" ++ unpack (layout p1)) expectationFailure ("expected:\n" ++ unpack (layout p2) ++ "\nbut got:\n" ++ unpack (layout p1))
makeTrivialRouter :: (HasServer layout '[]) => Proxy layout -> Router makeTrivialRouter :: (HasServer layout '[]) => Proxy layout -> Router ()
makeTrivialRouter p = route p EmptyContext d makeTrivialRouter p =
where route p EmptyContext (emptyDelayed (FailFatal err501))
d = Delayed r r r r (\ _ _ _ -> FailFatal err501)
r = return (Route ())
type End = Get '[JSON] () type End = Get '[JSON] ()

View file

@ -20,7 +20,6 @@ module Servant.Server.UsingContextSpec.TestCombinators where
import GHC.TypeLits import GHC.TypeLits
import Servant import Servant
import Servant.Server.Internal.RoutingApplication
data ExtractFromContext data ExtractFromContext
@ -31,7 +30,7 @@ instance (HasContextEntry context String, HasServer subApi context) =>
String -> ServerT subApi m String -> ServerT subApi m
route Proxy context delayed = route Proxy context delayed =
route subProxy context (fmap (inject context) delayed :: Delayed (Server subApi)) route subProxy context (fmap (inject context) delayed)
where where
subProxy :: Proxy subApi subProxy :: Proxy subApi
subProxy = Proxy subProxy = Proxy

View file

@ -48,7 +48,7 @@ import Servant.API ((:<|>) (..), (:>), AuthProtect,
Raw, RemoteHost, ReqBody, Raw, RemoteHost, ReqBody,
StdMethod (..), Verb, addHeader) StdMethod (..), Verb, addHeader)
import Servant.API.Internal.Test.ComprehensiveAPI import Servant.API.Internal.Test.ComprehensiveAPI
import Servant.Server (ServantErr (..), Server, Handler, err401, err403, import Servant.Server (Server, Handler, err401, err403,
err404, serve, serveWithContext, err404, serve, serveWithContext,
Context((:.), EmptyContext)) Context((:.), EmptyContext))
import Test.Hspec (Spec, context, describe, it, import Test.Hspec (Spec, context, describe, it,