diff --git a/servant-docs/src/Servant/Docs/Internal.hs b/servant-docs/src/Servant/Docs/Internal.hs index 53ae472d..33cb86a0 100644 --- a/servant-docs/src/Servant/Docs/Internal.hs +++ b/servant-docs/src/Servant/Docs/Internal.hs @@ -496,19 +496,6 @@ sampleByteStrings ctypes@Proxy Proxy = enc (t, s) = uncurry (t,,) <$> allMimeRender ctypes s in concatMap enc samples' --- | Generate a list of 'MediaType' values describing the content types --- accepted by an API component. -class SupportedTypes (list :: [*]) where - supportedTypes :: Proxy list -> [M.MediaType] - -instance SupportedTypes '[] where - supportedTypes Proxy = [] - -instance (Accept ctype, SupportedTypes rest) => SupportedTypes (ctype ': rest) - where - supportedTypes Proxy = - contentType (Proxy :: Proxy ctype) : supportedTypes (Proxy :: Proxy rest) - -- | The class that helps us automatically get documentation -- for GET parameters. -- @@ -709,14 +696,14 @@ instance #if MIN_VERSION_base(4,8,0) {-# OVERLAPPABLe #-} #endif - (ToSample a, IsNonEmpty cts, AllMimeRender cts a, SupportedTypes cts) + (ToSample a, IsNonEmpty cts, AllMimeRender cts a) => HasDocs (Delete cts a) where docsFor Proxy (endpoint, action) DocOptions{..} = single endpoint' action' where endpoint' = endpoint & method .~ DocDELETE action' = action & response.respBody .~ take _maxSamples (sampleByteStrings t p) - & response.respTypes .~ supportedTypes t + & response.respTypes .~ allMime t t = Proxy :: Proxy cts p = Proxy :: Proxy a @@ -724,7 +711,7 @@ instance #if MIN_VERSION_base(4,8,0) {-# OVERLAPPING #-} #endif - (ToSample a, IsNonEmpty cts, AllMimeRender cts a, SupportedTypes cts + (ToSample a, IsNonEmpty cts, AllMimeRender cts a , AllHeaderSamples ls , GetHeaders (HList ls) ) => HasDocs (Delete cts (Headers ls a)) where docsFor Proxy (endpoint, action) DocOptions{..} = @@ -733,7 +720,7 @@ instance where hdrs = allHeaderToSample (Proxy :: Proxy ls) endpoint' = endpoint & method .~ DocDELETE action' = action & response.respBody .~ take _maxSamples (sampleByteStrings t p) - & response.respTypes .~ supportedTypes t + & response.respTypes .~ allMime t & response.respHeaders .~ hdrs t = Proxy :: Proxy cts p = Proxy :: Proxy a @@ -742,14 +729,14 @@ instance #if MIN_VERSION_base(4,8,0) {-# OVERLAPPABLe #-} #endif - (ToSample a, IsNonEmpty cts, AllMimeRender cts a, SupportedTypes cts) + (ToSample a, IsNonEmpty cts, AllMimeRender cts a) => HasDocs (Get cts a) where docsFor Proxy (endpoint, action) DocOptions{..} = single endpoint' action' where endpoint' = endpoint & method .~ DocGET action' = action & response.respBody .~ take _maxSamples (sampleByteStrings t p) - & response.respTypes .~ supportedTypes t + & response.respTypes .~ allMime t t = Proxy :: Proxy cts p = Proxy :: Proxy a @@ -757,7 +744,7 @@ instance #if MIN_VERSION_base(4,8,0) {-# OVERLAPPING #-} #endif - (ToSample a, IsNonEmpty cts, AllMimeRender cts a, SupportedTypes cts + (ToSample a, IsNonEmpty cts, AllMimeRender cts a , AllHeaderSamples ls , GetHeaders (HList ls) ) => HasDocs (Get cts (Headers ls a)) where docsFor Proxy (endpoint, action) DocOptions{..} = @@ -766,7 +753,7 @@ instance where hdrs = allHeaderToSample (Proxy :: Proxy ls) endpoint' = endpoint & method .~ DocGET action' = action & response.respBody .~ take _maxSamples (sampleByteStrings t p) - & response.respTypes .~ supportedTypes t + & response.respTypes .~ allMime t & response.respHeaders .~ hdrs t = Proxy :: Proxy cts p = Proxy :: Proxy a @@ -784,14 +771,14 @@ instance #if MIN_VERSION_base(4,8,0) {-# OVERLAPPABLE #-} #endif - (ToSample a, IsNonEmpty cts, AllMimeRender cts a, SupportedTypes cts) + (ToSample a, IsNonEmpty cts, AllMimeRender cts a) => HasDocs (Post cts a) where docsFor Proxy (endpoint, action) DocOptions{..} = single endpoint' action' where endpoint' = endpoint & method .~ DocPOST action' = action & response.respBody .~ take _maxSamples (sampleByteStrings t p) - & response.respTypes .~ supportedTypes t + & response.respTypes .~ allMime t & response.respStatus .~ 201 t = Proxy :: Proxy cts p = Proxy :: Proxy a @@ -800,7 +787,7 @@ instance #if MIN_VERSION_base(4,8,0) {-# OVERLAPPING #-} #endif - (ToSample a, IsNonEmpty cts, AllMimeRender cts a, SupportedTypes cts + (ToSample a, IsNonEmpty cts, AllMimeRender cts a , AllHeaderSamples ls , GetHeaders (HList ls) ) => HasDocs (Post cts (Headers ls a)) where docsFor Proxy (endpoint, action) DocOptions{..} = @@ -809,7 +796,7 @@ instance where hdrs = allHeaderToSample (Proxy :: Proxy ls) endpoint' = endpoint & method .~ DocPOST action' = action & response.respBody .~ take _maxSamples (sampleByteStrings t p) - & response.respTypes .~ supportedTypes t + & response.respTypes .~ allMime t & response.respStatus .~ 201 & response.respHeaders .~ hdrs t = Proxy :: Proxy cts @@ -819,14 +806,14 @@ instance #if MIN_VERSION_base(4,8,0) {-# OVERLAPPABLE #-} #endif - (ToSample a, IsNonEmpty cts, AllMimeRender cts a, SupportedTypes cts) + (ToSample a, IsNonEmpty cts, AllMimeRender cts a) => HasDocs (Put cts a) where docsFor Proxy (endpoint, action) DocOptions{..} = single endpoint' action' where endpoint' = endpoint & method .~ DocPUT action' = action & response.respBody .~ take _maxSamples (sampleByteStrings t p) - & response.respTypes .~ supportedTypes t + & response.respTypes .~ allMime t & response.respStatus .~ 200 t = Proxy :: Proxy cts p = Proxy :: Proxy a @@ -835,8 +822,8 @@ instance #if MIN_VERSION_base(4,8,0) {-# OVERLAPPING #-} #endif - (ToSample a, IsNonEmpty cts, AllMimeRender cts a, SupportedTypes cts - , AllHeaderSamples ls , GetHeaders (HList ls) ) + ( ToSample a, IsNonEmpty cts, AllMimeRender cts a, + AllHeaderSamples ls , GetHeaders (HList ls) ) => HasDocs (Put cts (Headers ls a)) where docsFor Proxy (endpoint, action) DocOptions{..} = single endpoint' action' @@ -844,7 +831,7 @@ instance where hdrs = allHeaderToSample (Proxy :: Proxy ls) endpoint' = endpoint & method .~ DocPUT action' = action & response.respBody .~ take _maxSamples (sampleByteStrings t p) - & response.respTypes .~ supportedTypes t + & response.respTypes .~ allMime t & response.respStatus .~ 200 & response.respHeaders .~ hdrs t = Proxy :: Proxy cts @@ -890,8 +877,7 @@ instance HasDocs Raw where -- example data. However, there's no reason to believe that the instances of -- 'AllMimeUnrender' and 'AllMimeRender' actually agree (or to suppose that -- both are even defined) for any particular type. -instance (ToSample a, IsNonEmpty cts, AllMimeRender cts a, HasDocs sublayout - , SupportedTypes cts) +instance (ToSample a, IsNonEmpty cts, AllMimeRender cts a, HasDocs sublayout) => HasDocs (ReqBody cts a :> sublayout) where docsFor Proxy (endpoint, action) = @@ -899,7 +885,7 @@ instance (ToSample a, IsNonEmpty cts, AllMimeRender cts a, HasDocs sublayout where sublayoutP = Proxy :: Proxy sublayout action' = action & rqbody .~ sampleByteString t p - & rqtypes .~ supportedTypes t + & rqtypes .~ allMime t t = Proxy :: Proxy cts p = Proxy :: Proxy a @@ -957,4 +943,3 @@ instance ToSample a => ToSample (Product a) instance ToSample a => ToSample (First a) instance ToSample a => ToSample (Last a) instance ToSample a => ToSample (Dual a) - diff --git a/servant-examples/auth-combinator/auth-combinator.hs b/servant-examples/auth-combinator/auth-combinator.hs index c0b4299d..ec152782 100644 --- a/servant-examples/auth-combinator/auth-combinator.hs +++ b/servant-examples/auth-combinator/auth-combinator.hs @@ -10,7 +10,6 @@ import Data.Aeson import Data.ByteString (ByteString) import Data.Text (Text) import GHC.Generics -import Network.HTTP.Types import Network.Wai import Network.Wai.Handler.Warp import Servant @@ -29,15 +28,16 @@ data AuthProtected instance HasServer rest => HasServer (AuthProtected :> rest) where type ServerT (AuthProtected :> rest) m = ServerT rest m - route Proxy a = WithRequest $ \ request -> - route (Proxy :: Proxy rest) $ do - case lookup "Cookie" (requestHeaders request) of - Nothing -> return $! FailFatal err401 { errBody = "Missing auth header" } - Just v -> do - authGranted <- isGoodCookie v - if authGranted - then a - else return $ FailFatal err403 { errBody = "Invalid cookie" } + route Proxy subserver = WithRequest $ \ request -> + route (Proxy :: Proxy rest) $ addAcceptCheck subserver $ cookieCheck request + where + cookieCheck req = case lookup "Cookie" (requestHeaders req) of + Nothing -> return $ FailFatal err401 { errBody = "Missing auth header" } + Just v -> do + authGranted <- isGoodCookie v + if authGranted + then return $ Route () + else return $ FailFatal err403 { errBody = "Invalid cookie" } type PrivateAPI = Get '[JSON] [PrivateData] diff --git a/servant-server/servant-server.cabal b/servant-server/servant-server.cabal index 1a7335d3..9b69d9c4 100644 --- a/servant-server/servant-server.cabal +++ b/servant-server/servant-server.cabal @@ -108,6 +108,7 @@ test-suite spec , network >= 2.6 , QuickCheck , parsec + , safe , servant , servant-server , string-conversions diff --git a/servant-server/src/Servant/Server.hs b/servant-server/src/Servant/Server.hs index f6781b66..a26941ea 100644 --- a/servant-server/src/Servant/Server.hs +++ b/servant-server/src/Servant/Server.hs @@ -103,7 +103,10 @@ import Servant.Server.Internal.Enter -- > main = Network.Wai.Handler.Warp.run 8080 app -- serve :: HasServer layout => Proxy layout -> Server layout -> Application -serve p server = toApplication (runRouter (route p (return (Route server)))) +serve p server = toApplication (runRouter (route p d)) + where + d = Delayed r r r (\ _ _ -> Route server) + r = return (Route ()) -- Documentation diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index 6c717fa2..4200d052 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -46,7 +46,9 @@ import Servant.API ((:<|>) (..), (:>), Capture, Raw, RemoteHost, ReqBody, Vault) import Servant.API.ContentTypes (AcceptHeader (..), AllCTRender (..), - AllCTUnrender (..)) + AllCTUnrender (..), + AllMime, + canHandleAcceptH) import Servant.API.ResponseHeaders (GetHeaders, Headers, getHeaders, getResponse) @@ -60,7 +62,7 @@ import Web.HttpApiData.Internal (parseUrlPieceMaybe, parseHeaderMaybe, class HasServer layout where type ServerT layout (m :: * -> *) :: * - route :: Proxy layout -> IO (RouteResult (Server layout)) -> Router + route :: Proxy layout -> Delayed (Server layout) -> Router type Server layout = ServerT layout (ExceptT ServantErr IO) @@ -81,8 +83,8 @@ instance (HasServer a, HasServer b) => HasServer (a :<|> b) where type ServerT (a :<|> b) m = ServerT a m :<|> ServerT b m - route Proxy server = choice (route pa (extractL <$> server)) - (route pb (extractR <$> server)) + route Proxy server = choice (route pa ((\ (a :<|> _) -> a) <$> server)) + (route pb ((\ (_ :<|> b) -> b) <$> server)) where pa = Proxy :: Proxy a pb = Proxy :: Proxy b @@ -112,12 +114,15 @@ instance (KnownSymbol capture, FromHttpApiData a, HasServer sublayout) type ServerT (Capture capture a :> sublayout) m = a -> ServerT sublayout m - route Proxy subserver = - DynamicRouter $ \ first -> case captured captureProxy first of - Nothing -> LeafRouter (\_ r -> r $ Fail err404) - Just v -> route (Proxy :: Proxy sublayout) (feedTo subserver v) - - where captureProxy = Proxy :: Proxy (Capture capture a) + route Proxy d = + DynamicRouter $ \ first -> + route (Proxy :: Proxy sublayout) + (addCapture d $ case captured captureProxy first of + Nothing -> return $ Fail err404 + Just v -> return $ Route v + ) + where + captureProxy = Proxy :: Proxy (Capture capture a) allowedMethodHead :: Method -> Request -> Bool allowedMethodHead method request = method == methodGet && requestMethod request == methodHead @@ -130,56 +135,64 @@ processMethodRouter :: forall a. ConvertibleStrings a B.ByteString -> Maybe [(HeaderName, B.ByteString)] -> Request -> RouteResult Response processMethodRouter handleA status method headers request = case handleA of - Nothing -> FailFatal err406 - Just (contentT, body) -> Route $! responseLBS status hdrs bdy + Nothing -> FailFatal err406 -- this should not happen (checked before), so we make it fatal if it does + Just (contentT, body) -> Route $ responseLBS status hdrs bdy where bdy = if allowedMethodHead method request then "" else body hdrs = (hContentType, cs contentT) : (fromMaybe [] headers) +methodCheck :: Method -> Request -> IO (RouteResult ()) +methodCheck method request + | allowedMethod method request = return $ Route () + | otherwise = return $ Fail err405 + +acceptCheck :: (AllMime list) => Proxy list -> B.ByteString -> IO (RouteResult ()) +acceptCheck proxy accH + | canHandleAcceptH proxy (AcceptHeader accH) = return $ Route () + | otherwise = return $ Fail err406 + methodRouter :: (AllCTRender ctypes a) => Method -> Proxy ctypes -> Status - -> IO (RouteResult (ExceptT ServantErr IO a)) + -> Delayed (ExceptT ServantErr IO a) -> Router methodRouter method proxy status action = LeafRouter route' where route' request respond - | pathIsEmpty request && allowedMethod method request = do - runAction action respond $ \ output -> do - let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request - handleA = handleAcceptH proxy (AcceptHeader accH) output - processMethodRouter handleA status method Nothing request - | pathIsEmpty request && requestMethod request /= method = - respond $ Fail err405 + | pathIsEmpty request = + let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request + in runAction (action `addMethodCheck` methodCheck method request + `addAcceptCheck` acceptCheck proxy accH + ) respond $ \ output -> do + let handleA = handleAcceptH proxy (AcceptHeader accH) output + processMethodRouter handleA status method Nothing request | otherwise = respond $ Fail err404 methodRouterHeaders :: (GetHeaders (Headers h v), AllCTRender ctypes v) => Method -> Proxy ctypes -> Status - -> IO (RouteResult (ExceptT ServantErr IO (Headers h v))) + -> Delayed (ExceptT ServantErr IO (Headers h v)) -> Router methodRouterHeaders method proxy status action = LeafRouter route' where route' request respond - | pathIsEmpty request && allowedMethod method request = do - runAction action respond $ \ output -> do - let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request - headers = getHeaders output - handleA = handleAcceptH proxy (AcceptHeader accH) (getResponse output) - processMethodRouter handleA status method (Just headers) request - | pathIsEmpty request && requestMethod request /= method = - respond $ Fail err405 + | pathIsEmpty request = + let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request + in runAction (action `addMethodCheck` methodCheck method request + `addAcceptCheck` acceptCheck proxy accH + ) respond $ \ output -> do + let headers = getHeaders output + handleA = handleAcceptH proxy (AcceptHeader accH) (getResponse output) + processMethodRouter handleA status method (Just headers) request | otherwise = respond $ Fail err404 methodRouterEmpty :: Method - -> IO (RouteResult (ExceptT ServantErr IO ())) + -> Delayed (ExceptT ServantErr IO ()) -> Router methodRouterEmpty method action = LeafRouter route' where route' request respond - | pathIsEmpty request && allowedMethod method request = do - runAction action respond $ \ () -> + | pathIsEmpty request = do + runAction (addMethodCheck action (methodCheck method request)) respond $ \ () -> Route $! responseLBS noContent204 [] "" - | pathIsEmpty request && requestMethod request /= method = - respond $ Fail err405 | otherwise = respond $ Fail err404 -- | If you have a 'Delete' endpoint in your API, @@ -300,7 +313,7 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout) route Proxy subserver = WithRequest $ \ request -> let mheader = parseHeaderMaybe =<< lookup str (requestHeaders request) - in route (Proxy :: Proxy sublayout) (feedTo subserver mheader) + in route (Proxy :: Proxy sublayout) (passToServer subserver mheader) where str = fromString $ symbolVal (Proxy :: Proxy sym) -- | When implementing the handler for a 'Post' endpoint, @@ -472,7 +485,7 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout) Just Nothing -> Nothing -- param present with no value -> Nothing Just (Just v) -> parseQueryParamMaybe v -- if present, we try to convert to -- the right type - in route (Proxy :: Proxy sublayout) (feedTo subserver param) + in route (Proxy :: Proxy sublayout) (passToServer subserver param) where paramname = cs $ symbolVal (Proxy :: Proxy sym) -- | If you use @'QueryParams' "authors" Text@ in one of the endpoints for your API, @@ -507,7 +520,7 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout) -- corresponding values parameters = filter looksLikeParam querytext values = mapMaybe (convert . snd) parameters - in route (Proxy :: Proxy sublayout) (feedTo subserver values) + in route (Proxy :: Proxy sublayout) (passToServer subserver values) where paramname = cs $ symbolVal (Proxy :: Proxy sym) looksLikeParam (name, _) = name == paramname || name == (paramname <> "[]") convert Nothing = Nothing @@ -537,7 +550,7 @@ instance (KnownSymbol sym, HasServer sublayout) Just Nothing -> True -- param is there, with no value Just (Just v) -> examine v -- param with a value Nothing -> False -- param not in the query string - in route (Proxy :: Proxy sublayout) (feedTo subserver param) + in route (Proxy :: Proxy sublayout) (passToServer subserver param) where paramname = cs $ symbolVal (Proxy :: Proxy sym) examine v | v == "true" || v == "1" || v == "" = True | otherwise = False @@ -555,7 +568,7 @@ instance HasServer Raw where type ServerT Raw m = Application route Proxy rawApplication = LeafRouter $ \ request respond -> do - r <- rawApplication + r <- runDelayed rawApplication case r of Route app -> app request (respond . Route) Fail a -> respond $ Fail a @@ -589,19 +602,21 @@ instance ( AllCTUnrender list a, HasServer sublayout a -> ServerT sublayout m route Proxy subserver = WithRequest $ \ request -> - route (Proxy :: Proxy sublayout) $ do - -- See HTTP RFC 2616, section 7.2.1 - -- http://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1 - -- See also "W3C Internet Media Type registration, consistency of use" - -- http://www.w3.org/2001/tag/2002/0129-mime - let contentTypeH = fromMaybe "application/octet-stream" - $ lookup hContentType $ requestHeaders request - mrqbody <- handleCTypeH (Proxy :: Proxy list) (cs contentTypeH) - <$> lazyRequestBody request - case mrqbody of - Nothing -> return $ FailFatal err415 - Just (Left e) -> return $ FailFatal err400 { errBody = cs e } - Just (Right v) -> feedTo subserver v + route (Proxy :: Proxy sublayout) (addBodyCheck subserver (bodyCheck request)) + where + bodyCheck request = do + -- See HTTP RFC 2616, section 7.2.1 + -- http://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1 + -- See also "W3C Internet Media Type registration, consistency of use" + -- http://www.w3.org/2001/tag/2002/0129-mime + let contentTypeH = fromMaybe "application/octet-stream" + $ lookup hContentType $ requestHeaders request + mrqbody <- handleCTypeH (Proxy :: Proxy list) (cs contentTypeH) + <$> lazyRequestBody request + case mrqbody of + Nothing -> return $ FailFatal err415 + Just (Left e) -> return $ FailFatal err400 { errBody = cs e } + Just (Right v) -> return $ Route v -- | Make sure the incoming request starts with @"/path"@, strip it and -- pass the rest of the request path to @sublayout@. @@ -618,13 +633,13 @@ instance HasServer api => HasServer (RemoteHost :> api) where type ServerT (RemoteHost :> api) m = SockAddr -> ServerT api m route Proxy subserver = WithRequest $ \req -> - route (Proxy :: Proxy api) (feedTo subserver $ remoteHost req) + route (Proxy :: Proxy api) (passToServer subserver $ remoteHost req) instance HasServer api => HasServer (IsSecure :> api) where type ServerT (IsSecure :> api) m = IsSecure -> ServerT api m route Proxy subserver = WithRequest $ \req -> - route (Proxy :: Proxy api) (feedTo subserver $ secure req) + route (Proxy :: Proxy api) (passToServer subserver $ secure req) where secure req = if isSecure req then Secure else NotSecure @@ -632,13 +647,13 @@ instance HasServer api => HasServer (Vault :> api) where type ServerT (Vault :> api) m = Vault -> ServerT api m route Proxy subserver = WithRequest $ \req -> - route (Proxy :: Proxy api) (feedTo subserver $ vault req) + route (Proxy :: Proxy api) (passToServer subserver $ vault req) instance HasServer api => HasServer (HttpVersion :> api) where type ServerT (HttpVersion :> api) m = HttpVersion -> ServerT api m route Proxy subserver = WithRequest $ \req -> - route (Proxy :: Proxy api) (feedTo subserver $ httpVersion req) + route (Proxy :: Proxy api) (passToServer subserver $ httpVersion req) pathIsEmpty :: Request -> Bool pathIsEmpty = go . pathInfo diff --git a/servant-server/src/Servant/Server/Internal/Router.hs b/servant-server/src/Servant/Server/Internal/Router.hs index 3914af0d..63b05c05 100644 --- a/servant-server/src/Servant/Server/Internal/Router.hs +++ b/servant-server/src/Servant/Server/Internal/Router.hs @@ -6,9 +6,9 @@ import Data.Map (Map) import qualified Data.Map as M import Data.Text (Text) import Network.Wai (Request, Response, pathInfo) -import Servant.Server.Internal.ServantErr import Servant.Server.Internal.PathInfo import Servant.Server.Internal.RoutingApplication +import Servant.Server.Internal.ServantErr type Router = Router' RoutingApplication @@ -77,10 +77,18 @@ runRouter (Choice r1 r2) request respond = Fail _ -> runRouter r2 request $ \ mResponse2 -> respond (highestPri mResponse1 mResponse2) _ -> respond mResponse1 - where - highestPri (Fail e1) (Fail e2) = - if errHTTPCode e1 == 404 && errHTTPCode e2 /= 404 - then Fail e2 - else Fail e1 - highestPri (Fail _) y = y - highestPri x _ = x + where + highestPri (Fail e1) (Fail e2) = + if worseHTTPCode (errHTTPCode e1) (errHTTPCode e2) + then Fail e2 + else Fail e1 + highestPri (Fail _) y = y + highestPri x _ = x + + +-- Priority on HTTP codes. +-- +-- It just so happens that 404 < 405 < 406 as far as +-- we are concerned here, so we can use (<). +worseHTTPCode :: Int -> Int -> Bool +worseHTTPCode = (<) diff --git a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs index f430fb2e..cc3f5965 100644 --- a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs +++ b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs @@ -2,6 +2,9 @@ {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE StandaloneDeriving #-} module Servant.Server.Internal.RoutingApplication where #if !MIN_VERSION_base(4,8,0) @@ -18,19 +21,18 @@ import Network.Wai (Application, Request, Response, ResponseReceived, requestBody, strictRequestBody) -import Servant.API ((:<|>) (..)) import Servant.Server.Internal.ServantErr type RoutingApplication = Request -- ^ the request, the field 'pathInfo' may be modified by url routing -> (RouteResult Response -> IO ResponseReceived) -> IO ResponseReceived --- | A wrapper around @'Either' 'RouteMismatch' a@. +-- | The result of matching against a path in the route tree. data RouteResult a = Fail ServantErr -- ^ Keep trying other paths. The @ServantErr@ - -- should only be 404 or 405. - | FailFatal ServantErr -- ^ Don't other paths. - | Route a + -- should only be 404, 405 or 406. + | FailFatal !ServantErr -- ^ Don't try other paths. + | Route !a deriving (Eq, Show, Read, Functor) data ReqBodyState = Uncalled @@ -63,15 +65,183 @@ toApplication ra request respond = do ra request{ requestBody = memoReqBody } routingRespond where routingRespond :: RouteResult Response -> IO ResponseReceived - routingRespond (Fail err) = respond $! responseServantErr err - routingRespond (FailFatal err) = respond $! responseServantErr err - routingRespond (Route v) = respond v + routingRespond (Fail err) = respond $ responseServantErr err + routingRespond (FailFatal err) = respond $ responseServantErr err + routingRespond (Route v) = respond v -runAction :: IO (RouteResult (ExceptT ServantErr IO a)) +-- TODO: The above may not be quite right yet. +-- +-- We currently mix up the order in which we perform checks +-- and the priority with which errors are reported. +-- +-- For example, we perform Capture checks prior to method checks, +-- and therefore get 404 before 405. +-- +-- However, we also perform body checks prior to method checks +-- now, and therefore get 415 before 405, which is wrong. +-- +-- If we delay Captures, but perform method checks eagerly, we +-- end up potentially preferring 405 over 404, whcih is also bad. +-- +-- So in principle, we'd like: +-- +-- static routes (can cause 404) +-- delayed captures (can cause 404) +-- methods (can cause 405) +-- delayed body (can cause 415, 400) +-- accept header (can cause 406) +-- +-- According to the HTTP decision diagram, the priority order +-- between HTTP status codes is as follows: +-- + +-- | A 'Delayed' is a representation of a handler with scheduled +-- delayed checks that can trigger errors. +-- +-- Why would we want to delay checks? +-- +-- There are two reasons: +-- +-- 1. Currently, the order in which we perform checks coincides +-- with the error we will generate. This is because during checks, +-- once an error occurs, we do not perform any subsequent checks, +-- but rather return this error. +-- +-- This is not a necessity: we could continue doing other checks, +-- and choose the preferred error. However, that would in general +-- mean more checking, which leads us to the other reason. +-- +-- 2. We really want to avoid doing certain checks too early. For +-- example, captures involve parsing, and are much more costly +-- than static route matches. In particular, if several paths +-- contain the "same" capture, we'd like as much as possible to +-- avoid trying the same parse many times. Also tricky is the +-- request body. Again, this involves parsing, but also, WAI makes +-- obtaining the request body a side-effecting operation. We +-- could/can work around this by manually caching the request body, +-- but we'd rather keep the number of times we actually try to +-- decode the request body to an absolute minimum. +-- +-- We prefer to have the following relative priorities of error +-- codes: +-- +-- @ +-- 404 +-- 405 (bad method) +-- 401 (unauthorized) +-- 415 (unsupported media type) +-- 400 (bad request) +-- 406 (not acceptable) +-- @ +-- +-- Therefore, while routing, we delay most checks so that they +-- will ultimately occur in the right order. +-- +-- A 'Delayed' contains three delayed blocks of tests, and +-- the actual handler: +-- +-- 1. Delayed captures. These can actually cause 404, and +-- while they're costly, they should be done first among the +-- delayed checks (at least as long as we do not decouple the +-- check order from the error reporting, see above). Delayed +-- captures can provide inputs to the actual handler. +-- +-- 2. Method check(s). This can cause a 405. On success, +-- it does not provide an input for the handler. Method checks +-- are comparatively cheap. +-- +-- 3. Body and accept header checks. The request body check can +-- cause both 400 and 415. This provides an input to the handler. +-- The accept header check can be performed as the final +-- computation in this block. It can cause a 406. +-- +data Delayed :: * -> * where + Delayed :: IO (RouteResult a) + -> IO (RouteResult ()) + -> IO (RouteResult b) + -> (a -> b -> RouteResult c) + -> Delayed c + +deriving instance Functor Delayed + +-- | Add a capture to the end of the capture block. +addCapture :: Delayed (a -> b) + -> IO (RouteResult a) + -> Delayed b +addCapture (Delayed captures method body server) new = + Delayed (combineRouteResults (,) captures new) method body (\ (x, v) y -> ($ v) <$> server x y) + +-- | Add a method check to the end of the method block. +addMethodCheck :: Delayed a + -> IO (RouteResult ()) + -> Delayed a +addMethodCheck (Delayed captures method body server) new = + Delayed captures (combineRouteResults const method new) body server + +-- | Add a body check to the end of the body block. +addBodyCheck :: Delayed (a -> b) + -> IO (RouteResult a) + -> Delayed b +addBodyCheck (Delayed captures method body server) new = + Delayed captures method (combineRouteResults (,) body new) (\ x (y, v) -> ($ v) <$> server x y) + +-- | Add an accept header check to the end of the body block. +-- The accept header check should occur after the body check, +-- but this will be the case, because the accept header check +-- is only scheduled by the method combinators. +addAcceptCheck :: Delayed a + -> IO (RouteResult ()) + -> Delayed a +addAcceptCheck (Delayed captures method body server) new = + Delayed captures method (combineRouteResults const body new) server + +-- | Many combinators extract information that is passed to +-- the handler without the possibility of failure. In such a +-- case, 'passToServer' can be used. +passToServer :: Delayed (a -> b) -> a -> Delayed b +passToServer d x = ($ x) <$> d + +-- | The combination 'IO . RouteResult' is a monad, but we +-- don't explicitly wrap it in a newtype in order to make it +-- an instance. This is the '>>=' of that monad. +-- +-- We stop on the first error. +bindRouteResults :: IO (RouteResult a) -> (a -> IO (RouteResult b)) -> IO (RouteResult b) +bindRouteResults m f = do + r <- m + case r of + Fail e -> return $ Fail e + FailFatal e -> return $ FailFatal e + Route a -> f a + +-- | Common special case of 'bindRouteResults', corresponding +-- to 'liftM2'. +combineRouteResults :: (a -> b -> c) -> IO (RouteResult a) -> IO (RouteResult b) -> IO (RouteResult c) +combineRouteResults f m1 m2 = + m1 `bindRouteResults` \ a -> + m2 `bindRouteResults` \ b -> + return (Route (f a b)) + +-- | Run a delayed server. Performs all scheduled operations +-- in order, and passes the results from the capture and body +-- blocks on to the actual handler. +runDelayed :: Delayed a + -> IO (RouteResult a) +runDelayed (Delayed captures method body server) = + captures `bindRouteResults` \ c -> + method `bindRouteResults` \ _ -> + body `bindRouteResults` \ b -> + return (server c b) + +-- | Runs a delayed server and the resulting action. +-- Takes a continuation that lets us send a response. +-- Also takes a continuation for how to turn the +-- result of the delayed server into a response. +runAction :: Delayed (ExceptT ServantErr IO a) -> (RouteResult Response -> IO r) -> (a -> RouteResult Response) -> IO r -runAction action respond k = action >>= go >>= respond +runAction action respond k = runDelayed action >>= go >>= respond where go (Fail e) = return $ Fail e go (FailFatal e) = return $ FailFatal e @@ -80,16 +250,3 @@ runAction action respond k = action >>= go >>= respond case e of Left err -> return . Route $ responseServantErr err Right x -> return $! k x - -feedTo :: IO (RouteResult (a -> b)) -> a -> IO (RouteResult b) -feedTo f x = (($ x) <$>) <$> f - -extractL :: RouteResult (a :<|> b) -> RouteResult a -extractL (Route (a :<|> _)) = Route a -extractL (Fail x) = Fail x -extractL (FailFatal x) = FailFatal x - -extractR :: RouteResult (a :<|> b) -> RouteResult b -extractR (Route (_ :<|> b)) = Route b -extractR (Fail x) = Fail x -extractR (FailFatal x) = FailFatal x diff --git a/servant-server/test/Servant/Server/ErrorSpec.hs b/servant-server/test/Servant/Server/ErrorSpec.hs index 9a0bb2dd..60212a4a 100644 --- a/servant-server/test/Servant/Server/ErrorSpec.hs +++ b/servant-server/test/Servant/Server/ErrorSpec.hs @@ -5,12 +5,14 @@ {-# OPTIONS_GHC -fno-warn-orphans #-} module Servant.Server.ErrorSpec (spec) where +import Control.Monad.Trans.Except (throwE) import Data.Aeson (encode) -import qualified Data.ByteString.Lazy.Char8 as BCL import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Lazy.Char8 as BCL import Data.Proxy import Network.HTTP.Types (hAccept, hContentType, methodGet, methodPost, methodPut) +import Safe (readMay) import Test.Hspec import Test.Hspec.Wai @@ -54,7 +56,7 @@ errorOrderApi :: Proxy ErrorOrderApi errorOrderApi = Proxy errorOrderServer :: Server ErrorOrderApi -errorOrderServer = \_ _ -> return 5 +errorOrderServer = \_ _ -> throwE err402 errorOrderSpec :: Spec errorOrderSpec = describe "HTTP error order" @@ -65,6 +67,7 @@ errorOrderSpec = describe "HTTP error order" badUrl = "home/nonexistent" badBody = "nonsense" goodContentType = (hContentType, "application/json") + goodAccept = (hAccept, "application/json") goodMethod = methodPost goodUrl = "home/2" goodBody = encode (5 :: Int) @@ -89,6 +92,10 @@ errorOrderSpec = describe "HTTP error order" request goodMethod goodUrl [goodContentType, badAccept] goodBody `shouldRespondWith` 406 + it "has handler-level errors as last priority" $ do + request goodMethod goodUrl [goodContentType, goodAccept] goodBody + `shouldRespondWith` 402 + type PrioErrorsApi = ReqBody '[JSON] Integer :> "foo" :> Get '[JSON] Integer prioErrorsApi :: Proxy PrioErrorsApi @@ -107,7 +114,7 @@ prioErrorsSpec = describe "PrioErrors" $ do `shouldRespondWith` resp where fulldescr = "returns " ++ show (matchStatus resp) ++ " on " ++ mdescr - ++ " " ++ (BC.unpack path) ++ " (" ++ cdescr ++ ")" + ++ " " ++ BC.unpack path ++ " (" ++ cdescr ++ ")" get' = ("GET", methodGet) put' = ("PUT", methodPut) @@ -140,7 +147,7 @@ prioErrorsSpec = describe "PrioErrors" $ do -- * Error Retry {{{ type ErrorRetryApi - = "a" :> ReqBody '[JSON] Int :> Post '[JSON] Int -- 0 + = "a" :> ReqBody '[JSON] Int :> Post '[JSON] Int -- err402 :<|> "a" :> ReqBody '[PlainText] Int :> Post '[JSON] Int -- 1 :<|> "a" :> ReqBody '[JSON] Int :> Post '[PlainText] Int -- 2 :<|> "a" :> ReqBody '[JSON] String :> Post '[JSON] Int -- 3 @@ -154,7 +161,7 @@ errorRetryApi = Proxy errorRetryServer :: Server ErrorRetryApi errorRetryServer - = (\_ -> return 0) + = (\_ -> throwE err402) :<|> (\_ -> return 1) :<|> (\_ -> return 2) :<|> (\_ -> return 3) @@ -181,18 +188,6 @@ errorRetrySpec = describe "Handler search" request methodGet "a" [jsonCT, jsonAccept] jsonBody `shouldRespondWith` 200 { matchBody = Just $ encode (4 :: Int) } - it "should not continue when Content-Types don't match" $ do - request methodPost "a" [plainCT, jsonAccept] jsonBody - `shouldRespondWith` 415 - - it "should not continue when body can't be deserialized" $ do - request methodPost "a" [jsonCT, jsonAccept] (encode ("nonsense" :: String)) - `shouldRespondWith` 400 - - it "should not continue when Accepts don't match" $ do - request methodPost "a" [jsonCT, plainAccept] jsonBody - `shouldRespondWith` 406 - -- }}} ------------------------------------------------------------------------------ -- * Error Choice {{{ @@ -233,7 +228,7 @@ errorChoiceSpec = describe "Multiple handlers return errors" request methodPost "path3" [(hContentType, "application/json")] "" `shouldRespondWith` 400 request methodPost "path4" [(hContentType, "text/plain;charset=utf-8"), - (hAccept, "application/json")] "" + (hAccept, "blah")] "5" `shouldRespondWith` 406 @@ -242,10 +237,8 @@ errorChoiceSpec = describe "Multiple handlers return errors" -- * Instances {{{ instance MimeUnrender PlainText Int where - mimeUnrender _ = Right . read . BCL.unpack + mimeUnrender _ x = maybe (Left "no parse") Right (readMay $ BCL.unpack x) instance MimeRender PlainText Int where mimeRender _ = BCL.pack . show -- }}} --- - diff --git a/servant-server/test/Servant/ServerSpec.hs b/servant-server/test/Servant/ServerSpec.hs index 4ee65423..11816853 100644 --- a/servant-server/test/Servant/ServerSpec.hs +++ b/servant-server/test/Servant/ServerSpec.hs @@ -1,20 +1,12 @@ -<<<<<<< HEAD {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} -======= -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleInstances #-} ->>>>>>> Review fixes {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeSynonymInstances #-} -<<<<<<< HEAD {-# LANGUAGE FlexibleInstances #-} -======= ->>>>>>> Review fixes module Servant.ServerSpec where @@ -55,6 +47,7 @@ import Test.Hspec (Spec, describe, it, shouldBe) import Test.Hspec.Wai (get, liftIO, matchHeaders, matchStatus, post, request, shouldRespondWith, with, (<:>)) +<<<<<<< HEAD import Servant.API ((:<|>) (..), (:>), Capture, Delete, Get, Header (..), Headers, HttpVersion, IsSecure (..), JSON, @@ -63,12 +56,12 @@ import Servant.API ((:<|>) (..), (:>), Capture, Delete, Raw, RemoteHost, ReqBody, addHeader) import Servant.Server (Server, serve, ServantErr(..), err404) +======= +import Servant.Server.Internal.RoutingApplication (toApplication, RouteResult(..)) +>>>>>>> Rebase cleanup and test fixes. import Servant.Server.Internal.Router (tweakResponse, runRouter, Router, Router'(LeafRouter)) -import Servant.Server.Internal.RoutingApplication - (RouteResult(..), RouteMismatch(..), - toApplication) -- * test data types @@ -279,13 +272,13 @@ queryParamSpec = do } let params3'' = "?unknown=" - response3' <- Network.Wai.Test.request defaultRequest{ + response3'' <- Network.Wai.Test.request defaultRequest{ rawQueryString = params3'', queryString = parseQuery params3'', pathInfo = ["b"] } liftIO $ - decode' (simpleBody response3') `shouldBe` Just alice{ + decode' (simpleBody response3'') `shouldBe` Just alice{ name = "Alice" } @@ -553,7 +546,7 @@ routerSpec = do router', router :: Router router' = tweakResponse (twk <$>) router - router = LeafRouter $ \_ cont -> cont (RR . Right $ responseBuilder (Status 201 "") [] "") + router = LeafRouter $ \_ cont -> cont (Route $ responseBuilder (Status 201 "") [] "") twk :: Response -> Response twk (ResponseBuilder (Status i s) hs b) = ResponseBuilder (Status (i + 1) s) hs b diff --git a/servant/src/Servant/API/ContentTypes.hs b/servant/src/Servant/API/ContentTypes.hs index db8eb61e..ab857ce2 100644 --- a/servant/src/Servant/API/ContentTypes.hs +++ b/servant/src/Servant/API/ContentTypes.hs @@ -57,12 +57,14 @@ module Servant.API.ContentTypes , AcceptHeader(..) , AllCTRender(..) , AllCTUnrender(..) + , AllMime(..) , AllMimeRender(..) , AllMimeUnrender(..) , FromFormUrlEncoded(..) , ToFormUrlEncoded(..) , IsNonEmpty , eitherDecodeLenient + , canHandleAcceptH ) where #if !MIN_VERSION_base(4,8,0) @@ -81,6 +83,7 @@ import Data.ByteString.Lazy (ByteString, fromStrict, toStrict) import qualified Data.ByteString.Lazy as B import qualified Data.ByteString.Lazy.Char8 as BC +import Data.Maybe (isJust) import Data.Monoid import Data.String.Conversions (cs) import qualified Data.Text as TextS @@ -156,14 +159,13 @@ newtype AcceptHeader = AcceptHeader BS.ByteString class Accept ctype => MimeRender ctype a where mimeRender :: Proxy ctype -> a -> ByteString -class AllCTRender (list :: [*]) a where +class (AllMimeRender list a) => AllCTRender (list :: [*]) a where -- If the Accept header can be matched, returns (Just) a tuple of the -- Content-Type and response (serialization of @a@ into the appropriate -- mimetype). handleAcceptH :: Proxy list -> AcceptHeader -> a -> Maybe (ByteString, ByteString) -instance ( AllMimeRender ctyps a, IsNonEmpty ctyps - ) => AllCTRender ctyps a where +instance (AllMimeRender ctyps a, IsNonEmpty ctyps) => AllCTRender ctyps a where handleAcceptH _ (AcceptHeader accept) val = M.mapAcceptMedia lkup accept where pctyps = Proxy :: Proxy ctyps amrs = allMimeRender pctyps val @@ -211,11 +213,24 @@ instance ( AllMimeUnrender ctyps a, IsNonEmpty ctyps -------------------------------------------------------------------------- -- * Utils (Internal) +class AllMime (list :: [*]) where + allMime :: Proxy list -> [M.MediaType] + +instance AllMime '[] where + allMime _ = [] + +instance (Accept ctyp, AllMime ctyps) => AllMime (ctyp ': ctyps) where + allMime _ = (contentType pctyp):allMime pctyps + where pctyp = Proxy :: Proxy ctyp + pctyps = Proxy :: Proxy ctyps + +canHandleAcceptH :: AllMime list => Proxy list -> AcceptHeader -> Bool +canHandleAcceptH p (AcceptHeader h ) = isJust $ M.matchAccept (allMime p) h -------------------------------------------------------------------------- -- Check that all elements of list are instances of MimeRender -------------------------------------------------------------------------- -class AllMimeRender (list :: [*]) a where +class (AllMime list) => AllMimeRender (list :: [*]) a where allMimeRender :: Proxy list -> a -- value to serialize -> [(M.MediaType, ByteString)] -- content-types/response pairs @@ -239,7 +254,7 @@ instance AllMimeRender '[] a where -------------------------------------------------------------------------- -- Check that all elements of list are instances of MimeUnrender -------------------------------------------------------------------------- -class AllMimeUnrender (list :: [*]) a where +class (AllMime list) => AllMimeUnrender (list :: [*]) a where allMimeUnrender :: Proxy list -> ByteString -> [(M.MediaType, Either String a)]