Incorporate new GADT-based auth into Delayed
This commit is contained in:
parent
c169d0bd59
commit
4d23cada4c
2 changed files with 107 additions and 58 deletions
|
@ -1,9 +1,11 @@
|
|||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE FunctionalDependencies #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
|
||||
module Servant.Server.Internal.Authentication
|
||||
( AuthData (..)
|
||||
|
@ -11,6 +13,7 @@ module Servant.Server.Internal.Authentication
|
|||
, basicAuthLax
|
||||
, basicAuthStrict
|
||||
, jwtAuthStrict
|
||||
, SimpleAuthProtected
|
||||
) where
|
||||
|
||||
import Control.Monad (guard)
|
||||
|
@ -40,19 +43,20 @@ import qualified Web.JWT as JWT (decode)
|
|||
|
||||
-- | Class to represent the ability to extract authentication-related
|
||||
-- data from a 'Request' object.
|
||||
class AuthData a where
|
||||
authData :: Request -> Maybe a
|
||||
class AuthData a e | a -> e where
|
||||
authData :: Request -> Either e a
|
||||
|
||||
authProtect :: OnMissing IO ServantErr (missingPolicy :: AuthPolicy)
|
||||
-> OnUnauthenticated IO ServantErr (unauthPolicy :: AuthPolicy) errorIndex authData
|
||||
-> (authData -> IO (Either errorIndex usr))
|
||||
-- | combinator to create authentication protected servers.
|
||||
authProtect :: OnMissing IO ServantErr missingPolicy missingError
|
||||
-> OnUnauthenticated IO ServantErr unauthPolicy unauthError authData
|
||||
-> (authData -> IO (Either unauthError usr))
|
||||
-> subserver
|
||||
-> AuthProtected IO ServantErr missingPolicy unauthPolicy errorIndex authData usr subserver
|
||||
-> AuthProtected IO ServantErr missingPolicy missingError unauthPolicy unauthError authData usr subserver
|
||||
authProtect = AuthProtected
|
||||
|
||||
-- | 'BasicAuth' instance for authData
|
||||
instance AuthData (BasicAuth realm) where
|
||||
authData request = do
|
||||
instance AuthData (BasicAuth realm) () where
|
||||
authData request = maybe (Left ()) Right $ do
|
||||
authBs <- lookup "Authorization" (requestHeaders request)
|
||||
let (x,y) = B.break isSpace authBs
|
||||
guard (B.map toLower x == "basic")
|
||||
|
@ -72,8 +76,8 @@ basicAuthFailure p = let realmBytes = (fromString . symbolVal) p
|
|||
-- | OnMisisng handler for Basic Authentication
|
||||
basicMissingHandler :: forall realm. KnownSymbol realm
|
||||
=> Proxy realm
|
||||
-> OnMissing IO ServantErr 'Strict
|
||||
basicMissingHandler p = StrictMissing (return $ basicAuthFailure p)
|
||||
-> OnMissing IO ServantErr 'Strict ()
|
||||
basicMissingHandler p = StrictMissing (const $ return (basicAuthFailure p))
|
||||
|
||||
-- | OnUnauthenticated handler for Basic Authentication
|
||||
basicUnauthenticatedHandler :: forall realm. KnownSymbol realm
|
||||
|
@ -85,7 +89,7 @@ basicUnauthenticatedHandler p = StrictUnauthenticated (const . const (return $ b
|
|||
basicAuthStrict :: forall realm usr subserver. KnownSymbol realm
|
||||
=> (BasicAuth realm -> IO (Maybe usr))
|
||||
-> subserver
|
||||
-> AuthProtected IO ServantErr 'Strict 'Strict () (BasicAuth realm) usr subserver
|
||||
-> AuthProtected IO ServantErr 'Strict () 'Strict () (BasicAuth realm) usr subserver
|
||||
basicAuthStrict check sub =
|
||||
let mHandler = basicMissingHandler (Proxy :: Proxy realm)
|
||||
unauthHandler = basicUnauthenticatedHandler (Proxy :: Proxy realm)
|
||||
|
@ -96,14 +100,14 @@ basicAuthStrict check sub =
|
|||
basicAuthLax :: KnownSymbol realm
|
||||
=> (BasicAuth realm -> IO (Maybe usr))
|
||||
-> subserver
|
||||
-> AuthProtected IO ServantErr 'Lax 'Lax () (BasicAuth realm) usr subserver
|
||||
-> AuthProtected IO ServantErr 'Lax () 'Lax () (BasicAuth realm) usr subserver
|
||||
basicAuthLax check sub =
|
||||
let check' = \a -> maybe (Left ()) Right <$> check a
|
||||
in AuthProtected LaxMissing LaxUnauthenticated check' sub
|
||||
|
||||
-- | Authentication data we extract from requests for JWT-based authentication.
|
||||
instance AuthData JWTAuth where
|
||||
authData req = do
|
||||
instance AuthData JWTAuth () where
|
||||
authData req = maybe (Left ()) Right $ do
|
||||
hdr <- lookup "Authorization" . requestHeaders $ req
|
||||
["Bearer", token] <- return . splitOn " " . decodeUtf8 $ hdr
|
||||
_ <- JWT.decode token -- try decode it. otherwise it's not a proper token
|
||||
|
@ -116,9 +120,13 @@ jwtWithError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>
|
|||
-- | OnMissing handler for Strict, JWT-based authentication
|
||||
jwtAuthStrict :: Secret
|
||||
-> subserver
|
||||
-> AuthProtected IO ServantErr 'Strict 'Strict () JWTAuth (JWT VerifiedJWT) subserver
|
||||
-> AuthProtected IO ServantErr 'Strict () 'Strict () JWTAuth (JWT VerifiedJWT) subserver
|
||||
jwtAuthStrict secret sub =
|
||||
let missingHandler = StrictMissing (return $ jwtWithError "invalid_request")
|
||||
let missingHandler = StrictMissing (const $ return (jwtWithError "invalid_request"))
|
||||
unauthHandler = StrictUnauthenticated (const . const (return $ jwtWithError "invalid_token"))
|
||||
check = return . maybe (Left ()) Right . decodeAndVerifySignature secret . unJWTAuth
|
||||
in AuthProtected missingHandler unauthHandler check sub
|
||||
|
||||
-- | A type alias to make simple authentication endpoints
|
||||
type SimpleAuthProtected mPolicy uPolicy authData usr subserver =
|
||||
AuthProtected IO ServantErr mPolicy () uPolicy () authData usr subserver
|
||||
|
|
|
@ -2,10 +2,11 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveFunctor #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
module Servant.Server.Internal.RoutingApplication where
|
||||
|
||||
#if !MIN_VERSION_base(4,8,0)
|
||||
|
@ -20,8 +21,8 @@ import Network.Wai (Application, Request,
|
|||
Response, ResponseReceived,
|
||||
requestBody,
|
||||
strictRequestBody)
|
||||
import Servant.API.Authentication (AuthPolicy(Strict,Lax))
|
||||
import Servant.Server.Internal.Authentication
|
||||
import Servant.API.Authentication (AuthProtected (..), AuthPolicy(Strict,Lax),
|
||||
OnMissing (..), OnUnauthenticated (..), SAuthPolicy (..))
|
||||
import Servant.Server.Internal.ServantErr
|
||||
|
||||
type RoutingApplication =
|
||||
|
@ -179,43 +180,83 @@ addMethodCheck :: Delayed a
|
|||
addMethodCheck (Delayed captures method auth body server) new =
|
||||
Delayed captures (combineRouteResults const method new) auth body server
|
||||
|
||||
-- | Add a method to perform authorization in strict mode.
|
||||
addAuthStrictCheck :: Delayed (AuthProtected auth usr (usr -> a) 'Strict)
|
||||
-> IO (RouteResult (Maybe auth))
|
||||
-- | helper type family to capture server handled values for various policies
|
||||
type family AuthDelayedReturn (mP :: AuthPolicy) mE (uP :: AuthPolicy) uE usr :: * where
|
||||
AuthDelayedReturn 'Strict mE 'Strict uE usr = usr
|
||||
AuthDelayedReturn 'Strict mE 'Lax uE usr = Either uE usr
|
||||
AuthDelayedReturn 'Lax mE 'Strict uE usr = Either mE usr
|
||||
AuthDelayedReturn 'Lax mE 'Lax uE usr = Either (Either mE uE) usr
|
||||
|
||||
-- | Internal method to generate auth checkers for various policies. Scary type signature
|
||||
-- but it does help with understanding the logic of how each policy works. See
|
||||
-- examples below.
|
||||
genAuthCheck :: (OnMissing IO ServantErr mPolicy mError -> (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a) -> mError -> IO (RouteResult a))
|
||||
-> (OnUnauthenticated IO ServantErr uPolicy uError auth -> (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a) -> uError -> auth -> IO (RouteResult a))
|
||||
-> (usr -> (AuthDelayedReturn mPolicy mError uPolicy uError usr))
|
||||
-> Delayed (AuthProtected IO ServantErr mPolicy mError uPolicy uError auth usr (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a))
|
||||
-> IO (RouteResult (Either mError auth))
|
||||
-> Delayed a
|
||||
addAuthStrictCheck delayed@(Delayed captures method _ body _) new =
|
||||
let newAuth = runDelayed delayed `bindRouteResults` \authProtectionStrict -> new `bindRouteResults` \mAuthData -> case mAuthData of
|
||||
|
||||
Nothing -> do
|
||||
-- we're in strict mode: don't let the request go
|
||||
-- call the provided "on missing auth" handler
|
||||
resp <- onMissingAuthData (authHandlers authProtectionStrict)
|
||||
return $ FailFatal resp
|
||||
|
||||
-- successfully pulled auth data out of the request
|
||||
Just aData -> do
|
||||
mUsr <- (checkAuthStrict authProtectionStrict) aData
|
||||
case mUsr of
|
||||
-- this user is not authenticated
|
||||
Nothing -> do
|
||||
resp <- onUnauthenticated (authHandlers authProtectionStrict) aData
|
||||
return $ FailFatal resp
|
||||
|
||||
-- this user is authenticated
|
||||
Just usr ->
|
||||
(return . Route . subServerStrict authProtectionStrict) usr
|
||||
genAuthCheck missingHandler unauthHandler returnHandler d@(Delayed captures method _ body _) new =
|
||||
let newAuth =
|
||||
runDelayed d `bindRouteResults` \ authProtection ->
|
||||
new `bindRouteResults` \ eAuthData ->
|
||||
case eAuthData of
|
||||
-- we failed to extract authentication data from the request
|
||||
Left mError -> missingHandler (onMissing authProtection) (subserver authProtection) mError
|
||||
-- auth data was succesfully extracted from the request
|
||||
Right aData -> do
|
||||
eUsr <- checkAuth authProtection aData
|
||||
case eUsr of
|
||||
-- we failed to authenticate the user
|
||||
Left uError -> unauthHandler (onUnauthenticated authProtection) (subserver authProtection) uError aData
|
||||
-- user was authenticated
|
||||
Right usr ->
|
||||
(return . Route . subserver authProtection) (returnHandler usr)
|
||||
in Delayed captures method newAuth body (\_ y _ -> Route y)
|
||||
|
||||
-- | Add a method to perform authorization in strict mode.
|
||||
addAuthLaxCheck :: Delayed (AuthProtected auth usr (Maybe usr -> a) 'Lax)
|
||||
-> IO (RouteResult (Maybe auth))
|
||||
-- | Delayed auth checker for Strict Missing and Strict Unauthentication
|
||||
addAuthCheckSS :: Delayed (AuthProtected IO ServantErr 'Strict mError 'Strict uError auth usr (usr -> a))
|
||||
-> IO (RouteResult (Either mError auth))
|
||||
-> Delayed a
|
||||
addAuthLaxCheck delayed@(Delayed captures method _ body _) new =
|
||||
let newAuth = runDelayed delayed `bindRouteResults` \authProtectionLax -> new `bindRouteResults` \mAuthData ->
|
||||
fmap (Route . subServerLax authProtectionLax)
|
||||
(maybe (pure Nothing) (checkAuthLax authProtectionLax) mAuthData)
|
||||
addAuthCheckSS = genAuthCheck (\(StrictMissing handler) _ e -> FailFatal <$> handler e)
|
||||
(\(StrictUnauthenticated handler) _ e a -> FailFatal <$> handler e a)
|
||||
id
|
||||
|
||||
in Delayed captures method newAuth body (\_ y _ -> Route y)
|
||||
-- | Delayed auth checker for Strict Missing and Lax Unauthentication
|
||||
addAuthCheckSL :: Delayed (AuthProtected IO ServantErr 'Strict mError 'Lax uError auth usr (Either uError usr -> a))
|
||||
-> IO (RouteResult (Either mError auth))
|
||||
-> Delayed a
|
||||
addAuthCheckSL = genAuthCheck (\(StrictMissing handler) _ e -> FailFatal <$> handler e)
|
||||
(\(LaxUnauthenticated) cont e _ -> (return . Route . cont) (Left e))
|
||||
Right
|
||||
|
||||
|
||||
-- | Delayed auth checker for Lax Missing and Strict Unauthentication
|
||||
addAuthCheckLS :: Delayed (AuthProtected IO ServantErr 'Lax mError 'Strict uError auth usr (Either mError usr -> a))
|
||||
-> IO (RouteResult (Either mError auth))
|
||||
-> Delayed a
|
||||
addAuthCheckLS = genAuthCheck (\(LaxMissing) cont e -> (return . Route . cont) (Left e))
|
||||
(\(StrictUnauthenticated handler) _ e a -> FailFatal <$> handler e a)
|
||||
Right
|
||||
|
||||
-- | Delayed auth checker for Lax Missing and Lax Unauthentication
|
||||
addAuthCheckLL :: Delayed (AuthProtected IO ServantErr 'Lax mError 'Lax uError auth usr (Either (Either mError uError) usr -> a))
|
||||
-> IO (RouteResult (Either mError auth))
|
||||
-> Delayed a
|
||||
addAuthCheckLL = genAuthCheck (\(LaxMissing) cont e -> (return . Route . cont) (Left (Left e)))
|
||||
(\(LaxUnauthenticated) cont e _ -> (return . Route . cont) (Left (Right e)))
|
||||
Right
|
||||
|
||||
-- | Add an auth check by supplying OnMissing policies and OnUnauthenticated policies.
|
||||
addAuthCheck :: SAuthPolicy mPolicy
|
||||
-> SAuthPolicy uPolicy
|
||||
-> Delayed (AuthProtected IO ServantErr mPolicy mError uPolicy uError auth usr (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a))
|
||||
-> IO (RouteResult (Either mError auth))
|
||||
-> Delayed a
|
||||
addAuthCheck SStrict SStrict = addAuthCheckSS
|
||||
addAuthCheck SStrict SLax = addAuthCheckSL
|
||||
addAuthCheck SLax SStrict = addAuthCheckLS
|
||||
addAuthCheck SLax SLax = addAuthCheckLL
|
||||
|
||||
-- | Add a body check to the end of the body block.
|
||||
addBodyCheck :: Delayed (a -> b)
|
||||
|
|
Loading…
Reference in a new issue