diff --git a/servant-server/src/Servant/Server/Internal/Authentication.hs b/servant-server/src/Servant/Server/Internal/Authentication.hs index fc9d020f..c120dbe0 100644 --- a/servant-server/src/Servant/Server/Internal/Authentication.hs +++ b/servant-server/src/Servant/Server/Internal/Authentication.hs @@ -6,14 +6,10 @@ {-# LANGUAGE FlexibleInstances #-} module Servant.Server.Internal.Authentication -( AuthProtected (..) -, AuthData (..) -, AuthHandlers (AuthHandlers, onMissingAuthData, onUnauthenticated) +( AuthData (..) +, authProtect , basicAuthLax , basicAuthStrict -, laxProtect -, strictProtect -, jwtAuthHandlers , jwtAuthStrict ) where @@ -34,9 +30,11 @@ import Data.Text (splitOn) import Network.Wai (Request, requestHeaders) import Servant.Server.Internal.ServantErr (err401, ServantErr(errHeaders)) import Servant.API.Authentication (AuthPolicy (Strict, Lax), - AuthProtected, + AuthProtected(..), BasicAuth (BasicAuth), - JWTAuth(..)) + JWTAuth(..), + OnMissing (..), + OnUnauthenticated (..)) import Web.JWT (decodeAndVerifySignature, JWT, VerifiedJWT, Secret) import qualified Web.JWT as JWT (decode) @@ -45,40 +43,12 @@ import qualified Web.JWT as JWT (decode) class AuthData a where authData :: Request -> Maybe a --- | handlers to deal with authentication failures. -data AuthHandlers authData = AuthHandlers - { -- we couldn't find the right type of auth data (or any, for that matter) - onMissingAuthData :: IO ServantErr - , - -- we found the right type of auth data in the request but the check failed - onUnauthenticated :: authData -> IO ServantErr - } - --- | concrete type to provide when in 'Strict' mode. -data instance AuthProtected authData usr subserver 'Strict = - AuthProtectedStrict { checkAuthStrict :: authData -> IO (Maybe usr) - , authHandlers :: AuthHandlers authData - , subServerStrict :: subserver - } - --- | concrete type to provide when in 'Lax' mode. -data instance AuthProtected authData usr subserver 'Lax = - AuthProtectedLax { checkAuthLax :: authData -> IO (Maybe usr) - , subServerLax :: subserver - } - --- | handy function to build an auth-protected bit of API with a 'Lax' policy -laxProtect :: (authData -> IO (Maybe usr)) -- ^ check auth - -> subserver -- ^ the handlers for the auth-aware bits of the API - -> AuthProtected authData usr subserver 'Lax -laxProtect = AuthProtectedLax - --- | handy function to build an auth-protected bit of API with a 'Strict' policy -strictProtect :: (authData -> IO (Maybe usr)) -- ^ check auth - -> AuthHandlers authData -- ^ functions to call on auth failure - -> subserver -- ^ handlers for the auth-protected bits of the API - -> AuthProtected authData usr subserver 'Strict -strictProtect = AuthProtectedStrict +authProtect :: OnMissing IO ServantErr (missingPolicy :: AuthPolicy) + -> OnUnauthenticated IO ServantErr (unauthPolicy :: AuthPolicy) errorIndex authData + -> (authData -> IO (Either errorIndex usr)) + -> subserver + -> AuthProtected IO ServantErr missingPolicy unauthPolicy errorIndex authData usr subserver +authProtect = AuthProtected -- | 'BasicAuth' instance for authData instance AuthData (BasicAuth realm) where @@ -91,31 +61,47 @@ instance AuthData (BasicAuth realm) where (_, password) <- B.uncons passWithColonAtHead return $ BasicAuth username password --- | handlers for Basic Authentication. -basicAuthHandlers :: forall realm. KnownSymbol realm => AuthHandlers (BasicAuth realm) -basicAuthHandlers = - let realmBytes = (fromString . symbolVal) (Proxy :: Proxy realm) - headerBytes = "Basic realm=\"" <> realmBytes <> "\"" - authFailure = err401 { errHeaders = [("WWW-Authenticate", headerBytes)] } - in - AuthHandlers (return authFailure) ((const . return) authFailure) +-- | failure response for Basic Authentication +basicAuthFailure :: forall realm. KnownSymbol realm + => Proxy realm + -> ServantErr +basicAuthFailure p = let realmBytes = (fromString . symbolVal) p + headerBytes = "Basic realm=\"" <> realmBytes <> "\"" + in err401 { errHeaders = [("WWW-Authenticate", headerBytes)] } + +-- | OnMisisng handler for Basic Authentication +basicMissingHandler :: forall realm. KnownSymbol realm + => Proxy realm + -> OnMissing IO ServantErr 'Strict +basicMissingHandler p = StrictMissing (return $ basicAuthFailure p) + +-- | OnUnauthenticated handler for Basic Authentication +basicUnauthenticatedHandler :: forall realm. KnownSymbol realm + => Proxy realm + -> OnUnauthenticated IO ServantErr 'Strict () (BasicAuth realm) +basicUnauthenticatedHandler p = StrictUnauthenticated (const . const (return $ basicAuthFailure p)) -- | Basic authentication combinator with strict failure. -basicAuthStrict :: KnownSymbol realm +basicAuthStrict :: forall realm usr subserver. KnownSymbol realm => (BasicAuth realm -> IO (Maybe usr)) -> subserver - -> AuthProtected (BasicAuth realm) usr subserver 'Strict -basicAuthStrict check subserver = strictProtect check basicAuthHandlers subserver + -> AuthProtected IO ServantErr 'Strict 'Strict () (BasicAuth realm) usr subserver +basicAuthStrict check sub = + let mHandler = basicMissingHandler (Proxy :: Proxy realm) + unauthHandler = basicUnauthenticatedHandler (Proxy :: Proxy realm) + check' = \auth -> maybe (Left ()) Right <$> check auth + in AuthProtected mHandler unauthHandler check' sub -- | Basic authentication combinator with lax failure. basicAuthLax :: KnownSymbol realm => (BasicAuth realm -> IO (Maybe usr)) -> subserver - -> AuthProtected (BasicAuth realm) usr subserver 'Lax -basicAuthLax = laxProtect - - + -> AuthProtected IO ServantErr 'Lax 'Lax () (BasicAuth realm) usr subserver +basicAuthLax check sub = + let check' = \a -> maybe (Left ()) Right <$> check a + in AuthProtected LaxMissing LaxUnauthenticated check' sub +-- | Authentication data we extract from requests for JWT-based authentication. instance AuthData JWTAuth where authData req = do hdr <- lookup "Authorization" . requestHeaders $ req @@ -123,19 +109,16 @@ instance AuthData JWTAuth where _ <- JWT.decode token -- try decode it. otherwise it's not a proper token return . JWTAuth $ token +-- | helper method to construct jwt handlers +jwtWithError :: B.ByteString -> ServantErr +jwtWithError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] } -jwtAuthHandlers :: AuthHandlers JWTAuth -jwtAuthHandlers = AuthHandlers (return missingData) ((const . return) authFailure) - where - withError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] } - missingData = withError "invalid_request" - authFailure = withError "invalid_token" - - --- | A default implementation of an AuthProtected for JWT. --- Use this to quickly add jwt authentication to your project. --- One can use strictProtect and laxProtect to make more complex authentication --- and authorization schemes. -jwtAuthStrict :: Secret -> subserver -> AuthProtected JWTAuth (JWT VerifiedJWT) subserver 'Strict -jwtAuthStrict secret subserver = strictProtect (return . decodeAndVerifySignature secret . unJWTAuth) jwtAuthHandlers subserver - +-- | OnMissing handler for Strict, JWT-based authentication +jwtAuthStrict :: Secret + -> subserver + -> AuthProtected IO ServantErr 'Strict 'Strict () JWTAuth (JWT VerifiedJWT) subserver +jwtAuthStrict secret sub = + let missingHandler = StrictMissing (return $ jwtWithError "invalid_request") + unauthHandler = StrictUnauthenticated (const . const (return $ jwtWithError "invalid_token")) + check = return . maybe (Left ()) Right . decodeAndVerifySignature secret . unJWTAuth + in AuthProtected missingHandler unauthHandler check sub