From 87a3c55c0b4167ee3a30b5a79a2b4eeeea14555b Mon Sep 17 00:00:00 2001 From: Felix Paulusma Date: Mon, 14 Mar 2022 01:25:16 +0100 Subject: [PATCH] servant-auth-server: added new type of Auth as a PoC --- .../src/Servant/Auth/Server/Internal.hs | 160 +++++++++++++++++- 1 file changed, 154 insertions(+), 6 deletions(-) diff --git a/servant-auth/servant-auth-server/src/Servant/Auth/Server/Internal.hs b/servant-auth/servant-auth-server/src/Servant/Auth/Server/Internal.hs index 2e825c0a..7663dea0 100644 --- a/servant-auth/servant-auth-server/src/Servant/Auth/Server/Internal.hs +++ b/servant-auth/servant-auth-server/src/Servant/Auth/Server/Internal.hs @@ -1,24 +1,38 @@ {-# LANGUAGE CPP #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-orphans #-} module Servant.Auth.Server.Internal where +import Control.Monad.Except (runExceptT, join) import Control.Monad.Trans (liftIO) -import Servant ((:>), Handler, HasServer (..), - Proxy (..), - HasContextEntry(getContextEntry)) -import Servant.Auth +import Data.Kind (Type) +import Data.Typeable (Typeable, typeRep) +import Network.Wai (Request, queryString) +import Servant +import Servant.API.Modifiers (FoldLenient, FoldRequired, RequestArgument, unfoldRequestArgument) +import Servant.Auth (Auth) import Servant.Auth.JWT (ToJWT) +import Data.Text (Text) import Servant.Auth.Server.Internal.AddSetCookie import Servant.Auth.Server.Internal.Class import Servant.Auth.Server.Internal.Cookie import Servant.Auth.Server.Internal.ConfigTypes -import Servant.Auth.Server.Internal.JWT +-- import Servant.Auth.Server.Internal.JWT import Servant.Auth.Server.Internal.Types -import Servant.Server.Internal (DelayedIO, addAuthCheck, withRequest) +import Servant.Server.Experimental.Auth (AuthHandler (..)) +import Servant.Server.Internal (DelayedIO, addAuthCheck, delayedFail, delayedFailFatal, mkContextWithErrorFormatter, withRequest, MkContextWithErrorFormatter) +import qualified Data.Text as T +import qualified Data.Text.Lazy.Encoding as TLE +import qualified Data.Text.Lazy as TL +import Data.ByteString.Lazy (ByteString) +import GHC.TypeLits (KnownSymbol, symbolVal) +import Network.HTTP.Types (queryToQueryText) +import Data.String (fromString) instance ( n ~ 'S ('S 'Z) , HasServer (AddSetCookiesApi n api) ctxs, AreAuths auths ctxs v @@ -68,3 +82,137 @@ instance ( n ~ 'S ('S 'Z) -> (AuthResult v, SetCookieList n) -> ServerT (AddSetCookiesApi n api) Handler go fn (authResult, cookies) = addSetCookies cookies $ fn authResult + + +{- +NewAuth is a "quick" PoC to have a more modular way of providing +authentications and the checking thereof. + +In the current implementation, all of the 'auths' are checked one +by one, and the first that is present is tried and will either +be returned when successful, or when not, throw an error given +the appropriate 'ErrorFormatter' or return a 'Left err' if the +'Lenient' modifier is set. +Only if NONE of the auths are present will it either throw an +'err401' or, if the 'Optional' modifier is set, return a 'Nothing'. +This error might also be customizeable if we make it be required +from the context. +-} + +data NewAuth (mods :: [Type]) (auths :: [Type]) (a :: Type) + deriving (Typeable) + +type NewAuthResult a = Maybe (Either Text a) + +instance + ( HasServer api ctxs + , HasContextEntry (MkContextWithErrorFormatter ctxs) ErrorFormatters + , SBoolI (FoldRequired mods) + , SBoolI (FoldLenient mods) + , AllAuth auths a + , HasContextEntry ctxs (AuthHandler Request (NewAuthResult a)) + ) => HasServer (NewAuth mods auths a :> api) ctxs where + type ServerT (NewAuth mods auths a :> api) m = + RequestArgument mods a -> ServerT api m + hoistServerWithContext _ pc nt s = + hoistServerWithContext (Proxy @api) pc nt . s + route _ context subserver = + route (Proxy @api) context $ subserver `addAuthCheck` authCheck + where + errorFormatters :: ErrorFormatters + errorFormatters = getContextEntry $ mkContextWithErrorFormatter context + + authCheck :: DelayedIO (RequestArgument mods a) + authCheck = withRequest $ \req -> do + mRes <- processAllAuthHandlers req authHandlers + case mRes of + Nothing -> case sbool :: SBool (FoldRequired mods) of + STrue -> absent + SFalse -> pure Nothing + Just (badForm, eRes) -> + unfoldRequestArgument (Proxy @mods) absent badForm $ Just eRes + + processAllAuthHandlers :: + Request -> + [NewAuthHandler Request (NewAuthResult a)] -> + DelayedIO (Maybe (Text -> DelayedIO (RequestArgument mods a), Either Text a)) + processAllAuthHandlers _req [] = pure Nothing + processAllAuthHandlers req (auth : auths) = do + eRes <- liftIO . runExceptT . runHandler' $ getHandler auth req + either delayedFail go eRes + where + go Nothing = processAllAuthHandlers req auths + go (Just res) = pure $ Just (badForm, res) + badForm err = delayedFailFatal $ errFmtr rep req . T.unpack $ msg <> err + errFmtr = getErrorFormatter auth errorFormatters + rep = typeRep (Proxy :: Proxy NewAuth) + msg = "Authentication via " <> getAuthName auth <> " failed: " + + + authHandlers :: [NewAuthHandler Request (NewAuthResult a)] + authHandlers = allAuthHandlers (Proxy @auths) + + toLBS :: Text -> ByteString + toLBS = TLE.encodeUtf8 . TL.fromStrict + + absent :: DelayedIO (RequestArgument mods a) + absent = + delayedFailFatal $ err401{ errBody = toLBS msg } + where + allAuthMethodNames = getAuthName <$> authHandlers + msg = case allAuthMethodNames of + [] -> "No authentication required, something went wrong." + [auth] -> "Authentication required: " <> auth + auths -> "One of the following authentications required: " <> T.intercalate ", " auths + +data NewAuthHandler r res = NewAuthHandler + { getHandler :: r -> Handler res + -- Used in the following errors: + -- "One of the following authentications required: {authName}, {authName}, etc." + -- And + -- "Authentication via {authName}: (Left Text)" + , getAuthName :: Text + , getErrorFormatter :: ErrorFormatters -> ErrorFormatter + } deriving Typeable + +class AllAuth (auths :: [Type]) a where + allAuthHandlers :: proxy auths -> [NewAuthHandler Request (NewAuthResult a)] + +instance AllAuth '[] a where + allAuthHandlers _ = [] + +instance (AllAuth auths a, HasAuthHandler auth a) => AllAuth (auth ': auths) a where + allAuthHandlers _ = getAuthHandler (Proxy @auth) : allAuthHandlers (Proxy @auths) + +class HasAuthHandler auth a where + getAuthHandler :: proxy auth -> NewAuthHandler Request (NewAuthResult a) + +{- +The following is an example of a partial implementation to be used by users +where they will only have to supply a 'FromJWT' instance for their type. +There's a lot of possibilities, but this is just a quick and easy example. +-} + +-- | Designates a query parameter named @sym@ which should contain a valid JWT +-- +-- E.g. @JWTQueryParam "token"@ will try to parse the value from the query +-- parameter named "token" in the path (i.e. "example.com/path?token=...") +data JWTQueryParam sym + +-- | Pretty library agnostic way of providing a JWT API, but less plug'n'play. +-- +-- The Servant.Auth.JWT 'FromJWT' forces the 'Value' to be in the "dat" field, +-- but makes this implementation easier and user-friendly. So it depends on +-- what we'd prefer. +class FromJWT a where + parseJWT :: Text -> Either Text a + +instance (FromJWT a, KnownSymbol sym) => HasAuthHandler (JWTQueryParam sym) a where + getAuthHandler _ = NewAuthHandler + { getHandler = \req -> + let paramname = fromString $ symbolVal (Proxy :: Proxy sym) + mev = join . lookup paramname . queryToQueryText $ queryString req + in pure $ parseJWT <$> mev + , getAuthName = "JWT Query Parameter" + , getErrorFormatter = urlParseErrorFormatter + }