Update server authentication with new GADT

This commit is contained in:
aaron levin 2015-12-24 12:09:12 +01:00
parent 0fbd84bfd7
commit 743c51b3c5

View file

@ -6,14 +6,10 @@
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
module Servant.Server.Internal.Authentication module Servant.Server.Internal.Authentication
( AuthProtected (..) ( AuthData (..)
, AuthData (..) , authProtect
, AuthHandlers (AuthHandlers, onMissingAuthData, onUnauthenticated)
, basicAuthLax , basicAuthLax
, basicAuthStrict , basicAuthStrict
, laxProtect
, strictProtect
, jwtAuthHandlers
, jwtAuthStrict , jwtAuthStrict
) where ) where
@ -34,9 +30,11 @@ import Data.Text (splitOn)
import Network.Wai (Request, requestHeaders) import Network.Wai (Request, requestHeaders)
import Servant.Server.Internal.ServantErr (err401, ServantErr(errHeaders)) import Servant.Server.Internal.ServantErr (err401, ServantErr(errHeaders))
import Servant.API.Authentication (AuthPolicy (Strict, Lax), import Servant.API.Authentication (AuthPolicy (Strict, Lax),
AuthProtected, AuthProtected(..),
BasicAuth (BasicAuth), BasicAuth (BasicAuth),
JWTAuth(..)) JWTAuth(..),
OnMissing (..),
OnUnauthenticated (..))
import Web.JWT (decodeAndVerifySignature, JWT, VerifiedJWT, Secret) import Web.JWT (decodeAndVerifySignature, JWT, VerifiedJWT, Secret)
import qualified Web.JWT as JWT (decode) import qualified Web.JWT as JWT (decode)
@ -45,40 +43,12 @@ import qualified Web.JWT as JWT (decode)
class AuthData a where class AuthData a where
authData :: Request -> Maybe a authData :: Request -> Maybe a
-- | handlers to deal with authentication failures. authProtect :: OnMissing IO ServantErr (missingPolicy :: AuthPolicy)
data AuthHandlers authData = AuthHandlers -> OnUnauthenticated IO ServantErr (unauthPolicy :: AuthPolicy) errorIndex authData
{ -- we couldn't find the right type of auth data (or any, for that matter) -> (authData -> IO (Either errorIndex usr))
onMissingAuthData :: IO ServantErr -> subserver
, -> AuthProtected IO ServantErr missingPolicy unauthPolicy errorIndex authData usr subserver
-- we found the right type of auth data in the request but the check failed authProtect = AuthProtected
onUnauthenticated :: authData -> IO ServantErr
}
-- | concrete type to provide when in 'Strict' mode.
data instance AuthProtected authData usr subserver 'Strict =
AuthProtectedStrict { checkAuthStrict :: authData -> IO (Maybe usr)
, authHandlers :: AuthHandlers authData
, subServerStrict :: subserver
}
-- | concrete type to provide when in 'Lax' mode.
data instance AuthProtected authData usr subserver 'Lax =
AuthProtectedLax { checkAuthLax :: authData -> IO (Maybe usr)
, subServerLax :: subserver
}
-- | handy function to build an auth-protected bit of API with a 'Lax' policy
laxProtect :: (authData -> IO (Maybe usr)) -- ^ check auth
-> subserver -- ^ the handlers for the auth-aware bits of the API
-> AuthProtected authData usr subserver 'Lax
laxProtect = AuthProtectedLax
-- | handy function to build an auth-protected bit of API with a 'Strict' policy
strictProtect :: (authData -> IO (Maybe usr)) -- ^ check auth
-> AuthHandlers authData -- ^ functions to call on auth failure
-> subserver -- ^ handlers for the auth-protected bits of the API
-> AuthProtected authData usr subserver 'Strict
strictProtect = AuthProtectedStrict
-- | 'BasicAuth' instance for authData -- | 'BasicAuth' instance for authData
instance AuthData (BasicAuth realm) where instance AuthData (BasicAuth realm) where
@ -91,31 +61,47 @@ instance AuthData (BasicAuth realm) where
(_, password) <- B.uncons passWithColonAtHead (_, password) <- B.uncons passWithColonAtHead
return $ BasicAuth username password return $ BasicAuth username password
-- | handlers for Basic Authentication. -- | failure response for Basic Authentication
basicAuthHandlers :: forall realm. KnownSymbol realm => AuthHandlers (BasicAuth realm) basicAuthFailure :: forall realm. KnownSymbol realm
basicAuthHandlers = => Proxy realm
let realmBytes = (fromString . symbolVal) (Proxy :: Proxy realm) -> ServantErr
basicAuthFailure p = let realmBytes = (fromString . symbolVal) p
headerBytes = "Basic realm=\"" <> realmBytes <> "\"" headerBytes = "Basic realm=\"" <> realmBytes <> "\""
authFailure = err401 { errHeaders = [("WWW-Authenticate", headerBytes)] } in err401 { errHeaders = [("WWW-Authenticate", headerBytes)] }
in
AuthHandlers (return authFailure) ((const . return) authFailure) -- | OnMisisng handler for Basic Authentication
basicMissingHandler :: forall realm. KnownSymbol realm
=> Proxy realm
-> OnMissing IO ServantErr 'Strict
basicMissingHandler p = StrictMissing (return $ basicAuthFailure p)
-- | OnUnauthenticated handler for Basic Authentication
basicUnauthenticatedHandler :: forall realm. KnownSymbol realm
=> Proxy realm
-> OnUnauthenticated IO ServantErr 'Strict () (BasicAuth realm)
basicUnauthenticatedHandler p = StrictUnauthenticated (const . const (return $ basicAuthFailure p))
-- | Basic authentication combinator with strict failure. -- | Basic authentication combinator with strict failure.
basicAuthStrict :: KnownSymbol realm basicAuthStrict :: forall realm usr subserver. KnownSymbol realm
=> (BasicAuth realm -> IO (Maybe usr)) => (BasicAuth realm -> IO (Maybe usr))
-> subserver -> subserver
-> AuthProtected (BasicAuth realm) usr subserver 'Strict -> AuthProtected IO ServantErr 'Strict 'Strict () (BasicAuth realm) usr subserver
basicAuthStrict check subserver = strictProtect check basicAuthHandlers subserver basicAuthStrict check sub =
let mHandler = basicMissingHandler (Proxy :: Proxy realm)
unauthHandler = basicUnauthenticatedHandler (Proxy :: Proxy realm)
check' = \auth -> maybe (Left ()) Right <$> check auth
in AuthProtected mHandler unauthHandler check' sub
-- | Basic authentication combinator with lax failure. -- | Basic authentication combinator with lax failure.
basicAuthLax :: KnownSymbol realm basicAuthLax :: KnownSymbol realm
=> (BasicAuth realm -> IO (Maybe usr)) => (BasicAuth realm -> IO (Maybe usr))
-> subserver -> subserver
-> AuthProtected (BasicAuth realm) usr subserver 'Lax -> AuthProtected IO ServantErr 'Lax 'Lax () (BasicAuth realm) usr subserver
basicAuthLax = laxProtect basicAuthLax check sub =
let check' = \a -> maybe (Left ()) Right <$> check a
in AuthProtected LaxMissing LaxUnauthenticated check' sub
-- | Authentication data we extract from requests for JWT-based authentication.
instance AuthData JWTAuth where instance AuthData JWTAuth where
authData req = do authData req = do
hdr <- lookup "Authorization" . requestHeaders $ req hdr <- lookup "Authorization" . requestHeaders $ req
@ -123,19 +109,16 @@ instance AuthData JWTAuth where
_ <- JWT.decode token -- try decode it. otherwise it's not a proper token _ <- JWT.decode token -- try decode it. otherwise it's not a proper token
return . JWTAuth $ token return . JWTAuth $ token
-- | helper method to construct jwt handlers
jwtWithError :: B.ByteString -> ServantErr
jwtWithError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] }
jwtAuthHandlers :: AuthHandlers JWTAuth -- | OnMissing handler for Strict, JWT-based authentication
jwtAuthHandlers = AuthHandlers (return missingData) ((const . return) authFailure) jwtAuthStrict :: Secret
where -> subserver
withError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] } -> AuthProtected IO ServantErr 'Strict 'Strict () JWTAuth (JWT VerifiedJWT) subserver
missingData = withError "invalid_request" jwtAuthStrict secret sub =
authFailure = withError "invalid_token" let missingHandler = StrictMissing (return $ jwtWithError "invalid_request")
unauthHandler = StrictUnauthenticated (const . const (return $ jwtWithError "invalid_token"))
check = return . maybe (Left ()) Right . decodeAndVerifySignature secret . unJWTAuth
-- | A default implementation of an AuthProtected for JWT. in AuthProtected missingHandler unauthHandler check sub
-- Use this to quickly add jwt authentication to your project.
-- One can use strictProtect and laxProtect to make more complex authentication
-- and authorization schemes.
jwtAuthStrict :: Secret -> subserver -> AuthProtected JWTAuth (JWT VerifiedJWT) subserver 'Strict
jwtAuthStrict secret subserver = strictProtect (return . decodeAndVerifySignature secret . unJWTAuth) jwtAuthHandlers subserver