Use new Delayed type to capture authentication
This commit is contained in:
parent
667dbbc8cd
commit
1aca415ec7
2 changed files with 24 additions and 26 deletions
|
@ -31,18 +31,14 @@ import Data.Word8 (isSpace, toLower, _colon)
|
|||
import GHC.TypeLits (KnownSymbol, symbolVal)
|
||||
import Data.Text.Encoding (decodeUtf8)
|
||||
import Data.Text (splitOn)
|
||||
import Network.HTTP.Types.Status (status401)
|
||||
import Network.Wai (Request, Response, requestHeaders,
|
||||
responseBuilder)
|
||||
import Network.Wai (Request, requestHeaders)
|
||||
import Servant.Server.Internal.ServantErr (err401, ServantErr(errHeaders))
|
||||
import Servant.API.Authentication (AuthPolicy (Strict, Lax),
|
||||
AuthProtected,
|
||||
BasicAuth (BasicAuth),
|
||||
JWTAuth (..))
|
||||
JWTAuth)
|
||||
import Servant.Server.Internal.RoutingApplication (RouteResult(FailFatal))
|
||||
|
||||
import Web.JWT (JWT, VerifiedJWT, Secret)
|
||||
import qualified Web.JWT as JWT (decode, decodeAndVerifySignature, secret)
|
||||
JWTAuth(..))
|
||||
import Web.JWT (decodeAndVerifySignature, JWT, VerifiedJWT, Secret)
|
||||
import qualified Web.JWT as JWT (decode)
|
||||
|
||||
-- | Class to represent the ability to extract authentication-related
|
||||
-- data from a 'Request' object.
|
||||
|
@ -52,10 +48,10 @@ class AuthData a where
|
|||
-- | handlers to deal with authentication failures.
|
||||
data AuthHandlers authData = AuthHandlers
|
||||
{ -- we couldn't find the right type of auth data (or any, for that matter)
|
||||
onMissingAuthData :: IO ServantError
|
||||
onMissingAuthData :: IO ServantErr
|
||||
,
|
||||
-- we found the right type of auth data in the request but the check failed
|
||||
onUnauthenticated :: authData -> IO ServantError
|
||||
onUnauthenticated :: authData -> IO ServantErr
|
||||
}
|
||||
|
||||
-- | concrete type to provide when in 'Strict' mode.
|
||||
|
@ -100,7 +96,8 @@ basicAuthHandlers :: forall realm. KnownSymbol realm => AuthHandlers (BasicAuth
|
|||
basicAuthHandlers =
|
||||
let realmBytes = (fromString . symbolVal) (Proxy :: Proxy realm)
|
||||
headerBytes = "Basic realm=\"" <> realmBytes <> "\""
|
||||
authFailure = responseBuilder status401 [("WWW-Authenticate", headerBytes)] mempty in
|
||||
authFailure = err401 { errHeaders = [("WWW-Authenticate", headerBytes)] }
|
||||
in
|
||||
AuthHandlers (return authFailure) ((const . return) authFailure)
|
||||
|
||||
-- | Basic authentication combinator with strict failure.
|
||||
|
@ -130,8 +127,7 @@ instance AuthData JWTAuth where
|
|||
jwtAuthHandlers :: AuthHandlers JWTAuth
|
||||
jwtAuthHandlers = AuthHandlers (return missingData) ((const . return) authFailure)
|
||||
where
|
||||
withError e =
|
||||
responseBuilder status401 [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] mempty
|
||||
withError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] }
|
||||
missingData = withError "invalid_request"
|
||||
authFailure = withError "invalid_token"
|
||||
|
||||
|
@ -141,5 +137,5 @@ jwtAuthHandlers = AuthHandlers (return missingData) ((const . return) authFailur
|
|||
-- One can use strictProtect and laxProtect to make more complex authentication
|
||||
-- and authorization schemes.
|
||||
jwtAuthStrict :: Secret -> subserver -> AuthProtected JWTAuth (JWT VerifiedJWT) subserver 'Strict
|
||||
jwtAuthStrict secret subserver = strictProtect (return . JWT.decodeAndVerifySignature secret . unJWTAuth) jwtAuthHandlers subserver
|
||||
jwtAuthStrict secret subserver = strictProtect (return . decodeAndVerifySignature secret . unJWTAuth) jwtAuthHandlers subserver
|
||||
|
||||
|
|
|
@ -179,11 +179,13 @@ addMethodCheck :: Delayed a
|
|||
addMethodCheck (Delayed captures method auth body server) new =
|
||||
Delayed captures (combineRouteResults const method new) auth body server
|
||||
|
||||
-- | Add a method to perform authorization in strict mode.
|
||||
addAuthStrictCheck :: Delayed (AuthProtected auth usr (usr -> a) 'Strict)
|
||||
-> IO (RouteResult (Maybe usr))
|
||||
-> IO (RouteResult (Maybe auth))
|
||||
-> Delayed a
|
||||
addAuthStrictCheck delayed@(Delayed captures method _ body server) new =
|
||||
let newAuth = runDelayed delayed `bindRouteResults` \ authProtectionStrict -> new `bindRouteResults` \ mAuthData -> case mAuthData of
|
||||
-- -> Delayed a
|
||||
addAuthStrictCheck delayed@(Delayed captures method _ body _) new =
|
||||
let newAuth = runDelayed delayed `bindRouteResults` \authProtectionStrict -> new `bindRouteResults` \mAuthData -> case mAuthData of
|
||||
|
||||
Nothing -> do
|
||||
-- we're in strict mode: don't let the request go
|
||||
|
@ -192,26 +194,25 @@ addAuthStrictCheck delayed@(Delayed captures method _ body server) new =
|
|||
return $ FailFatal resp
|
||||
|
||||
-- successfully pulled auth data out of the request
|
||||
Just authData -> do
|
||||
mUsr <- (checkAuthStrict authProtectionStrict) authData
|
||||
Just aData -> do
|
||||
mUsr <- (checkAuthStrict authProtectionStrict) aData
|
||||
case mUsr of
|
||||
-- this user is not authenticated
|
||||
Nothing -> do
|
||||
resp <- onUnauthenticated (authHandlers authProtectionStrict) authData
|
||||
resp <- onUnauthenticated (authHandlers authProtectionStrict) aData
|
||||
return $ FailFatal resp
|
||||
|
||||
-- this user is authenticated
|
||||
Just usr ->
|
||||
(return . Route . subServerStrict authProtectionStrict) usr
|
||||
in Delayed captures method newAuth body server
|
||||
|
||||
in Delayed captures method newAuth body (\_ y _ -> Route y)
|
||||
|
||||
-- | Add a body check to the end of the body block.
|
||||
addBodyCheck :: Delayed (a -> b)
|
||||
-> IO (RouteResult a)
|
||||
-> Delayed b
|
||||
addBodyCheck (Delayed captures method auth body server) new =
|
||||
Delayed captures method auth (combineRouteResults (,) body new) (\ x (y, v) z -> ($ v) <$> server x y z)
|
||||
Delayed captures method auth (combineRouteResults (,) body new) (\ x y (z, v) -> ($ v) <$> server x y z)
|
||||
|
||||
-- | Add an accept header check to the end of the body block.
|
||||
-- The accept header check should occur after the body check,
|
||||
|
@ -255,11 +256,12 @@ combineRouteResults f m1 m2 =
|
|||
-- blocks on to the actual handler.
|
||||
runDelayed :: Delayed a
|
||||
-> IO (RouteResult a)
|
||||
runDelayed (Delayed captures method body server) =
|
||||
runDelayed (Delayed captures method auth body server) =
|
||||
captures `bindRouteResults` \ c ->
|
||||
method `bindRouteResults` \ _ ->
|
||||
auth `bindRouteResults` \ a ->
|
||||
body `bindRouteResults` \ b ->
|
||||
return (server c b)
|
||||
return (server c a b)
|
||||
|
||||
-- | Runs a delayed server and the resulting action.
|
||||
-- Takes a continuation that lets us send a response.
|
||||
|
|
Loading…
Reference in a new issue