allow to access contexts

This commit is contained in:
Sönke Hahn 2016-10-23 22:38:01 -04:00
parent f9085b6b7a
commit d7587d1df9
2 changed files with 50 additions and 21 deletions

View file

@ -7,6 +7,7 @@
-- fixme: document phases -- fixme: document phases
-- fixme: add doctests -- fixme: add doctests
-- fixme: document that the req body can only be consumed once -- fixme: document that the req body can only be consumed once
-- fixme: document dependency problem
module Servant.Server.Utils.CustomCombinators ( module Servant.Server.Utils.CustomCombinators (
CombinatorImplementation, CombinatorImplementation,
@ -57,60 +58,60 @@ makeCaptureCombinator ::
forall api combinator arg context . forall api combinator arg context .
(HasServer api context, (HasServer api context,
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) => WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) =>
(Text -> IO (RouteResult arg)) (Context context -> Text -> IO (RouteResult arg))
-> CombinatorImplementation combinator arg api context -> CombinatorImplementation combinator arg api context
makeCaptureCombinator getArg = CI $ \ Proxy context delayed -> makeCaptureCombinator getArg = CI $ \ Proxy context delayed ->
CaptureRouter $ CaptureRouter $
route (Proxy :: Proxy api) context $ addCapture delayed $ \ captured -> route (Proxy :: Proxy api) context $ addCapture delayed $ \ captured ->
(liftRouteResult =<< liftIO (getArg captured)) (liftRouteResult =<< liftIO (getArg context captured))
makeRequestCheckCombinator :: makeRequestCheckCombinator ::
forall api combinator context . forall api combinator context .
(HasServer api context, (HasServer api context,
WithArg () (ServerT api Handler) ~ ServerT api Handler) => WithArg () (ServerT api Handler) ~ ServerT api Handler) =>
(Request -> IO (RouteResult ())) (Context context -> Request -> IO (RouteResult ()))
-> CombinatorImplementation combinator () api context -> CombinatorImplementation combinator () api context
makeRequestCheckCombinator check = CI $ \ Proxy context delayed -> makeRequestCheckCombinator check = CI $ \ Proxy context delayed ->
route (Proxy :: Proxy api) context $ addMethodCheck delayed $ route (Proxy :: Proxy api) context $ addMethodCheck delayed $
withRequest $ \ request -> withRequest $ \ request ->
liftRouteResult =<< liftIO (check (protectBody "makeRequestCheckCombinator" request)) liftRouteResult =<< liftIO (check context (protectBody "makeRequestCheckCombinator" request))
makeAuthCombinator :: makeAuthCombinator ::
forall api combinator arg context . forall api combinator arg context .
(HasServer api context, (HasServer api context,
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) => WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) =>
(Request -> IO (RouteResult arg)) (Context context -> Request -> IO (RouteResult arg))
-> CombinatorImplementation combinator arg api context -> CombinatorImplementation combinator arg api context
makeAuthCombinator authCheck = CI $ \ Proxy context delayed -> makeAuthCombinator authCheck = CI $ \ Proxy context delayed ->
route (Proxy :: Proxy api) context $ addAuthCheck delayed $ route (Proxy :: Proxy api) context $ addAuthCheck delayed $
withRequest $ \ request -> withRequest $ \ request ->
liftRouteResult =<< liftIO (authCheck (protectBody "makeAuthCombinator" request)) liftRouteResult =<< liftIO (authCheck context (protectBody "makeAuthCombinator" request))
makeReqBodyCombinator :: makeReqBodyCombinator ::
forall api combinator arg context . forall api combinator arg context .
(ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler), (ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler),
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler), WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler),
HasServer api context) => HasServer api context) =>
(IO ByteString -> arg) (Context context -> IO ByteString -> arg)
-> CombinatorImplementation combinator arg api context -> CombinatorImplementation combinator arg api context
makeReqBodyCombinator getArg = CI $ \ Proxy context delayed -> makeReqBodyCombinator getArg = CI $ \ Proxy context delayed ->
route (Proxy :: Proxy api) context $ addBodyCheck delayed route (Proxy :: Proxy api) context $ addBodyCheck delayed
(return ()) (return ())
(\ () -> withRequest $ \ request -> (\ () -> withRequest $ \ request ->
liftRouteResult $ Route $ getArg $ requestBody request) liftRouteResult $ Route $ getArg context $ requestBody request)
makeCombinator :: makeCombinator ::
forall api combinator arg context . forall api combinator arg context .
(ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler), (ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler),
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler), WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler),
HasServer api context) => HasServer api context) =>
(Request -> IO (RouteResult arg)) (Context context -> Request -> IO (RouteResult arg))
-> CombinatorImplementation combinator arg api context -> CombinatorImplementation combinator arg api context
makeCombinator getArg = CI $ \ Proxy context delayed -> makeCombinator getArg = CI $ \ Proxy context delayed ->
route (Proxy :: Proxy api) context $ addBodyCheck delayed route (Proxy :: Proxy api) context $ addBodyCheck delayed
(return ()) (return ())
(\ () -> withRequest $ \ request -> (\ () -> withRequest $ \ request ->
liftRouteResult =<< liftIO (getArg (protectBody "makeCombinator" request))) liftRouteResult =<< liftIO (getArg context (protectBody "makeCombinator" request)))
protectBody :: String -> Request -> Request protectBody :: String -> Request -> Request
protectBody name request = request{ protectBody name request = request{

View file

@ -23,7 +23,7 @@ import Data.Text hiding (map)
import Network.HTTP.Types import Network.HTTP.Types
import Network.Wai import Network.Wai
import Network.Wai.Internal import Network.Wai.Internal
import Test.Hspec import Test.Hspec hiding (context)
import Servant.API import Servant.API
import Servant.Server import Servant.Server
@ -113,6 +113,20 @@ spec = do
runApp app request `shouldThrow` runApp app request `shouldThrow`
errorCall "ERROR: makeAuthCombinator: combinator must not access the request body" errorCall "ERROR: makeAuthCombinator: combinator must not access the request body"
it "allows to access the context" $ do
let server (User name) = return name
context :: Context '[ [(SBS.ByteString, User)] ]
context = [("secret", User "Bob")] :. EmptyContext
app = serveWithContext (Proxy :: Proxy (AuthWithContext :> Get' String)) context server
request = defaultRequest{
requestHeaders =
("Auth", "secret") :
requestHeaders defaultRequest
}
response <- runApp app request
responseStatus response `shouldBe` ok200
responseBodyLbs response `shouldReturn` "\"Bob\""
describe "makeCombinator" $ do describe "makeCombinator" $ do
it "allows to write a combinator by providing a function (Request -> a)" $ do it "allows to write a combinator by providing a function (Request -> a)" $ do
let server = return let server = return
@ -161,9 +175,6 @@ spec = do
response <- runApp app request response <- runApp app request
responseBodyLbs response `shouldReturn` "\"foobar\"" responseBodyLbs response `shouldReturn` "\"foobar\""
it "allows to access the context" $ do
pending
it "allows to implement combinators in terms of existing combinators" $ do it "allows to implement combinators in terms of existing combinators" $ do
pending pending
@ -175,7 +186,7 @@ data StringCapture
instance HasServer api context => HasServer (StringCapture :> api) context where instance HasServer api context => HasServer (StringCapture :> api) context where
type ServerT (StringCapture :> api) m = String -> ServerT api m type ServerT (StringCapture :> api) m = String -> ServerT api m
route = runCI $ makeCaptureCombinator getCapture route = runCI $ makeCaptureCombinator (const getCapture)
getCapture :: Text -> IO (RouteResult String) getCapture :: Text -> IO (RouteResult String)
getCapture snippet = return $ case snippet of getCapture snippet = return $ case snippet of
@ -188,7 +199,7 @@ data CheckFooHeader
instance HasServer api context => HasServer (CheckFooHeader :> api) context where instance HasServer api context => HasServer (CheckFooHeader :> api) context where
type ServerT (CheckFooHeader :> api) m = ServerT api m type ServerT (CheckFooHeader :> api) m = ServerT api m
route = runCI $ makeRequestCheckCombinator checkFooHeader route = runCI $ makeRequestCheckCombinator (const checkFooHeader)
checkFooHeader :: Request -> IO (RouteResult ()) checkFooHeader :: Request -> IO (RouteResult ())
checkFooHeader request = return $ checkFooHeader request = return $
@ -201,7 +212,7 @@ data InvalidRequestCheckCombinator
instance HasServer api context => HasServer (InvalidRequestCheckCombinator :> api) context where instance HasServer api context => HasServer (InvalidRequestCheckCombinator :> api) context where
type ServerT (InvalidRequestCheckCombinator :> api) m = ServerT api m type ServerT (InvalidRequestCheckCombinator :> api) m = ServerT api m
route = runCI $ makeRequestCheckCombinator accessReqBody route = runCI $ makeRequestCheckCombinator (const accessReqBody)
accessReqBody :: Request -> IO (RouteResult ()) accessReqBody :: Request -> IO (RouteResult ())
accessReqBody request = do accessReqBody request = do
@ -217,7 +228,7 @@ data User = User String
instance HasServer api context => HasServer (AuthCombinator :> api) context where instance HasServer api context => HasServer (AuthCombinator :> api) context where
type ServerT (AuthCombinator :> api) m = User -> ServerT api m type ServerT (AuthCombinator :> api) m = User -> ServerT api m
route = runCI $ makeAuthCombinator checkAuth route = runCI $ makeAuthCombinator (const checkAuth)
checkAuth :: Request -> IO (RouteResult User) checkAuth :: Request -> IO (RouteResult User)
checkAuth request = return $ case lookup "Auth" (requestHeaders request) of checkAuth request = return $ case lookup "Auth" (requestHeaders request) of
@ -230,20 +241,37 @@ data InvalidAuthCombinator
instance HasServer api context => HasServer (InvalidAuthCombinator :> api) context where instance HasServer api context => HasServer (InvalidAuthCombinator :> api) context where
type ServerT (InvalidAuthCombinator :> api) m = User -> ServerT api m type ServerT (InvalidAuthCombinator :> api) m = User -> ServerT api m
route = runCI $ makeAuthCombinator authWithReqBody route = runCI $ makeAuthCombinator (const authWithReqBody)
authWithReqBody :: Request -> IO (RouteResult User) authWithReqBody :: Request -> IO (RouteResult User)
authWithReqBody request = do authWithReqBody request = do
body <- fromBody $ requestBody request body <- fromBody $ requestBody request
deepseq body (return $ Route $ User $ cs body) deepseq body (return $ Route $ User $ cs body)
data AuthWithContext
instance (HasContextEntry context [(SBS.ByteString, User)], HasServer api context) =>
HasServer (AuthWithContext :> api) context where
type ServerT (AuthWithContext :> api) m = User -> ServerT api m
route = runCI $ makeAuthCombinator authWithContext
-- fixme: remove foralls from haddock
authWithContext :: (HasContextEntry context [(SBS.ByteString, User)]) =>
Context context -> Request -> IO (RouteResult User)
authWithContext context request = return $ case lookup "Auth" (requestHeaders request) of
Just authToken -> case lookup authToken userDict of
Just user -> Route user
where
userDict = getContextEntry context
-- * general combinators -- * general combinators
data FooHeader data FooHeader
instance HasServer api context => HasServer (FooHeader :> api) context where instance HasServer api context => HasServer (FooHeader :> api) context where
type ServerT (FooHeader :> api) m = String -> ServerT api m type ServerT (FooHeader :> api) m = String -> ServerT api m
route = runCI $ makeCombinator getCustom route = runCI $ makeCombinator (const getCustom)
getCustom :: Request -> IO (RouteResult String) getCustom :: Request -> IO (RouteResult String)
getCustom request = return $ case lookup "Foo" (requestHeaders request) of getCustom request = return $ case lookup "Foo" (requestHeaders request) of
@ -258,7 +286,7 @@ data Source = Source (IO SBS.ByteString)
instance HasServer api context => HasServer (StreamRequest :> api) context where instance HasServer api context => HasServer (StreamRequest :> api) context where
type ServerT (StreamRequest :> api) m = Source -> ServerT api m type ServerT (StreamRequest :> api) m = Source -> ServerT api m
route = runCI $ makeReqBodyCombinator getSource route = runCI $ makeReqBodyCombinator (const getSource)
getSource :: IO SBS.ByteString -> Source getSource :: IO SBS.ByteString -> Source
getSource = Source getSource = Source