add implementAuthCombinator and implementRequestCheck
This commit is contained in:
parent
7177f0a729
commit
be5e6e59c7
2 changed files with 123 additions and 40 deletions
|
@ -7,7 +7,9 @@
|
|||
module Servant.Server.CombinatorUtils (
|
||||
CombinatorImplementation,
|
||||
runCI,
|
||||
captureCombinator,
|
||||
implementCaptureCombinator,
|
||||
implementRequestCheck,
|
||||
implementAuthCombinator,
|
||||
argumentCombinator,
|
||||
-- * re-exports
|
||||
RouteResult(..),
|
||||
|
@ -25,37 +27,60 @@ data CombinatorImplementation combinator arg api context where
|
|||
CI :: (forall env .
|
||||
Proxy (combinator :> api)
|
||||
-> Context context
|
||||
-> Delayed env (arg -> Server api)
|
||||
-> Delayed env (WithArg arg (Server api))
|
||||
-> Router' env RoutingApplication)
|
||||
-> CombinatorImplementation combinator arg api context
|
||||
|
||||
type family WithArg arg rest where
|
||||
WithArg () rest = rest
|
||||
WithArg arg rest = arg -> rest
|
||||
|
||||
runCI :: CombinatorImplementation combinator arg api context
|
||||
-> Proxy (combinator :> api)
|
||||
-> Context context
|
||||
-> Delayed env (arg -> Server api)
|
||||
-> Delayed env (WithArg arg (Server api))
|
||||
-> Router' env RoutingApplication
|
||||
runCI (CI i) = i
|
||||
|
||||
captureCombinator ::
|
||||
implementCaptureCombinator ::
|
||||
forall api combinator arg context .
|
||||
(HasServer api context) =>
|
||||
(HasServer api context,
|
||||
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) =>
|
||||
(Text -> RouteResult arg)
|
||||
-> CombinatorImplementation combinator arg api context
|
||||
captureCombinator getArg = CI $ \ Proxy context delayed ->
|
||||
implementCaptureCombinator getArg = CI $ \ Proxy context delayed ->
|
||||
CaptureRouter $
|
||||
route (Proxy :: Proxy api) context $ addCapture delayed $ \ captured ->
|
||||
(liftRouteResult (getArg captured))
|
||||
|
||||
implementRequestCheck ::
|
||||
forall api combinator context .
|
||||
(HasServer api context,
|
||||
WithArg () (ServerT api Handler) ~ ServerT api Handler) =>
|
||||
(Request -> RouteResult ())
|
||||
-> CombinatorImplementation combinator () api context
|
||||
implementRequestCheck check = CI $ \ Proxy context delayed ->
|
||||
route (Proxy :: Proxy api) context $ addMethodCheck delayed $
|
||||
withRequest $ \ request -> liftRouteResult $ check request
|
||||
|
||||
implementAuthCombinator ::
|
||||
forall api combinator arg context .
|
||||
(HasServer api context,
|
||||
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) =>
|
||||
(Request -> RouteResult arg)
|
||||
-> CombinatorImplementation combinator arg api context
|
||||
implementAuthCombinator authCheck = CI $ \ Proxy context delayed ->
|
||||
route (Proxy :: Proxy api) context $ addAuthCheck delayed $
|
||||
withRequest $ \ request -> liftRouteResult $ authCheck request
|
||||
|
||||
argumentCombinator ::
|
||||
forall api combinator arg context .
|
||||
(ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler),
|
||||
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler),
|
||||
HasServer api context) =>
|
||||
(Request -> RouteResult arg)
|
||||
-> CombinatorImplementation combinator arg api context
|
||||
argumentCombinator getArg = CI $ \ Proxy context delayed ->
|
||||
route (Proxy :: Proxy api) context $
|
||||
addBodyCheck delayed contentTypeCheck bodyCheck
|
||||
where
|
||||
contentTypeCheck = return ()
|
||||
bodyCheck () = withRequest $ \ request ->
|
||||
liftRouteResult (getArg request)
|
||||
route (Proxy :: Proxy api) context $ addBodyCheck delayed -- fixme: shouldn't be body
|
||||
(return ())
|
||||
(\ () -> withRequest $ \ request -> liftRouteResult (getArg request))
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
|
@ -48,24 +49,6 @@ responseBodyLbs response = do
|
|||
|
||||
spec :: Spec
|
||||
spec = do
|
||||
it "allows to write a combinator by providing a function (Request -> a)" $ do
|
||||
let server = return
|
||||
app = serve (Proxy :: Proxy (FooHeader :> Get' String)) server
|
||||
request = defaultRequest{
|
||||
requestHeaders =
|
||||
("FooHeader", "foo") :
|
||||
requestHeaders defaultRequest
|
||||
}
|
||||
response <- runApp app request
|
||||
responseBodyLbs response `shouldReturn` "\"foo\""
|
||||
|
||||
it "allows to write a combinator the errors out" $ do
|
||||
let server = return
|
||||
app = serve (Proxy :: Proxy (FooHeader :> Get' String)) server
|
||||
request = defaultRequest
|
||||
response <- runApp app request
|
||||
responseStatus response `shouldBe` status400
|
||||
|
||||
it "allows to write capture combinators" $ do
|
||||
let server = return
|
||||
app = serve (Proxy :: Proxy (StringCapture :> Get' String)) server
|
||||
|
@ -76,6 +59,53 @@ spec = do
|
|||
response <- runApp app request
|
||||
responseBodyLbs response `shouldReturn` "\"foo\""
|
||||
|
||||
it "allows to write request check combinators" $ do
|
||||
let server = return ()
|
||||
app = serve (Proxy :: Proxy (CheckFooHeader :> Get' ())) server
|
||||
request = defaultRequest{
|
||||
requestHeaders =
|
||||
("Foo", "foo") :
|
||||
requestHeaders defaultRequest
|
||||
}
|
||||
response <- runApp app request
|
||||
responseBodyLbs response `shouldReturn` "[]"
|
||||
|
||||
it "allows to write a combinator that errors out" $ do
|
||||
let server = return
|
||||
app = serve (Proxy :: Proxy (StringCapture :> Get' String)) server
|
||||
request = defaultRequest {
|
||||
rawPathInfo = "/error",
|
||||
pathInfo = ["error"]
|
||||
}
|
||||
response <- runApp app request
|
||||
responseStatus response `shouldBe` status418
|
||||
|
||||
it "allows to write a combinator using IO" $ do
|
||||
pending
|
||||
|
||||
it "allows to write a combinator by providing a function (Request -> a)" $ do
|
||||
let server = return
|
||||
app = serve (Proxy :: Proxy (FooHeader :> Get' String)) server
|
||||
request = defaultRequest{
|
||||
requestHeaders =
|
||||
("Foo", "foo") :
|
||||
requestHeaders defaultRequest
|
||||
}
|
||||
response <- runApp app request
|
||||
responseBodyLbs response `shouldReturn` "\"foo\""
|
||||
|
||||
it "allows to write an auth combinator" $ do
|
||||
let server (User name) = return name
|
||||
app = serve (Proxy :: Proxy (AuthCombinator :> Get' String)) server
|
||||
request = defaultRequest{
|
||||
requestHeaders =
|
||||
("Auth", "secret") :
|
||||
requestHeaders defaultRequest
|
||||
}
|
||||
response <- runApp app request
|
||||
responseStatus response `shouldBe` ok200
|
||||
responseBodyLbs response `shouldReturn` "\"Alice\""
|
||||
|
||||
it "allows to pick the request check phase" $ do
|
||||
pending
|
||||
|
||||
|
@ -96,6 +126,43 @@ spec = do
|
|||
|
||||
type Get' = Get '[JSON]
|
||||
|
||||
data StringCapture
|
||||
|
||||
instance HasServer api context => HasServer (StringCapture :> api) context where
|
||||
type ServerT (StringCapture :> api) m = String -> ServerT api m
|
||||
route = runCI $ implementCaptureCombinator getCapture
|
||||
|
||||
getCapture :: Text -> RouteResult String
|
||||
getCapture = \case
|
||||
"error" -> FailFatal $ ServantErr 418 "I'm a teapot" "" []
|
||||
text -> Route $ cs text
|
||||
|
||||
data CheckFooHeader
|
||||
|
||||
instance HasServer api context => HasServer (CheckFooHeader :> api) context where
|
||||
type ServerT (CheckFooHeader :> api) m = ServerT api m
|
||||
route = runCI $ implementRequestCheck checkFooHeader
|
||||
|
||||
checkFooHeader :: Request -> RouteResult ()
|
||||
checkFooHeader request = case lookup "Foo" (requestHeaders request) of
|
||||
Just _ -> Route ()
|
||||
Nothing -> FailFatal err400
|
||||
|
||||
data AuthCombinator
|
||||
|
||||
data User = User String
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance HasServer api context => HasServer (AuthCombinator :> api) context where
|
||||
type ServerT (AuthCombinator :> api) m = User -> ServerT api m
|
||||
route = runCI $ implementAuthCombinator checkAuth
|
||||
|
||||
checkAuth :: Request -> RouteResult User
|
||||
checkAuth request = case lookup "Auth" (requestHeaders request) of
|
||||
Just "secret" -> Route $ User "Alice"
|
||||
Just _ -> FailFatal err401
|
||||
Nothing -> FailFatal err400
|
||||
|
||||
data FooHeader
|
||||
|
||||
instance HasServer api context => HasServer (FooHeader :> api) context where
|
||||
|
@ -103,15 +170,6 @@ instance HasServer api context => HasServer (FooHeader :> api) context where
|
|||
route = runCI $ argumentCombinator getCustom
|
||||
|
||||
getCustom :: Request -> RouteResult String
|
||||
getCustom request = case lookup "FooHeader" (requestHeaders request) of
|
||||
getCustom request = case lookup "Foo" (requestHeaders request) of
|
||||
Nothing -> FailFatal err400
|
||||
Just l -> Route $ cs l
|
||||
|
||||
data StringCapture
|
||||
|
||||
instance HasServer api context => HasServer (StringCapture :> api) context where
|
||||
type ServerT (StringCapture :> api) m = String -> ServerT api m
|
||||
route = runCI $ captureCombinator getCapture
|
||||
|
||||
getCapture :: Text -> RouteResult String
|
||||
getCapture = Route . cs
|
||||
|
|
Loading…
Reference in a new issue