add implementAuthCombinator and implementRequestCheck

This commit is contained in:
Sönke Hahn 2016-10-23 16:29:06 -04:00
parent 7177f0a729
commit be5e6e59c7
2 changed files with 123 additions and 40 deletions

View file

@ -7,7 +7,9 @@
module Servant.Server.CombinatorUtils (
CombinatorImplementation,
runCI,
captureCombinator,
implementCaptureCombinator,
implementRequestCheck,
implementAuthCombinator,
argumentCombinator,
-- * re-exports
RouteResult(..),
@ -25,37 +27,60 @@ data CombinatorImplementation combinator arg api context where
CI :: (forall env .
Proxy (combinator :> api)
-> Context context
-> Delayed env (arg -> Server api)
-> Delayed env (WithArg arg (Server api))
-> Router' env RoutingApplication)
-> CombinatorImplementation combinator arg api context
type family WithArg arg rest where
WithArg () rest = rest
WithArg arg rest = arg -> rest
runCI :: CombinatorImplementation combinator arg api context
-> Proxy (combinator :> api)
-> Context context
-> Delayed env (arg -> Server api)
-> Delayed env (WithArg arg (Server api))
-> Router' env RoutingApplication
runCI (CI i) = i
captureCombinator ::
implementCaptureCombinator ::
forall api combinator arg context .
(HasServer api context) =>
(HasServer api context,
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) =>
(Text -> RouteResult arg)
-> CombinatorImplementation combinator arg api context
captureCombinator getArg = CI $ \ Proxy context delayed ->
implementCaptureCombinator getArg = CI $ \ Proxy context delayed ->
CaptureRouter $
route (Proxy :: Proxy api) context $ addCapture delayed $ \ captured ->
(liftRouteResult (getArg captured))
implementRequestCheck ::
forall api combinator context .
(HasServer api context,
WithArg () (ServerT api Handler) ~ ServerT api Handler) =>
(Request -> RouteResult ())
-> CombinatorImplementation combinator () api context
implementRequestCheck check = CI $ \ Proxy context delayed ->
route (Proxy :: Proxy api) context $ addMethodCheck delayed $
withRequest $ \ request -> liftRouteResult $ check request
implementAuthCombinator ::
forall api combinator arg context .
(HasServer api context,
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) =>
(Request -> RouteResult arg)
-> CombinatorImplementation combinator arg api context
implementAuthCombinator authCheck = CI $ \ Proxy context delayed ->
route (Proxy :: Proxy api) context $ addAuthCheck delayed $
withRequest $ \ request -> liftRouteResult $ authCheck request
argumentCombinator ::
forall api combinator arg context .
(ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler),
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler),
HasServer api context) =>
(Request -> RouteResult arg)
-> CombinatorImplementation combinator arg api context
argumentCombinator getArg = CI $ \ Proxy context delayed ->
route (Proxy :: Proxy api) context $
addBodyCheck delayed contentTypeCheck bodyCheck
where
contentTypeCheck = return ()
bodyCheck () = withRequest $ \ request ->
liftRouteResult (getArg request)
route (Proxy :: Proxy api) context $ addBodyCheck delayed -- fixme: shouldn't be body
(return ())
(\ () -> withRequest $ \ request -> liftRouteResult (getArg request))

View file

