From 08786aadbec639b2f769ca7059b5a429ee447251 Mon Sep 17 00:00:00 2001 From: Philipp Balzarek Date: Thu, 6 Apr 2017 13:59:16 +0200 Subject: [PATCH] Check for parse errors in HasServer Header instance --- servant-server/src/Servant/Server/Internal.hs | 28 +++++++++++---- .../Server/Internal/RoutingApplication.hs | 34 ++++++++++++++----- .../Server/Internal/RoutingApplicationSpec.hs | 3 +- servant-server/test/Servant/ServerSpec.hs | 7 ++++ 4 files changed, 55 insertions(+), 17 deletions(-) diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index 686cf59d..926e654e 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -46,7 +46,7 @@ import Network.Wai (Application, Request, Response, responseLBS, vault) import Prelude () import Prelude.Compat -import Web.HttpApiData (FromHttpApiData, parseHeaderMaybe, +import Web.HttpApiData (FromHttpApiData, parseHeader, parseQueryParam, parseUrlPieceMaybe, parseUrlPieces) @@ -280,10 +280,21 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer api context) type ServerT (Header sym a :> api) m = Maybe a -> ServerT api m - route Proxy context subserver = - let mheader req = parseHeaderMaybe =<< lookup str (requestHeaders req) - in route (Proxy :: Proxy api) context (passToServer subserver mheader) - where str = fromString $ symbolVal (Proxy :: Proxy sym) + route Proxy context subserver = route (Proxy :: Proxy api) context $ + subserver `addHeaderCheck` withRequest headerCheck + where + headerName = symbolVal (Proxy :: Proxy sym) + headerCheck req = + case lookup (fromString headerName) (requestHeaders req) of + Nothing -> return Nothing + Just txt -> + case parseHeader txt of + Left e -> delayedFailFatal err400 + { errBody = cs $ "Error parsing header " + <> fromString headerName + <> " failed: " <> e + } + Right header -> return $ Just header -- | If you use @'QueryParam' "author" Text@ in one of the endpoints for your API, -- this automatically requires your server-side handler to be a function @@ -321,7 +332,8 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer api context) Just (Just v) -> case parseQueryParam v of Left e -> delayedFailFatal err400 - { errBody = cs $ "Error parsing query parameter " <> paramname <> " failed: " <> e + { errBody = cs $ "Error parsing query parameter " + <> paramname <> " failed: " <> e } Right param -> return $ Just param @@ -364,7 +376,9 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer api context) case partitionEithers $ fmap parseQueryParam params of ([], parsed) -> return parsed (errs, _) -> delayedFailFatal err400 - { errBody = cs $ "Error parsing query parameter(s) " <> paramname <> " failed: " <> T.intercalate ", " errs + { errBody = cs $ "Error parsing query parameter(s) " + <> paramname <> " failed: " + <> T.intercalate ", " errs } where params :: [T.Text] diff --git a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs index 9c8a411c..9488a70a 100644 --- a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs +++ b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs @@ -160,7 +160,9 @@ toApplication ra request respond = ra request routingRespond -- 5. Query parameter checks. They require parsing and can cause 400 if the -- parsing fails. Query parameter checks provide inputs to the handler -- --- 6. Body check. The request body check can cause 400. +-- 6. Header Checks. They also require parsing and can cause 400 if parsing fails. +-- +-- 7. Body check. The request body check can cause 400. -- data Delayed env c where Delayed :: { capturesD :: env -> DelayedIO captures @@ -169,9 +171,11 @@ data Delayed env c where , acceptD :: DelayedIO () , contentD :: DelayedIO contentType , paramsD :: DelayedIO params + , headersD :: DelayedIO headers , bodyD :: contentType -> DelayedIO body , serverD :: captures -> params + -> headers -> auth -> body -> Request @@ -181,7 +185,7 @@ data Delayed env c where instance Functor (Delayed env) where fmap f Delayed{..} = Delayed - { serverD = \ c p a b req -> f <$> serverD c p a b req + { serverD = \ c p h a b req -> f <$> serverD c p h a b req , .. } -- Note [Existential Record Update] @@ -213,7 +217,7 @@ runDelayedIO m req = transResourceT runRouteResultT $ runReaderT (runDelayedIO' -- | A 'Delayed' without any stored checks. emptyDelayed :: RouteResult a -> Delayed env a emptyDelayed result = - Delayed (const r) r r r r r (const r) (\ _ _ _ _ _ -> result) + Delayed (const r) r r r r r r (const r) (\ _ _ _ _ _ _ -> result) where r = return () @@ -238,7 +242,7 @@ addCapture :: Delayed env (a -> b) addCapture Delayed{..} new = Delayed { capturesD = \ (txt, env) -> (,) <$> capturesD env <*> new txt - , serverD = \ (x, v) p a b req -> ($ v) <$> serverD x p a b req + , serverD = \ (x, v) p h a b req -> ($ v) <$> serverD x p h a b req , .. } -- Note [Existential Record Update] @@ -249,7 +253,18 @@ addParameterCheck :: Delayed env (a -> b) addParameterCheck Delayed {..} new = Delayed { paramsD = (,) <$> paramsD <*> new - , serverD = \c (p, pNew) a b req -> ($ pNew) <$> serverD c p a b req + , serverD = \c (p, pNew) h a b req -> ($ pNew) <$> serverD c p h a b req + , .. + } + +-- | Add a parameter check to the end of the params block +addHeaderCheck :: Delayed env (a -> b) + -> DelayedIO a + -> Delayed env b +addHeaderCheck Delayed {..} new = + Delayed + { headersD = (,) <$> headersD <*> new + , serverD = \c p (h, hNew) a b req -> ($ hNew) <$> serverD c p h a b req , .. } @@ -270,7 +285,7 @@ addAuthCheck :: Delayed env (a -> b) addAuthCheck Delayed{..} new = Delayed { authD = (,) <$> authD <*> new - , serverD = \ c p (y, v) b req -> ($ v) <$> serverD c p y b req + , serverD = \ c p h (y, v) b req -> ($ v) <$> serverD c p h y b req , .. } -- Note [Existential Record Update] @@ -286,7 +301,7 @@ addBodyCheck Delayed{..} newContentD newBodyD = Delayed { contentD = (,) <$> contentD <*> newContentD , bodyD = \(content, c) -> (,) <$> bodyD content <*> newBodyD c - , serverD = \ c p a (z, v) req -> ($ v) <$> serverD c p a z req + , serverD = \ c p h a (z, v) req -> ($ v) <$> serverD c p h a z req , .. } -- Note [Existential Record Update] @@ -316,7 +331,7 @@ addAcceptCheck Delayed{..} new = passToServer :: Delayed env (a -> b) -> (Request -> a) -> Delayed env b passToServer Delayed{..} x = Delayed - { serverD = \ c p a b req -> ($ x req) <$> serverD c p a b req + { serverD = \ c p h a b req -> ($ x req) <$> serverD c p h a b req , .. } -- Note [Existential Record Update] @@ -338,8 +353,9 @@ runDelayed Delayed{..} env = runDelayedIO $ do acceptD content <- contentD p <- paramsD -- Has to be before body parsing, but after content-type checks + h <- headersD b <- bodyD content - liftRouteResult (serverD c p a b r) + liftRouteResult (serverD c p h a b r) -- | Runs a delayed server and the resulting action. -- Takes a continuation that lets us send a response. diff --git a/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs b/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs index a3be12f5..30710db0 100644 --- a/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs +++ b/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs @@ -63,11 +63,12 @@ delayed body srv = Delayed , acceptD = return () , contentD = return () , paramsD = return () + , headersD = return () , bodyD = \() -> do liftIO (writeTestResource "hia" >> putStrLn "garbage created") _ <- register (freeTestResource >> putStrLn "garbage collected") body - , serverD = \() () () _body _req -> srv + , serverD = \() () () () _body _req -> srv } simpleRun :: Delayed () (Handler ()) diff --git a/servant-server/test/Servant/ServerSpec.hs b/servant-server/test/Servant/ServerSpec.hs index c0042f44..40a850c7 100644 --- a/servant-server/test/Servant/ServerSpec.hs +++ b/servant-server/test/Servant/ServerSpec.hs @@ -477,6 +477,13 @@ headerSpec = describe "Servant.API.Header" $ do it "passes the header to the handler (String)" $ delete' "/" "" `shouldRespondWith` 200 + with (return (serve headerApi expectsInt)) $ do + let delete' x = THW.request methodDelete x [("MyHeader", "not a number")] + + it "checks for parse errors" $ + delete' "/" "" `shouldRespondWith` 400 + + -- }}} ------------------------------------------------------------------------------ -- * rawSpec {{{