throw an exception in case of unallowed request body access
This commit is contained in:
parent
833551e2ea
commit
6a5256c3ff
2 changed files with 63 additions and 17 deletions
|
@ -19,6 +19,7 @@ module Servant.Server.Utils.CustomCombinators (
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Control.Monad.IO.Class
|
import Control.Monad.IO.Class
|
||||||
|
import Control.Exception (throwIO, ErrorCall(..))
|
||||||
import Data.ByteString
|
import Data.ByteString
|
||||||
import Data.Proxy
|
import Data.Proxy
|
||||||
import Data.Text
|
import Data.Text
|
||||||
|
@ -36,6 +37,7 @@ data CombinatorImplementation combinator arg api context where
|
||||||
-> Router' env RoutingApplication)
|
-> Router' env RoutingApplication)
|
||||||
-> CombinatorImplementation combinator arg api context
|
-> CombinatorImplementation combinator arg api context
|
||||||
|
|
||||||
|
-- fixme: get rid of WithArg?
|
||||||
type family WithArg arg rest where
|
type family WithArg arg rest where
|
||||||
WithArg () rest = rest
|
WithArg () rest = rest
|
||||||
WithArg arg rest = arg -> rest
|
WithArg arg rest = arg -> rest
|
||||||
|
@ -66,7 +68,8 @@ makeRequestCheckCombinator ::
|
||||||
-> CombinatorImplementation combinator () api context
|
-> CombinatorImplementation combinator () api context
|
||||||
makeRequestCheckCombinator check = CI $ \ Proxy context delayed ->
|
makeRequestCheckCombinator check = CI $ \ Proxy context delayed ->
|
||||||
route (Proxy :: Proxy api) context $ addMethodCheck delayed $
|
route (Proxy :: Proxy api) context $ addMethodCheck delayed $
|
||||||
withRequest $ \ request -> liftRouteResult =<< liftIO (check request)
|
withRequest $ \ request ->
|
||||||
|
liftRouteResult =<< liftIO (check (protectBody "makeRequestCheckCombinator" request))
|
||||||
|
|
||||||
makeAuthCombinator ::
|
makeAuthCombinator ::
|
||||||
forall api combinator arg context .
|
forall api combinator arg context .
|
||||||
|
@ -76,7 +79,8 @@ makeAuthCombinator ::
|
||||||
-> CombinatorImplementation combinator arg api context
|
-> CombinatorImplementation combinator arg api context
|
||||||
makeAuthCombinator authCheck = CI $ \ Proxy context delayed ->
|
makeAuthCombinator authCheck = CI $ \ Proxy context delayed ->
|
||||||
route (Proxy :: Proxy api) context $ addAuthCheck delayed $
|
route (Proxy :: Proxy api) context $ addAuthCheck delayed $
|
||||||
withRequest $ \ request -> liftRouteResult =<< liftIO (authCheck request)
|
withRequest $ \ request ->
|
||||||
|
liftRouteResult =<< liftIO (authCheck (protectBody "makeAuthCombinator" request))
|
||||||
|
|
||||||
makeReqBodyCombinator ::
|
makeReqBodyCombinator ::
|
||||||
forall api combinator arg context .
|
forall api combinator arg context .
|
||||||
|
@ -88,7 +92,8 @@ makeReqBodyCombinator ::
|
||||||
makeReqBodyCombinator getArg = CI $ \ Proxy context delayed ->
|
makeReqBodyCombinator getArg = CI $ \ Proxy context delayed ->
|
||||||
route (Proxy :: Proxy api) context $ addBodyCheck delayed
|
route (Proxy :: Proxy api) context $ addBodyCheck delayed
|
||||||
(return ())
|
(return ())
|
||||||
(\ () -> withRequest $ \ request -> liftRouteResult $ Route $ getArg $ requestBody request)
|
(\ () -> withRequest $ \ request ->
|
||||||
|
liftRouteResult $ Route $ getArg $ requestBody request)
|
||||||
|
|
||||||
makeCombinator ::
|
makeCombinator ::
|
||||||
forall api combinator arg context .
|
forall api combinator arg context .
|
||||||
|
@ -100,4 +105,11 @@ makeCombinator ::
|
||||||
makeCombinator getArg = CI $ \ Proxy context delayed ->
|
makeCombinator getArg = CI $ \ Proxy context delayed ->
|
||||||
route (Proxy :: Proxy api) context $ addBodyCheck delayed
|
route (Proxy :: Proxy api) context $ addBodyCheck delayed
|
||||||
(return ())
|
(return ())
|
||||||
(\ () -> withRequest $ \ request -> liftRouteResult =<< liftIO (getArg request))
|
(\ () -> withRequest $ \ request ->
|
||||||
|
liftRouteResult =<< liftIO (getArg (protectBody "makeCombinator" request)))
|
||||||
|
|
||||||
|
protectBody :: String -> Request -> Request
|
||||||
|
protectBody name request = request{
|
||||||
|
requestBody = throwIO $ ErrorCall $
|
||||||
|
"ERROR: " ++ name ++ ": combinator must not access the request body"
|
||||||
|
}
|
||||||
|
|
|
@ -86,7 +86,6 @@ spec = do
|
||||||
responseStatus response `shouldBe` ok200
|
responseStatus response `shouldBe` ok200
|
||||||
responseBodyLbs response `shouldReturn` "\"Alice\""
|
responseBodyLbs response `shouldReturn` "\"Alice\""
|
||||||
|
|
||||||
-- fixme: rename
|
|
||||||
it "allows to write a combinator by providing a function (Request -> a)" $ do
|
it "allows to write a combinator by providing a function (Request -> a)" $ do
|
||||||
let server = return
|
let server = return
|
||||||
app = serve (Proxy :: Proxy (FooHeader :> Get' String)) server
|
app = serve (Proxy :: Proxy (FooHeader :> Get' String)) server
|
||||||
|
@ -108,14 +107,6 @@ spec = do
|
||||||
chunk <- a
|
chunk <- a
|
||||||
return (r, chunk)
|
return (r, chunk)
|
||||||
[] -> return ([], "")
|
[] -> return ([], "")
|
||||||
fromBody :: IO SBS.ByteString -> IO SBS.ByteString
|
|
||||||
fromBody getChunk = do
|
|
||||||
chunk <- getChunk
|
|
||||||
if chunk == ""
|
|
||||||
then return ""
|
|
||||||
else do
|
|
||||||
rest <- fromBody getChunk
|
|
||||||
return $ chunk <> rest
|
|
||||||
|
|
||||||
it "allows to write combinators" $ do
|
it "allows to write combinators" $ do
|
||||||
body <- toBody $ map return ["foo", "bar"]
|
body <- toBody $ map return ["foo", "bar"]
|
||||||
|
@ -152,11 +143,21 @@ spec = do
|
||||||
response <- runApp app request
|
response <- runApp app request
|
||||||
responseStatus response `shouldBe` status418
|
responseStatus response `shouldBe` status418
|
||||||
|
|
||||||
it "allows to pick the request check phase" $ do
|
|
||||||
pending
|
|
||||||
|
|
||||||
it "disallows to access the request body unless in the checkBody phase" $ do
|
it "disallows to access the request body unless in the checkBody phase" $ do
|
||||||
pending
|
let server = return ()
|
||||||
|
app = serve (Proxy :: Proxy (InvalidRequestCheckCombinator :> Get' ())) server
|
||||||
|
request = defaultRequest
|
||||||
|
runApp app request `shouldThrow`
|
||||||
|
errorCall "ERROR: makeRequestCheckCombinator: combinator must not access the request body"
|
||||||
|
|
||||||
|
it "disallows to access the request body unless in the auth phase" $ do
|
||||||
|
let server _user = return "foo"
|
||||||
|
app = serve (Proxy :: Proxy (InvalidAuthCombinator :> Get' String)) server
|
||||||
|
request = defaultRequest
|
||||||
|
runApp app request `shouldThrow`
|
||||||
|
errorCall "ERROR: makeAuthCombinator: combinator must not access the request body"
|
||||||
|
|
||||||
|
-- fixme: reorder tests
|
||||||
|
|
||||||
it "allows to access the context" $ do
|
it "allows to access the context" $ do
|
||||||
pending
|
pending
|
||||||
|
@ -228,3 +229,36 @@ instance HasServer api context => HasServer (StreamRequest :> api) context where
|
||||||
|
|
||||||
getSource :: IO SBS.ByteString -> Source
|
getSource :: IO SBS.ByteString -> Source
|
||||||
getSource = Source
|
getSource = Source
|
||||||
|
|
||||||
|
-- | a combinator that tries to access the request body in an invalid way
|
||||||
|
data InvalidRequestCheckCombinator
|
||||||
|
|
||||||
|
instance HasServer api context => HasServer (InvalidRequestCheckCombinator :> api) context where
|
||||||
|
type ServerT (InvalidRequestCheckCombinator :> api) m = ServerT api m
|
||||||
|
route = runCI $ makeRequestCheckCombinator accessReqBody
|
||||||
|
|
||||||
|
accessReqBody :: Request -> IO (RouteResult ())
|
||||||
|
accessReqBody request = do
|
||||||
|
body <- fromBody $ requestBody request
|
||||||
|
deepseq body (return $ Route ())
|
||||||
|
|
||||||
|
-- | a combinator that tries to access the request body in an invalid way
|
||||||
|
data InvalidAuthCombinator
|
||||||
|
|
||||||
|
instance HasServer api context => HasServer (InvalidAuthCombinator :> api) context where
|
||||||
|
type ServerT (InvalidAuthCombinator :> api) m = User -> ServerT api m
|
||||||
|
route = runCI $ makeAuthCombinator authWithReqBody
|
||||||
|
|
||||||
|
authWithReqBody :: Request -> IO (RouteResult User)
|
||||||
|
authWithReqBody request = do
|
||||||
|
body <- fromBody $ requestBody request
|
||||||
|
deepseq body (return $ Route $ User $ cs body)
|
||||||
|
|
||||||
|
fromBody :: IO SBS.ByteString -> IO SBS.ByteString
|
||||||
|
fromBody getChunk = do
|
||||||
|
chunk <- getChunk
|
||||||
|
if chunk == ""
|
||||||
|
then return ""
|
||||||
|
else do
|
||||||
|
rest <- fromBody getChunk
|
||||||
|
return $ chunk <> rest
|
||||||
|
|
Loading…
Add table
Reference in a new issue