diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index 76cb2b76..c27400c9 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -13,6 +13,7 @@ module Servant.Server.Internal ( module Servant.Server.Internal + , module Servant.Server.Internal.Auth , module Servant.Server.Internal.Config , module Servant.Server.Internal.Router , module Servant.Server.Internal.RoutingApplication @@ -22,44 +23,59 @@ module Servant.Server.Internal #if !MIN_VERSION_base(4,8,0) import Control.Applicative ((<$>)) #endif -import Control.Monad.Trans.Except (ExceptT) -import qualified Data.ByteString as B -import qualified Data.ByteString.Char8 as BC8 -import qualified Data.ByteString.Lazy as BL -import qualified Data.Map as M -import Data.Maybe (fromMaybe, mapMaybe) -import Data.String (fromString) -import Data.String.Conversions (cs, (<>)) -import Data.Text (Text) +import Control.Monad.Trans.Except (ExceptT, runExceptT) +import qualified Data.ByteString as B +import qualified Data.ByteString.Char8 as BC8 +import qualified Data.ByteString.Lazy as BL +import qualified Data.Map as M +import Data.Maybe (fromMaybe, + mapMaybe) +import Data.String (fromString) +import Data.String.Conversions (cs, (<>)) +import Data.Text (Text) import Data.Typeable -import GHC.Exts (Constraint) -import GHC.TypeLits (KnownNat, KnownSymbol, natVal, - symbolVal) -import Network.HTTP.Types hiding (Header, ResponseHeaders) -import Network.Socket (SockAddr) -import Network.Wai (Application, Request, Response, - httpVersion, isSecure, - lazyRequestBody, pathInfo, - rawQueryString, remoteHost, - requestHeaders, requestMethod, - responseLBS, vault) -import Web.HttpApiData (FromHttpApiData) -import Web.HttpApiData.Internal (parseHeaderMaybe, - parseQueryParamMaybe, - parseUrlPieceMaybe) +import GHC.Exts (Constraint) +import GHC.TypeLits (KnownNat, + KnownSymbol, + natVal, symbolVal) +import Network.HTTP.Types hiding (Header, + ResponseHeaders) +import Network.Socket (SockAddr) +import Network.Wai (Application, + Request, Response, + httpVersion, + isSecure, + lazyRequestBody, + pathInfo, + rawQueryString, + remoteHost, + requestHeaders, + requestMethod, + responseLBS, vault) +import Web.HttpApiData (FromHttpApiData) +import Web.HttpApiData.Internal (parseHeaderMaybe, parseQueryParamMaybe, + parseUrlPieceMaybe) -import Servant.API ((:<|>) (..), (:>), Capture, BasicAuth, - Verb, ReflectMethod(reflectMethod), - IsSecure(..), Header, - QueryFlag, QueryParam, QueryParams, - Raw, RemoteHost, ReqBody, Vault) -import Servant.API.ContentTypes (AcceptHeader (..), - AllCTRender (..), - AllCTUnrender (..), - AllMime, - canHandleAcceptH) -import Servant.API.ResponseHeaders (GetHeaders, Headers, getHeaders, - getResponse) +import Servant.API ((:<|>) (..), (:>), + AuthProtect, + BasicAuth, Capture, + Header, + IsSecure (..), + QueryFlag, + QueryParam, + QueryParams, Raw, ReflectMethod (reflectMethod), + RemoteHost, + ReqBody, Vault, + Verb) +import Servant.API.ContentTypes (AcceptHeader (..), + AllCTRender (..), + AllCTUnrender (..), + AllMime, + canHandleAcceptH) +import Servant.API.ResponseHeaders (GetHeaders, + Headers, + getHeaders, + getResponse) import Servant.Server.Internal.Auth import Servant.Server.Internal.Config @@ -466,6 +482,7 @@ instance HasServer api => HasServer (HttpVersion :> api) where route Proxy cfg subserver = WithRequest $ \req -> route (Proxy :: Proxy api) cfg (passToServer subserver $ httpVersion req) +-- | Basic Authentication instance (KnownSymbol realm, HasServer api) => HasServer (BasicAuth realm usr :> api) where type ServerT (BasicAuth realm usr :> api) m = usr -> ServerT api m @@ -479,6 +496,17 @@ instance (KnownSymbol realm, HasServer api) baCfg = getConfigEntry (Proxy :: Proxy realm) cfg authCheck req = runBasicAuth req realm baCfg +-- | General Authentication +instance HasServer api => HasServer (AuthProtect tag usr :> api) where + type ServerT (AuthProtect tag usr :> api) m = usr -> ServerT api m + type HasCfg (AuthProtect tag usr :> api) c + = (HasConfigEntry c tag (AuthHandler Request usr), HasCfg api c) + + route Proxy cfg subserver = WithRequest $ \ request -> + route (Proxy :: Proxy api) cfg (subserver `addAuthCheck` authCheck request) + where + authHandler = unAuthHandler (getConfigEntry (Proxy :: Proxy tag) cfg) + authCheck = fmap (either FailFatal Route) . runExceptT . authHandler pathIsEmpty :: Request -> Bool pathIsEmpty = go . pathInfo diff --git a/servant-server/src/Servant/Server/Internal/Auth.hs b/servant-server/src/Servant/Server/Internal/Auth.hs index 3eac4d80..141e61e1 100644 --- a/servant-server/src/Servant/Server/Internal/Auth.hs +++ b/servant-server/src/Servant/Server/Internal/Auth.hs @@ -4,6 +4,7 @@ module Servant.Server.Internal.Auth where import Control.Monad (guard) +import Control.Monad.Trans.Except (ExceptT) import qualified Data.ByteString as BS import Data.ByteString.Base64 (decodeLenient) import Data.Monoid ((<>)) @@ -18,6 +19,13 @@ import Servant.Server.Internal.ServantErr -- * General Auth +-- | Handlers for AuthProtected resources +newtype AuthHandler r usr = AuthHandler + { unAuthHandler :: r -> ExceptT ServantErr IO usr } + +mkAuthHandler :: (r -> ExceptT ServantErr IO usr) -> AuthHandler r usr +mkAuthHandler = AuthHandler + -- | The result of authentication/authorization data AuthResult usr = Unauthorized diff --git a/servant/src/Servant/API.hs b/servant/src/Servant/API.hs index 4f8d6bce..784395db 100644 --- a/servant/src/Servant/API.hs +++ b/servant/src/Servant/API.hs @@ -52,7 +52,7 @@ module Servant.API ( ) where import Servant.API.Alternative ((:<|>) (..)) -import Servant.API.Auth (BasicAuth) +import Servant.API.Auth (BasicAuth, AuthProtect) import Servant.API.Capture (Capture) import Servant.API.ContentTypes (Accept (..), FormUrlEncoded, FromFormUrlEncoded (..), JSON, diff --git a/servant/src/Servant/API/Auth.hs b/servant/src/Servant/API/Auth.hs index db2ab653..72ac1332 100644 --- a/servant/src/Servant/API/Auth.hs +++ b/servant/src/Servant/API/Auth.hs @@ -18,3 +18,7 @@ import GHC.TypeLits (Symbol) -- relatively efficient. data BasicAuth (realm :: Symbol) usr deriving (Typeable) + +-- | A generalized Authentication combinator. +data AuthProtect tag usr + deriving (Typeable)