diff --git a/servant-server/src/Servant/Server.hs b/servant-server/src/Servant/Server.hs index 58c34ea3..2a886683 100644 --- a/servant-server/src/Servant/Server.hs +++ b/servant-server/src/Servant/Server.hs @@ -33,7 +33,7 @@ module Servant.Server , embedNat , squashNat , generalizeNat - + , tweakResponse -- * Default error type , ServantErr(..) diff --git a/servant-server/src/Servant/Server/Internal/Router.hs b/servant-server/src/Servant/Server/Internal/Router.hs index f188955e..89f7c144 100644 --- a/servant-server/src/Servant/Server/Internal/Router.hs +++ b/servant-server/src/Servant/Server/Internal/Router.hs @@ -1,25 +1,34 @@ +{-# LANGUAGE DeriveFunctor #-} + module Servant.Server.Internal.Router where import Data.Map (Map) import qualified Data.Map as M import Data.Monoid ((<>)) import Data.Text (Text) -import Network.Wai (Request, pathInfo) +import Network.Wai (Request, Response, pathInfo) import Servant.Server.Internal.PathInfo import Servant.Server.Internal.RoutingApplication +type Router = Router' RoutingApplication + -- | Internal representation of a router. -data Router = +data Router' a = WithRequest (Request -> Router) -- ^ current request is passed to the router | StaticRouter (Map Text Router) -- ^ first path component used for lookup and removed afterwards | DynamicRouter (Text -> Router) -- ^ first path component used for lookup and removed afterwards - | LeafRouter RoutingApplication + | LeafRouter a -- ^ to be used for routes that match an empty path | Choice Router Router -- ^ left-biased choice between two routers + deriving Functor + +-- | Apply a transformation to the response of a `Router`. +tweakResponse :: (RouteResult Response -> RouteResult Response) -> Router -> Router +tweakResponse f = fmap (\a -> \req cont -> a req (cont . f)) -- | Smart constructor for the choice between routers. -- We currently optimize the following cases: @@ -69,4 +78,3 @@ runRouter (Choice r1 r2) request respond = then runRouter r2 request $ \ mResponse2 -> respond (mResponse1 <> mResponse2) else respond mResponse1 - diff --git a/servant-server/test/Servant/ServerSpec.hs b/servant-server/test/Servant/ServerSpec.hs index 45519e42..9c7e85c3 100644 --- a/servant-server/test/Servant/ServerSpec.hs +++ b/servant-server/test/Servant/ServerSpec.hs @@ -1,14 +1,17 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeSynonymInstances #-} -{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE FlexibleInstances #-} module Servant.ServerSpec where - +#if !MIN_VERSION_base(4,8,0) +import Control.Applicative ((<$>)) +#endif import Control.Monad (forM_, when) import Control.Monad.Trans.Except (ExceptT, throwE) import Data.Aeson (FromJSON, ToJSON, decode', encode) @@ -23,10 +26,12 @@ import GHC.Generics (Generic) import Network.HTTP.Types (hAccept, hContentType, methodDelete, methodGet, methodHead, methodPatch, methodPost, methodPut, - ok200, parseQuery, status409) + ok200, parseQuery, status409, + Status(..)) import Network.Wai (Application, Request, pathInfo, queryString, rawQueryString, - responseLBS) + responseLBS, responseBuilder) +import Network.Wai.Internal (Response(ResponseBuilder)) import Network.Wai.Test (defaultRequest, request, runSession, simpleBody) import Test.Hspec (Spec, describe, it, shouldBe) @@ -41,8 +46,12 @@ import Servant.API ((:<|>) (..), (:>), Post, Put, RemoteHost, QueryFlag, QueryParam, QueryParams, Raw, ReqBody) import Servant.Server (Server, serve, ServantErr(..), err404) +import Servant.Server.Internal.Router + (tweakResponse, runRouter, + Router, Router'(LeafRouter)) import Servant.Server.Internal.RoutingApplication - (RouteMismatch (..)) + (RouteResult(..), RouteMismatch(..), + toApplication) -- * test data types @@ -92,6 +101,7 @@ spec = do unionSpec prioErrorsSpec errorsSpec + routerSpec responseHeadersSpec miscReqCombinatorsSpec @@ -697,6 +707,24 @@ errorsSpec = do nf <> ib `shouldBe` ib nf <> wm `shouldBe` wm +routerSpec :: Spec +routerSpec = do + describe "Servant.Server.Internal.Router" $ do + let app' :: Application + app' = toApplication $ runRouter router' + + router', router :: Router + router' = tweakResponse (twk <$>) router + router = LeafRouter $ \_ cont -> cont (RR . Right $ responseBuilder (Status 201 "") [] "") + + twk :: Response -> Response + twk (ResponseBuilder (Status i s) hs b) = ResponseBuilder (Status (i + 1) s) hs b + twk b = b + + describe "tweakResponse" . with (return app') $ do + it "calls f on route result" $ do + get "" `shouldRespondWith` 202 + type MiscCombinatorsAPI = "version" :> HttpVersion :> Get '[JSON] String :<|> "secure" :> IsSecure :> Get '[JSON] String