allow to access contexts
This commit is contained in:
parent
f9085b6b7a
commit
d7587d1df9
2 changed files with 50 additions and 21 deletions
|
@ -7,6 +7,7 @@
|
||||||
-- fixme: document phases
|
-- fixme: document phases
|
||||||
-- fixme: add doctests
|
-- fixme: add doctests
|
||||||
-- fixme: document that the req body can only be consumed once
|
-- fixme: document that the req body can only be consumed once
|
||||||
|
-- fixme: document dependency problem
|
||||||
|
|
||||||
module Servant.Server.Utils.CustomCombinators (
|
module Servant.Server.Utils.CustomCombinators (
|
||||||
CombinatorImplementation,
|
CombinatorImplementation,
|
||||||
|
@ -57,60 +58,60 @@ makeCaptureCombinator ::
|
||||||
forall api combinator arg context .
|
forall api combinator arg context .
|
||||||
(HasServer api context,
|
(HasServer api context,
|
||||||
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) =>
|
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) =>
|
||||||
(Text -> IO (RouteResult arg))
|
(Context context -> Text -> IO (RouteResult arg))
|
||||||
-> CombinatorImplementation combinator arg api context
|
-> CombinatorImplementation combinator arg api context
|
||||||
makeCaptureCombinator getArg = CI $ \ Proxy context delayed ->
|
makeCaptureCombinator getArg = CI $ \ Proxy context delayed ->
|
||||||
CaptureRouter $
|
CaptureRouter $
|
||||||
route (Proxy :: Proxy api) context $ addCapture delayed $ \ captured ->
|
route (Proxy :: Proxy api) context $ addCapture delayed $ \ captured ->
|
||||||
(liftRouteResult =<< liftIO (getArg captured))
|
(liftRouteResult =<< liftIO (getArg context captured))
|
||||||
|
|
||||||
makeRequestCheckCombinator ::
|
makeRequestCheckCombinator ::
|
||||||
forall api combinator context .
|
forall api combinator context .
|
||||||
(HasServer api context,
|
(HasServer api context,
|
||||||
WithArg () (ServerT api Handler) ~ ServerT api Handler) =>
|
WithArg () (ServerT api Handler) ~ ServerT api Handler) =>
|
||||||
(Request -> IO (RouteResult ()))
|
(Context context -> Request -> IO (RouteResult ()))
|
||||||
-> CombinatorImplementation combinator () api context
|
-> CombinatorImplementation combinator () api context
|
||||||
makeRequestCheckCombinator check = CI $ \ Proxy context delayed ->
|
makeRequestCheckCombinator check = CI $ \ Proxy context delayed ->
|
||||||
route (Proxy :: Proxy api) context $ addMethodCheck delayed $
|
route (Proxy :: Proxy api) context $ addMethodCheck delayed $
|
||||||
withRequest $ \ request ->
|
withRequest $ \ request ->
|
||||||
liftRouteResult =<< liftIO (check (protectBody "makeRequestCheckCombinator" request))
|
liftRouteResult =<< liftIO (check context (protectBody "makeRequestCheckCombinator" request))
|
||||||
|
|
||||||
makeAuthCombinator ::
|
makeAuthCombinator ::
|
||||||
forall api combinator arg context .
|
forall api combinator arg context .
|
||||||
(HasServer api context,
|
(HasServer api context,
|
||||||
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) =>
|
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler)) =>
|
||||||
(Request -> IO (RouteResult arg))
|
(Context context -> Request -> IO (RouteResult arg))
|
||||||
-> CombinatorImplementation combinator arg api context
|
-> CombinatorImplementation combinator arg api context
|
||||||
makeAuthCombinator authCheck = CI $ \ Proxy context delayed ->
|
makeAuthCombinator authCheck = CI $ \ Proxy context delayed ->
|
||||||
route (Proxy :: Proxy api) context $ addAuthCheck delayed $
|
route (Proxy :: Proxy api) context $ addAuthCheck delayed $
|
||||||
withRequest $ \ request ->
|
withRequest $ \ request ->
|
||||||
liftRouteResult =<< liftIO (authCheck (protectBody "makeAuthCombinator" request))
|
liftRouteResult =<< liftIO (authCheck context (protectBody "makeAuthCombinator" request))
|
||||||
|
|
||||||
makeReqBodyCombinator ::
|
makeReqBodyCombinator ::
|
||||||
forall api combinator arg context .
|
forall api combinator arg context .
|
||||||
(ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler),
|
(ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler),
|
||||||
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler),
|
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler),
|
||||||
HasServer api context) =>
|
HasServer api context) =>
|
||||||
(IO ByteString -> arg)
|
(Context context -> IO ByteString -> arg)
|
||||||
-> CombinatorImplementation combinator arg api context
|
-> CombinatorImplementation combinator arg api context
|
||||||
makeReqBodyCombinator getArg = CI $ \ Proxy context delayed ->
|
makeReqBodyCombinator getArg = CI $ \ Proxy context delayed ->
|
||||||
route (Proxy :: Proxy api) context $ addBodyCheck delayed
|
route (Proxy :: Proxy api) context $ addBodyCheck delayed
|
||||||
(return ())
|
(return ())
|
||||||
(\ () -> withRequest $ \ request ->
|
(\ () -> withRequest $ \ request ->
|
||||||
liftRouteResult $ Route $ getArg $ requestBody request)
|
liftRouteResult $ Route $ getArg context $ requestBody request)
|
||||||
|
|
||||||
makeCombinator ::
|
makeCombinator ::
|
||||||
forall api combinator arg context .
|
forall api combinator arg context .
|
||||||
(ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler),
|
(ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler),
|
||||||
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler),
|
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler),
|
||||||
HasServer api context) =>
|
HasServer api context) =>
|
||||||
(Request -> IO (RouteResult arg))
|
(Context context -> Request -> IO (RouteResult arg))
|
||||||
-> CombinatorImplementation combinator arg api context
|
-> CombinatorImplementation combinator arg api context
|
||||||
makeCombinator getArg = CI $ \ Proxy context delayed ->
|
makeCombinator getArg = CI $ \ Proxy context delayed ->
|
||||||
route (Proxy :: Proxy api) context $ addBodyCheck delayed
|
route (Proxy :: Proxy api) context $ addBodyCheck delayed
|
||||||
(return ())
|
(return ())
|
||||||
(\ () -> withRequest $ \ request ->
|
(\ () -> withRequest $ \ request ->
|
||||||
liftRouteResult =<< liftIO (getArg (protectBody "makeCombinator" request)))
|
liftRouteResult =<< liftIO (getArg context (protectBody "makeCombinator" request)))
|
||||||
|
|
||||||
protectBody :: String -> Request -> Request
|
protectBody :: String -> Request -> Request
|
||||||
protectBody name request = request{
|
protectBody name request = request{
|
||||||
|
|
|
@ -23,7 +23,7 @@ import Data.Text hiding (map)
|
||||||
import Network.HTTP.Types
|
import Network.HTTP.Types
|
||||||
import Network.Wai
|
import Network.Wai
|
||||||
import Network.Wai.Internal
|
import Network.Wai.Internal
|
||||||
import Test.Hspec
|
import Test.Hspec hiding (context)
|
||||||
|
|
||||||
import Servant.API
|
import Servant.API
|
||||||
import Servant.Server
|
import Servant.Server
|
||||||
|
@ -113,6 +113,20 @@ spec = do
|
||||||
runApp app request `shouldThrow`
|
runApp app request `shouldThrow`
|
||||||
errorCall "ERROR: makeAuthCombinator: combinator must not access the request body"
|
errorCall "ERROR: makeAuthCombinator: combinator must not access the request body"
|
||||||
|
|
||||||
|
it "allows to access the context" $ do
|
||||||
|
let server (User name) = return name
|
||||||
|
context :: Context '[ [(SBS.ByteString, User)] ]
|
||||||
|
context = [("secret", User "Bob")] :. EmptyContext
|
||||||
|
app = serveWithContext (Proxy :: Proxy (AuthWithContext :> Get' String)) context server
|
||||||
|
request = defaultRequest{
|
||||||
|
requestHeaders =
|
||||||
|
("Auth", "secret") :
|
||||||
|
requestHeaders defaultRequest
|
||||||
|
}
|
||||||
|
response <- runApp app request
|
||||||
|
responseStatus response `shouldBe` ok200
|
||||||
|
responseBodyLbs response `shouldReturn` "\"Bob\""
|
||||||
|
|
||||||
describe "makeCombinator" $ do
|
describe "makeCombinator" $ do
|
||||||
it "allows to write a combinator by providing a function (Request -> a)" $ do
|
it "allows to write a combinator by providing a function (Request -> a)" $ do
|
||||||
let server = return
|
let server = return
|
||||||
|
@ -161,9 +175,6 @@ spec = do
|
||||||
response <- runApp app request
|
response <- runApp app request
|
||||||
responseBodyLbs response `shouldReturn` "\"foobar\""
|
responseBodyLbs response `shouldReturn` "\"foobar\""
|
||||||
|
|
||||||
it "allows to access the context" $ do
|
|
||||||
pending
|
|
||||||
|
|
||||||
it "allows to implement combinators in terms of existing combinators" $ do
|
it "allows to implement combinators in terms of existing combinators" $ do
|
||||||
pending
|
pending
|
||||||
|
|
||||||
|
@ -175,7 +186,7 @@ data StringCapture
|
||||||
|
|
||||||
instance HasServer api context => HasServer (StringCapture :> api) context where
|
instance HasServer api context => HasServer (StringCapture :> api) context where
|
||||||
type ServerT (StringCapture :> api) m = String -> ServerT api m
|
type ServerT (StringCapture :> api) m = String -> ServerT api m
|
||||||
route = runCI $ makeCaptureCombinator getCapture
|
route = runCI $ makeCaptureCombinator (const getCapture)
|
||||||
|
|
||||||
getCapture :: Text -> IO (RouteResult String)
|
getCapture :: Text -> IO (RouteResult String)
|
||||||
getCapture snippet = return $ case snippet of
|
getCapture snippet = return $ case snippet of
|
||||||
|
@ -188,7 +199,7 @@ data CheckFooHeader
|
||||||
|
|
||||||
instance HasServer api context => HasServer (CheckFooHeader :> api) context where
|
instance HasServer api context => HasServer (CheckFooHeader :> api) context where
|
||||||
type ServerT (CheckFooHeader :> api) m = ServerT api m
|
type ServerT (CheckFooHeader :> api) m = ServerT api m
|
||||||
route = runCI $ makeRequestCheckCombinator checkFooHeader
|
route = runCI $ makeRequestCheckCombinator (const checkFooHeader)
|
||||||
|
|
||||||
checkFooHeader :: Request -> IO (RouteResult ())
|
checkFooHeader :: Request -> IO (RouteResult ())
|
||||||
checkFooHeader request = return $
|
checkFooHeader request = return $
|
||||||
|
@ -201,7 +212,7 @@ data InvalidRequestCheckCombinator
|
||||||
|
|
||||||
instance HasServer api context => HasServer (InvalidRequestCheckCombinator :> api) context where
|
instance HasServer api context => HasServer (InvalidRequestCheckCombinator :> api) context where
|
||||||
type ServerT (InvalidRequestCheckCombinator :> api) m = ServerT api m
|
type ServerT (InvalidRequestCheckCombinator :> api) m = ServerT api m
|
||||||
route = runCI $ makeRequestCheckCombinator accessReqBody
|
route = runCI $ makeRequestCheckCombinator (const accessReqBody)
|
||||||
|
|
||||||
accessReqBody :: Request -> IO (RouteResult ())
|
accessReqBody :: Request -> IO (RouteResult ())
|
||||||
accessReqBody request = do
|
accessReqBody request = do
|
||||||
|
@ -217,7 +228,7 @@ data User = User String
|
||||||
|
|
||||||
instance HasServer api context => HasServer (AuthCombinator :> api) context where
|
instance HasServer api context => HasServer (AuthCombinator :> api) context where
|
||||||
type ServerT (AuthCombinator :> api) m = User -> ServerT api m
|
type ServerT (AuthCombinator :> api) m = User -> ServerT api m
|
||||||
route = runCI $ makeAuthCombinator checkAuth
|
route = runCI $ makeAuthCombinator (const checkAuth)
|
||||||
|
|
||||||
checkAuth :: Request -> IO (RouteResult User)
|
checkAuth :: Request -> IO (RouteResult User)
|
||||||
checkAuth request = return $ case lookup "Auth" (requestHeaders request) of
|
checkAuth request = return $ case lookup "Auth" (requestHeaders request) of
|
||||||
|
@ -230,20 +241,37 @@ data InvalidAuthCombinator
|
||||||
|
|
||||||
instance HasServer api context => HasServer (InvalidAuthCombinator :> api) context where
|
instance HasServer api context => HasServer (InvalidAuthCombinator :> api) context where
|
||||||
type ServerT (InvalidAuthCombinator :> api) m = User -> ServerT api m
|
type ServerT (InvalidAuthCombinator :> api) m = User -> ServerT api m
|
||||||
route = runCI $ makeAuthCombinator authWithReqBody
|
route = runCI $ makeAuthCombinator (const authWithReqBody)
|
||||||
|
|
||||||
authWithReqBody :: Request -> IO (RouteResult User)
|
authWithReqBody :: Request -> IO (RouteResult User)
|
||||||
authWithReqBody request = do
|
authWithReqBody request = do
|
||||||
body <- fromBody $ requestBody request
|
body <- fromBody $ requestBody request
|
||||||
deepseq body (return $ Route $ User $ cs body)
|
deepseq body (return $ Route $ User $ cs body)
|
||||||
|
|
||||||
|
data AuthWithContext
|
||||||
|
|
||||||
|
instance (HasContextEntry context [(SBS.ByteString, User)], HasServer api context) =>
|
||||||
|
HasServer (AuthWithContext :> api) context where
|
||||||
|
type ServerT (AuthWithContext :> api) m = User -> ServerT api m
|
||||||
|
route = runCI $ makeAuthCombinator authWithContext
|
||||||
|
|
||||||
|
-- fixme: remove foralls from haddock
|
||||||
|
|
||||||
|
authWithContext :: (HasContextEntry context [(SBS.ByteString, User)]) =>
|
||||||
|
Context context -> Request -> IO (RouteResult User)
|
||||||
|
authWithContext context request = return $ case lookup "Auth" (requestHeaders request) of
|
||||||
|
Just authToken -> case lookup authToken userDict of
|
||||||
|
Just user -> Route user
|
||||||
|
where
|
||||||
|
userDict = getContextEntry context
|
||||||
|
|
||||||
-- * general combinators
|
-- * general combinators
|
||||||
|
|
||||||
data FooHeader
|
data FooHeader
|
||||||
|
|
||||||
instance HasServer api context => HasServer (FooHeader :> api) context where
|
instance HasServer api context => HasServer (FooHeader :> api) context where
|
||||||
type ServerT (FooHeader :> api) m = String -> ServerT api m
|
type ServerT (FooHeader :> api) m = String -> ServerT api m
|
||||||
route = runCI $ makeCombinator getCustom
|
route = runCI $ makeCombinator (const getCustom)
|
||||||
|
|
||||||
getCustom :: Request -> IO (RouteResult String)
|
getCustom :: Request -> IO (RouteResult String)
|
||||||
getCustom request = return $ case lookup "Foo" (requestHeaders request) of
|
getCustom request = return $ case lookup "Foo" (requestHeaders request) of
|
||||||
|
@ -258,7 +286,7 @@ data Source = Source (IO SBS.ByteString)
|
||||||
|
|
||||||
instance HasServer api context => HasServer (StreamRequest :> api) context where
|
instance HasServer api context => HasServer (StreamRequest :> api) context where
|
||||||
type ServerT (StreamRequest :> api) m = Source -> ServerT api m
|
type ServerT (StreamRequest :> api) m = Source -> ServerT api m
|
||||||
route = runCI $ makeReqBodyCombinator getSource
|
route = runCI $ makeReqBodyCombinator (const getSource)
|
||||||
|
|
||||||
getSource :: IO SBS.ByteString -> Source
|
getSource :: IO SBS.ByteString -> Source
|
||||||
getSource = Source
|
getSource = Source
|
||||||
|
|
Loading…
Reference in a new issue