From 08786aadbec639b2f769ca7059b5a429ee447251 Mon Sep 17 00:00:00 2001
From: Philipp Balzarek
Date: Thu, 6 Apr 2017 13:59:16 +0200
Subject: [PATCH] Check for parse errors in HasServer Header instance
---
servant-server/src/Servant/Server/Internal.hs | 28 +++++++++++----
.../Server/Internal/RoutingApplication.hs | 34 ++++++++++++++-----
.../Server/Internal/RoutingApplicationSpec.hs | 3 +-
servant-server/test/Servant/ServerSpec.hs | 7 ++++
4 files changed, 55 insertions(+), 17 deletions(-)
diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs
index 686cf59d..926e654e 100644
--- a/servant-server/src/Servant/Server/Internal.hs
+++ b/servant-server/src/Servant/Server/Internal.hs
@@ -46,7 +46,7 @@ import Network.Wai (Application, Request, Response,
responseLBS, vault)
import Prelude ()
import Prelude.Compat
-import Web.HttpApiData (FromHttpApiData, parseHeaderMaybe,
+import Web.HttpApiData (FromHttpApiData, parseHeader,
parseQueryParam,
parseUrlPieceMaybe,
parseUrlPieces)
@@ -280,10 +280,21 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer api context)
type ServerT (Header sym a :> api) m =
Maybe a -> ServerT api m
- route Proxy context subserver =
- let mheader req = parseHeaderMaybe =<< lookup str (requestHeaders req)
- in route (Proxy :: Proxy api) context (passToServer subserver mheader)
- where str = fromString $ symbolVal (Proxy :: Proxy sym)
+ route Proxy context subserver = route (Proxy :: Proxy api) context $
+ subserver `addHeaderCheck` withRequest headerCheck
+ where
+ headerName = symbolVal (Proxy :: Proxy sym)
+ headerCheck req =
+ case lookup (fromString headerName) (requestHeaders req) of
+ Nothing -> return Nothing
+ Just txt ->
+ case parseHeader txt of
+ Left e -> delayedFailFatal err400
+ { errBody = cs $ "Error parsing header "
+ <> fromString headerName
+ <> " failed: " <> e
+ }
+ Right header -> return $ Just header
-- | If you use @'QueryParam' "author" Text@ in one of the endpoints for your API,
-- this automatically requires your server-side handler to be a function
@@ -321,7 +332,8 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer api context)
Just (Just v) ->
case parseQueryParam v of
Left e -> delayedFailFatal err400
- { errBody = cs $ "Error parsing query parameter " <> paramname <> " failed: " <> e
+ { errBody = cs $ "Error parsing query parameter "
+ <> paramname <> " failed: " <> e
}
Right param -> return $ Just param
@@ -364,7 +376,9 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer api context)
case partitionEithers $ fmap parseQueryParam params of
([], parsed) -> return parsed
(errs, _) -> delayedFailFatal err400
- { errBody = cs $ "Error parsing query parameter(s) " <> paramname <> " failed: " <> T.intercalate ", " errs
+ { errBody = cs $ "Error parsing query parameter(s) "
+ <> paramname <> " failed: "
+ <> T.intercalate ", " errs
}
where
params :: [T.Text]
diff --git a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs
index 9c8a411c..9488a70a 100644
--- a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs
+++ b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs
@@ -160,7 +160,9 @@ toApplication ra request respond = ra request routingRespond
-- 5. Query parameter checks. They require parsing and can cause 400 if the
-- parsing fails. Query parameter checks provide inputs to the handler
--
--- 6. Body check. The request body check can cause 400.
+-- 6. Header Checks. They also require parsing and can cause 400 if parsing fails.
+--
+-- 7. Body check. The request body check can cause 400.
--
data Delayed env c where
Delayed :: { capturesD :: env -> DelayedIO captures
@@ -169,9 +171,11 @@ data Delayed env c where
, acceptD :: DelayedIO ()
, contentD :: DelayedIO contentType
, paramsD :: DelayedIO params
+ , headersD :: DelayedIO headers
, bodyD :: contentType -> DelayedIO body
, serverD :: captures
-> params
+ -> headers
-> auth
-> body
-> Request
@@ -181,7 +185,7 @@ data Delayed env c where
instance Functor (Delayed env) where
fmap f Delayed{..} =
Delayed
- { serverD = \ c p a b req -> f <$> serverD c p a b req
+ { serverD = \ c p h a b req -> f <$> serverD c p h a b req
, ..
} -- Note [Existential Record Update]
@@ -213,7 +217,7 @@ runDelayedIO m req = transResourceT runRouteResultT $ runReaderT (runDelayedIO'
-- | A 'Delayed' without any stored checks.
emptyDelayed :: RouteResult a -> Delayed env a
emptyDelayed result =
- Delayed (const r) r r r r r (const r) (\ _ _ _ _ _ -> result)
+ Delayed (const r) r r r r r r (const r) (\ _ _ _ _ _ _ -> result)
where
r = return ()
@@ -238,7 +242,7 @@ addCapture :: Delayed env (a -> b)
addCapture Delayed{..} new =
Delayed
{ capturesD = \ (txt, env) -> (,) <$> capturesD env <*> new txt
- , serverD = \ (x, v) p a b req -> ($ v) <$> serverD x p a b req
+ , serverD = \ (x, v) p h a b req -> ($ v) <$> serverD x p h a b req
, ..
} -- Note [Existential Record Update]
@@ -249,7 +253,18 @@ addParameterCheck :: Delayed env (a -> b)
addParameterCheck Delayed {..} new =
Delayed
{ paramsD = (,) <$> paramsD <*> new
- , serverD = \c (p, pNew) a b req -> ($ pNew) <$> serverD c p a b req
+ , serverD = \c (p, pNew) h a b req -> ($ pNew) <$> serverD c p h a b req
+ , ..
+ }
+
+-- | Add a parameter check to the end of the params block
+addHeaderCheck :: Delayed env (a -> b)
+ -> DelayedIO a
+ -> Delayed env b
+addHeaderCheck Delayed {..} new =
+ Delayed
+ { headersD = (,) <$> headersD <*> new
+ , serverD = \c p (h, hNew) a b req -> ($ hNew) <$> serverD c p h a b req
, ..
}
@@ -270,7 +285,7 @@ addAuthCheck :: Delayed env (a -> b)
addAuthCheck Delayed{..} new =
Delayed
{ authD = (,) <$> authD <*> new
- , serverD = \ c p (y, v) b req -> ($ v) <$> serverD c p y b req
+ , serverD = \ c p h (y, v) b req -> ($ v) <$> serverD c p h y b req
, ..
} -- Note [Existential Record Update]
@@ -286,7 +301,7 @@ addBodyCheck Delayed{..} newContentD newBodyD =
Delayed
{ contentD = (,) <$> contentD <*> newContentD
, bodyD = \(content, c) -> (,) <$> bodyD content <*> newBodyD c
- , serverD = \ c p a (z, v) req -> ($ v) <$> serverD c p a z req
+ , serverD = \ c p h a (z, v) req -> ($ v) <$> serverD c p h a z req
, ..
} -- Note [Existential Record Update]
@@ -316,7 +331,7 @@ addAcceptCheck Delayed{..} new =
passToServer :: Delayed env (a -> b) -> (Request -> a) -> Delayed env b
passToServer Delayed{..} x =
Delayed
- { serverD = \ c p a b req -> ($ x req) <$> serverD c p a b req
+ { serverD = \ c p h a b req -> ($ x req) <$> serverD c p h a b req
, ..
} -- Note [Existential Record Update]
@@ -338,8 +353,9 @@ runDelayed Delayed{..} env = runDelayedIO $ do
acceptD
content <- contentD
p <- paramsD -- Has to be before body parsing, but after content-type checks
+ h <- headersD
b <- bodyD content
- liftRouteResult (serverD c p a b r)
+ liftRouteResult (serverD c p h a b r)
-- | Runs a delayed server and the resulting action.
-- Takes a continuation that lets us send a response.
diff --git a/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs b/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs
index a3be12f5..30710db0 100644
--- a/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs
+++ b/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs
@@ -63,11 +63,12 @@ delayed body srv = Delayed
, acceptD = return ()
, contentD = return ()
, paramsD = return ()
+ , headersD = return ()
, bodyD = \() -> do
liftIO (writeTestResource "hia" >> putStrLn "garbage created")
_ <- register (freeTestResource >> putStrLn "garbage collected")
body
- , serverD = \() () () _body _req -> srv
+ , serverD = \() () () () _body _req -> srv
}
simpleRun :: Delayed () (Handler ())
diff --git a/servant-server/test/Servant/ServerSpec.hs b/servant-server/test/Servant/ServerSpec.hs
index c0042f44..40a850c7 100644
--- a/servant-server/test/Servant/ServerSpec.hs
+++ b/servant-server/test/Servant/ServerSpec.hs
@@ -477,6 +477,13 @@ headerSpec = describe "Servant.API.Header" $ do
it "passes the header to the handler (String)" $
delete' "/" "" `shouldRespondWith` 200
+ with (return (serve headerApi expectsInt)) $ do
+ let delete' x = THW.request methodDelete x [("MyHeader", "not a number")]
+
+ it "checks for parse errors" $
+ delete' "/" "" `shouldRespondWith` 400
+
+
-- }}}
------------------------------------------------------------------------------
-- * rawSpec {{{