diff --git a/servant-server/src/Servant/Server/Internal/Authentication.hs b/servant-server/src/Servant/Server/Internal/Authentication.hs index c120dbe0..9f15ba7a 100644 --- a/servant-server/src/Servant/Server/Internal/Authentication.hs +++ b/servant-server/src/Servant/Server/Internal/Authentication.hs @@ -1,9 +1,11 @@ -{-# LANGUAGE CPP #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} module Servant.Server.Internal.Authentication ( AuthData (..) @@ -11,6 +13,7 @@ module Servant.Server.Internal.Authentication , basicAuthLax , basicAuthStrict , jwtAuthStrict +, SimpleAuthProtected ) where import Control.Monad (guard) @@ -40,19 +43,20 @@ import qualified Web.JWT as JWT (decode) -- | Class to represent the ability to extract authentication-related -- data from a 'Request' object. -class AuthData a where - authData :: Request -> Maybe a +class AuthData a e | a -> e where + authData :: Request -> Either e a -authProtect :: OnMissing IO ServantErr (missingPolicy :: AuthPolicy) - -> OnUnauthenticated IO ServantErr (unauthPolicy :: AuthPolicy) errorIndex authData - -> (authData -> IO (Either errorIndex usr)) +-- | combinator to create authentication protected servers. +authProtect :: OnMissing IO ServantErr missingPolicy missingError + -> OnUnauthenticated IO ServantErr unauthPolicy unauthError authData + -> (authData -> IO (Either unauthError usr)) -> subserver - -> AuthProtected IO ServantErr missingPolicy unauthPolicy errorIndex authData usr subserver + -> AuthProtected IO ServantErr missingPolicy missingError unauthPolicy unauthError authData usr subserver authProtect = AuthProtected -- | 'BasicAuth' instance for authData -instance AuthData (BasicAuth realm) where - authData request = do +instance AuthData (BasicAuth realm) () where + authData request = maybe (Left ()) Right $ do authBs <- lookup "Authorization" (requestHeaders request) let (x,y) = B.break isSpace authBs guard (B.map toLower x == "basic") @@ -72,8 +76,8 @@ basicAuthFailure p = let realmBytes = (fromString . symbolVal) p -- | OnMisisng handler for Basic Authentication basicMissingHandler :: forall realm. KnownSymbol realm => Proxy realm - -> OnMissing IO ServantErr 'Strict -basicMissingHandler p = StrictMissing (return $ basicAuthFailure p) + -> OnMissing IO ServantErr 'Strict () +basicMissingHandler p = StrictMissing (const $ return (basicAuthFailure p)) -- | OnUnauthenticated handler for Basic Authentication basicUnauthenticatedHandler :: forall realm. KnownSymbol realm @@ -85,7 +89,7 @@ basicUnauthenticatedHandler p = StrictUnauthenticated (const . const (return $ b basicAuthStrict :: forall realm usr subserver. KnownSymbol realm => (BasicAuth realm -> IO (Maybe usr)) -> subserver - -> AuthProtected IO ServantErr 'Strict 'Strict () (BasicAuth realm) usr subserver + -> AuthProtected IO ServantErr 'Strict () 'Strict () (BasicAuth realm) usr subserver basicAuthStrict check sub = let mHandler = basicMissingHandler (Proxy :: Proxy realm) unauthHandler = basicUnauthenticatedHandler (Proxy :: Proxy realm) @@ -96,14 +100,14 @@ basicAuthStrict check sub = basicAuthLax :: KnownSymbol realm => (BasicAuth realm -> IO (Maybe usr)) -> subserver - -> AuthProtected IO ServantErr 'Lax 'Lax () (BasicAuth realm) usr subserver + -> AuthProtected IO ServantErr 'Lax () 'Lax () (BasicAuth realm) usr subserver basicAuthLax check sub = let check' = \a -> maybe (Left ()) Right <$> check a in AuthProtected LaxMissing LaxUnauthenticated check' sub -- | Authentication data we extract from requests for JWT-based authentication. -instance AuthData JWTAuth where - authData req = do +instance AuthData JWTAuth () where + authData req = maybe (Left ()) Right $ do hdr <- lookup "Authorization" . requestHeaders $ req ["Bearer", token] <- return . splitOn " " . decodeUtf8 $ hdr _ <- JWT.decode token -- try decode it. otherwise it's not a proper token @@ -116,9 +120,13 @@ jwtWithError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<> -- | OnMissing handler for Strict, JWT-based authentication jwtAuthStrict :: Secret -> subserver - -> AuthProtected IO ServantErr 'Strict 'Strict () JWTAuth (JWT VerifiedJWT) subserver + -> AuthProtected IO ServantErr 'Strict () 'Strict () JWTAuth (JWT VerifiedJWT) subserver jwtAuthStrict secret sub = - let missingHandler = StrictMissing (return $ jwtWithError "invalid_request") + let missingHandler = StrictMissing (const $ return (jwtWithError "invalid_request")) unauthHandler = StrictUnauthenticated (const . const (return $ jwtWithError "invalid_token")) check = return . maybe (Left ()) Right . decodeAndVerifySignature secret . unJWTAuth in AuthProtected missingHandler unauthHandler check sub + +-- | A type alias to make simple authentication endpoints +type SimpleAuthProtected mPolicy uPolicy authData usr subserver = + AuthProtected IO ServantErr mPolicy () uPolicy () authData usr subserver diff --git a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs index 707c0f05..27e45506 100644 --- a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs +++ b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs @@ -2,10 +2,11 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE TypeOperators #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} module Servant.Server.Internal.RoutingApplication where #if !MIN_VERSION_base(4,8,0) @@ -20,8 +21,8 @@ import Network.Wai (Application, Request, Response, ResponseReceived, requestBody, strictRequestBody) -import Servant.API.Authentication (AuthPolicy(Strict,Lax)) -import Servant.Server.Internal.Authentication +import Servant.API.Authentication (AuthProtected (..), AuthPolicy(Strict,Lax), + OnMissing (..), OnUnauthenticated (..), SAuthPolicy (..)) import Servant.Server.Internal.ServantErr type RoutingApplication = @@ -179,43 +180,83 @@ addMethodCheck :: Delayed a addMethodCheck (Delayed captures method auth body server) new = Delayed captures (combineRouteResults const method new) auth body server --- | Add a method to perform authorization in strict mode. -addAuthStrictCheck :: Delayed (AuthProtected auth usr (usr -> a) 'Strict) - -> IO (RouteResult (Maybe auth)) - -> Delayed a -addAuthStrictCheck delayed@(Delayed captures method _ body _) new = - let newAuth = runDelayed delayed `bindRouteResults` \authProtectionStrict -> new `bindRouteResults` \mAuthData -> case mAuthData of +-- | helper type family to capture server handled values for various policies +type family AuthDelayedReturn (mP :: AuthPolicy) mE (uP :: AuthPolicy) uE usr :: * where + AuthDelayedReturn 'Strict mE 'Strict uE usr = usr + AuthDelayedReturn 'Strict mE 'Lax uE usr = Either uE usr + AuthDelayedReturn 'Lax mE 'Strict uE usr = Either mE usr + AuthDelayedReturn 'Lax mE 'Lax uE usr = Either (Either mE uE) usr - Nothing -> do - -- we're in strict mode: don't let the request go - -- call the provided "on missing auth" handler - resp <- onMissingAuthData (authHandlers authProtectionStrict) - return $ FailFatal resp - - -- successfully pulled auth data out of the request - Just aData -> do - mUsr <- (checkAuthStrict authProtectionStrict) aData - case mUsr of - -- this user is not authenticated - Nothing -> do - resp <- onUnauthenticated (authHandlers authProtectionStrict) aData - return $ FailFatal resp - - -- this user is authenticated - Just usr -> - (return . Route . subServerStrict authProtectionStrict) usr +-- | Internal method to generate auth checkers for various policies. Scary type signature +-- but it does help with understanding the logic of how each policy works. See +-- examples below. +genAuthCheck :: (OnMissing IO ServantErr mPolicy mError -> (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a) -> mError -> IO (RouteResult a)) + -> (OnUnauthenticated IO ServantErr uPolicy uError auth -> (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a) -> uError -> auth -> IO (RouteResult a)) + -> (usr -> (AuthDelayedReturn mPolicy mError uPolicy uError usr)) + -> Delayed (AuthProtected IO ServantErr mPolicy mError uPolicy uError auth usr (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a)) + -> IO (RouteResult (Either mError auth)) + -> Delayed a +genAuthCheck missingHandler unauthHandler returnHandler d@(Delayed captures method _ body _) new = + let newAuth = + runDelayed d `bindRouteResults` \ authProtection -> + new `bindRouteResults` \ eAuthData -> + case eAuthData of + -- we failed to extract authentication data from the request + Left mError -> missingHandler (onMissing authProtection) (subserver authProtection) mError + -- auth data was succesfully extracted from the request + Right aData -> do + eUsr <- checkAuth authProtection aData + case eUsr of + -- we failed to authenticate the user + Left uError -> unauthHandler (onUnauthenticated authProtection) (subserver authProtection) uError aData + -- user was authenticated + Right usr -> + (return . Route . subserver authProtection) (returnHandler usr) in Delayed captures method newAuth body (\_ y _ -> Route y) --- | Add a method to perform authorization in strict mode. -addAuthLaxCheck :: Delayed (AuthProtected auth usr (Maybe usr -> a) 'Lax) - -> IO (RouteResult (Maybe auth)) - -> Delayed a -addAuthLaxCheck delayed@(Delayed captures method _ body _) new = - let newAuth = runDelayed delayed `bindRouteResults` \authProtectionLax -> new `bindRouteResults` \mAuthData -> - fmap (Route . subServerLax authProtectionLax) - (maybe (pure Nothing) (checkAuthLax authProtectionLax) mAuthData) +-- | Delayed auth checker for Strict Missing and Strict Unauthentication +addAuthCheckSS :: Delayed (AuthProtected IO ServantErr 'Strict mError 'Strict uError auth usr (usr -> a)) + -> IO (RouteResult (Either mError auth)) + -> Delayed a +addAuthCheckSS = genAuthCheck (\(StrictMissing handler) _ e -> FailFatal <$> handler e) + (\(StrictUnauthenticated handler) _ e a -> FailFatal <$> handler e a) + id - in Delayed captures method newAuth body (\_ y _ -> Route y) +-- | Delayed auth checker for Strict Missing and Lax Unauthentication +addAuthCheckSL :: Delayed (AuthProtected IO ServantErr 'Strict mError 'Lax uError auth usr (Either uError usr -> a)) + -> IO (RouteResult (Either mError auth)) + -> Delayed a +addAuthCheckSL = genAuthCheck (\(StrictMissing handler) _ e -> FailFatal <$> handler e) + (\(LaxUnauthenticated) cont e _ -> (return . Route . cont) (Left e)) + Right + + +-- | Delayed auth checker for Lax Missing and Strict Unauthentication +addAuthCheckLS :: Delayed (AuthProtected IO ServantErr 'Lax mError 'Strict uError auth usr (Either mError usr -> a)) + -> IO (RouteResult (Either mError auth)) + -> Delayed a +addAuthCheckLS = genAuthCheck (\(LaxMissing) cont e -> (return . Route . cont) (Left e)) + (\(StrictUnauthenticated handler) _ e a -> FailFatal <$> handler e a) + Right + +-- | Delayed auth checker for Lax Missing and Lax Unauthentication +addAuthCheckLL :: Delayed (AuthProtected IO ServantErr 'Lax mError 'Lax uError auth usr (Either (Either mError uError) usr -> a)) + -> IO (RouteResult (Either mError auth)) + -> Delayed a +addAuthCheckLL = genAuthCheck (\(LaxMissing) cont e -> (return . Route . cont) (Left (Left e))) + (\(LaxUnauthenticated) cont e _ -> (return . Route . cont) (Left (Right e))) + Right + +-- | Add an auth check by supplying OnMissing policies and OnUnauthenticated policies. +addAuthCheck :: SAuthPolicy mPolicy + -> SAuthPolicy uPolicy + -> Delayed (AuthProtected IO ServantErr mPolicy mError uPolicy uError auth usr (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a)) + -> IO (RouteResult (Either mError auth)) + -> Delayed a +addAuthCheck SStrict SStrict = addAuthCheckSS +addAuthCheck SStrict SLax = addAuthCheckSL +addAuthCheck SLax SStrict = addAuthCheckLS +addAuthCheck SLax SLax = addAuthCheckLL -- | Add a body check to the end of the body block. addBodyCheck :: Delayed (a -> b)