From 1aca415ec7c7689bf666ad52e5adc315b1e670af Mon Sep 17 00:00:00 2001 From: aaron levin Date: Thu, 17 Dec 2015 00:39:53 +0100 Subject: [PATCH] Use new Delayed type to capture authentication --- .../Servant/Server/Internal/Authentication.hs | 26 ++++++++----------- .../Server/Internal/RoutingApplication.hs | 24 +++++++++-------- 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/servant-server/src/Servant/Server/Internal/Authentication.hs b/servant-server/src/Servant/Server/Internal/Authentication.hs index 7f4765e6..fc9d020f 100644 --- a/servant-server/src/Servant/Server/Internal/Authentication.hs +++ b/servant-server/src/Servant/Server/Internal/Authentication.hs @@ -31,18 +31,14 @@ import Data.Word8 (isSpace, toLower, _colon) import GHC.TypeLits (KnownSymbol, symbolVal) import Data.Text.Encoding (decodeUtf8) import Data.Text (splitOn) -import Network.HTTP.Types.Status (status401) -import Network.Wai (Request, Response, requestHeaders, - responseBuilder) +import Network.Wai (Request, requestHeaders) +import Servant.Server.Internal.ServantErr (err401, ServantErr(errHeaders)) import Servant.API.Authentication (AuthPolicy (Strict, Lax), AuthProtected, BasicAuth (BasicAuth), - JWTAuth (..)) - JWTAuth) -import Servant.Server.Internal.RoutingApplication (RouteResult(FailFatal)) - -import Web.JWT (JWT, VerifiedJWT, Secret) -import qualified Web.JWT as JWT (decode, decodeAndVerifySignature, secret) + JWTAuth(..)) +import Web.JWT (decodeAndVerifySignature, JWT, VerifiedJWT, Secret) +import qualified Web.JWT as JWT (decode) -- | Class to represent the ability to extract authentication-related -- data from a 'Request' object. @@ -52,10 +48,10 @@ class AuthData a where -- | handlers to deal with authentication failures. data AuthHandlers authData = AuthHandlers { -- we couldn't find the right type of auth data (or any, for that matter) - onMissingAuthData :: IO ServantError + onMissingAuthData :: IO ServantErr , -- we found the right type of auth data in the request but the check failed - onUnauthenticated :: authData -> IO ServantError + onUnauthenticated :: authData -> IO ServantErr } -- | concrete type to provide when in 'Strict' mode. @@ -100,7 +96,8 @@ basicAuthHandlers :: forall realm. KnownSymbol realm => AuthHandlers (BasicAuth basicAuthHandlers = let realmBytes = (fromString . symbolVal) (Proxy :: Proxy realm) headerBytes = "Basic realm=\"" <> realmBytes <> "\"" - authFailure = responseBuilder status401 [("WWW-Authenticate", headerBytes)] mempty in + authFailure = err401 { errHeaders = [("WWW-Authenticate", headerBytes)] } + in AuthHandlers (return authFailure) ((const . return) authFailure) -- | Basic authentication combinator with strict failure. @@ -130,8 +127,7 @@ instance AuthData JWTAuth where jwtAuthHandlers :: AuthHandlers JWTAuth jwtAuthHandlers = AuthHandlers (return missingData) ((const . return) authFailure) where - withError e = - responseBuilder status401 [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] mempty + withError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] } missingData = withError "invalid_request" authFailure = withError "invalid_token" @@ -141,5 +137,5 @@ jwtAuthHandlers = AuthHandlers (return missingData) ((const . return) authFailur -- One can use strictProtect and laxProtect to make more complex authentication -- and authorization schemes. jwtAuthStrict :: Secret -> subserver -> AuthProtected JWTAuth (JWT VerifiedJWT) subserver 'Strict -jwtAuthStrict secret subserver = strictProtect (return . JWT.decodeAndVerifySignature secret . unJWTAuth) jwtAuthHandlers subserver +jwtAuthStrict secret subserver = strictProtect (return . decodeAndVerifySignature secret . unJWTAuth) jwtAuthHandlers subserver diff --git a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs index 70cd62b3..1455e036 100644 --- a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs +++ b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs @@ -179,11 +179,13 @@ addMethodCheck :: Delayed a addMethodCheck (Delayed captures method auth body server) new = Delayed captures (combineRouteResults const method new) auth body server +-- | Add a method to perform authorization in strict mode. addAuthStrictCheck :: Delayed (AuthProtected auth usr (usr -> a) 'Strict) - -> IO (RouteResult (Maybe usr)) + -> IO (RouteResult (Maybe auth)) -> Delayed a -addAuthStrictCheck delayed@(Delayed captures method _ body server) new = - let newAuth = runDelayed delayed `bindRouteResults` \ authProtectionStrict -> new `bindRouteResults` \ mAuthData -> case mAuthData of + -- -> Delayed a +addAuthStrictCheck delayed@(Delayed captures method _ body _) new = + let newAuth = runDelayed delayed `bindRouteResults` \authProtectionStrict -> new `bindRouteResults` \mAuthData -> case mAuthData of Nothing -> do -- we're in strict mode: don't let the request go @@ -192,26 +194,25 @@ addAuthStrictCheck delayed@(Delayed captures method _ body server) new = return $ FailFatal resp -- successfully pulled auth data out of the request - Just authData -> do - mUsr <- (checkAuthStrict authProtectionStrict) authData + Just aData -> do + mUsr <- (checkAuthStrict authProtectionStrict) aData case mUsr of -- this user is not authenticated Nothing -> do - resp <- onUnauthenticated (authHandlers authProtectionStrict) authData + resp <- onUnauthenticated (authHandlers authProtectionStrict) aData return $ FailFatal resp -- this user is authenticated Just usr -> (return . Route . subServerStrict authProtectionStrict) usr - in Delayed captures method newAuth body server - + in Delayed captures method newAuth body (\_ y _ -> Route y) -- | Add a body check to the end of the body block. addBodyCheck :: Delayed (a -> b) -> IO (RouteResult a) -> Delayed b addBodyCheck (Delayed captures method auth body server) new = - Delayed captures method auth (combineRouteResults (,) body new) (\ x (y, v) z -> ($ v) <$> server x y z) + Delayed captures method auth (combineRouteResults (,) body new) (\ x y (z, v) -> ($ v) <$> server x y z) -- | Add an accept header check to the end of the body block. -- The accept header check should occur after the body check, @@ -255,11 +256,12 @@ combineRouteResults f m1 m2 = -- blocks on to the actual handler. runDelayed :: Delayed a -> IO (RouteResult a) -runDelayed (Delayed captures method body server) = +runDelayed (Delayed captures method auth body server) = captures `bindRouteResults` \ c -> method `bindRouteResults` \ _ -> + auth `bindRouteResults` \ a -> body `bindRouteResults` \ b -> - return (server c b) + return (server c a b) -- | Runs a delayed server and the resulting action. -- Takes a continuation that lets us send a response.