Update server authentication with new GADT
This commit is contained in:
parent
0fbd84bfd7
commit
743c51b3c5
1 changed files with 55 additions and 72 deletions
|
@ -6,14 +6,10 @@
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
|
|
||||||
module Servant.Server.Internal.Authentication
|
module Servant.Server.Internal.Authentication
|
||||||
( AuthProtected (..)
|
( AuthData (..)
|
||||||
, AuthData (..)
|
, authProtect
|
||||||
, AuthHandlers (AuthHandlers, onMissingAuthData, onUnauthenticated)
|
|
||||||
, basicAuthLax
|
, basicAuthLax
|
||||||
, basicAuthStrict
|
, basicAuthStrict
|
||||||
, laxProtect
|
|
||||||
, strictProtect
|
|
||||||
, jwtAuthHandlers
|
|
||||||
, jwtAuthStrict
|
, jwtAuthStrict
|
||||||
) where
|
) where
|
||||||
|
|
||||||
|
@ -34,9 +30,11 @@ import Data.Text (splitOn)
|
||||||
import Network.Wai (Request, requestHeaders)
|
import Network.Wai (Request, requestHeaders)
|
||||||
import Servant.Server.Internal.ServantErr (err401, ServantErr(errHeaders))
|
import Servant.Server.Internal.ServantErr (err401, ServantErr(errHeaders))
|
||||||
import Servant.API.Authentication (AuthPolicy (Strict, Lax),
|
import Servant.API.Authentication (AuthPolicy (Strict, Lax),
|
||||||
AuthProtected,
|
AuthProtected(..),
|
||||||
BasicAuth (BasicAuth),
|
BasicAuth (BasicAuth),
|
||||||
JWTAuth(..))
|
JWTAuth(..),
|
||||||
|
OnMissing (..),
|
||||||
|
OnUnauthenticated (..))
|
||||||
import Web.JWT (decodeAndVerifySignature, JWT, VerifiedJWT, Secret)
|
import Web.JWT (decodeAndVerifySignature, JWT, VerifiedJWT, Secret)
|
||||||
import qualified Web.JWT as JWT (decode)
|
import qualified Web.JWT as JWT (decode)
|
||||||
|
|
||||||
|
@ -45,40 +43,12 @@ import qualified Web.JWT as JWT (decode)
|
||||||
class AuthData a where
|
class AuthData a where
|
||||||
authData :: Request -> Maybe a
|
authData :: Request -> Maybe a
|
||||||
|
|
||||||
-- | handlers to deal with authentication failures.
|
authProtect :: OnMissing IO ServantErr (missingPolicy :: AuthPolicy)
|
||||||
data AuthHandlers authData = AuthHandlers
|
-> OnUnauthenticated IO ServantErr (unauthPolicy :: AuthPolicy) errorIndex authData
|
||||||
{ -- we couldn't find the right type of auth data (or any, for that matter)
|
-> (authData -> IO (Either errorIndex usr))
|
||||||
onMissingAuthData :: IO ServantErr
|
-> subserver
|
||||||
,
|
-> AuthProtected IO ServantErr missingPolicy unauthPolicy errorIndex authData usr subserver
|
||||||
-- we found the right type of auth data in the request but the check failed
|
authProtect = AuthProtected
|
||||||
onUnauthenticated :: authData -> IO ServantErr
|
|
||||||
}
|
|
||||||
|
|
||||||
-- | concrete type to provide when in 'Strict' mode.
|
|
||||||
data instance AuthProtected authData usr subserver 'Strict =
|
|
||||||
AuthProtectedStrict { checkAuthStrict :: authData -> IO (Maybe usr)
|
|
||||||
, authHandlers :: AuthHandlers authData
|
|
||||||
, subServerStrict :: subserver
|
|
||||||
}
|
|
||||||
|
|
||||||
-- | concrete type to provide when in 'Lax' mode.
|
|
||||||
data instance AuthProtected authData usr subserver 'Lax =
|
|
||||||
AuthProtectedLax { checkAuthLax :: authData -> IO (Maybe usr)
|
|
||||||
, subServerLax :: subserver
|
|
||||||
}
|
|
||||||
|
|
||||||
-- | handy function to build an auth-protected bit of API with a 'Lax' policy
|
|
||||||
laxProtect :: (authData -> IO (Maybe usr)) -- ^ check auth
|
|
||||||
-> subserver -- ^ the handlers for the auth-aware bits of the API
|
|
||||||
-> AuthProtected authData usr subserver 'Lax
|
|
||||||
laxProtect = AuthProtectedLax
|
|
||||||
|
|
||||||
-- | handy function to build an auth-protected bit of API with a 'Strict' policy
|
|
||||||
strictProtect :: (authData -> IO (Maybe usr)) -- ^ check auth
|
|
||||||
-> AuthHandlers authData -- ^ functions to call on auth failure
|
|
||||||
-> subserver -- ^ handlers for the auth-protected bits of the API
|
|
||||||
-> AuthProtected authData usr subserver 'Strict
|
|
||||||
strictProtect = AuthProtectedStrict
|
|
||||||
|
|
||||||
-- | 'BasicAuth' instance for authData
|
-- | 'BasicAuth' instance for authData
|
||||||
instance AuthData (BasicAuth realm) where
|
instance AuthData (BasicAuth realm) where
|
||||||
|
@ -91,31 +61,47 @@ instance AuthData (BasicAuth realm) where
|
||||||
(_, password) <- B.uncons passWithColonAtHead
|
(_, password) <- B.uncons passWithColonAtHead
|
||||||
return $ BasicAuth username password
|
return $ BasicAuth username password
|
||||||
|
|
||||||
-- | handlers for Basic Authentication.
|
-- | failure response for Basic Authentication
|
||||||
basicAuthHandlers :: forall realm. KnownSymbol realm => AuthHandlers (BasicAuth realm)
|
basicAuthFailure :: forall realm. KnownSymbol realm
|
||||||
basicAuthHandlers =
|
=> Proxy realm
|
||||||
let realmBytes = (fromString . symbolVal) (Proxy :: Proxy realm)
|
-> ServantErr
|
||||||
headerBytes = "Basic realm=\"" <> realmBytes <> "\""
|
basicAuthFailure p = let realmBytes = (fromString . symbolVal) p
|
||||||
authFailure = err401 { errHeaders = [("WWW-Authenticate", headerBytes)] }
|
headerBytes = "Basic realm=\"" <> realmBytes <> "\""
|
||||||
in
|
in err401 { errHeaders = [("WWW-Authenticate", headerBytes)] }
|
||||||
AuthHandlers (return authFailure) ((const . return) authFailure)
|
|
||||||
|
-- | OnMisisng handler for Basic Authentication
|
||||||
|
basicMissingHandler :: forall realm. KnownSymbol realm
|
||||||
|
=> Proxy realm
|
||||||
|
-> OnMissing IO ServantErr 'Strict
|
||||||
|
basicMissingHandler p = StrictMissing (return $ basicAuthFailure p)
|
||||||
|
|
||||||
|
-- | OnUnauthenticated handler for Basic Authentication
|
||||||
|
basicUnauthenticatedHandler :: forall realm. KnownSymbol realm
|
||||||
|
=> Proxy realm
|
||||||
|
-> OnUnauthenticated IO ServantErr 'Strict () (BasicAuth realm)
|
||||||
|
basicUnauthenticatedHandler p = StrictUnauthenticated (const . const (return $ basicAuthFailure p))
|
||||||
|
|
||||||
-- | Basic authentication combinator with strict failure.
|
-- | Basic authentication combinator with strict failure.
|
||||||
basicAuthStrict :: KnownSymbol realm
|
basicAuthStrict :: forall realm usr subserver. KnownSymbol realm
|
||||||
=> (BasicAuth realm -> IO (Maybe usr))
|
=> (BasicAuth realm -> IO (Maybe usr))
|
||||||
-> subserver
|
-> subserver
|
||||||
-> AuthProtected (BasicAuth realm) usr subserver 'Strict
|
-> AuthProtected IO ServantErr 'Strict 'Strict () (BasicAuth realm) usr subserver
|
||||||
basicAuthStrict check subserver = strictProtect check basicAuthHandlers subserver
|
basicAuthStrict check sub =
|
||||||
|
let mHandler = basicMissingHandler (Proxy :: Proxy realm)
|
||||||
|
unauthHandler = basicUnauthenticatedHandler (Proxy :: Proxy realm)
|
||||||
|
check' = \auth -> maybe (Left ()) Right <$> check auth
|
||||||
|
in AuthProtected mHandler unauthHandler check' sub
|
||||||
|
|
||||||
-- | Basic authentication combinator with lax failure.
|
-- | Basic authentication combinator with lax failure.
|
||||||
basicAuthLax :: KnownSymbol realm
|
basicAuthLax :: KnownSymbol realm
|
||||||
=> (BasicAuth realm -> IO (Maybe usr))
|
=> (BasicAuth realm -> IO (Maybe usr))
|
||||||
-> subserver
|
-> subserver
|
||||||
-> AuthProtected (BasicAuth realm) usr subserver 'Lax
|
-> AuthProtected IO ServantErr 'Lax 'Lax () (BasicAuth realm) usr subserver
|
||||||
basicAuthLax = laxProtect
|
basicAuthLax check sub =
|
||||||
|
let check' = \a -> maybe (Left ()) Right <$> check a
|
||||||
|
in AuthProtected LaxMissing LaxUnauthenticated check' sub
|
||||||
|
|
||||||
|
-- | Authentication data we extract from requests for JWT-based authentication.
|
||||||
instance AuthData JWTAuth where
|
instance AuthData JWTAuth where
|
||||||
authData req = do
|
authData req = do
|
||||||
hdr <- lookup "Authorization" . requestHeaders $ req
|
hdr <- lookup "Authorization" . requestHeaders $ req
|
||||||
|
@ -123,19 +109,16 @@ instance AuthData JWTAuth where
|
||||||
_ <- JWT.decode token -- try decode it. otherwise it's not a proper token
|
_ <- JWT.decode token -- try decode it. otherwise it's not a proper token
|
||||||
return . JWTAuth $ token
|
return . JWTAuth $ token
|
||||||
|
|
||||||
|
-- | helper method to construct jwt handlers
|
||||||
|
jwtWithError :: B.ByteString -> ServantErr
|
||||||
|
jwtWithError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] }
|
||||||
|
|
||||||
jwtAuthHandlers :: AuthHandlers JWTAuth
|
-- | OnMissing handler for Strict, JWT-based authentication
|
||||||
jwtAuthHandlers = AuthHandlers (return missingData) ((const . return) authFailure)
|
jwtAuthStrict :: Secret
|
||||||
where
|
-> subserver
|
||||||
withError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] }
|
-> AuthProtected IO ServantErr 'Strict 'Strict () JWTAuth (JWT VerifiedJWT) subserver
|
||||||
missingData = withError "invalid_request"
|
jwtAuthStrict secret sub =
|
||||||
authFailure = withError "invalid_token"
|
let missingHandler = StrictMissing (return $ jwtWithError "invalid_request")
|
||||||
|
unauthHandler = StrictUnauthenticated (const . const (return $ jwtWithError "invalid_token"))
|
||||||
|
check = return . maybe (Left ()) Right . decodeAndVerifySignature secret . unJWTAuth
|
||||||
-- | A default implementation of an AuthProtected for JWT.
|
in AuthProtected missingHandler unauthHandler check sub
|
||||||
-- Use this to quickly add jwt authentication to your project.
|
|
||||||
-- One can use strictProtect and laxProtect to make more complex authentication
|
|
||||||
-- and authorization schemes.
|
|
||||||
jwtAuthStrict :: Secret -> subserver -> AuthProtected JWTAuth (JWT VerifiedJWT) subserver 'Strict
|
|
||||||
jwtAuthStrict secret subserver = strictProtect (return . decodeAndVerifySignature secret . unJWTAuth) jwtAuthHandlers subserver
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue