diff --git a/servant-server/servant-server.cabal b/servant-server/servant-server.cabal index f354a1e8..84743972 100644 --- a/servant-server/servant-server.cabal +++ b/servant-server/servant-server.cabal @@ -46,6 +46,7 @@ library Servant.Server.Internal.DelayedIO Servant.Server.Internal.ErrorFormatter Servant.Server.Internal.Handler + Servant.Server.Internal.RouterEnv Servant.Server.Internal.RouteResult Servant.Server.Internal.Router Servant.Server.Internal.RoutingApplication diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index a2b4f033..5d2800fe 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -26,6 +26,7 @@ module Servant.Server.Internal , module Servant.Server.Internal.ErrorFormatter , module Servant.Server.Internal.Handler , module Servant.Server.Internal.Router + , module Servant.Server.Internal.RouterEnv , module Servant.Server.Internal.RouteResult , module Servant.Server.Internal.RoutingApplication , module Servant.Server.Internal.ServerError @@ -76,7 +77,7 @@ import Servant.API QueryParam', QueryParams, Raw, ReflectMethod (reflectMethod), RemoteHost, ReqBody', SBool (..), SBoolI (..), SourceIO, Stream, StreamBody', Summary, ToSourceIO (..), Vault, Verb, - WithNamedContext, NamedRoutes) + WithNamedContext, WithRoutingHeader, NamedRoutes) import Servant.API.Generic (GenericMode(..), ToServant, ToServantApi, GServantProduct, toServant, fromServant) import Servant.API.ContentTypes (AcceptHeader (..), AllCTRender (..), AllCTUnrender (..), @@ -103,6 +104,7 @@ import Servant.Server.Internal.ErrorFormatter import Servant.Server.Internal.Handler import Servant.Server.Internal.Router import Servant.Server.Internal.RouteResult +import Servant.Server.Internal.RouterEnv import Servant.Server.Internal.RoutingApplication import Servant.Server.Internal.ServerError @@ -241,6 +243,23 @@ instance (KnownSymbol capture, FromHttpApiData a, Typeable a formatError = urlParseErrorFormatter $ getContextEntry (mkContextWithErrorFormatter context) hint = CaptureHint (T.pack $ symbolVal $ Proxy @capture) (typeRep (Proxy :: Proxy [a])) +-- | Using 'WithRoutingHeaders' in one of the endpoints for your API, +-- will automatically add routing headers to the response generated by the server. +-- +-- @since 0.20 +-- +instance ( HasServer api context + , HasContextEntry (MkContextWithErrorFormatter context) ErrorFormatters + ) + => HasServer (WithRoutingHeader :> api) context where + + type ServerT (WithRoutingHeader :> api) m = ServerT api m + + hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy :: Proxy api) pc nt s + + route _ context d = + EnvRouter enableRoutingHeaders $ route (Proxy :: Proxy api) context d + allowedMethodHead :: Method -> Request -> Bool allowedMethodHead method request = method == methodGet && requestMethod request == methodHead @@ -292,7 +311,10 @@ noContentRouter method status action = leafRouter route' route' env request respond = runAction (action `addMethodCheck` methodCheck method request) env request respond $ \ _output -> - Route $ responseLBS status [] "" + let headers = if (shouldReturnRoutedPath env) + then [(hRoutedPathHeader, cs $ routedPathRepr env)] + else [] + in Route $ responseLBS status headers "" instance {-# OVERLAPPABLE #-} ( AllCTRender ctypes a, ReflectMethod method, KnownNat status diff --git a/servant-server/src/Servant/Server/Internal/Delayed.hs b/servant-server/src/Servant/Server/Internal/Delayed.hs index 3ba89574..029d95ca 100644 --- a/servant-server/src/Servant/Server/Internal/Delayed.hs +++ b/servant-server/src/Servant/Server/Internal/Delayed.hs @@ -14,11 +14,15 @@ import Control.Monad.Reader (ask) import Control.Monad.Trans.Resource (ResourceT, runResourceT) +import Data.String.Conversions + (cs) import Network.Wai - (Request, Response) + (Request, Response, mapResponseHeaders) import Servant.Server.Internal.DelayedIO import Servant.Server.Internal.Handler +import Servant.Server.Internal.RouterEnv + (RouterEnv (..), hRoutedPathHeader, routedPathRepr) import Servant.Server.Internal.RouteResult import Servant.Server.Internal.ServerError @@ -228,12 +232,12 @@ passToServer Delayed{..} x = -- This should only be called once per request; otherwise the guarantees about -- effect and HTTP error ordering break down. runDelayed :: Delayed env a - -> env + -> RouterEnv env -> Request -> ResourceT IO (RouteResult a) runDelayed Delayed{..} env = runDelayedIO $ do r <- ask - c <- capturesD env + c <- capturesD $ routerEnv env methodD a <- authD acceptD @@ -248,7 +252,7 @@ runDelayed Delayed{..} env = runDelayedIO $ do -- Also takes a continuation for how to turn the -- result of the delayed server into a response. runAction :: Delayed env (Handler a) - -> env + -> RouterEnv env -> Request -> (RouteResult Response -> IO r) -> (a -> RouteResult Response) @@ -261,8 +265,12 @@ runAction action env req respond k = runResourceT $ go (Route a) = liftIO $ do e <- runHandler a case e of - Left err -> return . Route $ responseServerError err - Right x -> return $! k x + Left err -> return . Route . withRoutingHeaders $ responseServerError err + Right x -> return $! withRoutingHeaders <$> k x + withRoutingHeaders :: Response -> Response + withRoutingHeaders = if shouldReturnRoutedPath env + then mapResponseHeaders ((hRoutedPathHeader, cs $ routedPathRepr env) :) + else id {- Note [Existential Record Update] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/servant-server/src/Servant/Server/Internal/Router.hs b/servant-server/src/Servant/Server/Internal/Router.hs index 0a3391ce..794ab400 100644 --- a/servant-server/src/Servant/Server/Internal/Router.hs +++ b/servant-server/src/Servant/Server/Internal/Router.hs @@ -2,6 +2,7 @@ {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TupleSections #-} module Servant.Server.Internal.Router where import Prelude () @@ -17,29 +18,16 @@ import qualified Data.Map as M import Data.Text (Text) import qualified Data.Text as T -import Data.Typeable - (TypeRep) import Network.Wai (Response, pathInfo) import Servant.Server.Internal.ErrorFormatter +import Servant.Server.Internal.RouterEnv import Servant.Server.Internal.RouteResult import Servant.Server.Internal.RoutingApplication import Servant.Server.Internal.ServerError type Router env = Router' env RoutingApplication -data CaptureHint = CaptureHint - { captureName :: Text - , captureType :: TypeRep - } - deriving (Show, Eq) - -toCaptureTag :: CaptureHint -> Text -toCaptureTag hint = captureName hint <> "::" <> (T.pack . show) (captureType hint) - -toCaptureTags :: [CaptureHint] -> Text -toCaptureTags hints = "<" <> T.intercalate "|" (map toCaptureTag hints) <> ">" - -- | Internal representation of a router. -- -- The first argument describes an environment type that is @@ -48,7 +36,7 @@ toCaptureTags hints = "<" <> T.intercalate "|" (map toCaptureTag hints) <> ">" -- components that can be used to process captures. -- data Router' env a = - StaticRouter (Map Text (Router' env a)) [env -> a] + StaticRouter (Map Text (Router' env a)) [RouterEnv env -> a] -- ^ the map contains routers for subpaths (first path component used -- for lookup and removed afterwards), the list contains handlers -- for the empty path, to be tried in order @@ -58,10 +46,13 @@ data Router' env a = | CaptureAllRouter [CaptureHint] (Router' ([Text], env) a) -- ^ all path components are passed to the child router in its -- environment and are removed afterwards - | RawRouter (env -> a) + | RawRouter (RouterEnv env -> a) -- ^ to be used for routes we do not know anything about | Choice (Router' env a) (Router' env a) -- ^ left-biased choice between two routers + | EnvRouter (RouterEnv env -> RouterEnv env) (Router' env a) + -- ^ modifies the environment, and passes it to the child router + -- @since 0.20 deriving Functor -- | Smart constructor for a single static path component. @@ -71,7 +62,7 @@ pathRouter t r = StaticRouter (M.singleton t r) [] -- | Smart constructor for a leaf, i.e., a router that expects -- the empty path. -- -leafRouter :: (env -> a) -> Router' env a +leafRouter :: (RouterEnv env -> a) -> Router' env a leafRouter l = StaticRouter M.empty [l] -- | Smart constructor for the choice between routers. @@ -126,6 +117,7 @@ routerStructure (Choice r1 r2) = ChoiceStructure (routerStructure r1) (routerStructure r2) +routerStructure (EnvRouter _ r) = routerStructure r -- | Compare the structure of two routers. -- @@ -172,9 +164,9 @@ tweakResponse f = fmap (\a -> \req cont -> a req (cont . f)) -- | Interpret a router as an application. runRouter :: NotFoundErrorFormatter -> Router () -> RoutingApplication -runRouter fmt r = runRouterEnv fmt r () +runRouter fmt r = runRouterEnv fmt r $ emptyEnv () -runRouterEnv :: NotFoundErrorFormatter -> Router env -> env -> RoutingApplication +runRouterEnv :: NotFoundErrorFormatter -> Router env -> RouterEnv env -> RoutingApplication runRouterEnv fmt router env request respond = case router of StaticRouter table ls -> @@ -184,24 +176,31 @@ runRouterEnv fmt router env request respond = [""] -> runChoice fmt ls env request respond first : rest | Just router' <- M.lookup first table -> let request' = request { pathInfo = rest } - in runRouterEnv fmt router' env request' respond + newEnv = appendPathPiece (StaticPiece first) env + in runRouterEnv fmt router' newEnv request' respond _ -> respond $ Fail $ fmt request - CaptureRouter _ router' -> + CaptureRouter hints router' -> case pathInfo request of [] -> respond $ Fail $ fmt request -- This case is to handle trailing slashes. [""] -> respond $ Fail $ fmt request first : rest -> let request' = request { pathInfo = rest } - in runRouterEnv fmt router' (first, env) request' respond - CaptureAllRouter _ router' -> + newEnv = appendPathPiece (CapturePiece hints) env + newEnv' = ((first,) <$> newEnv) + in runRouterEnv fmt router' newEnv' request' respond + CaptureAllRouter hints router' -> let segments = pathInfo request request' = request { pathInfo = [] } - in runRouterEnv fmt router' (segments, env) request' respond + newEnv = appendPathPiece (CapturePiece hints) env + newEnv' = ((segments,) <$> newEnv) + in runRouterEnv fmt router' newEnv' request' respond RawRouter app -> app env request respond Choice r1 r2 -> runChoice fmt [runRouterEnv fmt r1, runRouterEnv fmt r2] env request respond + EnvRouter f router' -> + runRouterEnv fmt router' (f env) request respond -- | Try a list of routing applications in order. -- We stop as soon as one fails fatally or succeeds. diff --git a/servant-server/src/Servant/Server/Internal/RouterEnv.hs b/servant-server/src/Servant/Server/Internal/RouterEnv.hs new file mode 100644 index 00000000..15c628fe --- /dev/null +++ b/servant-server/src/Servant/Server/Internal/RouterEnv.hs @@ -0,0 +1,65 @@ +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} +-- | This module contains the `RouterEnv env` type and associated functions. +-- `RouterEnv env` encapsulates the `env` type (as in `Router env a`), +-- which contains a tuple-encoded list of url pieces parsed from the incoming request. +-- The encapsulation makes it possible to pass more information throughout +-- the routing process, and ultimately to the computation of the `Delayed env c` +-- associated with each request. +-- The type and functions have been designed to be extensible: it should remain easy +-- to add a new field to the record and manipulate it. +-- +-- @since 0.20 +-- +module Servant.Server.Internal.RouterEnv where + +import Data.Text + (Text) +import qualified Data.Text as T +import Data.Typeable + (TypeRep) +import Network.HTTP.Types.Header + (HeaderName) + +data RouterEnv env = RouterEnv + { routedPath :: [PathPiece] + , shouldReturnRoutedPath :: Bool + , routerEnv :: env + } + deriving Functor + +emptyEnv :: a -> RouterEnv a +emptyEnv v = RouterEnv [] False v + +enableRoutingHeaders :: RouterEnv env -> RouterEnv env +enableRoutingHeaders env = env { shouldReturnRoutedPath = True } + +routedPathRepr :: RouterEnv env -> Text +routedPathRepr RouterEnv{routedPath = path} = + "/" <> T.intercalate "/" (map go $ reverse path) + where + go (StaticPiece p) = p + go (CapturePiece p) = toCaptureTags p + +data PathPiece + = StaticPiece Text + | CapturePiece [CaptureHint] + +appendPathPiece :: PathPiece -> RouterEnv a -> RouterEnv a +appendPathPiece p env@RouterEnv{..} = env { routedPath = p:routedPath } + +data CaptureHint = CaptureHint + { captureName :: Text + , captureType :: TypeRep + } + deriving (Show, Eq) + +toCaptureTag :: CaptureHint -> Text +toCaptureTag hint = captureName hint <> "::" <> (T.pack . show) (captureType hint) + +toCaptureTags :: [CaptureHint] -> Text +toCaptureTags hints = "<" <> T.intercalate "|" (map toCaptureTag hints) <> ">" + +hRoutedPathHeader :: HeaderName +hRoutedPathHeader = "Servant-Routed-Path" diff --git a/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs b/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs index 04443c9d..87fab549 100644 --- a/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs +++ b/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs @@ -80,7 +80,7 @@ delayed body srv = Delayed simpleRun :: Delayed () (Handler ()) -> IO () simpleRun d = fmap (either ignoreE id) . try $ - runAction d () defaultRequest (\_ -> return ()) (\_ -> FailFatal err500) + runAction d (emptyEnv ()) defaultRequest (\_ -> return ()) (\_ -> FailFatal err500) where ignoreE :: SomeException -> () ignoreE = const () diff --git a/servant/servant.cabal b/servant/servant.cabal index 32b63feb..d175bc37 100644 --- a/servant/servant.cabal +++ b/servant/servant.cabal @@ -38,6 +38,7 @@ library Servant.API.Capture Servant.API.ContentTypes Servant.API.Description + Servant.API.Environment Servant.API.Empty Servant.API.Experimental.Auth Servant.API.Fragment diff --git a/servant/src/Servant/API.hs b/servant/src/Servant/API.hs index 22309dce..2673dac4 100644 --- a/servant/src/Servant/API.hs +++ b/servant/src/Servant/API.hs @@ -7,6 +7,8 @@ module Servant.API ( -- | Type-level combinator for alternative endpoints: @':<|>'@ module Servant.API.Empty, -- | Type-level combinator for an empty API: @'EmptyAPI'@ + module Servant.API.Environment, + -- | Type-level combinators to modify the routing environment: @'WithRoutingHeader'@ module Servant.API.Modifiers, -- | Type-level modifiers for 'QueryParam', 'Header' and 'ReqBody'. @@ -97,6 +99,8 @@ import Servant.API.Description (Description, Summary) import Servant.API.Empty (EmptyAPI (..)) +import Servant.API.Environment + (WithRoutingHeader) import Servant.API.Experimental.Auth (AuthProtect) import Servant.API.Fragment diff --git a/servant/src/Servant/API/Environment.hs b/servant/src/Servant/API/Environment.hs new file mode 100644 index 00000000..08e477d7 --- /dev/null +++ b/servant/src/Servant/API/Environment.hs @@ -0,0 +1,29 @@ +{-# OPTIONS_HADDOCK not-home #-} +-- | Define API combinator that modify the behaviour of the routing environment. +module Servant.API.Environment (WithRoutingHeader) where + +-- | Modify the behaviour of the following sub-API, such that all endpoint of said API +-- return an additional routing header in their response. +-- A routing header is a header that specifies which endpoint the incoming request was +-- routed to. Endpoint are designated by their path, in which @Capture@ combinators are +-- replaced by a capture hint. +-- This header can be used by downstream middlewares to gather information about +-- individual endpoints, since in most cases a routing header uniquely identifies a +-- single endpoint. +-- +-- Example: +-- +-- >>> type MyApi = WithRoutingHeader :> "by-id" :> Capture "id" Int :> Get '[JSON] Foo +-- >>> -- GET /by-id/1234 will return a response with the following header: +-- >>> -- ("Servant-Routed-Path", "/by-id/") +-- +-- @since 0.20 +-- +data WithRoutingHeader + +-- $setup +-- >>> import Servant.API +-- >>> import Data.Aeson +-- >>> import Data.Text +-- >>> data Foo +-- >>> instance ToJSON Foo where { toJSON = undefined } diff --git a/servant/src/Servant/API/TypeLevel.hs b/servant/src/Servant/API/TypeLevel.hs index 4a5e3c3b..a75dc331 100644 --- a/servant/src/Servant/API/TypeLevel.hs +++ b/servant/src/Servant/API/TypeLevel.hs @@ -57,6 +57,8 @@ import Servant.API.Alternative (type (:<|>)) import Servant.API.Capture (Capture, CaptureAll) +import Servant.API.Environment + (WithRoutingHeader) import Servant.API.Fragment import Servant.API.Header (Header) @@ -130,6 +132,7 @@ type family IsElem endpoint api :: Constraint where IsElem e (sa :<|> sb) = Or (IsElem e sa) (IsElem e sb) IsElem (e :> sa) (e :> sb) = IsElem sa sb IsElem sa (Header sym x :> sb) = IsElem sa sb + IsElem sa (WithRoutingHeader :> sb) = IsElem sa sb IsElem sa (ReqBody y x :> sb) = IsElem sa sb IsElem (CaptureAll z y :> sa) (CaptureAll x y :> sb) = IsElem sa sb