Use new Delayed type to capture authentication
This commit is contained in:
parent
667dbbc8cd
commit
1aca415ec7
2 changed files with 24 additions and 26 deletions
|
@ -31,18 +31,14 @@ import Data.Word8 (isSpace, toLower, _colon)
|
||||||
import GHC.TypeLits (KnownSymbol, symbolVal)
|
import GHC.TypeLits (KnownSymbol, symbolVal)
|
||||||
import Data.Text.Encoding (decodeUtf8)
|
import Data.Text.Encoding (decodeUtf8)
|
||||||
import Data.Text (splitOn)
|
import Data.Text (splitOn)
|
||||||
import Network.HTTP.Types.Status (status401)
|
import Network.Wai (Request, requestHeaders)
|
||||||
import Network.Wai (Request, Response, requestHeaders,
|
import Servant.Server.Internal.ServantErr (err401, ServantErr(errHeaders))
|
||||||
responseBuilder)
|
|
||||||
import Servant.API.Authentication (AuthPolicy (Strict, Lax),
|
import Servant.API.Authentication (AuthPolicy (Strict, Lax),
|
||||||
AuthProtected,
|
AuthProtected,
|
||||||
BasicAuth (BasicAuth),
|
BasicAuth (BasicAuth),
|
||||||
JWTAuth (..))
|
JWTAuth(..))
|
||||||
JWTAuth)
|
import Web.JWT (decodeAndVerifySignature, JWT, VerifiedJWT, Secret)
|
||||||
import Servant.Server.Internal.RoutingApplication (RouteResult(FailFatal))
|
import qualified Web.JWT as JWT (decode)
|
||||||
|
|
||||||
import Web.JWT (JWT, VerifiedJWT, Secret)
|
|
||||||
import qualified Web.JWT as JWT (decode, decodeAndVerifySignature, secret)
|
|
||||||
|
|
||||||
-- | Class to represent the ability to extract authentication-related
|
-- | Class to represent the ability to extract authentication-related
|
||||||
-- data from a 'Request' object.
|
-- data from a 'Request' object.
|
||||||
|
@ -52,10 +48,10 @@ class AuthData a where
|
||||||
-- | handlers to deal with authentication failures.
|
-- | handlers to deal with authentication failures.
|
||||||
data AuthHandlers authData = AuthHandlers
|
data AuthHandlers authData = AuthHandlers
|
||||||
{ -- we couldn't find the right type of auth data (or any, for that matter)
|
{ -- we couldn't find the right type of auth data (or any, for that matter)
|
||||||
onMissingAuthData :: IO ServantError
|
onMissingAuthData :: IO ServantErr
|
||||||
,
|
,
|
||||||
-- we found the right type of auth data in the request but the check failed
|
-- we found the right type of auth data in the request but the check failed
|
||||||
onUnauthenticated :: authData -> IO ServantError
|
onUnauthenticated :: authData -> IO ServantErr
|
||||||
}
|
}
|
||||||
|
|
||||||
-- | concrete type to provide when in 'Strict' mode.
|
-- | concrete type to provide when in 'Strict' mode.
|
||||||
|
@ -100,7 +96,8 @@ basicAuthHandlers :: forall realm. KnownSymbol realm => AuthHandlers (BasicAuth
|
||||||
basicAuthHandlers =
|
basicAuthHandlers =
|
||||||
let realmBytes = (fromString . symbolVal) (Proxy :: Proxy realm)
|
let realmBytes = (fromString . symbolVal) (Proxy :: Proxy realm)
|
||||||
headerBytes = "Basic realm=\"" <> realmBytes <> "\""
|
headerBytes = "Basic realm=\"" <> realmBytes <> "\""
|
||||||
authFailure = responseBuilder status401 [("WWW-Authenticate", headerBytes)] mempty in
|
authFailure = err401 { errHeaders = [("WWW-Authenticate", headerBytes)] }
|
||||||
|
in
|
||||||
AuthHandlers (return authFailure) ((const . return) authFailure)
|
AuthHandlers (return authFailure) ((const . return) authFailure)
|
||||||
|
|
||||||
-- | Basic authentication combinator with strict failure.
|
-- | Basic authentication combinator with strict failure.
|
||||||
|
@ -130,8 +127,7 @@ instance AuthData JWTAuth where
|
||||||
jwtAuthHandlers :: AuthHandlers JWTAuth
|
jwtAuthHandlers :: AuthHandlers JWTAuth
|
||||||
jwtAuthHandlers = AuthHandlers (return missingData) ((const . return) authFailure)
|
jwtAuthHandlers = AuthHandlers (return missingData) ((const . return) authFailure)
|
||||||
where
|
where
|
||||||
withError e =
|
withError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] }
|
||||||
responseBuilder status401 [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] mempty
|
|
||||||
missingData = withError "invalid_request"
|
missingData = withError "invalid_request"
|
||||||
authFailure = withError "invalid_token"
|
authFailure = withError "invalid_token"
|
||||||
|
|
||||||
|
@ -141,5 +137,5 @@ jwtAuthHandlers = AuthHandlers (return missingData) ((const . return) authFailur
|
||||||
-- One can use strictProtect and laxProtect to make more complex authentication
|
-- One can use strictProtect and laxProtect to make more complex authentication
|
||||||
-- and authorization schemes.
|
-- and authorization schemes.
|
||||||
jwtAuthStrict :: Secret -> subserver -> AuthProtected JWTAuth (JWT VerifiedJWT) subserver 'Strict
|
jwtAuthStrict :: Secret -> subserver -> AuthProtected JWTAuth (JWT VerifiedJWT) subserver 'Strict
|
||||||
jwtAuthStrict secret subserver = strictProtect (return . JWT.decodeAndVerifySignature secret . unJWTAuth) jwtAuthHandlers subserver
|
jwtAuthStrict secret subserver = strictProtect (return . decodeAndVerifySignature secret . unJWTAuth) jwtAuthHandlers subserver
|
||||||
|
|
||||||
|
|
|
@ -179,11 +179,13 @@ addMethodCheck :: Delayed a
|
||||||
addMethodCheck (Delayed captures method auth body server) new =
|
addMethodCheck (Delayed captures method auth body server) new =
|
||||||
Delayed captures (combineRouteResults const method new) auth body server
|
Delayed captures (combineRouteResults const method new) auth body server
|
||||||
|
|
||||||
|
-- | Add a method to perform authorization in strict mode.
|
||||||
addAuthStrictCheck :: Delayed (AuthProtected auth usr (usr -> a) 'Strict)
|
addAuthStrictCheck :: Delayed (AuthProtected auth usr (usr -> a) 'Strict)
|
||||||
-> IO (RouteResult (Maybe usr))
|
-> IO (RouteResult (Maybe auth))
|
||||||
-> Delayed a
|
-> Delayed a
|
||||||
addAuthStrictCheck delayed@(Delayed captures method _ body server) new =
|
-- -> Delayed a
|
||||||
let newAuth = runDelayed delayed `bindRouteResults` \ authProtectionStrict -> new `bindRouteResults` \ mAuthData -> case mAuthData of
|
addAuthStrictCheck delayed@(Delayed captures method _ body _) new =
|
||||||
|
let newAuth = runDelayed delayed `bindRouteResults` \authProtectionStrict -> new `bindRouteResults` \mAuthData -> case mAuthData of
|
||||||
|
|
||||||
Nothing -> do
|
Nothing -> do
|
||||||
-- we're in strict mode: don't let the request go
|
-- we're in strict mode: don't let the request go
|
||||||
|
@ -192,26 +194,25 @@ addAuthStrictCheck delayed@(Delayed captures method _ body server) new =
|
||||||
return $ FailFatal resp
|
return $ FailFatal resp
|
||||||
|
|
||||||
-- successfully pulled auth data out of the request
|
-- successfully pulled auth data out of the request
|
||||||
Just authData -> do
|
Just aData -> do
|
||||||
mUsr <- (checkAuthStrict authProtectionStrict) authData
|
mUsr <- (checkAuthStrict authProtectionStrict) aData
|
||||||
case mUsr of
|
case mUsr of
|
||||||
-- this user is not authenticated
|
-- this user is not authenticated
|
||||||
Nothing -> do
|
Nothing -> do
|
||||||
resp <- onUnauthenticated (authHandlers authProtectionStrict) authData
|
resp <- onUnauthenticated (authHandlers authProtectionStrict) aData
|
||||||
return $ FailFatal resp
|
return $ FailFatal resp
|
||||||
|
|
||||||
-- this user is authenticated
|
-- this user is authenticated
|
||||||
Just usr ->
|
Just usr ->
|
||||||
(return . Route . subServerStrict authProtectionStrict) usr
|
(return . Route . subServerStrict authProtectionStrict) usr
|
||||||
in Delayed captures method newAuth body server
|
in Delayed captures method newAuth body (\_ y _ -> Route y)
|
||||||
|
|
||||||
|
|
||||||
-- | Add a body check to the end of the body block.
|
-- | Add a body check to the end of the body block.
|
||||||
addBodyCheck :: Delayed (a -> b)
|
addBodyCheck :: Delayed (a -> b)
|
||||||
-> IO (RouteResult a)
|
-> IO (RouteResult a)
|
||||||
-> Delayed b
|
-> Delayed b
|
||||||
addBodyCheck (Delayed captures method auth body server) new =
|
addBodyCheck (Delayed captures method auth body server) new =
|
||||||
Delayed captures method auth (combineRouteResults (,) body new) (\ x (y, v) z -> ($ v) <$> server x y z)
|
Delayed captures method auth (combineRouteResults (,) body new) (\ x y (z, v) -> ($ v) <$> server x y z)
|
||||||
|
|
||||||
-- | Add an accept header check to the end of the body block.
|
-- | Add an accept header check to the end of the body block.
|
||||||
-- The accept header check should occur after the body check,
|
-- The accept header check should occur after the body check,
|
||||||
|
@ -255,11 +256,12 @@ combineRouteResults f m1 m2 =
|
||||||
-- blocks on to the actual handler.
|
-- blocks on to the actual handler.
|
||||||
runDelayed :: Delayed a
|
runDelayed :: Delayed a
|
||||||
-> IO (RouteResult a)
|
-> IO (RouteResult a)
|
||||||
runDelayed (Delayed captures method body server) =
|
runDelayed (Delayed captures method auth body server) =
|
||||||
captures `bindRouteResults` \ c ->
|
captures `bindRouteResults` \ c ->
|
||||||
method `bindRouteResults` \ _ ->
|
method `bindRouteResults` \ _ ->
|
||||||
|
auth `bindRouteResults` \ a ->
|
||||||
body `bindRouteResults` \ b ->
|
body `bindRouteResults` \ b ->
|
||||||
return (server c b)
|
return (server c a b)
|
||||||
|
|
||||||
-- | Runs a delayed server and the resulting action.
|
-- | Runs a delayed server and the resulting action.
|
||||||
-- Takes a continuation that lets us send a response.
|
-- Takes a continuation that lets us send a response.
|
||||||
|
|
Loading…
Reference in a new issue