From 833551e2ea7d65d0b624593de86bae03c18a8adf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6nke=20Hahn?= Date: Sun, 23 Oct 2016 21:44:43 -0400 Subject: [PATCH] allow IO --- .../Servant/Server/Utils/CustomCombinators.hs | 19 ++++++++------- .../Server/Utils/CustomCombinatorsSpec.hs | 24 +++++++++---------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/servant-server/src/Servant/Server/Utils/CustomCombinators.hs b/servant-server/src/Servant/Server/Utils/CustomCombinators.hs index e0e133e6..7e82c61a 100644 --- a/servant-server/src/Servant/Server/Utils/CustomCombinators.hs +++ b/servant-server/src/Servant/Server/Utils/CustomCombinators.hs @@ -18,6 +18,7 @@ module Servant.Server.Utils.CustomCombinators ( RouteResult(..), ) where +import Control.Monad.IO.Class import Data.ByteString import Data.Proxy import Data.Text @@ -50,32 +51,32 @@ makeCaptureCombinator :: forall api combinator arg context . (HasServer api context, WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) => - (Text -> RouteResult arg) + (Text -> IO (RouteResult arg)) -> CombinatorImplementation combinator arg api context makeCaptureCombinator getArg = CI $ \ Proxy context delayed -> CaptureRouter $ route (Proxy :: Proxy api) context $ addCapture delayed $ \ captured -> - (liftRouteResult (getArg captured)) + (liftRouteResult =<< liftIO (getArg captured)) makeRequestCheckCombinator :: forall api combinator context . (HasServer api context, WithArg () (ServerT api Handler) ~ ServerT api Handler) => - (Request -> RouteResult ()) + (Request -> IO (RouteResult ())) -> CombinatorImplementation combinator () api context makeRequestCheckCombinator check = CI $ \ Proxy context delayed -> route (Proxy :: Proxy api) context $ addMethodCheck delayed $ - withRequest $ \ request -> liftRouteResult $ check request + withRequest $ \ request -> liftRouteResult =<< liftIO (check request) makeAuthCombinator :: forall api combinator arg context . (HasServer api context, WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) => - (Request -> RouteResult arg) + (Request -> IO (RouteResult arg)) -> CombinatorImplementation combinator arg api context makeAuthCombinator authCheck = CI $ \ Proxy context delayed -> route (Proxy :: Proxy api) context $ addAuthCheck delayed $ - withRequest $ \ request -> liftRouteResult $ authCheck request + withRequest $ \ request -> liftRouteResult =<< liftIO (authCheck request) makeReqBodyCombinator :: forall api combinator arg context . @@ -94,9 +95,9 @@ makeCombinator :: (ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler), WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler), HasServer api context) => - (Request -> RouteResult arg) + (Request -> IO (RouteResult arg)) -> CombinatorImplementation combinator arg api context makeCombinator getArg = CI $ \ Proxy context delayed -> - route (Proxy :: Proxy api) context $ addBodyCheck delayed -- fixme: shouldn't be body + route (Proxy :: Proxy api) context $ addBodyCheck delayed (return ()) - (\ () -> withRequest $ \ request -> liftRouteResult $ getArg request) + (\ () -> withRequest $ \ request -> liftRouteResult =<< liftIO (getArg request)) diff --git a/servant-server/test/Servant/Server/Utils/CustomCombinatorsSpec.hs b/servant-server/test/Servant/Server/Utils/CustomCombinatorsSpec.hs index 31839cd8..d28b27a1 100644 --- a/servant-server/test/Servant/Server/Utils/CustomCombinatorsSpec.hs +++ b/servant-server/test/Servant/Server/Utils/CustomCombinatorsSpec.hs @@ -152,9 +152,6 @@ spec = do response <- runApp app request responseStatus response `shouldBe` status418 - it "allows to write a combinator using IO" $ do - pending - it "allows to pick the request check phase" $ do pending @@ -178,8 +175,8 @@ instance HasServer api context => HasServer (StringCapture :> api) context where type ServerT (StringCapture :> api) m = String -> ServerT api m route = runCI $ makeCaptureCombinator getCapture -getCapture :: Text -> RouteResult String -getCapture = \case +getCapture :: Text -> IO (RouteResult String) +getCapture snippet = return $ case snippet of "error" -> FailFatal $ ServantErr 418 "I'm a teapot" "" [] text -> Route $ cs text @@ -189,10 +186,11 @@ instance HasServer api context => HasServer (CheckFooHeader :> api) context wher type ServerT (CheckFooHeader :> api) m = ServerT api m route = runCI $ makeRequestCheckCombinator checkFooHeader -checkFooHeader :: Request -> RouteResult () -checkFooHeader request = case lookup "Foo" (requestHeaders request) of - Just _ -> Route () - Nothing -> FailFatal err400 +checkFooHeader :: Request -> IO (RouteResult ()) +checkFooHeader request = return $ + case lookup "Foo" (requestHeaders request) of + Just _ -> Route () + Nothing -> FailFatal err400 data AuthCombinator @@ -203,8 +201,8 @@ instance HasServer api context => HasServer (AuthCombinator :> api) context wher type ServerT (AuthCombinator :> api) m = User -> ServerT api m route = runCI $ makeAuthCombinator checkAuth -checkAuth :: Request -> RouteResult User -checkAuth request = case lookup "Auth" (requestHeaders request) of +checkAuth :: Request -> IO (RouteResult User) +checkAuth request = return $ case lookup "Auth" (requestHeaders request) of Just "secret" -> Route $ User "Alice" Just _ -> FailFatal err401 Nothing -> FailFatal err400 @@ -215,8 +213,8 @@ instance HasServer api context => HasServer (FooHeader :> api) context where type ServerT (FooHeader :> api) m = String -> ServerT api m route = runCI $ makeCombinator getCustom -getCustom :: Request -> RouteResult String -getCustom request = case lookup "Foo" (requestHeaders request) of +getCustom :: Request -> IO (RouteResult String) +getCustom request = return $ case lookup "Foo" (requestHeaders request) of Nothing -> FailFatal err400 Just l -> Route $ cs l