Use new Delayed type to capture authentication

This commit is contained in:
aaron levin 2015-12-17 00:39:53 +01:00
parent 667dbbc8cd
commit 1aca415ec7
2 changed files with 24 additions and 26 deletions

View file

@ -31,18 +31,14 @@ import Data.Word8 (isSpace, toLower, _colon)
import GHC.TypeLits (KnownSymbol, symbolVal)
import Data.Text.Encoding (decodeUtf8)
import Data.Text (splitOn)
import Network.HTTP.Types.Status (status401)
import Network.Wai (Request, Response, requestHeaders,
responseBuilder)
import Network.Wai (Request, requestHeaders)
import Servant.Server.Internal.ServantErr (err401, ServantErr(errHeaders))
import Servant.API.Authentication (AuthPolicy (Strict, Lax),
AuthProtected,
BasicAuth (BasicAuth),
JWTAuth(..))
JWTAuth)
import Servant.Server.Internal.RoutingApplication (RouteResult(FailFatal))
import Web.JWT (JWT, VerifiedJWT, Secret)
import qualified Web.JWT as JWT (decode, decodeAndVerifySignature, secret)
import Web.JWT (decodeAndVerifySignature, JWT, VerifiedJWT, Secret)
import qualified Web.JWT as JWT (decode)
-- | Class to represent the ability to extract authentication-related
-- data from a 'Request' object.
@ -52,10 +48,10 @@ class AuthData a where
-- | handlers to deal with authentication failures.
data AuthHandlers authData = AuthHandlers
{ -- we couldn't find the right type of auth data (or any, for that matter)
onMissingAuthData :: IO ServantError
onMissingAuthData :: IO ServantErr
,
-- we found the right type of auth data in the request but the check failed
onUnauthenticated :: authData -> IO ServantError
onUnauthenticated :: authData -> IO ServantErr
}
-- | concrete type to provide when in 'Strict' mode.
@ -100,7 +96,8 @@ basicAuthHandlers :: forall realm. KnownSymbol realm => AuthHandlers (BasicAuth
basicAuthHandlers =
let realmBytes = (fromString . symbolVal) (Proxy :: Proxy realm)
headerBytes = "Basic realm=\"" <> realmBytes <> "\""
authFailure = responseBuilder status401 [("WWW-Authenticate", headerBytes)] mempty in
authFailure = err401 { errHeaders = [("WWW-Authenticate", headerBytes)] }
in
AuthHandlers (return authFailure) ((const . return) authFailure)
-- | Basic authentication combinator with strict failure.
@ -130,8 +127,7 @@ instance AuthData JWTAuth where
jwtAuthHandlers :: AuthHandlers JWTAuth
jwtAuthHandlers = AuthHandlers (return missingData) ((const . return) authFailure)
where
withError e =
responseBuilder status401 [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] mempty
withError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] }
missingData = withError "invalid_request"
authFailure = withError "invalid_token"
@ -141,5 +137,5 @@ jwtAuthHandlers = AuthHandlers (return missingData) ((const . return) authFailur
-- One can use strictProtect and laxProtect to make more complex authentication
-- and authorization schemes.
jwtAuthStrict :: Secret -> subserver -> AuthProtected JWTAuth (JWT VerifiedJWT) subserver 'Strict
jwtAuthStrict secret subserver = strictProtect (return . JWT.decodeAndVerifySignature secret . unJWTAuth) jwtAuthHandlers subserver
jwtAuthStrict secret subserver = strictProtect (return . decodeAndVerifySignature secret . unJWTAuth) jwtAuthHandlers subserver

View file

@ -179,10 +179,12 @@ addMethodCheck :: Delayed a
addMethodCheck (Delayed captures method auth body server) new =
Delayed captures (combineRouteResults const method new) auth body server
-- | Add a method to perform authorization in strict mode.
addAuthStrictCheck :: Delayed (AuthProtected auth usr (usr -> a) 'Strict)
-> IO (RouteResult (Maybe usr))
-> IO (RouteResult (Maybe auth))
-> Delayed a
addAuthStrictCheck delayed@(Delayed captures method _ body server) new =
-- -> Delayed a
addAuthStrictCheck delayed@(Delayed captures method _ body _) new =
let newAuth = runDelayed delayed `bindRouteResults` \authProtectionStrict -> new `bindRouteResults` \mAuthData -> case mAuthData of
Nothing -> do
@ -192,26 +194,25 @@ addAuthStrictCheck delayed@(Delayed captures method _ body server) new =
return $ FailFatal resp
-- successfully pulled auth data out of the request
Just authData -> do
mUsr <- (checkAuthStrict authProtectionStrict) authData
Just aData -> do
mUsr <- (checkAuthStrict authProtectionStrict) aData
case mUsr of
-- this user is not authenticated
Nothing -> do
resp <- onUnauthenticated (authHandlers authProtectionStrict) authData
resp <- onUnauthenticated (authHandlers authProtectionStrict) aData
return $ FailFatal resp
-- this user is authenticated
Just usr ->
(return . Route . subServerStrict authProtectionStrict) usr
in Delayed captures method newAuth body server
in Delayed captures method newAuth body (\_ y _ -> Route y)
-- | Add a body check to the end of the body block.
addBodyCheck :: Delayed (a -> b)
-> IO (RouteResult a)
-> Delayed b
addBodyCheck (Delayed captures method auth body server) new =
Delayed captures method auth (combineRouteResults (,) body new) (\ x (y, v) z -> ($ v) <$> server x y z)
Delayed captures method auth (combineRouteResults (,) body new) (\ x y (z, v) -> ($ v) <$> server x y z)
-- | Add an accept header check to the end of the body block.
-- The accept header check should occur after the body check,
@ -255,11 +256,12 @@ combineRouteResults f m1 m2 =
-- blocks on to the actual handler.
runDelayed :: Delayed a
-> IO (RouteResult a)
runDelayed (Delayed captures method body server) =
runDelayed (Delayed captures method auth body server) =
captures `bindRouteResults` \ c ->
method `bindRouteResults` \ _ ->
auth `bindRouteResults` \ a ->
body `bindRouteResults` \ b ->
return (server c b)
return (server c a b)
-- | Runs a delayed server and the resulting action.
-- Takes a continuation that lets us send a response.