Resolve todos

This commit is contained in:
Oleg Grenrus 2017-01-16 14:17:20 +02:00
parent a61551b87f
commit 58e931f48a
5 changed files with 108 additions and 74 deletions

View file

@ -28,10 +28,11 @@ import Control.Monad.Trans.Resource (runResourceT)
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC8 import qualified Data.ByteString.Char8 as BC8
import qualified Data.ByteString.Lazy as BL import qualified Data.ByteString.Lazy as BL
import Data.Maybe (fromMaybe, mapMaybe)
import Data.Either (partitionEithers) import Data.Either (partitionEithers)
import Data.Maybe (fromMaybe)
import Data.String (fromString) import Data.String (fromString)
import Data.String.Conversions (cs, (<>)) import Data.String.Conversions (cs, (<>))
import qualified Data.Text as T
import Data.Typeable import Data.Typeable
import GHC.TypeLits (KnownNat, KnownSymbol, natVal, import GHC.TypeLits (KnownNat, KnownSymbol, natVal,
symbolVal) symbolVal)
@ -319,10 +320,9 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer api context)
Just Nothing -> return Nothing -- param present with no value -> Nothing Just Nothing -> return Nothing -- param present with no value -> Nothing
Just (Just v) -> Just (Just v) ->
case parseQueryParam v of case parseQueryParam v of
-- TODO: This should set an error description (including Left e -> delayedFailFatal err400
-- paramname) { errBody = cs $ "Error parsing query parameter " <> paramname <> " failed: " <> e
Left _e -> delayedFailFatal err400 -- parsing the request }
-- paramter failed
Right param -> return $ Just param Right param -> return $ Just param
delayed = addParameterCheck subserver . withRequest $ \req -> delayed = addParameterCheck subserver . withRequest $ \req ->
@ -356,26 +356,25 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer api context)
type ServerT (QueryParams sym a :> api) m = type ServerT (QueryParams sym a :> api) m =
[a] -> ServerT api m [a] -> ServerT api m
route Proxy context subserver = route Proxy context subserver = route (Proxy :: Proxy api) context $
let querytext r = parseQueryText $ rawQueryString r subserver `addParameterCheck` withRequest paramsCheck
-- if sym is "foo", we look for query string parameters where
-- named "foo" or "foo[]" and call parseQueryParam on the paramname = cs $ symbolVal (Proxy :: Proxy sym)
-- corresponding values paramsCheck req =
parameters r = filter looksLikeParam (querytext r) case partitionEithers $ fmap parseQueryParam params of
parseParam (paramName, paramTxt) = ([], parsed) -> return parsed
case parseQueryParam (fromMaybe "" paramTxt) of (errs, _) -> delayedFailFatal err400
Left _e -> Left paramName -- On error, remember name of parameter { errBody = cs $ "Error parsing query parameter(s) " <> paramname <> " failed: " <> T.intercalate ", " errs
Right paramVal -> Right paramVal }
parseParams req = where
case partitionEithers $ parseParam <$> parameters req of params :: [T.Text]
([], params) -> return params -- No errors params = mapMaybe snd
-- TODO: This should set an error description . filter (looksLikeParam . fst)
(_errors, _) -> delayedFailFatal err400 . parseQueryText
delayed = addParameterCheck subserver . withRequest $ \req -> . rawQueryString
parseParams req $ req
in route (Proxy :: Proxy api) context delayed
where paramname = cs $ symbolVal (Proxy :: Proxy sym) looksLikeParam name = name == paramname || name == (paramname <> "[]")
looksLikeParam (name, _) = name == paramname || name == (paramname <> "[]")
-- | If you use @'QueryFlag' "published"@ in one of the endpoints for your API, -- | If you use @'QueryFlag' "published"@ in one of the endpoints for your API,
-- this automatically requires your server-side handler to be a function -- this automatically requires your server-side handler to be a function
@ -457,22 +456,28 @@ instance ( AllCTUnrender list a, HasServer api context
type ServerT (ReqBody list a :> api) m = type ServerT (ReqBody list a :> api) m =
a -> ServerT api m a -> ServerT api m
route Proxy context subserver = route Proxy context subserver
route (Proxy :: Proxy api) context (addBodyCheck subserver bodyCheck) = route (Proxy :: Proxy api) context $
addBodyCheck subserver ctCheck bodyCheck
where where
bodyCheck = withRequest $ \ request -> do -- Content-Type check, we only lookup we can try to parse the request body
ctCheck = withRequest $ \ request -> do
-- See HTTP RFC 2616, section 7.2.1 -- See HTTP RFC 2616, section 7.2.1
-- http://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1 -- http://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1
-- See also "W3C Internet Media Type registration, consistency of use" -- See also "W3C Internet Media Type registration, consistency of use"
-- http://www.w3.org/2001/tag/2002/0129-mime -- http://www.w3.org/2001/tag/2002/0129-mime
let contentTypeH = fromMaybe "application/octet-stream" let contentTypeH = fromMaybe "application/octet-stream"
$ lookup hContentType $ requestHeaders request $ lookup hContentType $ requestHeaders request
mrqbody <- handleCTypeH (Proxy :: Proxy list) (cs contentTypeH) case canHandleCTypeH (Proxy :: Proxy list) (cs contentTypeH) :: Maybe (BL.ByteString -> Either String a) of
<$> liftIO (lazyRequestBody request) Nothing -> delayedFailFatal err415
Just f -> return f
-- Body check, we get a body parsing functions as the first argument.
bodyCheck f = withRequest $ \ request -> do
mrqbody <- f <$> liftIO (lazyRequestBody request)
case mrqbody of case mrqbody of
Nothing -> delayedFailFatal err415 Left e -> delayedFailFatal err400 { errBody = cs e }
Just (Left e) -> delayedFailFatal err400 { errBody = cs e } Right v -> return v
Just (Right v) -> return v
-- | Make sure the incoming request starts with @"/path"@, strip it and -- | Make sure the incoming request starts with @"/path"@, strip it and
-- pass the rest of the request path to @api@. -- pass the rest of the request path to @api@.

View file

@ -132,14 +132,14 @@ toApplication ra request respond = ra request routingRespond
-- 405 (bad method) -- 405 (bad method)
-- 401 (unauthorized) -- 401 (unauthorized)
-- 415 (unsupported media type) -- 415 (unsupported media type)
-- 400 (bad request)
-- 406 (not acceptable) -- 406 (not acceptable)
-- 400 (bad request)
-- @ -- @
-- --
-- Therefore, while routing, we delay most checks so that they -- Therefore, while routing, we delay most checks so that they
-- will ultimately occur in the right order. -- will ultimately occur in the right order.
-- --
-- A 'Delayed' contains four delayed blocks of tests, and -- A 'Delayed' contains many delayed blocks of tests, and
-- the actual handler: -- the actual handler:
-- --
-- 1. Delayed captures. These can actually cause 404, and -- 1. Delayed captures. These can actually cause 404, and
@ -148,24 +148,28 @@ toApplication ra request respond = ra request routingRespond
-- check order from the error reporting, see above). Delayed -- check order from the error reporting, see above). Delayed
-- captures can provide inputs to the actual handler. -- captures can provide inputs to the actual handler.
-- --
-- 2. Query parameter checks. They require parsing and can cause 400 if the -- 2. Method check(s). This can cause a 405. On success,
-- parsing fails. Query parameter checks provide inputs to the handler
--
-- 3. Method check(s). This can cause a 405. On success,
-- it does not provide an input for the handler. Method checks -- it does not provide an input for the handler. Method checks
-- are comparatively cheap. -- are comparatively cheap.
-- --
-- 4. Body and accept header checks. The request body check can -- 3. Authentication checks. This can cause 401.
-- cause both 400 and 415. This provides an input to the handler. --
-- The accept header check can be performed as the final -- 4. Accept and content type header checks. These checks
-- computation in this block. It can cause a 406. -- can cause 415 and 406 errors.
--
-- 5. Query parameter checks. They require parsing and can cause 400 if the
-- parsing fails. Query parameter checks provide inputs to the handler
--
-- 6. Body check. The request body check can cause 400.
-- --
data Delayed env c where data Delayed env c where
Delayed :: { capturesD :: env -> DelayedIO captures Delayed :: { capturesD :: env -> DelayedIO captures
, paramsD :: DelayedIO params
, methodD :: DelayedIO () , methodD :: DelayedIO ()
, authD :: DelayedIO auth , authD :: DelayedIO auth
, bodyD :: DelayedIO body , acceptD :: DelayedIO ()
, contentD :: DelayedIO contentType
, paramsD :: DelayedIO params
, bodyD :: contentType -> DelayedIO body
, serverD :: captures , serverD :: captures
-> params -> params
-> auth -> auth
@ -209,7 +213,7 @@ runDelayedIO m req = transResourceT runRouteResultT $ runReaderT (runDelayedIO'
-- | A 'Delayed' without any stored checks. -- | A 'Delayed' without any stored checks.
emptyDelayed :: RouteResult a -> Delayed env a emptyDelayed :: RouteResult a -> Delayed env a
emptyDelayed result = emptyDelayed result =
Delayed (const r) r r r r (\ _ _ _ _ _ -> result) Delayed (const r) r r r r r (const r) (\ _ _ _ _ _ -> result)
where where
r = return () r = return ()
@ -270,20 +274,25 @@ addAuthCheck Delayed{..} new =
, .. , ..
} -- Note [Existential Record Update] } -- Note [Existential Record Update]
-- | Add a body check to the end of the body block. -- | Add a content type and body checks around parameter checks.
--
-- We'll report failed content type check (415), before trying to parse
-- query parameters (400). Which, in turn, happens before request body parsing.
addBodyCheck :: Delayed env (a -> b) addBodyCheck :: Delayed env (a -> b)
-> DelayedIO a -> DelayedIO c -- ^ content type check
-> (c -> DelayedIO a) -- ^ body check
-> Delayed env b -> Delayed env b
addBodyCheck Delayed{..} new = addBodyCheck Delayed{..} newContentD newBodyD =
Delayed Delayed
{ bodyD = (,) <$> bodyD <*> new { contentD = (,) <$> contentD <*> newContentD
, bodyD = \(content, c) -> (,) <$> bodyD content <*> newBodyD c
, serverD = \ c p a (z, v) req -> ($ v) <$> serverD c p a z req , serverD = \ c p a (z, v) req -> ($ v) <$> serverD c p a z req
, .. , ..
} -- Note [Existential Record Update] } -- Note [Existential Record Update]
-- | Add an accept header check to the beginning of the body -- | Add an accept header check before handling parameters.
-- block. There is a tradeoff here. In principle, we'd like -- In principle, we'd like
-- to take a bad body (400) response take precedence over a -- to take a bad body (400) response take precedence over a
-- failed accept check (406). BUT to allow streaming the body, -- failed accept check (406). BUT to allow streaming the body,
-- we cannot run the body check and then still backtrack. -- we cannot run the body check and then still backtrack.
@ -297,7 +306,7 @@ addAcceptCheck :: Delayed env a
-> Delayed env a -> Delayed env a
addAcceptCheck Delayed{..} new = addAcceptCheck Delayed{..} new =
Delayed Delayed
{ bodyD = new *> bodyD { acceptD = acceptD *> new
, .. , ..
} -- Note [Existential Record Update] } -- Note [Existential Record Update]
@ -322,12 +331,14 @@ runDelayed :: Delayed env a
-> Request -> Request
-> ResourceT IO (RouteResult a) -> ResourceT IO (RouteResult a)
runDelayed Delayed{..} env = runDelayedIO $ do runDelayed Delayed{..} env = runDelayedIO $ do
r <- ask
c <- capturesD env c <- capturesD env
methodD methodD
a <- authD a <- authD
b <- bodyD acceptD
r <- ask content <- contentD
p <- paramsD -- Has to be after body to respect the relative error order p <- paramsD -- Has to be before body parsing, but after content-type checks
b <- bodyD content
liftRouteResult (serverD c p a b r) liftRouteResult (serverD c p a b r)
-- | Runs a delayed server and the resulting action. -- | Runs a delayed server and the resulting action.

View file

@ -6,6 +6,7 @@
{-# OPTIONS_GHC -fno-warn-orphans #-} {-# OPTIONS_GHC -fno-warn-orphans #-}
module Servant.Server.ErrorSpec (spec) where module Servant.Server.ErrorSpec (spec) where
import Control.Monad (when)
import Data.Aeson (encode) import Data.Aeson (encode)
import qualified Data.ByteString.Char8 as BC import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Lazy.Char8 as BCL import qualified Data.ByteString.Lazy.Char8 as BCL
@ -114,10 +115,19 @@ errorOrderSpec =
`shouldRespondWith` 415 `shouldRespondWith` 415
it "has 400 as its sixth highest priority error" $ do it "has 400 as its sixth highest priority error" $ do
request goodMethod badParams [goodAuth, goodContentType, goodAccept] goodBody badParamsRes <- request goodMethod badParams [goodAuth, goodContentType, goodAccept] goodBody
`shouldRespondWith` 400 badBodyRes <- request goodMethod goodUrl [goodAuth, goodContentType, goodAccept] badBody
request goodMethod goodUrl [goodAuth, goodContentType, goodAccept] badBody
`shouldRespondWith` 400 -- Both bad body and bad params result in 400
return badParamsRes `shouldRespondWith` 400
return badBodyRes `shouldRespondWith` 400
-- Param check should occur before body checks
both <- request goodMethod badParams [goodAuth, goodContentType, goodAccept ] badBody
when (both /= badParamsRes) $ liftIO $
expectationFailure $ "badParams + badBody /= badParams: " ++ show both ++ ", " ++ show badParamsRes
when (both == badBodyRes) $ liftIO $
expectationFailure $ "badParams + badBody == badBody: " ++ show both
it "has handler-level errors as last priority" $ do it "has handler-level errors as last priority" $ do
request goodMethod goodUrl [goodAuth, goodContentType, goodAccept] goodBody request goodMethod goodUrl [goodAuth, goodContentType, goodAccept] goodBody

View file

@ -19,6 +19,7 @@ import Data.Proxy
import GHC.TypeLits (Symbol, KnownSymbol, symbolVal) import GHC.TypeLits (Symbol, KnownSymbol, symbolVal)
import Servant import Servant
import Servant.Server.Internal.RoutingApplication import Servant.Server.Internal.RoutingApplication
import Network.Wai (defaultRequest)
import Test.Hspec import Test.Hspec
import Test.Hspec.Wai (request, shouldRespondWith, with) import Test.Hspec.Wai (request, shouldRespondWith, with)
@ -56,12 +57,14 @@ freeTestResource = modifyIORef delayedTestRef $ \r -> case r of
delayed :: DelayedIO () -> RouteResult (Handler ()) -> Delayed () (Handler ()) delayed :: DelayedIO () -> RouteResult (Handler ()) -> Delayed () (Handler ())
delayed body srv = Delayed delayed body srv = Delayed
{ capturesD = \_ -> return () { capturesD = \() -> return ()
, methodD = return () , methodD = return ()
, authD = return () , authD = return ()
, acceptD = return ()
, contentD = return ()
, paramsD = return () , paramsD = return ()
, bodyD = do , bodyD = \() -> do
liftIO (writeTestResource"hia" >> putStrLn "garbage created") liftIO (writeTestResource "hia" >> putStrLn "garbage created")
_ <- register (freeTestResource >> putStrLn "garbage collected") _ <- register (freeTestResource >> putStrLn "garbage collected")
body body
, serverD = \() () () _body _req -> srv , serverD = \() () () _body _req -> srv
@ -70,7 +73,7 @@ delayed body srv = Delayed
simpleRun :: Delayed () (Handler ()) simpleRun :: Delayed () (Handler ())
-> IO () -> IO ()
simpleRun d = fmap (either ignoreE id) . try $ simpleRun d = fmap (either ignoreE id) . try $
runAction d () undefined (\_ -> return ()) (\_ -> FailFatal err500) runAction d () defaultRequest (\_ -> return ()) (\_ -> FailFatal err500)
where ignoreE :: SomeException -> () where ignoreE :: SomeException -> ()
ignoreE = const () ignoreE = const ()
@ -85,10 +88,10 @@ data Res (sym :: Symbol)
instance (KnownSymbol sym, HasServer api ctx) => HasServer (Res sym :> api) ctx where instance (KnownSymbol sym, HasServer api ctx) => HasServer (Res sym :> api) ctx where
type ServerT (Res sym :> api) m = IORef (TestResource String) -> ServerT api m type ServerT (Res sym :> api) m = IORef (TestResource String) -> ServerT api m
route Proxy ctx server = route (Proxy :: Proxy api) ctx $ route Proxy ctx server = route (Proxy :: Proxy api) ctx $
server `addBodyCheck` check addBodyCheck server (return ()) check
where where
sym = symbolVal (Proxy :: Proxy sym) sym = symbolVal (Proxy :: Proxy sym)
check = do check () = do
liftIO $ writeTestResource sym liftIO $ writeTestResource sym
_ <- register freeTestResource _ <- register freeTestResource
return delayedTestRef return delayedTestRef

View file

@ -220,14 +220,20 @@ class Accept ctype => MimeUnrender ctype a where
{-# MINIMAL mimeUnrender | mimeUnrenderWithType #-} {-# MINIMAL mimeUnrender | mimeUnrenderWithType #-}
class AllCTUnrender (list :: [*]) a where class AllCTUnrender (list :: [*]) a where
canHandleCTypeH
:: Proxy list
-> ByteString -- Content-Type header
-> Maybe (ByteString -> Either String a)
handleCTypeH :: Proxy list handleCTypeH :: Proxy list
-> ByteString -- Content-Type header -> ByteString -- Content-Type header
-> ByteString -- Request body -> ByteString -- Request body
-> Maybe (Either String a) -> Maybe (Either String a)
handleCTypeH p ctypeH body = ($ body) `fmap` canHandleCTypeH p ctypeH
instance ( AllMimeUnrender ctyps a ) => AllCTUnrender ctyps a where instance ( AllMimeUnrender ctyps a ) => AllCTUnrender ctyps a where
handleCTypeH _ ctypeH body = M.mapContentMedia lkup (cs ctypeH) canHandleCTypeH p ctypeH =
where lkup = allMimeUnrender (Proxy :: Proxy ctyps) body M.mapContentMedia (allMimeUnrender p) (cs ctypeH)
-------------------------------------------------------------------------- --------------------------------------------------------------------------
-- * Utils (Internal) -- * Utils (Internal)
@ -292,20 +298,19 @@ instance OVERLAPPING_
-------------------------------------------------------------------------- --------------------------------------------------------------------------
class (AllMime list) => AllMimeUnrender (list :: [*]) a where class (AllMime list) => AllMimeUnrender (list :: [*]) a where
allMimeUnrender :: Proxy list allMimeUnrender :: Proxy list
-> ByteString -> [(M.MediaType, ByteString -> Either String a)]
-> [(M.MediaType, Either String a)]
instance AllMimeUnrender '[] a where instance AllMimeUnrender '[] a where
allMimeUnrender _ _ = [] allMimeUnrender _ = []
instance ( MimeUnrender ctyp a instance ( MimeUnrender ctyp a
, AllMimeUnrender ctyps a , AllMimeUnrender ctyps a
) => AllMimeUnrender (ctyp ': ctyps) a where ) => AllMimeUnrender (ctyp ': ctyps) a where
allMimeUnrender _ bs = allMimeUnrender _ =
(map mk $ NE.toList $ contentTypes pctyp) (map mk $ NE.toList $ contentTypes pctyp)
++ allMimeUnrender pctyps bs ++ allMimeUnrender pctyps
where where
mk ct = (ct, mimeUnrenderWithType pctyp ct bs) mk ct = (ct, \bs -> mimeUnrenderWithType pctyp ct bs)
pctyp = Proxy :: Proxy ctyp pctyp = Proxy :: Proxy ctyp
pctyps = Proxy :: Proxy ctyps pctyps = Proxy :: Proxy ctyps