Update server authentication with new GADT

This commit is contained in:
aaron levin 2015-12-24 12:09:12 +01:00
parent 0fbd84bfd7
commit 743c51b3c5

View file

@ -6,14 +6,10 @@
{-# LANGUAGE FlexibleInstances #-}
module Servant.Server.Internal.Authentication
( AuthProtected (..)
, AuthData (..)
, AuthHandlers (AuthHandlers, onMissingAuthData, onUnauthenticated)
( AuthData (..)
, authProtect
, basicAuthLax
, basicAuthStrict
, laxProtect
, strictProtect
, jwtAuthHandlers
, jwtAuthStrict
) where
@ -34,9 +30,11 @@ import Data.Text (splitOn)
import Network.Wai (Request, requestHeaders)
import Servant.Server.Internal.ServantErr (err401, ServantErr(errHeaders))
import Servant.API.Authentication (AuthPolicy (Strict, Lax),
AuthProtected,
AuthProtected(..),
BasicAuth (BasicAuth),
JWTAuth(..))
JWTAuth(..),
OnMissing (..),
OnUnauthenticated (..))
import Web.JWT (decodeAndVerifySignature, JWT, VerifiedJWT, Secret)
import qualified Web.JWT as JWT (decode)
@ -45,40 +43,12 @@ import qualified Web.JWT as JWT (decode)
class AuthData a where
authData :: Request -> Maybe a
-- | handlers to deal with authentication failures.
data AuthHandlers authData = AuthHandlers
{ -- we couldn't find the right type of auth data (or any, for that matter)
onMissingAuthData :: IO ServantErr
,
-- we found the right type of auth data in the request but the check failed
onUnauthenticated :: authData -> IO ServantErr
}
-- | concrete type to provide when in 'Strict' mode.
data instance AuthProtected authData usr subserver 'Strict =
AuthProtectedStrict { checkAuthStrict :: authData -> IO (Maybe usr)
, authHandlers :: AuthHandlers authData
, subServerStrict :: subserver
}
-- | concrete type to provide when in 'Lax' mode.
data instance AuthProtected authData usr subserver 'Lax =
AuthProtectedLax { checkAuthLax :: authData -> IO (Maybe usr)
, subServerLax :: subserver
}
-- | handy function to build an auth-protected bit of API with a 'Lax' policy
laxProtect :: (authData -> IO (Maybe usr)) -- ^ check auth
-> subserver -- ^ the handlers for the auth-aware bits of the API
-> AuthProtected authData usr subserver 'Lax
laxProtect = AuthProtectedLax
-- | handy function to build an auth-protected bit of API with a 'Strict' policy
strictProtect :: (authData -> IO (Maybe usr)) -- ^ check auth
-> AuthHandlers authData -- ^ functions to call on auth failure
-> subserver -- ^ handlers for the auth-protected bits of the API
-> AuthProtected authData usr subserver 'Strict
strictProtect = AuthProtectedStrict
authProtect :: OnMissing IO ServantErr (missingPolicy :: AuthPolicy)
-> OnUnauthenticated IO ServantErr (unauthPolicy :: AuthPolicy) errorIndex authData
-> (authData -> IO (Either errorIndex usr))
-> subserver
-> AuthProtected IO ServantErr missingPolicy unauthPolicy errorIndex authData usr subserver
authProtect = AuthProtected
-- | 'BasicAuth' instance for authData
instance AuthData (BasicAuth realm) where
@ -91,31 +61,47 @@ instance AuthData (BasicAuth realm) where
(_, password) <- B.uncons passWithColonAtHead
return $ BasicAuth username password
-- | handlers for Basic Authentication.
basicAuthHandlers :: forall realm. KnownSymbol realm => AuthHandlers (BasicAuth realm)
basicAuthHandlers =
let realmBytes = (fromString . symbolVal) (Proxy :: Proxy realm)
headerBytes = "Basic realm=\"" <> realmBytes <> "\""
authFailure = err401 { errHeaders = [("WWW-Authenticate", headerBytes)] }
in
AuthHandlers (return authFailure) ((const . return) authFailure)
-- | failure response for Basic Authentication
basicAuthFailure :: forall realm. KnownSymbol realm
=> Proxy realm
-> ServantErr
basicAuthFailure p = let realmBytes = (fromString . symbolVal) p
headerBytes = "Basic realm=\"" <> realmBytes <> "\""
in err401 { errHeaders = [("WWW-Authenticate", headerBytes)] }
-- | OnMisisng handler for Basic Authentication
basicMissingHandler :: forall realm. KnownSymbol realm
=> Proxy realm
-> OnMissing IO ServantErr 'Strict
basicMissingHandler p = StrictMissing (return $ basicAuthFailure p)
-- | OnUnauthenticated handler for Basic Authentication
basicUnauthenticatedHandler :: forall realm. KnownSymbol realm
=> Proxy realm
-> OnUnauthenticated IO ServantErr 'Strict () (BasicAuth realm)
basicUnauthenticatedHandler p = StrictUnauthenticated (const . const (return $ basicAuthFailure p))
-- | Basic authentication combinator with strict failure.
basicAuthStrict :: KnownSymbol realm
basicAuthStrict :: forall realm usr subserver. KnownSymbol realm
=> (BasicAuth realm -> IO (Maybe usr))
-> subserver
-> AuthProtected (BasicAuth realm) usr subserver 'Strict
basicAuthStrict check subserver = strictProtect check basicAuthHandlers subserver
-> AuthProtected IO ServantErr 'Strict 'Strict () (BasicAuth realm) usr subserver
basicAuthStrict check sub =
let mHandler = basicMissingHandler (Proxy :: Proxy realm)
unauthHandler = basicUnauthenticatedHandler (Proxy :: Proxy realm)
check' = \auth -> maybe (Left ()) Right <$> check auth
in AuthProtected mHandler unauthHandler check' sub
-- | Basic authentication combinator with lax failure.
basicAuthLax :: KnownSymbol realm
=> (BasicAuth realm -> IO (Maybe usr))
-> subserver
-> AuthProtected (BasicAuth realm) usr subserver 'Lax
basicAuthLax = laxProtect
-> AuthProtected IO ServantErr 'Lax 'Lax () (BasicAuth realm) usr subserver
basicAuthLax check sub =
let check' = \a -> maybe (Left ()) Right <$> check a
in AuthProtected LaxMissing LaxUnauthenticated check' sub
-- | Authentication data we extract from requests for JWT-based authentication.
instance AuthData JWTAuth where
authData req = do
hdr <- lookup "Authorization" . requestHeaders $ req
@ -123,19 +109,16 @@ instance AuthData JWTAuth where
_ <- JWT.decode token -- try decode it. otherwise it's not a proper token
return . JWTAuth $ token
-- | helper method to construct jwt handlers
jwtWithError :: B.ByteString -> ServantErr
jwtWithError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] }
jwtAuthHandlers :: AuthHandlers JWTAuth
jwtAuthHandlers = AuthHandlers (return missingData) ((const . return) authFailure)
where
withError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>e<>"\"")] }
missingData = withError "invalid_request"
authFailure = withError "invalid_token"
-- | A default implementation of an AuthProtected for JWT.
-- Use this to quickly add jwt authentication to your project.
-- One can use strictProtect and laxProtect to make more complex authentication
-- and authorization schemes.
jwtAuthStrict :: Secret -> subserver -> AuthProtected JWTAuth (JWT VerifiedJWT) subserver 'Strict
jwtAuthStrict secret subserver = strictProtect (return . decodeAndVerifySignature secret . unJWTAuth) jwtAuthHandlers subserver
-- | OnMissing handler for Strict, JWT-based authentication
jwtAuthStrict :: Secret
-> subserver
-> AuthProtected IO ServantErr 'Strict 'Strict () JWTAuth (JWT VerifiedJWT) subserver
jwtAuthStrict secret sub =
let missingHandler = StrictMissing (return $ jwtWithError "invalid_request")
unauthHandler = StrictUnauthenticated (const . const (return $ jwtWithError "invalid_token"))
check = return . maybe (Left ()) Right . decodeAndVerifySignature secret . unJWTAuth
in AuthProtected missingHandler unauthHandler check sub