diff --git a/servant-server/src/Servant/Server/Utils/CustomCombinators.hs b/servant-server/src/Servant/Server/Utils/CustomCombinators.hs index 7e82c61a..9f8b34a4 100644 --- a/servant-server/src/Servant/Server/Utils/CustomCombinators.hs +++ b/servant-server/src/Servant/Server/Utils/CustomCombinators.hs @@ -19,6 +19,7 @@ module Servant.Server.Utils.CustomCombinators ( ) where import Control.Monad.IO.Class +import Control.Exception (throwIO, ErrorCall(..)) import Data.ByteString import Data.Proxy import Data.Text @@ -36,6 +37,7 @@ data CombinatorImplementation combinator arg api context where -> Router' env RoutingApplication) -> CombinatorImplementation combinator arg api context +-- fixme: get rid of WithArg? type family WithArg arg rest where WithArg () rest = rest WithArg arg rest = arg -> rest @@ -66,7 +68,8 @@ makeRequestCheckCombinator :: -> CombinatorImplementation combinator () api context makeRequestCheckCombinator check = CI $ \ Proxy context delayed -> route (Proxy :: Proxy api) context $ addMethodCheck delayed $ - withRequest $ \ request -> liftRouteResult =<< liftIO (check request) + withRequest $ \ request -> + liftRouteResult =<< liftIO (check (protectBody "makeRequestCheckCombinator" request)) makeAuthCombinator :: forall api combinator arg context . @@ -76,7 +79,8 @@ makeAuthCombinator :: -> CombinatorImplementation combinator arg api context makeAuthCombinator authCheck = CI $ \ Proxy context delayed -> route (Proxy :: Proxy api) context $ addAuthCheck delayed $ - withRequest $ \ request -> liftRouteResult =<< liftIO (authCheck request) + withRequest $ \ request -> + liftRouteResult =<< liftIO (authCheck (protectBody "makeAuthCombinator" request)) makeReqBodyCombinator :: forall api combinator arg context . @@ -88,7 +92,8 @@ makeReqBodyCombinator :: makeReqBodyCombinator getArg = CI $ \ Proxy context delayed -> route (Proxy :: Proxy api) context $ addBodyCheck delayed (return ()) - (\ () -> withRequest $ \ request -> liftRouteResult $ Route $ getArg $ requestBody request) + (\ () -> withRequest $ \ request -> + liftRouteResult $ Route $ getArg $ requestBody request) makeCombinator :: forall api combinator arg context . @@ -100,4 +105,11 @@ makeCombinator :: makeCombinator getArg = CI $ \ Proxy context delayed -> route (Proxy :: Proxy api) context $ addBodyCheck delayed (return ()) - (\ () -> withRequest $ \ request -> liftRouteResult =<< liftIO (getArg request)) + (\ () -> withRequest $ \ request -> + liftRouteResult =<< liftIO (getArg (protectBody "makeCombinator" request))) + +protectBody :: String -> Request -> Request +protectBody name request = request{ + requestBody = throwIO $ ErrorCall $ + "ERROR: " ++ name ++ ": combinator must not access the request body" +} diff --git a/servant-server/test/Servant/Server/Utils/CustomCombinatorsSpec.hs b/servant-server/test/Servant/Server/Utils/CustomCombinatorsSpec.hs index d28b27a1..754280fb 100644 --- a/servant-server/test/Servant/Server/Utils/CustomCombinatorsSpec.hs +++ b/servant-server/test/Servant/Server/Utils/CustomCombinatorsSpec.hs @@ -86,7 +86,6 @@ spec = do responseStatus response `shouldBe` ok200 responseBodyLbs response `shouldReturn` "\"Alice\"" - -- fixme: rename it "allows to write a combinator by providing a function (Request -> a)" $ do let server = return app = serve (Proxy :: Proxy (FooHeader :> Get' String)) server @@ -108,14 +107,6 @@ spec = do chunk <- a return (r, chunk) [] -> return ([], "") - fromBody :: IO SBS.ByteString -> IO SBS.ByteString - fromBody getChunk = do - chunk <- getChunk - if chunk == "" - then return "" - else do - rest <- fromBody getChunk - return $ chunk <> rest it "allows to write combinators" $ do body <- toBody $ map return ["foo", "bar"] @@ -152,11 +143,21 @@ spec = do response <- runApp app request responseStatus response `shouldBe` status418 - it "allows to pick the request check phase" $ do - pending - it "disallows to access the request body unless in the checkBody phase" $ do - pending + let server = return () + app = serve (Proxy :: Proxy (InvalidRequestCheckCombinator :> Get' ())) server + request = defaultRequest + runApp app request `shouldThrow` + errorCall "ERROR: makeRequestCheckCombinator: combinator must not access the request body" + + it "disallows to access the request body unless in the auth phase" $ do + let server _user = return "foo" + app = serve (Proxy :: Proxy (InvalidAuthCombinator :> Get' String)) server + request = defaultRequest + runApp app request `shouldThrow` + errorCall "ERROR: makeAuthCombinator: combinator must not access the request body" + + -- fixme: reorder tests it "allows to access the context" $ do pending @@ -228,3 +229,36 @@ instance HasServer api context => HasServer (StreamRequest :> api) context where getSource :: IO SBS.ByteString -> Source getSource = Source + +-- | a combinator that tries to access the request body in an invalid way +data InvalidRequestCheckCombinator + +instance HasServer api context => HasServer (InvalidRequestCheckCombinator :> api) context where + type ServerT (InvalidRequestCheckCombinator :> api) m = ServerT api m + route = runCI $ makeRequestCheckCombinator accessReqBody + +accessReqBody :: Request -> IO (RouteResult ()) +accessReqBody request = do + body <- fromBody $ requestBody request + deepseq body (return $ Route ()) + +-- | a combinator that tries to access the request body in an invalid way +data InvalidAuthCombinator + +instance HasServer api context => HasServer (InvalidAuthCombinator :> api) context where + type ServerT (InvalidAuthCombinator :> api) m = User -> ServerT api m + route = runCI $ makeAuthCombinator authWithReqBody + +authWithReqBody :: Request -> IO (RouteResult User) +authWithReqBody request = do + body <- fromBody $ requestBody request + deepseq body (return $ Route $ User $ cs body) + +fromBody :: IO SBS.ByteString -> IO SBS.ByteString +fromBody getChunk = do + chunk <- getChunk + if chunk == "" + then return "" + else do + rest <- fromBody getChunk + return $ chunk <> rest