From 58e931f48a9dba161c146e40cf6fd57f908306a2 Mon Sep 17 00:00:00 2001 From: Oleg Grenrus Date: Mon, 16 Jan 2017 14:17:20 +0200 Subject: [PATCH] Resolve todos --- servant-server/src/Servant/Server/Internal.hs | 71 ++++++++++--------- .../Server/Internal/RoutingApplication.hs | 57 +++++++++------ .../test/Servant/Server/ErrorSpec.hs | 18 +++-- .../Server/Internal/RoutingApplicationSpec.hs | 15 ++-- servant/src/Servant/API/ContentTypes.hs | 21 +++--- 5 files changed, 108 insertions(+), 74 deletions(-) diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index 60c0a044..686cf59d 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -28,10 +28,11 @@ import Control.Monad.Trans.Resource (runResourceT) import qualified Data.ByteString as B import qualified Data.ByteString.Char8 as BC8 import qualified Data.ByteString.Lazy as BL +import Data.Maybe (fromMaybe, mapMaybe) import Data.Either (partitionEithers) -import Data.Maybe (fromMaybe) import Data.String (fromString) import Data.String.Conversions (cs, (<>)) +import qualified Data.Text as T import Data.Typeable import GHC.TypeLits (KnownNat, KnownSymbol, natVal, symbolVal) @@ -319,10 +320,9 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer api context) Just Nothing -> return Nothing -- param present with no value -> Nothing Just (Just v) -> case parseQueryParam v of - -- TODO: This should set an error description (including - -- paramname) - Left _e -> delayedFailFatal err400 -- parsing the request - -- paramter failed + Left e -> delayedFailFatal err400 + { errBody = cs $ "Error parsing query parameter " <> paramname <> " failed: " <> e + } Right param -> return $ Just param delayed = addParameterCheck subserver . withRequest $ \req -> @@ -356,26 +356,25 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer api context) type ServerT (QueryParams sym a :> api) m = [a] -> ServerT api m - route Proxy context subserver = - let querytext r = parseQueryText $ rawQueryString r - -- if sym is "foo", we look for query string parameters - -- named "foo" or "foo[]" and call parseQueryParam on the - -- corresponding values - parameters r = filter looksLikeParam (querytext r) - parseParam (paramName, paramTxt) = - case parseQueryParam (fromMaybe "" paramTxt) of - Left _e -> Left paramName -- On error, remember name of parameter - Right paramVal -> Right paramVal - parseParams req = - case partitionEithers $ parseParam <$> parameters req of - ([], params) -> return params -- No errors - -- TODO: This should set an error description - (_errors, _) -> delayedFailFatal err400 - delayed = addParameterCheck subserver . withRequest $ \req -> - parseParams req - in route (Proxy :: Proxy api) context delayed - where paramname = cs $ symbolVal (Proxy :: Proxy sym) - looksLikeParam (name, _) = name == paramname || name == (paramname <> "[]") + route Proxy context subserver = route (Proxy :: Proxy api) context $ + subserver `addParameterCheck` withRequest paramsCheck + where + paramname = cs $ symbolVal (Proxy :: Proxy sym) + paramsCheck req = + case partitionEithers $ fmap parseQueryParam params of + ([], parsed) -> return parsed + (errs, _) -> delayedFailFatal err400 + { errBody = cs $ "Error parsing query parameter(s) " <> paramname <> " failed: " <> T.intercalate ", " errs + } + where + params :: [T.Text] + params = mapMaybe snd + . filter (looksLikeParam . fst) + . parseQueryText + . rawQueryString + $ req + + looksLikeParam name = name == paramname || name == (paramname <> "[]") -- | If you use @'QueryFlag' "published"@ in one of the endpoints for your API, -- this automatically requires your server-side handler to be a function @@ -457,22 +456,28 @@ instance ( AllCTUnrender list a, HasServer api context type ServerT (ReqBody list a :> api) m = a -> ServerT api m - route Proxy context subserver = - route (Proxy :: Proxy api) context (addBodyCheck subserver bodyCheck) + route Proxy context subserver + = route (Proxy :: Proxy api) context $ + addBodyCheck subserver ctCheck bodyCheck where - bodyCheck = withRequest $ \ request -> do + -- Content-Type check, we only lookup we can try to parse the request body + ctCheck = withRequest $ \ request -> do -- See HTTP RFC 2616, section 7.2.1 -- http://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1 -- See also "W3C Internet Media Type registration, consistency of use" -- http://www.w3.org/2001/tag/2002/0129-mime let contentTypeH = fromMaybe "application/octet-stream" $ lookup hContentType $ requestHeaders request - mrqbody <- handleCTypeH (Proxy :: Proxy list) (cs contentTypeH) - <$> liftIO (lazyRequestBody request) + case canHandleCTypeH (Proxy :: Proxy list) (cs contentTypeH) :: Maybe (BL.ByteString -> Either String a) of + Nothing -> delayedFailFatal err415 + Just f -> return f + + -- Body check, we get a body parsing functions as the first argument. + bodyCheck f = withRequest $ \ request -> do + mrqbody <- f <$> liftIO (lazyRequestBody request) case mrqbody of - Nothing -> delayedFailFatal err415 - Just (Left e) -> delayedFailFatal err400 { errBody = cs e } - Just (Right v) -> return v + Left e -> delayedFailFatal err400 { errBody = cs e } + Right v -> return v -- | Make sure the incoming request starts with @"/path"@, strip it and -- pass the rest of the request path to @api@. diff --git a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs index 3cee450b..9c8a411c 100644 --- a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs +++ b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs @@ -132,14 +132,14 @@ toApplication ra request respond = ra request routingRespond -- 405 (bad method) -- 401 (unauthorized) -- 415 (unsupported media type) --- 400 (bad request) -- 406 (not acceptable) +-- 400 (bad request) -- @ -- -- Therefore, while routing, we delay most checks so that they -- will ultimately occur in the right order. -- --- A 'Delayed' contains four delayed blocks of tests, and +-- A 'Delayed' contains many delayed blocks of tests, and -- the actual handler: -- -- 1. Delayed captures. These can actually cause 404, and @@ -148,24 +148,28 @@ toApplication ra request respond = ra request routingRespond -- check order from the error reporting, see above). Delayed -- captures can provide inputs to the actual handler. -- --- 2. Query parameter checks. They require parsing and can cause 400 if the --- parsing fails. Query parameter checks provide inputs to the handler --- --- 3. Method check(s). This can cause a 405. On success, +-- 2. Method check(s). This can cause a 405. On success, -- it does not provide an input for the handler. Method checks -- are comparatively cheap. -- --- 4. Body and accept header checks. The request body check can --- cause both 400 and 415. This provides an input to the handler. --- The accept header check can be performed as the final --- computation in this block. It can cause a 406. +-- 3. Authentication checks. This can cause 401. +-- +-- 4. Accept and content type header checks. These checks +-- can cause 415 and 406 errors. +-- +-- 5. Query parameter checks. They require parsing and can cause 400 if the +-- parsing fails. Query parameter checks provide inputs to the handler +-- +-- 6. Body check. The request body check can cause 400. -- data Delayed env c where Delayed :: { capturesD :: env -> DelayedIO captures - , paramsD :: DelayedIO params , methodD :: DelayedIO () , authD :: DelayedIO auth - , bodyD :: DelayedIO body + , acceptD :: DelayedIO () + , contentD :: DelayedIO contentType + , paramsD :: DelayedIO params + , bodyD :: contentType -> DelayedIO body , serverD :: captures -> params -> auth @@ -209,7 +213,7 @@ runDelayedIO m req = transResourceT runRouteResultT $ runReaderT (runDelayedIO' -- | A 'Delayed' without any stored checks. emptyDelayed :: RouteResult a -> Delayed env a emptyDelayed result = - Delayed (const r) r r r r (\ _ _ _ _ _ -> result) + Delayed (const r) r r r r r (const r) (\ _ _ _ _ _ -> result) where r = return () @@ -270,20 +274,25 @@ addAuthCheck Delayed{..} new = , .. } -- Note [Existential Record Update] --- | Add a body check to the end of the body block. +-- | Add a content type and body checks around parameter checks. +-- +-- We'll report failed content type check (415), before trying to parse +-- query parameters (400). Which, in turn, happens before request body parsing. addBodyCheck :: Delayed env (a -> b) - -> DelayedIO a + -> DelayedIO c -- ^ content type check + -> (c -> DelayedIO a) -- ^ body check -> Delayed env b -addBodyCheck Delayed{..} new = +addBodyCheck Delayed{..} newContentD newBodyD = Delayed - { bodyD = (,) <$> bodyD <*> new + { contentD = (,) <$> contentD <*> newContentD + , bodyD = \(content, c) -> (,) <$> bodyD content <*> newBodyD c , serverD = \ c p a (z, v) req -> ($ v) <$> serverD c p a z req , .. } -- Note [Existential Record Update] --- | Add an accept header check to the beginning of the body --- block. There is a tradeoff here. In principle, we'd like +-- | Add an accept header check before handling parameters. +-- In principle, we'd like -- to take a bad body (400) response take precedence over a -- failed accept check (406). BUT to allow streaming the body, -- we cannot run the body check and then still backtrack. @@ -297,7 +306,7 @@ addAcceptCheck :: Delayed env a -> Delayed env a addAcceptCheck Delayed{..} new = Delayed - { bodyD = new *> bodyD + { acceptD = acceptD *> new , .. } -- Note [Existential Record Update] @@ -322,12 +331,14 @@ runDelayed :: Delayed env a -> Request -> ResourceT IO (RouteResult a) runDelayed Delayed{..} env = runDelayedIO $ do + r <- ask c <- capturesD env methodD a <- authD - b <- bodyD - r <- ask - p <- paramsD -- Has to be after body to respect the relative error order + acceptD + content <- contentD + p <- paramsD -- Has to be before body parsing, but after content-type checks + b <- bodyD content liftRouteResult (serverD c p a b r) -- | Runs a delayed server and the resulting action. diff --git a/servant-server/test/Servant/Server/ErrorSpec.hs b/servant-server/test/Servant/Server/ErrorSpec.hs index 5efb7051..787185da 100644 --- a/servant-server/test/Servant/Server/ErrorSpec.hs +++ b/servant-server/test/Servant/Server/ErrorSpec.hs @@ -6,6 +6,7 @@ {-# OPTIONS_GHC -fno-warn-orphans #-} module Servant.Server.ErrorSpec (spec) where +import Control.Monad (when) import Data.Aeson (encode) import qualified Data.ByteString.Char8 as BC import qualified Data.ByteString.Lazy.Char8 as BCL @@ -114,10 +115,19 @@ errorOrderSpec = `shouldRespondWith` 415 it "has 400 as its sixth highest priority error" $ do - request goodMethod badParams [goodAuth, goodContentType, goodAccept] goodBody - `shouldRespondWith` 400 - request goodMethod goodUrl [goodAuth, goodContentType, goodAccept] badBody - `shouldRespondWith` 400 + badParamsRes <- request goodMethod badParams [goodAuth, goodContentType, goodAccept] goodBody + badBodyRes <- request goodMethod goodUrl [goodAuth, goodContentType, goodAccept] badBody + + -- Both bad body and bad params result in 400 + return badParamsRes `shouldRespondWith` 400 + return badBodyRes `shouldRespondWith` 400 + + -- Param check should occur before body checks + both <- request goodMethod badParams [goodAuth, goodContentType, goodAccept ] badBody + when (both /= badParamsRes) $ liftIO $ + expectationFailure $ "badParams + badBody /= badParams: " ++ show both ++ ", " ++ show badParamsRes + when (both == badBodyRes) $ liftIO $ + expectationFailure $ "badParams + badBody == badBody: " ++ show both it "has handler-level errors as last priority" $ do request goodMethod goodUrl [goodAuth, goodContentType, goodAccept] goodBody diff --git a/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs b/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs index c4b72fbf..a3be12f5 100644 --- a/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs +++ b/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs @@ -19,6 +19,7 @@ import Data.Proxy import GHC.TypeLits (Symbol, KnownSymbol, symbolVal) import Servant import Servant.Server.Internal.RoutingApplication +import Network.Wai (defaultRequest) import Test.Hspec import Test.Hspec.Wai (request, shouldRespondWith, with) @@ -56,12 +57,14 @@ freeTestResource = modifyIORef delayedTestRef $ \r -> case r of delayed :: DelayedIO () -> RouteResult (Handler ()) -> Delayed () (Handler ()) delayed body srv = Delayed - { capturesD = \_ -> return () + { capturesD = \() -> return () , methodD = return () , authD = return () + , acceptD = return () + , contentD = return () , paramsD = return () - , bodyD = do - liftIO (writeTestResource"hia" >> putStrLn "garbage created") + , bodyD = \() -> do + liftIO (writeTestResource "hia" >> putStrLn "garbage created") _ <- register (freeTestResource >> putStrLn "garbage collected") body , serverD = \() () () _body _req -> srv @@ -70,7 +73,7 @@ delayed body srv = Delayed simpleRun :: Delayed () (Handler ()) -> IO () simpleRun d = fmap (either ignoreE id) . try $ - runAction d () undefined (\_ -> return ()) (\_ -> FailFatal err500) + runAction d () defaultRequest (\_ -> return ()) (\_ -> FailFatal err500) where ignoreE :: SomeException -> () ignoreE = const () @@ -85,10 +88,10 @@ data Res (sym :: Symbol) instance (KnownSymbol sym, HasServer api ctx) => HasServer (Res sym :> api) ctx where type ServerT (Res sym :> api) m = IORef (TestResource String) -> ServerT api m route Proxy ctx server = route (Proxy :: Proxy api) ctx $ - server `addBodyCheck` check + addBodyCheck server (return ()) check where sym = symbolVal (Proxy :: Proxy sym) - check = do + check () = do liftIO $ writeTestResource sym _ <- register freeTestResource return delayedTestRef diff --git a/servant/src/Servant/API/ContentTypes.hs b/servant/src/Servant/API/ContentTypes.hs index d13d9951..d5967e2a 100644 --- a/servant/src/Servant/API/ContentTypes.hs +++ b/servant/src/Servant/API/ContentTypes.hs @@ -220,14 +220,20 @@ class Accept ctype => MimeUnrender ctype a where {-# MINIMAL mimeUnrender | mimeUnrenderWithType #-} class AllCTUnrender (list :: [*]) a where + canHandleCTypeH + :: Proxy list + -> ByteString -- Content-Type header + -> Maybe (ByteString -> Either String a) + handleCTypeH :: Proxy list -> ByteString -- Content-Type header -> ByteString -- Request body -> Maybe (Either String a) + handleCTypeH p ctypeH body = ($ body) `fmap` canHandleCTypeH p ctypeH instance ( AllMimeUnrender ctyps a ) => AllCTUnrender ctyps a where - handleCTypeH _ ctypeH body = M.mapContentMedia lkup (cs ctypeH) - where lkup = allMimeUnrender (Proxy :: Proxy ctyps) body + canHandleCTypeH p ctypeH = + M.mapContentMedia (allMimeUnrender p) (cs ctypeH) -------------------------------------------------------------------------- -- * Utils (Internal) @@ -292,20 +298,19 @@ instance OVERLAPPING_ -------------------------------------------------------------------------- class (AllMime list) => AllMimeUnrender (list :: [*]) a where allMimeUnrender :: Proxy list - -> ByteString - -> [(M.MediaType, Either String a)] + -> [(M.MediaType, ByteString -> Either String a)] instance AllMimeUnrender '[] a where - allMimeUnrender _ _ = [] + allMimeUnrender _ = [] instance ( MimeUnrender ctyp a , AllMimeUnrender ctyps a ) => AllMimeUnrender (ctyp ': ctyps) a where - allMimeUnrender _ bs = + allMimeUnrender _ = (map mk $ NE.toList $ contentTypes pctyp) - ++ allMimeUnrender pctyps bs + ++ allMimeUnrender pctyps where - mk ct = (ct, mimeUnrenderWithType pctyp ct bs) + mk ct = (ct, \bs -> mimeUnrenderWithType pctyp ct bs) pctyp = Proxy :: Proxy ctyp pctyps = Proxy :: Proxy ctyps