Incorporate new GADT-based auth into Delayed

This commit is contained in:
aaron levin 2015-12-24 17:20:29 +01:00
parent c169d0bd59
commit 4d23cada4c
2 changed files with 107 additions and 58 deletions

View file

@ -1,9 +1,11 @@
{-# LANGUAGE CPP #-} {-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-} {-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
module Servant.Server.Internal.Authentication module Servant.Server.Internal.Authentication
( AuthData (..) ( AuthData (..)
@ -11,6 +13,7 @@ module Servant.Server.Internal.Authentication
, basicAuthLax , basicAuthLax
, basicAuthStrict , basicAuthStrict
, jwtAuthStrict , jwtAuthStrict
, SimpleAuthProtected
) where ) where
import Control.Monad (guard) import Control.Monad (guard)
@ -40,19 +43,20 @@ import qualified Web.JWT as JWT (decode)
-- | Class to represent the ability to extract authentication-related -- | Class to represent the ability to extract authentication-related
-- data from a 'Request' object. -- data from a 'Request' object.
class AuthData a where class AuthData a e | a -> e where
authData :: Request -> Maybe a authData :: Request -> Either e a
authProtect :: OnMissing IO ServantErr (missingPolicy :: AuthPolicy) -- | combinator to create authentication protected servers.
-> OnUnauthenticated IO ServantErr (unauthPolicy :: AuthPolicy) errorIndex authData authProtect :: OnMissing IO ServantErr missingPolicy missingError
-> (authData -> IO (Either errorIndex usr)) -> OnUnauthenticated IO ServantErr unauthPolicy unauthError authData
-> (authData -> IO (Either unauthError usr))
-> subserver -> subserver
-> AuthProtected IO ServantErr missingPolicy unauthPolicy errorIndex authData usr subserver -> AuthProtected IO ServantErr missingPolicy missingError unauthPolicy unauthError authData usr subserver
authProtect = AuthProtected authProtect = AuthProtected
-- | 'BasicAuth' instance for authData -- | 'BasicAuth' instance for authData
instance AuthData (BasicAuth realm) where instance AuthData (BasicAuth realm) () where
authData request = do authData request = maybe (Left ()) Right $ do
authBs <- lookup "Authorization" (requestHeaders request) authBs <- lookup "Authorization" (requestHeaders request)
let (x,y) = B.break isSpace authBs let (x,y) = B.break isSpace authBs
guard (B.map toLower x == "basic") guard (B.map toLower x == "basic")
@ -72,8 +76,8 @@ basicAuthFailure p = let realmBytes = (fromString . symbolVal) p
-- | OnMisisng handler for Basic Authentication -- | OnMisisng handler for Basic Authentication
basicMissingHandler :: forall realm. KnownSymbol realm basicMissingHandler :: forall realm. KnownSymbol realm
=> Proxy realm => Proxy realm
-> OnMissing IO ServantErr 'Strict -> OnMissing IO ServantErr 'Strict ()
basicMissingHandler p = StrictMissing (return $ basicAuthFailure p) basicMissingHandler p = StrictMissing (const $ return (basicAuthFailure p))
-- | OnUnauthenticated handler for Basic Authentication -- | OnUnauthenticated handler for Basic Authentication
basicUnauthenticatedHandler :: forall realm. KnownSymbol realm basicUnauthenticatedHandler :: forall realm. KnownSymbol realm
@ -85,7 +89,7 @@ basicUnauthenticatedHandler p = StrictUnauthenticated (const . const (return $ b
basicAuthStrict :: forall realm usr subserver. KnownSymbol realm basicAuthStrict :: forall realm usr subserver. KnownSymbol realm
=> (BasicAuth realm -> IO (Maybe usr)) => (BasicAuth realm -> IO (Maybe usr))
-> subserver -> subserver
-> AuthProtected IO ServantErr 'Strict 'Strict () (BasicAuth realm) usr subserver -> AuthProtected IO ServantErr 'Strict () 'Strict () (BasicAuth realm) usr subserver
basicAuthStrict check sub = basicAuthStrict check sub =
let mHandler = basicMissingHandler (Proxy :: Proxy realm) let mHandler = basicMissingHandler (Proxy :: Proxy realm)
unauthHandler = basicUnauthenticatedHandler (Proxy :: Proxy realm) unauthHandler = basicUnauthenticatedHandler (Proxy :: Proxy realm)
@ -96,14 +100,14 @@ basicAuthStrict check sub =
basicAuthLax :: KnownSymbol realm basicAuthLax :: KnownSymbol realm
=> (BasicAuth realm -> IO (Maybe usr)) => (BasicAuth realm -> IO (Maybe usr))
-> subserver -> subserver
-> AuthProtected IO ServantErr 'Lax 'Lax () (BasicAuth realm) usr subserver -> AuthProtected IO ServantErr 'Lax () 'Lax () (BasicAuth realm) usr subserver
basicAuthLax check sub = basicAuthLax check sub =
let check' = \a -> maybe (Left ()) Right <$> check a let check' = \a -> maybe (Left ()) Right <$> check a
in AuthProtected LaxMissing LaxUnauthenticated check' sub in AuthProtected LaxMissing LaxUnauthenticated check' sub
-- | Authentication data we extract from requests for JWT-based authentication. -- | Authentication data we extract from requests for JWT-based authentication.
instance AuthData JWTAuth where instance AuthData JWTAuth () where
authData req = do authData req = maybe (Left ()) Right $ do
hdr <- lookup "Authorization" . requestHeaders $ req hdr <- lookup "Authorization" . requestHeaders $ req
["Bearer", token] <- return . splitOn " " . decodeUtf8 $ hdr ["Bearer", token] <- return . splitOn " " . decodeUtf8 $ hdr
_ <- JWT.decode token -- try decode it. otherwise it's not a proper token _ <- JWT.decode token -- try decode it. otherwise it's not a proper token
@ -116,9 +120,13 @@ jwtWithError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>
-- | OnMissing handler for Strict, JWT-based authentication -- | OnMissing handler for Strict, JWT-based authentication
jwtAuthStrict :: Secret jwtAuthStrict :: Secret
-> subserver -> subserver
-> AuthProtected IO ServantErr 'Strict 'Strict () JWTAuth (JWT VerifiedJWT) subserver -> AuthProtected IO ServantErr 'Strict () 'Strict () JWTAuth (JWT VerifiedJWT) subserver
jwtAuthStrict secret sub = jwtAuthStrict secret sub =
let missingHandler = StrictMissing (return $ jwtWithError "invalid_request") let missingHandler = StrictMissing (const $ return (jwtWithError "invalid_request"))
unauthHandler = StrictUnauthenticated (const . const (return $ jwtWithError "invalid_token")) unauthHandler = StrictUnauthenticated (const . const (return $ jwtWithError "invalid_token"))
check = return . maybe (Left ()) Right . decodeAndVerifySignature secret . unJWTAuth check = return . maybe (Left ()) Right . decodeAndVerifySignature secret . unJWTAuth
in AuthProtected missingHandler unauthHandler check sub in AuthProtected missingHandler unauthHandler check sub
-- | A type alias to make simple authentication endpoints
type SimpleAuthProtected mPolicy uPolicy authData usr subserver =
AuthProtected IO ServantErr mPolicy () uPolicy () authData usr subserver

View file

@ -2,10 +2,11 @@
{-# LANGUAGE DataKinds #-} {-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-} {-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-} {-# LANGUAGE KindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Servant.Server.Internal.RoutingApplication where module Servant.Server.Internal.RoutingApplication where
#if !MIN_VERSION_base(4,8,0) #if !MIN_VERSION_base(4,8,0)
@ -20,8 +21,8 @@ import Network.Wai (Application, Request,
Response, ResponseReceived, Response, ResponseReceived,
requestBody, requestBody,
strictRequestBody) strictRequestBody)
import Servant.API.Authentication (AuthPolicy(Strict,Lax)) import Servant.API.Authentication (AuthProtected (..), AuthPolicy(Strict,Lax),
import Servant.Server.Internal.Authentication OnMissing (..), OnUnauthenticated (..), SAuthPolicy (..))
import Servant.Server.Internal.ServantErr import Servant.Server.Internal.ServantErr
type RoutingApplication = type RoutingApplication =
@ -179,43 +180,83 @@ addMethodCheck :: Delayed a
addMethodCheck (Delayed captures method auth body server) new = addMethodCheck (Delayed captures method auth body server) new =
Delayed captures (combineRouteResults const method new) auth body server Delayed captures (combineRouteResults const method new) auth body server
-- | Add a method to perform authorization in strict mode. -- | helper type family to capture server handled values for various policies
addAuthStrictCheck :: Delayed (AuthProtected auth usr (usr -> a) 'Strict) type family AuthDelayedReturn (mP :: AuthPolicy) mE (uP :: AuthPolicy) uE usr :: * where
-> IO (RouteResult (Maybe auth)) AuthDelayedReturn 'Strict mE 'Strict uE usr = usr
AuthDelayedReturn 'Strict mE 'Lax uE usr = Either uE usr
AuthDelayedReturn 'Lax mE 'Strict uE usr = Either mE usr
AuthDelayedReturn 'Lax mE 'Lax uE usr = Either (Either mE uE) usr
-- | Internal method to generate auth checkers for various policies. Scary type signature
-- but it does help with understanding the logic of how each policy works. See
-- examples below.
genAuthCheck :: (OnMissing IO ServantErr mPolicy mError -> (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a) -> mError -> IO (RouteResult a))
-> (OnUnauthenticated IO ServantErr uPolicy uError auth -> (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a) -> uError -> auth -> IO (RouteResult a))
-> (usr -> (AuthDelayedReturn mPolicy mError uPolicy uError usr))
-> Delayed (AuthProtected IO ServantErr mPolicy mError uPolicy uError auth usr (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a -> Delayed a
addAuthStrictCheck delayed@(Delayed captures method _ body _) new = genAuthCheck missingHandler unauthHandler returnHandler d@(Delayed captures method _ body _) new =
let newAuth = runDelayed delayed `bindRouteResults` \authProtectionStrict -> new `bindRouteResults` \mAuthData -> case mAuthData of let newAuth =
runDelayed d `bindRouteResults` \ authProtection ->
Nothing -> do new `bindRouteResults` \ eAuthData ->
-- we're in strict mode: don't let the request go case eAuthData of
-- call the provided "on missing auth" handler -- we failed to extract authentication data from the request
resp <- onMissingAuthData (authHandlers authProtectionStrict) Left mError -> missingHandler (onMissing authProtection) (subserver authProtection) mError
return $ FailFatal resp -- auth data was succesfully extracted from the request
Right aData -> do
-- successfully pulled auth data out of the request eUsr <- checkAuth authProtection aData
Just aData -> do case eUsr of
mUsr <- (checkAuthStrict authProtectionStrict) aData -- we failed to authenticate the user
case mUsr of Left uError -> unauthHandler (onUnauthenticated authProtection) (subserver authProtection) uError aData
-- this user is not authenticated -- user was authenticated
Nothing -> do Right usr ->
resp <- onUnauthenticated (authHandlers authProtectionStrict) aData (return . Route . subserver authProtection) (returnHandler usr)
return $ FailFatal resp
-- this user is authenticated
Just usr ->
(return . Route . subServerStrict authProtectionStrict) usr
in Delayed captures method newAuth body (\_ y _ -> Route y) in Delayed captures method newAuth body (\_ y _ -> Route y)
-- | Add a method to perform authorization in strict mode. -- | Delayed auth checker for Strict Missing and Strict Unauthentication
addAuthLaxCheck :: Delayed (AuthProtected auth usr (Maybe usr -> a) 'Lax) addAuthCheckSS :: Delayed (AuthProtected IO ServantErr 'Strict mError 'Strict uError auth usr (usr -> a))
-> IO (RouteResult (Maybe auth)) -> IO (RouteResult (Either mError auth))
-> Delayed a -> Delayed a
addAuthLaxCheck delayed@(Delayed captures method _ body _) new = addAuthCheckSS = genAuthCheck (\(StrictMissing handler) _ e -> FailFatal <$> handler e)
let newAuth = runDelayed delayed `bindRouteResults` \authProtectionLax -> new `bindRouteResults` \mAuthData -> (\(StrictUnauthenticated handler) _ e a -> FailFatal <$> handler e a)
fmap (Route . subServerLax authProtectionLax) id
(maybe (pure Nothing) (checkAuthLax authProtectionLax) mAuthData)
in Delayed captures method newAuth body (\_ y _ -> Route y) -- | Delayed auth checker for Strict Missing and Lax Unauthentication
addAuthCheckSL :: Delayed (AuthProtected IO ServantErr 'Strict mError 'Lax uError auth usr (Either uError usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheckSL = genAuthCheck (\(StrictMissing handler) _ e -> FailFatal <$> handler e)
(\(LaxUnauthenticated) cont e _ -> (return . Route . cont) (Left e))
Right
-- | Delayed auth checker for Lax Missing and Strict Unauthentication
addAuthCheckLS :: Delayed (AuthProtected IO ServantErr 'Lax mError 'Strict uError auth usr (Either mError usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheckLS = genAuthCheck (\(LaxMissing) cont e -> (return . Route . cont) (Left e))
(\(StrictUnauthenticated handler) _ e a -> FailFatal <$> handler e a)
Right
-- | Delayed auth checker for Lax Missing and Lax Unauthentication
addAuthCheckLL :: Delayed (AuthProtected IO ServantErr 'Lax mError 'Lax uError auth usr (Either (Either mError uError) usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheckLL = genAuthCheck (\(LaxMissing) cont e -> (return . Route . cont) (Left (Left e)))
(\(LaxUnauthenticated) cont e _ -> (return . Route . cont) (Left (Right e)))
Right
-- | Add an auth check by supplying OnMissing policies and OnUnauthenticated policies.
addAuthCheck :: SAuthPolicy mPolicy
-> SAuthPolicy uPolicy
-> Delayed (AuthProtected IO ServantErr mPolicy mError uPolicy uError auth usr (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheck SStrict SStrict = addAuthCheckSS
addAuthCheck SStrict SLax = addAuthCheckSL
addAuthCheck SLax SStrict = addAuthCheckLS
addAuthCheck SLax SLax = addAuthCheckLL
-- | Add a body check to the end of the body block. -- | Add a body check to the end of the body block.
addBodyCheck :: Delayed (a -> b) addBodyCheck :: Delayed (a -> b)