add implementAuthCombinator and implementRequestCheck
This commit is contained in:
parent
7177f0a729
commit
be5e6e59c7
2 changed files with 123 additions and 40 deletions
|
@ -7,7 +7,9 @@
|
||||||
module Servant.Server.CombinatorUtils (
|
module Servant.Server.CombinatorUtils (
|
||||||
CombinatorImplementation,
|
CombinatorImplementation,
|
||||||
runCI,
|
runCI,
|
||||||
captureCombinator,
|
implementCaptureCombinator,
|
||||||
|
implementRequestCheck,
|
||||||
|
implementAuthCombinator,
|
||||||
argumentCombinator,
|
argumentCombinator,
|
||||||
-- * re-exports
|
-- * re-exports
|
||||||
RouteResult(..),
|
RouteResult(..),
|
||||||
|
@ -25,37 +27,60 @@ data CombinatorImplementation combinator arg api context where
|
||||||
CI :: (forall env .
|
CI :: (forall env .
|
||||||
Proxy (combinator :> api)
|
Proxy (combinator :> api)
|
||||||
-> Context context
|
-> Context context
|
||||||
-> Delayed env (arg -> Server api)
|
-> Delayed env (WithArg arg (Server api))
|
||||||
-> Router' env RoutingApplication)
|
-> Router' env RoutingApplication)
|
||||||
-> CombinatorImplementation combinator arg api context
|
-> CombinatorImplementation combinator arg api context
|
||||||
|
|
||||||
|
type family WithArg arg rest where
|
||||||
|
WithArg () rest = rest
|
||||||
|
WithArg arg rest = arg -> rest
|
||||||
|
|
||||||
runCI :: CombinatorImplementation combinator arg api context
|
runCI :: CombinatorImplementation combinator arg api context
|
||||||
-> Proxy (combinator :> api)
|
-> Proxy (combinator :> api)
|
||||||
-> Context context
|
-> Context context
|
||||||
-> Delayed env (arg -> Server api)
|
-> Delayed env (WithArg arg (Server api))
|
||||||
-> Router' env RoutingApplication
|
-> Router' env RoutingApplication
|
||||||
runCI (CI i) = i
|
runCI (CI i) = i
|
||||||
|
|
||||||
captureCombinator ::
|
implementCaptureCombinator ::
|
||||||
forall api combinator arg context .
|
forall api combinator arg context .
|
||||||
(HasServer api context) =>
|
(HasServer api context,
|
||||||
|
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) =>
|
||||||
(Text -> RouteResult arg)
|
(Text -> RouteResult arg)
|
||||||
-> CombinatorImplementation combinator arg api context
|
-> CombinatorImplementation combinator arg api context
|
||||||
captureCombinator getArg = CI $ \ Proxy context delayed ->
|
implementCaptureCombinator getArg = CI $ \ Proxy context delayed ->
|
||||||
CaptureRouter $
|
CaptureRouter $
|
||||||
route (Proxy :: Proxy api) context $ addCapture delayed $ \ captured ->
|
route (Proxy :: Proxy api) context $ addCapture delayed $ \ captured ->
|
||||||
(liftRouteResult (getArg captured))
|
(liftRouteResult (getArg captured))
|
||||||
|
|
||||||
|
implementRequestCheck ::
|
||||||
|
forall api combinator context .
|
||||||
|
(HasServer api context,
|
||||||
|
WithArg () (ServerT api Handler) ~ ServerT api Handler) =>
|
||||||
|
(Request -> RouteResult ())
|
||||||
|
-> CombinatorImplementation combinator () api context
|
||||||
|
implementRequestCheck check = CI $ \ Proxy context delayed ->
|
||||||
|
route (Proxy :: Proxy api) context $ addMethodCheck delayed $
|
||||||
|
withRequest $ \ request -> liftRouteResult $ check request
|
||||||
|
|
||||||
|
implementAuthCombinator ::
|
||||||
|
forall api combinator arg context .
|
||||||
|
(HasServer api context,
|
||||||
|
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) =>
|
||||||
|
(Request -> RouteResult arg)
|
||||||
|
-> CombinatorImplementation combinator arg api context
|
||||||
|
implementAuthCombinator authCheck = CI $ \ Proxy context delayed ->
|
||||||
|
route (Proxy :: Proxy api) context $ addAuthCheck delayed $
|
||||||
|
withRequest $ \ request -> liftRouteResult $ authCheck request
|
||||||
|
|
||||||
argumentCombinator ::
|
argumentCombinator ::
|
||||||
forall api combinator arg context .
|
forall api combinator arg context .
|
||||||
(ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler),
|
(ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler),
|
||||||
|
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler),
|
||||||
HasServer api context) =>
|
HasServer api context) =>
|
||||||
(Request -> RouteResult arg)
|
(Request -> RouteResult arg)
|
||||||
-> CombinatorImplementation combinator arg api context
|
-> CombinatorImplementation combinator arg api context
|
||||||
argumentCombinator getArg = CI $ \ Proxy context delayed ->
|
argumentCombinator getArg = CI $ \ Proxy context delayed ->
|
||||||
route (Proxy :: Proxy api) context $
|
route (Proxy :: Proxy api) context $ addBodyCheck delayed -- fixme: shouldn't be body
|
||||||
addBodyCheck delayed contentTypeCheck bodyCheck
|
(return ())
|
||||||
where
|
(\ () -> withRequest $ \ request -> liftRouteResult (getArg request))
|
||||||
contentTypeCheck = return ()
|
|
||||||
bodyCheck () = withRequest $ \ request ->
|
|
||||||
liftRouteResult (getArg request)
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
{-# LANGUAGE DataKinds #-}
|
{-# LANGUAGE DataKinds #-}
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
|
{-# LANGUAGE LambdaCase #-}
|
||||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE TypeFamilies #-}
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
@ -48,24 +49,6 @@ responseBodyLbs response = do
|
||||||
|
|
||||||
spec :: Spec
|
spec :: Spec
|
||||||
spec = do
|
spec = do
|
||||||
it "allows to write a combinator by providing a function (Request -> a)" $ do
|
|
||||||
let server = return
|
|
||||||
app = serve (Proxy :: Proxy (FooHeader :> Get' String)) server
|
|
||||||
request = defaultRequest{
|
|
||||||
requestHeaders =
|
|
||||||
("FooHeader", "foo") :
|
|
||||||
requestHeaders defaultRequest
|
|
||||||
}
|
|
||||||
response <- runApp app request
|
|
||||||
responseBodyLbs response `shouldReturn` "\"foo\""
|
|
||||||
|
|
||||||
it "allows to write a combinator the errors out" $ do
|
|
||||||
let server = return
|
|
||||||
app = serve (Proxy :: Proxy (FooHeader :> Get' String)) server
|
|
||||||
request = defaultRequest
|
|
||||||
response <- runApp app request
|
|
||||||
responseStatus response `shouldBe` status400
|
|
||||||
|
|
||||||
it "allows to write capture combinators" $ do
|
it "allows to write capture combinators" $ do
|
||||||
let server = return
|
let server = return
|
||||||
app = serve (Proxy :: Proxy (StringCapture :> Get' String)) server
|
app = serve (Proxy :: Proxy (StringCapture :> Get' String)) server
|
||||||
|
@ -76,6 +59,53 @@ spec = do
|
||||||
response <- runApp app request
|
response <- runApp app request
|
||||||
responseBodyLbs response `shouldReturn` "\"foo\""
|
responseBodyLbs response `shouldReturn` "\"foo\""
|
||||||
|
|
||||||
|
it "allows to write request check combinators" $ do
|
||||||
|
let server = return ()
|
||||||
|
app = serve (Proxy :: Proxy (CheckFooHeader :> Get' ())) server
|
||||||
|
request = defaultRequest{
|
||||||
|
requestHeaders =
|
||||||
|
("Foo", "foo") :
|
||||||
|
requestHeaders defaultRequest
|
||||||
|
}
|
||||||
|
response <- runApp app request
|
||||||
|
responseBodyLbs response `shouldReturn` "[]"
|
||||||
|
|
||||||
|
it "allows to write a combinator that errors out" $ do
|
||||||
|
let server = return
|
||||||
|
app = serve (Proxy :: Proxy (StringCapture :> Get' String)) server
|
||||||
|
request = defaultRequest {
|
||||||
|
rawPathInfo = "/error",
|
||||||
|
pathInfo = ["error"]
|
||||||
|
}
|
||||||
|
response <- runApp app request
|
||||||
|
responseStatus response `shouldBe` status418
|
||||||
|
|
||||||
|
it "allows to write a combinator using IO" $ do
|
||||||
|
pending
|
||||||
|
|
||||||
|
it "allows to write a combinator by providing a function (Request -> a)" $ do
|
||||||
|
let server = return
|
||||||
|
app = serve (Proxy :: Proxy (FooHeader :> Get' String)) server
|
||||||
|
request = defaultRequest{
|
||||||
|
requestHeaders =
|
||||||
|
("Foo", "foo") :
|
||||||
|
requestHeaders defaultRequest
|
||||||
|
}
|
||||||
|
response <- runApp app request
|
||||||
|
responseBodyLbs response `shouldReturn` "\"foo\""
|
||||||
|
|
||||||
|
it "allows to write an auth combinator" $ do
|
||||||
|
let server (User name) = return name
|
||||||
|
app = serve (Proxy :: Proxy (AuthCombinator :> Get' String)) server
|
||||||
|
request = defaultRequest{
|
||||||
|
requestHeaders =
|
||||||
|
("Auth", "secret") :
|
||||||
|
requestHeaders defaultRequest
|
||||||
|
}
|
||||||
|
response <- runApp app request
|
||||||
|
responseStatus response `shouldBe` ok200
|
||||||
|
responseBodyLbs response `shouldReturn` "\"Alice\""
|
||||||
|
|
||||||
it "allows to pick the request check phase" $ do
|
it "allows to pick the request check phase" $ do
|
||||||
pending
|
pending
|
||||||
|
|
||||||
|
@ -96,6 +126,43 @@ spec = do
|
||||||
|
|
||||||
type Get' = Get '[JSON]
|
type Get' = Get '[JSON]
|
||||||
|
|
||||||
|
data StringCapture
|
||||||
|
|
||||||
|
instance HasServer api context => HasServer (StringCapture :> api) context where
|
||||||
|
type ServerT (StringCapture :> api) m = String -> ServerT api m
|
||||||
|
route = runCI $ implementCaptureCombinator getCapture
|
||||||
|
|
||||||
|
getCapture :: Text -> RouteResult String
|
||||||
|
getCapture = \case
|
||||||
|
"error" -> FailFatal $ ServantErr 418 "I'm a teapot" "" []
|
||||||
|
text -> Route $ cs text
|
||||||
|
|
||||||
|
data CheckFooHeader
|
||||||
|
|
||||||
|
instance HasServer api context => HasServer (CheckFooHeader :> api) context where
|
||||||
|
type ServerT (CheckFooHeader :> api) m = ServerT api m
|
||||||
|
route = runCI $ implementRequestCheck checkFooHeader
|
||||||
|
|
||||||
|
checkFooHeader :: Request -> RouteResult ()
|
||||||
|
checkFooHeader request = case lookup "Foo" (requestHeaders request) of
|
||||||
|
Just _ -> Route ()
|
||||||
|
Nothing -> FailFatal err400
|
||||||
|
|
||||||
|
data AuthCombinator
|
||||||
|
|
||||||
|
data User = User String
|
||||||
|
deriving (Eq, Show)
|
||||||
|
|
||||||
|
instance HasServer api context => HasServer (AuthCombinator :> api) context where
|
||||||
|
type ServerT (AuthCombinator :> api) m = User -> ServerT api m
|
||||||
|
route = runCI $ implementAuthCombinator checkAuth
|
||||||
|
|
||||||
|
checkAuth :: Request -> RouteResult User
|
||||||
|
checkAuth request = case lookup "Auth" (requestHeaders request) of
|
||||||
|
Just "secret" -> Route $ User "Alice"
|
||||||
|
Just _ -> FailFatal err401
|
||||||
|
Nothing -> FailFatal err400
|
||||||
|
|
||||||
data FooHeader
|
data FooHeader
|
||||||
|
|
||||||
instance HasServer api context => HasServer (FooHeader :> api) context where
|
instance HasServer api context => HasServer (FooHeader :> api) context where
|
||||||
|
@ -103,15 +170,6 @@ instance HasServer api context => HasServer (FooHeader :> api) context where
|
||||||
route = runCI $ argumentCombinator getCustom
|
route = runCI $ argumentCombinator getCustom
|
||||||
|
|
||||||
getCustom :: Request -> RouteResult String
|
getCustom :: Request -> RouteResult String
|
||||||
getCustom request = case lookup "FooHeader" (requestHeaders request) of
|
getCustom request = case lookup "Foo" (requestHeaders request) of
|
||||||
Nothing -> FailFatal err400
|
Nothing -> FailFatal err400
|
||||||
Just l -> Route $ cs l
|
Just l -> Route $ cs l
|
||||||
|
|
||||||
data StringCapture
|
|
||||||
|
|
||||||
instance HasServer api context => HasServer (StringCapture :> api) context where
|
|
||||||
type ServerT (StringCapture :> api) m = String -> ServerT api m
|
|
||||||
route = runCI $ captureCombinator getCapture
|
|
||||||
|
|
||||||
getCapture :: Text -> RouteResult String
|
|
||||||
getCapture = Route . cs
|
|
||||||
|
|
Loading…
Reference in a new issue