Incorporate new GADT-based auth into Delayed

This commit is contained in:
aaron levin 2015-12-24 17:20:29 +01:00
parent c169d0bd59
commit 4d23cada4c
2 changed files with 107 additions and 58 deletions

View file

@ -1,9 +1,11 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
module Servant.Server.Internal.Authentication
( AuthData (..)
@ -11,6 +13,7 @@ module Servant.Server.Internal.Authentication
, basicAuthLax
, basicAuthStrict
, jwtAuthStrict
, SimpleAuthProtected
) where
import Control.Monad (guard)
@ -40,19 +43,20 @@ import qualified Web.JWT as JWT (decode)
-- | Class to represent the ability to extract authentication-related
-- data from a 'Request' object.
class AuthData a where
authData :: Request -> Maybe a
class AuthData a e | a -> e where
authData :: Request -> Either e a
authProtect :: OnMissing IO ServantErr (missingPolicy :: AuthPolicy)
-> OnUnauthenticated IO ServantErr (unauthPolicy :: AuthPolicy) errorIndex authData
-> (authData -> IO (Either errorIndex usr))
-- | combinator to create authentication protected servers.
authProtect :: OnMissing IO ServantErr missingPolicy missingError
-> OnUnauthenticated IO ServantErr unauthPolicy unauthError authData
-> (authData -> IO (Either unauthError usr))
-> subserver
-> AuthProtected IO ServantErr missingPolicy unauthPolicy errorIndex authData usr subserver
-> AuthProtected IO ServantErr missingPolicy missingError unauthPolicy unauthError authData usr subserver
authProtect = AuthProtected
-- | 'BasicAuth' instance for authData
instance AuthData (BasicAuth realm) where
authData request = do
instance AuthData (BasicAuth realm) () where
authData request = maybe (Left ()) Right $ do
authBs <- lookup "Authorization" (requestHeaders request)
let (x,y) = B.break isSpace authBs
guard (B.map toLower x == "basic")
@ -72,8 +76,8 @@ basicAuthFailure p = let realmBytes = (fromString . symbolVal) p
-- | OnMisisng handler for Basic Authentication
basicMissingHandler :: forall realm. KnownSymbol realm
=> Proxy realm
-> OnMissing IO ServantErr 'Strict
basicMissingHandler p = StrictMissing (return $ basicAuthFailure p)
-> OnMissing IO ServantErr 'Strict ()
basicMissingHandler p = StrictMissing (const $ return (basicAuthFailure p))
-- | OnUnauthenticated handler for Basic Authentication
basicUnauthenticatedHandler :: forall realm. KnownSymbol realm
@ -85,7 +89,7 @@ basicUnauthenticatedHandler p = StrictUnauthenticated (const . const (return $ b
basicAuthStrict :: forall realm usr subserver. KnownSymbol realm
=> (BasicAuth realm -> IO (Maybe usr))
-> subserver
-> AuthProtected IO ServantErr 'Strict 'Strict () (BasicAuth realm) usr subserver
-> AuthProtected IO ServantErr 'Strict () 'Strict () (BasicAuth realm) usr subserver
basicAuthStrict check sub =
let mHandler = basicMissingHandler (Proxy :: Proxy realm)
unauthHandler = basicUnauthenticatedHandler (Proxy :: Proxy realm)
@ -96,14 +100,14 @@ basicAuthStrict check sub =
basicAuthLax :: KnownSymbol realm
=> (BasicAuth realm -> IO (Maybe usr))
-> subserver
-> AuthProtected IO ServantErr 'Lax 'Lax () (BasicAuth realm) usr subserver
-> AuthProtected IO ServantErr 'Lax () 'Lax () (BasicAuth realm) usr subserver
basicAuthLax check sub =
let check' = \a -> maybe (Left ()) Right <$> check a
in AuthProtected LaxMissing LaxUnauthenticated check' sub
-- | Authentication data we extract from requests for JWT-based authentication.
instance AuthData JWTAuth where
authData req = do
instance AuthData JWTAuth () where
authData req = maybe (Left ()) Right $ do
hdr <- lookup "Authorization" . requestHeaders $ req
["Bearer", token] <- return . splitOn " " . decodeUtf8 $ hdr
_ <- JWT.decode token -- try decode it. otherwise it's not a proper token
@ -116,9 +120,13 @@ jwtWithError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>
-- | OnMissing handler for Strict, JWT-based authentication
jwtAuthStrict :: Secret
-> subserver
-> AuthProtected IO ServantErr 'Strict 'Strict () JWTAuth (JWT VerifiedJWT) subserver
-> AuthProtected IO ServantErr 'Strict () 'Strict () JWTAuth (JWT VerifiedJWT) subserver
jwtAuthStrict secret sub =
let missingHandler = StrictMissing (return $ jwtWithError "invalid_request")
let missingHandler = StrictMissing (const $ return (jwtWithError "invalid_request"))
unauthHandler = StrictUnauthenticated (const . const (return $ jwtWithError "invalid_token"))
check = return . maybe (Left ()) Right . decodeAndVerifySignature secret . unJWTAuth
in AuthProtected missingHandler unauthHandler check sub
-- | A type alias to make simple authentication endpoints
type SimpleAuthProtected mPolicy uPolicy authData usr subserver =
AuthProtected IO ServantErr mPolicy () uPolicy () authData usr subserver

View file

@ -2,10 +2,11 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Servant.Server.Internal.RoutingApplication where
#if !MIN_VERSION_base(4,8,0)
@ -20,8 +21,8 @@ import Network.Wai (Application, Request,
Response, ResponseReceived,
requestBody,
strictRequestBody)
import Servant.API.Authentication (AuthPolicy(Strict,Lax))
import Servant.Server.Internal.Authentication
import Servant.API.Authentication (AuthProtected (..), AuthPolicy(Strict,Lax),
OnMissing (..), OnUnauthenticated (..), SAuthPolicy (..))
import Servant.Server.Internal.ServantErr
type RoutingApplication =
@ -179,43 +180,83 @@ addMethodCheck :: Delayed a
addMethodCheck (Delayed captures method auth body server) new =
Delayed captures (combineRouteResults const method new) auth body server
-- | Add a method to perform authorization in strict mode.
addAuthStrictCheck :: Delayed (AuthProtected auth usr (usr -> a) 'Strict)
-> IO (RouteResult (Maybe auth))
-- | helper type family to capture server handled values for various policies
type family AuthDelayedReturn (mP :: AuthPolicy) mE (uP :: AuthPolicy) uE usr :: * where
AuthDelayedReturn 'Strict mE 'Strict uE usr = usr
AuthDelayedReturn 'Strict mE 'Lax uE usr = Either uE usr
AuthDelayedReturn 'Lax mE 'Strict uE usr = Either mE usr
AuthDelayedReturn 'Lax mE 'Lax uE usr = Either (Either mE uE) usr
-- | Internal method to generate auth checkers for various policies. Scary type signature
-- but it does help with understanding the logic of how each policy works. See
-- examples below.
genAuthCheck :: (OnMissing IO ServantErr mPolicy mError -> (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a) -> mError -> IO (RouteResult a))
-> (OnUnauthenticated IO ServantErr uPolicy uError auth -> (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a) -> uError -> auth -> IO (RouteResult a))
-> (usr -> (AuthDelayedReturn mPolicy mError uPolicy uError usr))
-> Delayed (AuthProtected IO ServantErr mPolicy mError uPolicy uError auth usr (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthStrictCheck delayed@(Delayed captures method _ body _) new =
let newAuth = runDelayed delayed `bindRouteResults` \authProtectionStrict -> new `bindRouteResults` \mAuthData -> case mAuthData of
Nothing -> do
-- we're in strict mode: don't let the request go
-- call the provided "on missing auth" handler
resp <- onMissingAuthData (authHandlers authProtectionStrict)
return $ FailFatal resp
-- successfully pulled auth data out of the request
Just aData -> do
mUsr <- (checkAuthStrict authProtectionStrict) aData
case mUsr of
-- this user is not authenticated
Nothing -> do
resp <- onUnauthenticated (authHandlers authProtectionStrict) aData
return $ FailFatal resp
-- this user is authenticated
Just usr ->
(return . Route . subServerStrict authProtectionStrict) usr
genAuthCheck missingHandler unauthHandler returnHandler d@(Delayed captures method _ body _) new =
let newAuth =
runDelayed d `bindRouteResults` \ authProtection ->
new `bindRouteResults` \ eAuthData ->
case eAuthData of
-- we failed to extract authentication data from the request
Left mError -> missingHandler (onMissing authProtection) (subserver authProtection) mError
-- auth data was succesfully extracted from the request
Right aData -> do
eUsr <- checkAuth authProtection aData
case eUsr of
-- we failed to authenticate the user
Left uError -> unauthHandler (onUnauthenticated authProtection) (subserver authProtection) uError aData
-- user was authenticated
Right usr ->
(return . Route . subserver authProtection) (returnHandler usr)
in Delayed captures method newAuth body (\_ y _ -> Route y)
-- | Add a method to perform authorization in strict mode.
addAuthLaxCheck :: Delayed (AuthProtected auth usr (Maybe usr -> a) 'Lax)
-> IO (RouteResult (Maybe auth))
-- | Delayed auth checker for Strict Missing and Strict Unauthentication
addAuthCheckSS :: Delayed (AuthProtected IO ServantErr 'Strict mError 'Strict uError auth usr (usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthLaxCheck delayed@(Delayed captures method _ body _) new =
let newAuth = runDelayed delayed `bindRouteResults` \authProtectionLax -> new `bindRouteResults` \mAuthData ->
fmap (Route . subServerLax authProtectionLax)
(maybe (pure Nothing) (checkAuthLax authProtectionLax) mAuthData)
addAuthCheckSS = genAuthCheck (\(StrictMissing handler) _ e -> FailFatal <$> handler e)
(\(StrictUnauthenticated handler) _ e a -> FailFatal <$> handler e a)
id
in Delayed captures method newAuth body (\_ y _ -> Route y)
-- | Delayed auth checker for Strict Missing and Lax Unauthentication
addAuthCheckSL :: Delayed (AuthProtected IO ServantErr 'Strict mError 'Lax uError auth usr (Either uError usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheckSL = genAuthCheck (\(StrictMissing handler) _ e -> FailFatal <$> handler e)
(\(LaxUnauthenticated) cont e _ -> (return . Route . cont) (Left e))
Right
-- | Delayed auth checker for Lax Missing and Strict Unauthentication
addAuthCheckLS :: Delayed (AuthProtected IO ServantErr 'Lax mError 'Strict uError auth usr (Either mError usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheckLS = genAuthCheck (\(LaxMissing) cont e -> (return . Route . cont) (Left e))
(\(StrictUnauthenticated handler) _ e a -> FailFatal <$> handler e a)
Right
-- | Delayed auth checker for Lax Missing and Lax Unauthentication
addAuthCheckLL :: Delayed (AuthProtected IO ServantErr 'Lax mError 'Lax uError auth usr (Either (Either mError uError) usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheckLL = genAuthCheck (\(LaxMissing) cont e -> (return . Route . cont) (Left (Left e)))
(\(LaxUnauthenticated) cont e _ -> (return . Route . cont) (Left (Right e)))
Right
-- | Add an auth check by supplying OnMissing policies and OnUnauthenticated policies.
addAuthCheck :: SAuthPolicy mPolicy
-> SAuthPolicy uPolicy
-> Delayed (AuthProtected IO ServantErr mPolicy mError uPolicy uError auth usr (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheck SStrict SStrict = addAuthCheckSS
addAuthCheck SStrict SLax = addAuthCheckSL
addAuthCheck SLax SStrict = addAuthCheckLS
addAuthCheck SLax SLax = addAuthCheckLL
-- | Add a body check to the end of the body block.
addBodyCheck :: Delayed (a -> b)