diff --git a/servant-server/src/Servant/Server/Utils/CustomCombinators.hs b/servant-server/src/Servant/Server/Utils/CustomCombinators.hs index ad91bd07..dcbd2d94 100644 --- a/servant-server/src/Servant/Server/Utils/CustomCombinators.hs +++ b/servant-server/src/Servant/Server/Utils/CustomCombinators.hs @@ -7,6 +7,7 @@ -- fixme: document phases -- fixme: add doctests -- fixme: document that the req body can only be consumed once +-- fixme: document dependency problem module Servant.Server.Utils.CustomCombinators ( CombinatorImplementation, @@ -57,60 +58,60 @@ makeCaptureCombinator :: forall api combinator arg context . (HasServer api context, WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) => - (Text -> IO (RouteResult arg)) + (Context context -> Text -> IO (RouteResult arg)) -> CombinatorImplementation combinator arg api context makeCaptureCombinator getArg = CI $ \ Proxy context delayed -> CaptureRouter $ route (Proxy :: Proxy api) context $ addCapture delayed $ \ captured -> - (liftRouteResult =<< liftIO (getArg captured)) + (liftRouteResult =<< liftIO (getArg context captured)) makeRequestCheckCombinator :: forall api combinator context . (HasServer api context, WithArg () (ServerT api Handler) ~ ServerT api Handler) => - (Request -> IO (RouteResult ())) + (Context context -> Request -> IO (RouteResult ())) -> CombinatorImplementation combinator () api context makeRequestCheckCombinator check = CI $ \ Proxy context delayed -> route (Proxy :: Proxy api) context $ addMethodCheck delayed $ withRequest $ \ request -> - liftRouteResult =<< liftIO (check (protectBody "makeRequestCheckCombinator" request)) + liftRouteResult =<< liftIO (check context (protectBody "makeRequestCheckCombinator" request)) makeAuthCombinator :: forall api combinator arg context . (HasServer api context, WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) => - (Request -> IO (RouteResult arg)) + (Context context -> Request -> IO (RouteResult arg)) -> CombinatorImplementation combinator arg api context makeAuthCombinator authCheck = CI $ \ Proxy context delayed -> route (Proxy :: Proxy api) context $ addAuthCheck delayed $ withRequest $ \ request -> - liftRouteResult =<< liftIO (authCheck (protectBody "makeAuthCombinator" request)) + liftRouteResult =<< liftIO (authCheck context (protectBody "makeAuthCombinator" request)) makeReqBodyCombinator :: forall api combinator arg context . (ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler), WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler), HasServer api context) => - (IO ByteString -> arg) + (Context context -> IO ByteString -> arg) -> CombinatorImplementation combinator arg api context makeReqBodyCombinator getArg = CI $ \ Proxy context delayed -> route (Proxy :: Proxy api) context $ addBodyCheck delayed (return ()) (\ () -> withRequest $ \ request -> - liftRouteResult $ Route $ getArg $ requestBody request) + liftRouteResult $ Route $ getArg context $ requestBody request) makeCombinator :: forall api combinator arg context . (ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler), WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler), HasServer api context) => - (Request -> IO (RouteResult arg)) + (Context context -> Request -> IO (RouteResult arg)) -> CombinatorImplementation combinator arg api context makeCombinator getArg = CI $ \ Proxy context delayed -> route (Proxy :: Proxy api) context $ addBodyCheck delayed (return ()) (\ () -> withRequest $ \ request -> - liftRouteResult =<< liftIO (getArg (protectBody "makeCombinator" request))) + liftRouteResult =<< liftIO (getArg context (protectBody "makeCombinator" request))) protectBody :: String -> Request -> Request protectBody name request = request{ diff --git a/servant-server/test/Servant/Server/Utils/CustomCombinatorsSpec.hs b/servant-server/test/Servant/Server/Utils/CustomCombinatorsSpec.hs index e2eead46..a1d62a1e 100644 --- a/servant-server/test/Servant/Server/Utils/CustomCombinatorsSpec.hs +++ b/servant-server/test/Servant/Server/Utils/CustomCombinatorsSpec.hs @@ -23,7 +23,7 @@ import Data.Text hiding (map) import Network.HTTP.Types import Network.Wai import Network.Wai.Internal -import Test.Hspec +import Test.Hspec hiding (context) import Servant.API import Servant.Server @@ -113,6 +113,20 @@ spec = do runApp app request `shouldThrow` errorCall "ERROR: makeAuthCombinator: combinator must not access the request body" + it "allows to access the context" $ do + let server (User name) = return name + context :: Context '[ [(SBS.ByteString, User)] ] + context = [("secret", User "Bob")] :. EmptyContext + app = serveWithContext (Proxy :: Proxy (AuthWithContext :> Get' String)) context server + request = defaultRequest{ + requestHeaders = + ("Auth", "secret") : + requestHeaders defaultRequest + } + response <- runApp app request + responseStatus response `shouldBe` ok200 + responseBodyLbs response `shouldReturn` "\"Bob\"" + describe "makeCombinator" $ do it "allows to write a combinator by providing a function (Request -> a)" $ do let server = return @@ -161,9 +175,6 @@ spec = do response <- runApp app request responseBodyLbs response `shouldReturn` "\"foobar\"" - it "allows to access the context" $ do - pending - it "allows to implement combinators in terms of existing combinators" $ do pending @@ -175,7 +186,7 @@ data StringCapture instance HasServer api context => HasServer (StringCapture :> api) context where type ServerT (StringCapture :> api) m = String -> ServerT api m - route = runCI $ makeCaptureCombinator getCapture + route = runCI $ makeCaptureCombinator (const getCapture) getCapture :: Text -> IO (RouteResult String) getCapture snippet = return $ case snippet of @@ -188,7 +199,7 @@ data CheckFooHeader instance HasServer api context => HasServer (CheckFooHeader :> api) context where type ServerT (CheckFooHeader :> api) m = ServerT api m - route = runCI $ makeRequestCheckCombinator checkFooHeader + route = runCI $ makeRequestCheckCombinator (const checkFooHeader) checkFooHeader :: Request -> IO (RouteResult ()) checkFooHeader request = return $ @@ -201,7 +212,7 @@ data InvalidRequestCheckCombinator instance HasServer api context => HasServer (InvalidRequestCheckCombinator :> api) context where type ServerT (InvalidRequestCheckCombinator :> api) m = ServerT api m - route = runCI $ makeRequestCheckCombinator accessReqBody + route = runCI $ makeRequestCheckCombinator (const accessReqBody) accessReqBody :: Request -> IO (RouteResult ()) accessReqBody request = do @@ -217,7 +228,7 @@ data User = User String instance HasServer api context => HasServer (AuthCombinator :> api) context where type ServerT (AuthCombinator :> api) m = User -> ServerT api m - route = runCI $ makeAuthCombinator checkAuth + route = runCI $ makeAuthCombinator (const checkAuth) checkAuth :: Request -> IO (RouteResult User) checkAuth request = return $ case lookup "Auth" (requestHeaders request) of @@ -230,20 +241,37 @@ data InvalidAuthCombinator instance HasServer api context => HasServer (InvalidAuthCombinator :> api) context where type ServerT (InvalidAuthCombinator :> api) m = User -> ServerT api m - route = runCI $ makeAuthCombinator authWithReqBody + route = runCI $ makeAuthCombinator (const authWithReqBody) authWithReqBody :: Request -> IO (RouteResult User) authWithReqBody request = do body <- fromBody $ requestBody request deepseq body (return $ Route $ User $ cs body) +data AuthWithContext + +instance (HasContextEntry context [(SBS.ByteString, User)], HasServer api context) => + HasServer (AuthWithContext :> api) context where + type ServerT (AuthWithContext :> api) m = User -> ServerT api m + route = runCI $ makeAuthCombinator authWithContext + +-- fixme: remove foralls from haddock + +authWithContext :: (HasContextEntry context [(SBS.ByteString, User)]) => + Context context -> Request -> IO (RouteResult User) +authWithContext context request = return $ case lookup "Auth" (requestHeaders request) of + Just authToken -> case lookup authToken userDict of + Just user -> Route user + where + userDict = getContextEntry context + -- * general combinators data FooHeader instance HasServer api context => HasServer (FooHeader :> api) context where type ServerT (FooHeader :> api) m = String -> ServerT api m - route = runCI $ makeCombinator getCustom + route = runCI $ makeCombinator (const getCustom) getCustom :: Request -> IO (RouteResult String) getCustom request = return $ case lookup "Foo" (requestHeaders request) of @@ -258,7 +286,7 @@ data Source = Source (IO SBS.ByteString) instance HasServer api context => HasServer (StreamRequest :> api) context where type ServerT (StreamRequest :> api) m = Source -> ServerT api m - route = runCI $ makeReqBodyCombinator getSource + route = runCI $ makeReqBodyCombinator (const getSource) getSource :: IO SBS.ByteString -> Source getSource = Source