Incorporate new GADT-based auth into Delayed
This commit is contained in:
parent
c169d0bd59
commit
4d23cada4c
2 changed files with 107 additions and 58 deletions
|
@ -1,9 +1,11 @@
|
||||||
{-# LANGUAGE CPP #-}
|
{-# LANGUAGE CPP #-}
|
||||||
{-# LANGUAGE DataKinds #-}
|
{-# LANGUAGE DataKinds #-}
|
||||||
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
|
{-# LANGUAGE FunctionalDependencies #-}
|
||||||
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
{-# LANGUAGE TypeFamilies #-}
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
|
||||||
|
|
||||||
module Servant.Server.Internal.Authentication
|
module Servant.Server.Internal.Authentication
|
||||||
( AuthData (..)
|
( AuthData (..)
|
||||||
|
@ -11,6 +13,7 @@ module Servant.Server.Internal.Authentication
|
||||||
, basicAuthLax
|
, basicAuthLax
|
||||||
, basicAuthStrict
|
, basicAuthStrict
|
||||||
, jwtAuthStrict
|
, jwtAuthStrict
|
||||||
|
, SimpleAuthProtected
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Control.Monad (guard)
|
import Control.Monad (guard)
|
||||||
|
@ -40,19 +43,20 @@ import qualified Web.JWT as JWT (decode)
|
||||||
|
|
||||||
-- | Class to represent the ability to extract authentication-related
|
-- | Class to represent the ability to extract authentication-related
|
||||||
-- data from a 'Request' object.
|
-- data from a 'Request' object.
|
||||||
class AuthData a where
|
class AuthData a e | a -> e where
|
||||||
authData :: Request -> Maybe a
|
authData :: Request -> Either e a
|
||||||
|
|
||||||
authProtect :: OnMissing IO ServantErr (missingPolicy :: AuthPolicy)
|
-- | combinator to create authentication protected servers.
|
||||||
-> OnUnauthenticated IO ServantErr (unauthPolicy :: AuthPolicy) errorIndex authData
|
authProtect :: OnMissing IO ServantErr missingPolicy missingError
|
||||||
-> (authData -> IO (Either errorIndex usr))
|
-> OnUnauthenticated IO ServantErr unauthPolicy unauthError authData
|
||||||
|
-> (authData -> IO (Either unauthError usr))
|
||||||
-> subserver
|
-> subserver
|
||||||
-> AuthProtected IO ServantErr missingPolicy unauthPolicy errorIndex authData usr subserver
|
-> AuthProtected IO ServantErr missingPolicy missingError unauthPolicy unauthError authData usr subserver
|
||||||
authProtect = AuthProtected
|
authProtect = AuthProtected
|
||||||
|
|
||||||
-- | 'BasicAuth' instance for authData
|
-- | 'BasicAuth' instance for authData
|
||||||
instance AuthData (BasicAuth realm) where
|
instance AuthData (BasicAuth realm) () where
|
||||||
authData request = do
|
authData request = maybe (Left ()) Right $ do
|
||||||
authBs <- lookup "Authorization" (requestHeaders request)
|
authBs <- lookup "Authorization" (requestHeaders request)
|
||||||
let (x,y) = B.break isSpace authBs
|
let (x,y) = B.break isSpace authBs
|
||||||
guard (B.map toLower x == "basic")
|
guard (B.map toLower x == "basic")
|
||||||
|
@ -72,8 +76,8 @@ basicAuthFailure p = let realmBytes = (fromString . symbolVal) p
|
||||||
-- | OnMisisng handler for Basic Authentication
|
-- | OnMisisng handler for Basic Authentication
|
||||||
basicMissingHandler :: forall realm. KnownSymbol realm
|
basicMissingHandler :: forall realm. KnownSymbol realm
|
||||||
=> Proxy realm
|
=> Proxy realm
|
||||||
-> OnMissing IO ServantErr 'Strict
|
-> OnMissing IO ServantErr 'Strict ()
|
||||||
basicMissingHandler p = StrictMissing (return $ basicAuthFailure p)
|
basicMissingHandler p = StrictMissing (const $ return (basicAuthFailure p))
|
||||||
|
|
||||||
-- | OnUnauthenticated handler for Basic Authentication
|
-- | OnUnauthenticated handler for Basic Authentication
|
||||||
basicUnauthenticatedHandler :: forall realm. KnownSymbol realm
|
basicUnauthenticatedHandler :: forall realm. KnownSymbol realm
|
||||||
|
@ -85,7 +89,7 @@ basicUnauthenticatedHandler p = StrictUnauthenticated (const . const (return $ b
|
||||||
basicAuthStrict :: forall realm usr subserver. KnownSymbol realm
|
basicAuthStrict :: forall realm usr subserver. KnownSymbol realm
|
||||||
=> (BasicAuth realm -> IO (Maybe usr))
|
=> (BasicAuth realm -> IO (Maybe usr))
|
||||||
-> subserver
|
-> subserver
|
||||||
-> AuthProtected IO ServantErr 'Strict 'Strict () (BasicAuth realm) usr subserver
|
-> AuthProtected IO ServantErr 'Strict () 'Strict () (BasicAuth realm) usr subserver
|
||||||
basicAuthStrict check sub =
|
basicAuthStrict check sub =
|
||||||
let mHandler = basicMissingHandler (Proxy :: Proxy realm)
|
let mHandler = basicMissingHandler (Proxy :: Proxy realm)
|
||||||
unauthHandler = basicUnauthenticatedHandler (Proxy :: Proxy realm)
|
unauthHandler = basicUnauthenticatedHandler (Proxy :: Proxy realm)
|
||||||
|
@ -96,14 +100,14 @@ basicAuthStrict check sub =
|
||||||
basicAuthLax :: KnownSymbol realm
|
basicAuthLax :: KnownSymbol realm
|
||||||
=> (BasicAuth realm -> IO (Maybe usr))
|
=> (BasicAuth realm -> IO (Maybe usr))
|
||||||
-> subserver
|
-> subserver
|
||||||
-> AuthProtected IO ServantErr 'Lax 'Lax () (BasicAuth realm) usr subserver
|
-> AuthProtected IO ServantErr 'Lax () 'Lax () (BasicAuth realm) usr subserver
|
||||||
basicAuthLax check sub =
|
basicAuthLax check sub =
|
||||||
let check' = \a -> maybe (Left ()) Right <$> check a
|
let check' = \a -> maybe (Left ()) Right <$> check a
|
||||||
in AuthProtected LaxMissing LaxUnauthenticated check' sub
|
in AuthProtected LaxMissing LaxUnauthenticated check' sub
|
||||||
|
|
||||||
-- | Authentication data we extract from requests for JWT-based authentication.
|
-- | Authentication data we extract from requests for JWT-based authentication.
|
||||||
instance AuthData JWTAuth where
|
instance AuthData JWTAuth () where
|
||||||
authData req = do
|
authData req = maybe (Left ()) Right $ do
|
||||||
hdr <- lookup "Authorization" . requestHeaders $ req
|
hdr <- lookup "Authorization" . requestHeaders $ req
|
||||||
["Bearer", token] <- return . splitOn " " . decodeUtf8 $ hdr
|
["Bearer", token] <- return . splitOn " " . decodeUtf8 $ hdr
|
||||||
_ <- JWT.decode token -- try decode it. otherwise it's not a proper token
|
_ <- JWT.decode token -- try decode it. otherwise it's not a proper token
|
||||||
|
@ -116,9 +120,13 @@ jwtWithError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>
|
||||||
-- | OnMissing handler for Strict, JWT-based authentication
|
-- | OnMissing handler for Strict, JWT-based authentication
|
||||||
jwtAuthStrict :: Secret
|
jwtAuthStrict :: Secret
|
||||||
-> subserver
|
-> subserver
|
||||||
-> AuthProtected IO ServantErr 'Strict 'Strict () JWTAuth (JWT VerifiedJWT) subserver
|
-> AuthProtected IO ServantErr 'Strict () 'Strict () JWTAuth (JWT VerifiedJWT) subserver
|
||||||
jwtAuthStrict secret sub =
|
jwtAuthStrict secret sub =
|
||||||
let missingHandler = StrictMissing (return $ jwtWithError "invalid_request")
|
let missingHandler = StrictMissing (const $ return (jwtWithError "invalid_request"))
|
||||||
unauthHandler = StrictUnauthenticated (const . const (return $ jwtWithError "invalid_token"))
|
unauthHandler = StrictUnauthenticated (const . const (return $ jwtWithError "invalid_token"))
|
||||||
check = return . maybe (Left ()) Right . decodeAndVerifySignature secret . unJWTAuth
|
check = return . maybe (Left ()) Right . decodeAndVerifySignature secret . unJWTAuth
|
||||||
in AuthProtected missingHandler unauthHandler check sub
|
in AuthProtected missingHandler unauthHandler check sub
|
||||||
|
|
||||||
|
-- | A type alias to make simple authentication endpoints
|
||||||
|
type SimpleAuthProtected mPolicy uPolicy authData usr subserver =
|
||||||
|
AuthProtected IO ServantErr mPolicy () uPolicy () authData usr subserver
|
||||||
|
|
|
@ -2,10 +2,11 @@
|
||||||
{-# LANGUAGE DataKinds #-}
|
{-# LANGUAGE DataKinds #-}
|
||||||
{-# LANGUAGE DeriveFunctor #-}
|
{-# LANGUAGE DeriveFunctor #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE TypeOperators #-}
|
|
||||||
{-# LANGUAGE GADTs #-}
|
{-# LANGUAGE GADTs #-}
|
||||||
{-# LANGUAGE KindSignatures #-}
|
{-# LANGUAGE KindSignatures #-}
|
||||||
{-# LANGUAGE StandaloneDeriving #-}
|
{-# LANGUAGE StandaloneDeriving #-}
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
{-# LANGUAGE TypeOperators #-}
|
||||||
module Servant.Server.Internal.RoutingApplication where
|
module Servant.Server.Internal.RoutingApplication where
|
||||||
|
|
||||||
#if !MIN_VERSION_base(4,8,0)
|
#if !MIN_VERSION_base(4,8,0)
|
||||||
|
@ -20,8 +21,8 @@ import Network.Wai (Application, Request,
|
||||||
Response, ResponseReceived,
|
Response, ResponseReceived,
|
||||||
requestBody,
|
requestBody,
|
||||||
strictRequestBody)
|
strictRequestBody)
|
||||||
import Servant.API.Authentication (AuthPolicy(Strict,Lax))
|
import Servant.API.Authentication (AuthProtected (..), AuthPolicy(Strict,Lax),
|
||||||
import Servant.Server.Internal.Authentication
|
OnMissing (..), OnUnauthenticated (..), SAuthPolicy (..))
|
||||||
import Servant.Server.Internal.ServantErr
|
import Servant.Server.Internal.ServantErr
|
||||||
|
|
||||||
type RoutingApplication =
|
type RoutingApplication =
|
||||||
|
@ -179,43 +180,83 @@ addMethodCheck :: Delayed a
|
||||||
addMethodCheck (Delayed captures method auth body server) new =
|
addMethodCheck (Delayed captures method auth body server) new =
|
||||||
Delayed captures (combineRouteResults const method new) auth body server
|
Delayed captures (combineRouteResults const method new) auth body server
|
||||||
|
|
||||||
-- | Add a method to perform authorization in strict mode.
|
-- | helper type family to capture server handled values for various policies
|
||||||
addAuthStrictCheck :: Delayed (AuthProtected auth usr (usr -> a) 'Strict)
|
type family AuthDelayedReturn (mP :: AuthPolicy) mE (uP :: AuthPolicy) uE usr :: * where
|
||||||
-> IO (RouteResult (Maybe auth))
|
AuthDelayedReturn 'Strict mE 'Strict uE usr = usr
|
||||||
|
AuthDelayedReturn 'Strict mE 'Lax uE usr = Either uE usr
|
||||||
|
AuthDelayedReturn 'Lax mE 'Strict uE usr = Either mE usr
|
||||||
|
AuthDelayedReturn 'Lax mE 'Lax uE usr = Either (Either mE uE) usr
|
||||||
|
|
||||||
|
-- | Internal method to generate auth checkers for various policies. Scary type signature
|
||||||
|
-- but it does help with understanding the logic of how each policy works. See
|
||||||
|
-- examples below.
|
||||||
|
genAuthCheck :: (OnMissing IO ServantErr mPolicy mError -> (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a) -> mError -> IO (RouteResult a))
|
||||||
|
-> (OnUnauthenticated IO ServantErr uPolicy uError auth -> (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a) -> uError -> auth -> IO (RouteResult a))
|
||||||
|
-> (usr -> (AuthDelayedReturn mPolicy mError uPolicy uError usr))
|
||||||
|
-> Delayed (AuthProtected IO ServantErr mPolicy mError uPolicy uError auth usr (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a))
|
||||||
|
-> IO (RouteResult (Either mError auth))
|
||||||
-> Delayed a
|
-> Delayed a
|
||||||
addAuthStrictCheck delayed@(Delayed captures method _ body _) new =
|
genAuthCheck missingHandler unauthHandler returnHandler d@(Delayed captures method _ body _) new =
|
||||||
let newAuth = runDelayed delayed `bindRouteResults` \authProtectionStrict -> new `bindRouteResults` \mAuthData -> case mAuthData of
|
let newAuth =
|
||||||
|
runDelayed d `bindRouteResults` \ authProtection ->
|
||||||
Nothing -> do
|
new `bindRouteResults` \ eAuthData ->
|
||||||
-- we're in strict mode: don't let the request go
|
case eAuthData of
|
||||||
-- call the provided "on missing auth" handler
|
-- we failed to extract authentication data from the request
|
||||||
resp <- onMissingAuthData (authHandlers authProtectionStrict)
|
Left mError -> missingHandler (onMissing authProtection) (subserver authProtection) mError
|
||||||
return $ FailFatal resp
|
-- auth data was succesfully extracted from the request
|
||||||
|
Right aData -> do
|
||||||
-- successfully pulled auth data out of the request
|
eUsr <- checkAuth authProtection aData
|
||||||
Just aData -> do
|
case eUsr of
|
||||||
mUsr <- (checkAuthStrict authProtectionStrict) aData
|
-- we failed to authenticate the user
|
||||||
case mUsr of
|
Left uError -> unauthHandler (onUnauthenticated authProtection) (subserver authProtection) uError aData
|
||||||
-- this user is not authenticated
|
-- user was authenticated
|
||||||
Nothing -> do
|
Right usr ->
|
||||||
resp <- onUnauthenticated (authHandlers authProtectionStrict) aData
|
(return . Route . subserver authProtection) (returnHandler usr)
|
||||||
return $ FailFatal resp
|
|
||||||
|
|
||||||
-- this user is authenticated
|
|
||||||
Just usr ->
|
|
||||||
(return . Route . subServerStrict authProtectionStrict) usr
|
|
||||||
in Delayed captures method newAuth body (\_ y _ -> Route y)
|
in Delayed captures method newAuth body (\_ y _ -> Route y)
|
||||||
|
|
||||||
-- | Add a method to perform authorization in strict mode.
|
-- | Delayed auth checker for Strict Missing and Strict Unauthentication
|
||||||
addAuthLaxCheck :: Delayed (AuthProtected auth usr (Maybe usr -> a) 'Lax)
|
addAuthCheckSS :: Delayed (AuthProtected IO ServantErr 'Strict mError 'Strict uError auth usr (usr -> a))
|
||||||
-> IO (RouteResult (Maybe auth))
|
-> IO (RouteResult (Either mError auth))
|
||||||
-> Delayed a
|
-> Delayed a
|
||||||
addAuthLaxCheck delayed@(Delayed captures method _ body _) new =
|
addAuthCheckSS = genAuthCheck (\(StrictMissing handler) _ e -> FailFatal <$> handler e)
|
||||||
let newAuth = runDelayed delayed `bindRouteResults` \authProtectionLax -> new `bindRouteResults` \mAuthData ->
|
(\(StrictUnauthenticated handler) _ e a -> FailFatal <$> handler e a)
|
||||||
fmap (Route . subServerLax authProtectionLax)
|
id
|
||||||
(maybe (pure Nothing) (checkAuthLax authProtectionLax) mAuthData)
|
|
||||||
|
|
||||||
in Delayed captures method newAuth body (\_ y _ -> Route y)
|
-- | Delayed auth checker for Strict Missing and Lax Unauthentication
|
||||||
|
addAuthCheckSL :: Delayed (AuthProtected IO ServantErr 'Strict mError 'Lax uError auth usr (Either uError usr -> a))
|
||||||
|
-> IO (RouteResult (Either mError auth))
|
||||||
|
-> Delayed a
|
||||||
|
addAuthCheckSL = genAuthCheck (\(StrictMissing handler) _ e -> FailFatal <$> handler e)
|
||||||
|
(\(LaxUnauthenticated) cont e _ -> (return . Route . cont) (Left e))
|
||||||
|
Right
|
||||||
|
|
||||||
|
|
||||||
|
-- | Delayed auth checker for Lax Missing and Strict Unauthentication
|
||||||
|
addAuthCheckLS :: Delayed (AuthProtected IO ServantErr 'Lax mError 'Strict uError auth usr (Either mError usr -> a))
|
||||||
|
-> IO (RouteResult (Either mError auth))
|
||||||
|
-> Delayed a
|
||||||
|
addAuthCheckLS = genAuthCheck (\(LaxMissing) cont e -> (return . Route . cont) (Left e))
|
||||||
|
(\(StrictUnauthenticated handler) _ e a -> FailFatal <$> handler e a)
|
||||||
|
Right
|
||||||
|
|
||||||
|
-- | Delayed auth checker for Lax Missing and Lax Unauthentication
|
||||||
|
addAuthCheckLL :: Delayed (AuthProtected IO ServantErr 'Lax mError 'Lax uError auth usr (Either (Either mError uError) usr -> a))
|
||||||
|
-> IO (RouteResult (Either mError auth))
|
||||||
|
-> Delayed a
|
||||||
|
addAuthCheckLL = genAuthCheck (\(LaxMissing) cont e -> (return . Route . cont) (Left (Left e)))
|
||||||
|
(\(LaxUnauthenticated) cont e _ -> (return . Route . cont) (Left (Right e)))
|
||||||
|
Right
|
||||||
|
|
||||||
|
-- | Add an auth check by supplying OnMissing policies and OnUnauthenticated policies.
|
||||||
|
addAuthCheck :: SAuthPolicy mPolicy
|
||||||
|
-> SAuthPolicy uPolicy
|
||||||
|
-> Delayed (AuthProtected IO ServantErr mPolicy mError uPolicy uError auth usr (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a))
|
||||||
|
-> IO (RouteResult (Either mError auth))
|
||||||
|
-> Delayed a
|
||||||
|
addAuthCheck SStrict SStrict = addAuthCheckSS
|
||||||
|
addAuthCheck SStrict SLax = addAuthCheckSL
|
||||||
|
addAuthCheck SLax SStrict = addAuthCheckLS
|
||||||
|
addAuthCheck SLax SLax = addAuthCheckLL
|
||||||
|
|
||||||
-- | Add a body check to the end of the body block.
|
-- | Add a body check to the end of the body block.
|
||||||
addBodyCheck :: Delayed (a -> b)
|
addBodyCheck :: Delayed (a -> b)
|
||||||
|
|
Loading…
Reference in a new issue