This commit is contained in:
Sönke Hahn 2016-10-23 21:44:43 -04:00
parent e5f46e8ba0
commit 833551e2ea
2 changed files with 21 additions and 22 deletions

View file

@ -18,6 +18,7 @@ module Servant.Server.Utils.CustomCombinators (
RouteResult(..), RouteResult(..),
) where ) where
import Control.Monad.IO.Class
import Data.ByteString import Data.ByteString
import Data.Proxy import Data.Proxy
import Data.Text import Data.Text
@ -50,32 +51,32 @@ makeCaptureCombinator ::
forall api combinator arg context . forall api combinator arg context .
(HasServer api context, (HasServer api context,
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) => WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) =>
(Text -> RouteResult arg) (Text -> IO (RouteResult arg))
-> CombinatorImplementation combinator arg api context -> CombinatorImplementation combinator arg api context
makeCaptureCombinator getArg = CI $ \ Proxy context delayed -> makeCaptureCombinator getArg = CI $ \ Proxy context delayed ->
CaptureRouter $ CaptureRouter $
route (Proxy :: Proxy api) context $ addCapture delayed $ \ captured -> route (Proxy :: Proxy api) context $ addCapture delayed $ \ captured ->
(liftRouteResult (getArg captured)) (liftRouteResult =<< liftIO (getArg captured))
makeRequestCheckCombinator :: makeRequestCheckCombinator ::
forall api combinator context . forall api combinator context .
(HasServer api context, (HasServer api context,
WithArg () (ServerT api Handler) ~ ServerT api Handler) => WithArg () (ServerT api Handler) ~ ServerT api Handler) =>
(Request -> RouteResult ()) (Request -> IO (RouteResult ()))
-> CombinatorImplementation combinator () api context -> CombinatorImplementation combinator () api context
makeRequestCheckCombinator check = CI $ \ Proxy context delayed -> makeRequestCheckCombinator check = CI $ \ Proxy context delayed ->
route (Proxy :: Proxy api) context $ addMethodCheck delayed $ route (Proxy :: Proxy api) context $ addMethodCheck delayed $
withRequest $ \ request -> liftRouteResult $ check request withRequest $ \ request -> liftRouteResult =<< liftIO (check request)
makeAuthCombinator :: makeAuthCombinator ::
forall api combinator arg context . forall api combinator arg context .
(HasServer api context, (HasServer api context,
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) => WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) =>
(Request -> RouteResult arg) (Request -> IO (RouteResult arg))
-> CombinatorImplementation combinator arg api context -> CombinatorImplementation combinator arg api context
makeAuthCombinator authCheck = CI $ \ Proxy context delayed -> makeAuthCombinator authCheck = CI $ \ Proxy context delayed ->
route (Proxy :: Proxy api) context $ addAuthCheck delayed $ route (Proxy :: Proxy api) context $ addAuthCheck delayed $
withRequest $ \ request -> liftRouteResult $ authCheck request withRequest $ \ request -> liftRouteResult =<< liftIO (authCheck request)
makeReqBodyCombinator :: makeReqBodyCombinator ::
forall api combinator arg context . forall api combinator arg context .
@ -94,9 +95,9 @@ makeCombinator ::
(ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler), (ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler),
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler), WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler),
HasServer api context) => HasServer api context) =>
(Request -> RouteResult arg) (Request -> IO (RouteResult arg))
-> CombinatorImplementation combinator arg api context -> CombinatorImplementation combinator arg api context
makeCombinator getArg = CI $ \ Proxy context delayed -> makeCombinator getArg = CI $ \ Proxy context delayed ->
route (Proxy :: Proxy api) context $ addBodyCheck delayed -- fixme: shouldn't be body route (Proxy :: Proxy api) context $ addBodyCheck delayed
(return ()) (return ())
(\ () -> withRequest $ \ request -> liftRouteResult $ getArg request) (\ () -> withRequest $ \ request -> liftRouteResult =<< liftIO (getArg request))

View file

@ -152,9 +152,6 @@ spec = do
response <- runApp app request response <- runApp app request
responseStatus response `shouldBe` status418 responseStatus response `shouldBe` status418
it "allows to write a combinator using IO" $ do
pending
it "allows to pick the request check phase" $ do it "allows to pick the request check phase" $ do
pending pending
@ -178,8 +175,8 @@ instance HasServer api context => HasServer (StringCapture :> api) context where
type ServerT (StringCapture :> api) m = String -> ServerT api m type ServerT (StringCapture :> api) m = String -> ServerT api m
route = runCI $ makeCaptureCombinator getCapture route = runCI $ makeCaptureCombinator getCapture
getCapture :: Text -> RouteResult String getCapture :: Text -> IO (RouteResult String)
getCapture = \case getCapture snippet = return $ case snippet of
"error" -> FailFatal $ ServantErr 418 "I'm a teapot" "" [] "error" -> FailFatal $ ServantErr 418 "I'm a teapot" "" []
text -> Route $ cs text text -> Route $ cs text
@ -189,10 +186,11 @@ instance HasServer api context => HasServer (CheckFooHeader :> api) context wher
type ServerT (CheckFooHeader :> api) m = ServerT api m type ServerT (CheckFooHeader :> api) m = ServerT api m
route = runCI $ makeRequestCheckCombinator checkFooHeader route = runCI $ makeRequestCheckCombinator checkFooHeader
checkFooHeader :: Request -> RouteResult () checkFooHeader :: Request -> IO (RouteResult ())
checkFooHeader request = case lookup "Foo" (requestHeaders request) of checkFooHeader request = return $
Just _ -> Route () case lookup "Foo" (requestHeaders request) of
Nothing -> FailFatal err400 Just _ -> Route ()
Nothing -> FailFatal err400
data AuthCombinator data AuthCombinator
@ -203,8 +201,8 @@ instance HasServer api context => HasServer (AuthCombinator :> api) context wher
type ServerT (AuthCombinator :> api) m = User -> ServerT api m type ServerT (AuthCombinator :> api) m = User -> ServerT api m
route = runCI $ makeAuthCombinator checkAuth route = runCI $ makeAuthCombinator checkAuth
checkAuth :: Request -> RouteResult User checkAuth :: Request -> IO (RouteResult User)
checkAuth request = case lookup "Auth" (requestHeaders request) of checkAuth request = return $ case lookup "Auth" (requestHeaders request) of
Just "secret" -> Route $ User "Alice" Just "secret" -> Route $ User "Alice"
Just _ -> FailFatal err401 Just _ -> FailFatal err401
Nothing -> FailFatal err400 Nothing -> FailFatal err400
@ -215,8 +213,8 @@ instance HasServer api context => HasServer (FooHeader :> api) context where
type ServerT (FooHeader :> api) m = String -> ServerT api m type ServerT (FooHeader :> api) m = String -> ServerT api m
route = runCI $ makeCombinator getCustom route = runCI $ makeCombinator getCustom
getCustom :: Request -> RouteResult String getCustom :: Request -> IO (RouteResult String)
getCustom request = case lookup "Foo" (requestHeaders request) of getCustom request = return $ case lookup "Foo" (requestHeaders request) of
Nothing -> FailFatal err400 Nothing -> FailFatal err400
Just l -> Route $ cs l Just l -> Route $ cs l