Update server authentication with new GADT
This commit is contained in:
parent
0fbd84bfd7
commit
743c51b3c5
1 changed files with 55 additions and 72 deletions
|
@ -6,14 +6,10 @@
|
|||
{-# LANGUAGE FlexibleInstances #-}
|
||||
|
||||
module Servant.Server.Internal.Authentication
|
||||
( AuthProtected (..)
|
||||
, AuthData (..)
|
||||
, AuthHandlers (AuthHandlers, onMissingAuthData, onUnauthenticated)
|
||||
( AuthData (..)
|
||||
, authProtect
|
||||
, basicAuthLax
|
||||
, basicAuthStrict
|
||||
, laxProtect
|
||||
, strictProtect
|
||||
, jwtAuthHandlers
|
||||
, jwtAuthStrict
|
||||
) where
|
||||
|
||||
|
@ -34,9 +30,11 @@ import Data.Text (splitOn)
|
|||
import Network.Wai (Request, requestHeaders)
|
||||
import Servant.Server.Internal.ServantErr (err401, ServantErr(errHeaders))
|
||||
import Servant.API.Authentication (AuthPolicy (Strict, Lax),
|
||||
AuthProtected,
|
||||
AuthProtected(..),
|
||||
BasicAuth (BasicAuth),
|
||||
JWTAuth(..))
|
||||
JWTAuth(..),
|
||||
OnMissing (..),
|
||||
OnUnauthenticated (..))
|
||||
import Web.JWT (decodeAndVerifySignature, JWT, VerifiedJWT, Secret)
|
||||
import qualified Web.JWT as JWT (decode)
|
||||
|
||||
|
@ -45,40 +43,12 @@ import qualified Web.JWT as JWT (decode)
|
|||
class AuthData a where
|
||||
authData :: Request -> Maybe a
|
||||
|
||||
-- | handlers to deal with authentication failures.
|
||||
data AuthHandlers authData = AuthHandlers
|
||||
{ -- we couldn't find the right type of auth data (or any, for that matter)
|
||||
onMissingAuthData :: IO ServantErr
|
||||
,
|
||||
-- we found the right type of auth data in the request but the check failed
|
||||
onUnauthenticated :: authData -> IO ServantErr
|
||||
}
|
||||
|
||||
-- | concrete type to provide when in 'Strict' mode.
|
||||
data instance AuthProtected authData usr subserver 'Strict =
|
||||
AuthProtectedStrict { checkAuthStrict :: authData -> IO (Maybe usr)
|
||||
, authHandlers :: AuthHandlers authData
|
||||
, subServerStrict :: subserver
|
||||
}
|
||||
|
||||
-- | concrete type to provide when in 'Lax' mode.
|
||||
data instance AuthProtected authData usr subserver 'Lax =
|
||||
AuthProtectedLax { checkAuthLax :: authData -> IO (Maybe usr)
|
||||
, subServerLax :: subserver
|
||||
}
|
||||
|
||||
-- | handy function to build an auth-protected bit of API with a 'Lax' policy
|
||||
laxProtect :: (authData -> IO (Maybe usr)) -- ^ check auth
|
||||
-> subserver -- ^ the handlers for the auth-aware bits of the API
|
||||
-> AuthProtected authData usr subserver 'Lax
|
||||
laxProtect = AuthProtectedLax
|
||||
|
||||
-- | handy function to build an auth-protected bit of API with a 'Strict' policy
|
||||
strictProtect :: (authData -> IO (Maybe usr)) -- ^ check auth
|
||||
-> AuthHandlers authData -- ^ functions to call on auth failure
|
||||
-> subserver -- ^ handlers for the auth-protected bits of the API
|
||||
-> AuthProtected authData usr subserver 'Strict
|
||||
strictProtect = AuthProtectedStrict
|
||||
authProtect :: OnMissing IO ServantErr (missingPolicy :: AuthPolicy)
|
||||
-> OnUnauthenticated IO ServantErr (unauthPolicy :: AuthPolicy) errorIndex authData
|
||||
-> (authData -> IO (Either errorIndex usr))
|
||||
-> subserver
|
||||
-> AuthProtected IO ServantErr missingPolicy unauthPolicy errorIndex authData usr subserver
|
||||
authProtect = AuthProtected
|
||||
|
||||
-- | 'BasicAuth' instance for authData
|
||||
instance AuthData (BasicAuth realm) where
|
||||
|
@ -91,31 +61,47 @@ instance AuthData (BasicAuth realm) where
|
|||
(_, password) <- B.uncons passWithColonAtHead
|
||||
return $ BasicAuth username password
|
||||
|
||||
-- | handlers for Basic Authentication.
|
||||
basicAuthHandlers :: forall realm. KnownSymbol realm => AuthHandlers (BasicAuth realm)
|
||||
basicAuthHandlers =
|
||||
let realmBytes = (fromString . symbolVal) (Proxy :: Proxy realm)
|
||||
headerBytes = "Basic realm=\"" <> realmBytes <> "\""
|
||||
authFailure = err401 { errHeaders = [("WWW-Authenticate", headerBytes)] }
|
||||
in
|
||||
AuthHandlers (return authFailure) ((const . return) authFailure)
|
||||
-- | failure response for Basic Authentication
|
||||
basicAuthFailure :: forall realm. KnownSymbol realm
|
||||
=> Proxy realm
|
||||
-> ServantErr
|
||||
basicAuthFailure p = let realmBytes = (fromString . symbolVal) p
|
||||
headerBytes = "Basic realm=\"" <> realmBytes <> "\""
|
||||
in err401 { errHeaders = [("WWW-Authenticate", headerBytes)] }
|
||||
|
||||
-- | OnMisisng handler for Basic Authentication
|
||||
basicMissingHandler :: forall realm. KnownSymbol realm
|
||||
=> Proxy realm
|
||||
-> OnMissing IO ServantErr 'Strict
|
||||
basicMissingHandler p = StrictMissing (return $ basicAuthFailure p)
|
||||
|
||||
-- | OnUnauthenticated handler for Basic Authentication
|
||||
basicUnauthenticatedHandler :: forall realm. KnownSymbol realm
|
||||
=> Proxy realm
|
||||
-> OnUnauthenticated IO ServantErr 'Strict () (BasicAuth realm)
|
||||
basicUnauthenticatedHandler p = StrictUnauthenticated (const . const (return $ basicAuthFailure p))
|
||||
|
||||
-- | Basic authentication combinator with strict failure.
|
||||
basicAuthStrict :: KnownSymbol realm
|
||||
basicAuthStrict :: forall realm usr subserver. KnownSymbol realm
|
||||
=> (BasicAuth realm -> IO (Maybe usr))
|
||||
-> subserver
|
||||
-> AuthProtected (BasicAuth realm) usr subserver 'Strict
|
||||
basicAuthStrict check subserver = strictProtect check basicAuthHandlers subserver
|
||||
-> AuthProtected IO ServantErr 'Strict 'Strict () (BasicAuth realm) usr subserver
|
||||
basicAuthStrict check sub =
|
||||
let mHandler = basicMissingHandler (Proxy :: Proxy realm)
|
||||
unauthHandler = basicUnauthenticatedHandler (Proxy :: Proxy realm)
|
||||
check' = \auth -> maybe (Left ()) Right <$> check auth
|
||||
in AuthProtected mHandler unauthHandler check' sub
|
||||
|
||||
-- | Basic authentication combinator with lax failure.
|
||||
basicAuthLax :: KnownSymbol realm
|
||||
=> (BasicAuth realm -> IO (Maybe usr))
|
||||
-> subserver
|
||||
-> AuthProtected (BasicAuth realm) usr subserver 'Lax
|
||||
basicAuthLax = laxProtect
|
||||
|
||||
|
||||
-> AuthProtected IO ServantErr 'Lax 'Lax () (BasicAuth realm) usr subserver
|
||||
basicAuthLax check sub =
|
||||
let check' = \a -> maybe (Left ()) Right <$> check a
|
||||
in AuthProtected LaxMissing LaxUnauthenticated check' sub
|
||||
|
||||
-- | Authentication data we extract from requests for JWT-based authentication.
|
||||
instance AuthData JWTAuth where
|
||||
authData req = do
|
||||
hdr <- lookup "Authorization" . requestHeaders $ req
|
||||
|
@ -123,19 +109,16 @@ instance AuthData JWTAuth where
|
|||
_ <- JWT.decode token -- try decode it. otherwise it's not a proper token
|
||||
return . JWTAuth $ token
|
||||
|
||||
-- | helper method to construct jwt handlers
|
||||
jwtWithError :: B.ByteString -> ServantErr
|
||||
jwtWithError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] }
|
||||
|
||||
jwtAuthHandlers :: AuthHandlers JWTAuth
|
||||
jwtAuthHandlers = AuthHandlers (return missingData) ((const . return) authFailure)
|
||||
where
|
||||
withError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] }
|
||||
missingData = withError "invalid_request"
|
||||
authFailure = withError "invalid_token"
|
||||
|
||||
|
||||
-- | A default implementation of an AuthProtected for JWT.
|
||||
-- Use this to quickly add jwt authentication to your project.
|
||||
-- One can use strictProtect and laxProtect to make more complex authentication
|
||||
-- and authorization schemes.
|
||||
jwtAuthStrict :: Secret -> subserver -> AuthProtected JWTAuth (JWT VerifiedJWT) subserver 'Strict
|
||||
jwtAuthStrict secret subserver = strictProtect (return . decodeAndVerifySignature secret . unJWTAuth) jwtAuthHandlers subserver
|
||||
|
||||
-- | OnMissing handler for Strict, JWT-based authentication
|
||||
jwtAuthStrict :: Secret
|
||||
-> subserver
|
||||
-> AuthProtected IO ServantErr 'Strict 'Strict () JWTAuth (JWT VerifiedJWT) subserver
|
||||
jwtAuthStrict secret sub =
|
||||
let missingHandler = StrictMissing (return $ jwtWithError "invalid_request")
|
||||
unauthHandler = StrictUnauthenticated (const . const (return $ jwtWithError "invalid_token"))
|
||||
check = return . maybe (Left ()) Right . decodeAndVerifySignature secret . unJWTAuth
|
||||
in AuthProtected missingHandler unauthHandler check sub
|
||||
|
|
Loading…
Reference in a new issue