@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
@ -48,24 +49,6 @@ responseBodyLbs response = do
spec :: Spec
spec = do
it "allows to write a combinator by providing a function (Request -> a)" $ do
let server = return
app = serve (Proxy :: Proxy (FooHeader :> Get' String)) server
request = defaultRequest{
requestHeaders =
("FooHeader", "foo") :
requestHeaders defaultRequest
}
response <- runApp app request
responseBodyLbs response `shouldReturn` "\"foo\""
it "allows to write a combinator the errors out" $ do
let server = return
app = serve (Proxy :: Proxy (FooHeader :> Get' String)) server
request = defaultRequest
response <- runApp app request
responseStatus response `shouldBe` status400
it "allows to write capture combinators" $ do
let server = return
app = serve (Proxy :: Proxy (StringCapture :> Get' String)) server
@ -76,6 +59,53 @@ spec = do
response <- runApp app request
responseBodyLbs response `shouldReturn` "\"foo\""
it "allows to write request check combinators" $ do
let server = return ()
app = serve (Proxy :: Proxy (CheckFooHeader :> Get' ())) server
request = defaultRequest{
requestHeaders =
("Foo", "foo") :
requestHeaders defaultRequest
}
response <- runApp app request
responseBodyLbs response `shouldReturn` "[]"
it "allows to write a combinator that errors out" $ do
let server = return
app = serve (Proxy :: Proxy (StringCapture :> Get' String)) server
request = defaultRequest {
rawPathInfo = "/error",
pathInfo = ["error"]
}
response <- runApp app request
responseStatus response `shouldBe` status418
it "allows to write a combinator using IO" $ do
pending
it "allows to write a combinator by providing a function (Request -> a)" $ do
let server = return
app = serve (Proxy :: Proxy (FooHeader :> Get' String)) server
request = defaultRequest{
requestHeaders =
("Foo", "foo") :
requestHeaders defaultRequest
}
response <- runApp app request
responseBodyLbs response `shouldReturn` "\"foo\""
it "allows to write an auth combinator" $ do
let server (User name) = return name
app = serve (Proxy :: Proxy (AuthCombinator :> Get' String)) server
request = defaultRequest{
requestHeaders =
("Auth", "secret") :
requestHeaders defaultRequest
}
response <- runApp app request
responseStatus response `shouldBe` ok200
responseBodyLbs response `shouldReturn` "\"Alice\""
it "allows to pick the request check phase" $ do
pending
@ -96,6 +126,43 @@ spec = do
type Get' = Get '[JSON]
data StringCapture
instance HasServer api context => HasServer (StringCapture :> api) context where
type ServerT (StringCapture :> api) m = String -> ServerT api m
route = runCI $ implementCaptureCombinator getCapture
getCapture :: Text -> RouteResult String
getCapture = \case
"error" -> FailFatal $ ServantErr 418 "I'm a teapot" "" []
text -> Route $ cs text
data CheckFooHeader
instance HasServer api context => HasServer (CheckFooHeader :> api) context where
type ServerT (CheckFooHeader :> api) m = ServerT api m
route = runCI $ implementRequestCheck checkFooHeader
checkFooHeader :: Request -> RouteResult ()
checkFooHeader request = case lookup "Foo" (requestHeaders request) of
Just _ -> Route ()
Nothing -> FailFatal err400
data AuthCombinator
data User = User String
deriving (Eq, Show)
instance HasServer api context => HasServer (AuthCombinator :> api) context where
type ServerT (AuthCombinator :> api) m = User -> ServerT api m
route = runCI $ implementAuthCombinator checkAuth
checkAuth :: Request -> RouteResult User
checkAuth request = case lookup "Auth" (requestHeaders request) of
Just "secret" -> Route $ User "Alice"
Just _ -> FailFatal err401
Nothing -> FailFatal err400
data FooHeader
instance HasServer api context => HasServer (FooHeader :> api) context where
@ -103,15 +170,6 @@ instance HasServer api context => HasServer (FooHeader :> api) context where
route = runCI $ argumentCombinator getCustom
getCustom :: Request -> RouteResult String
getCustom request = case lookup "FooHeader" (requestHeaders request) of
getCustom request = case lookup "Foo" (requestHeaders request) of
Nothing -> FailFatal err400
Just l -> Route $ cs l
data StringCapture
instance HasServer api context => HasServer (StringCapture :> api) context where
type ServerT (StringCapture :> api) m = String -> ServerT api m
route = runCI $ captureCombinator getCapture
getCapture :: Text -> RouteResult String
getCapture = Route . cs