Add the generalized auth combinator

This commit is contained in:
aaron levin 2016-01-08 11:35:26 +01:00
parent 38c3cb7045
commit 0764d9b84c
4 changed files with 77 additions and 37 deletions

View file

@ -13,6 +13,7 @@
module Servant.Server.Internal module Servant.Server.Internal
( module Servant.Server.Internal ( module Servant.Server.Internal
, module Servant.Server.Internal.Auth
, module Servant.Server.Internal.Config , module Servant.Server.Internal.Config
, module Servant.Server.Internal.Router , module Servant.Server.Internal.Router
, module Servant.Server.Internal.RoutingApplication , module Servant.Server.Internal.RoutingApplication
@ -22,44 +23,59 @@ module Servant.Server.Internal
#if !MIN_VERSION_base(4,8,0) #if !MIN_VERSION_base(4,8,0)
import Control.Applicative ((<$>)) import Control.Applicative ((<$>))
#endif #endif
import Control.Monad.Trans.Except (ExceptT) import Control.Monad.Trans.Except (ExceptT, runExceptT)
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC8 import qualified Data.ByteString.Char8 as BC8
import qualified Data.ByteString.Lazy as BL import qualified Data.ByteString.Lazy as BL
import qualified Data.Map as M import qualified Data.Map as M
import Data.Maybe (fromMaybe, mapMaybe) import Data.Maybe (fromMaybe,
import Data.String (fromString) mapMaybe)
import Data.String.Conversions (cs, (<>)) import Data.String (fromString)
import Data.Text (Text) import Data.String.Conversions (cs, (<>))
import Data.Text (Text)
import Data.Typeable import Data.Typeable
import GHC.Exts (Constraint) import GHC.Exts (Constraint)
import GHC.TypeLits (KnownNat, KnownSymbol, natVal, import GHC.TypeLits (KnownNat,
symbolVal) KnownSymbol,
import Network.HTTP.Types hiding (Header, ResponseHeaders) natVal, symbolVal)
import Network.Socket (SockAddr) import Network.HTTP.Types hiding (Header,
import Network.Wai (Application, Request, Response, ResponseHeaders)
httpVersion, isSecure, import Network.Socket (SockAddr)
lazyRequestBody, pathInfo, import Network.Wai (Application,
rawQueryString, remoteHost, Request, Response,
requestHeaders, requestMethod, httpVersion,
responseLBS, vault) isSecure,
import Web.HttpApiData (FromHttpApiData) lazyRequestBody,
import Web.HttpApiData.Internal (parseHeaderMaybe, pathInfo,
parseQueryParamMaybe, rawQueryString,
parseUrlPieceMaybe) remoteHost,
requestHeaders,
requestMethod,
responseLBS, vault)
import Web.HttpApiData (FromHttpApiData)
import Web.HttpApiData.Internal (parseHeaderMaybe, parseQueryParamMaybe,
parseUrlPieceMaybe)
import Servant.API ((:<|>) (..), (:>), Capture, BasicAuth, import Servant.API ((:<|>) (..), (:>),
Verb, ReflectMethod(reflectMethod), AuthProtect,
IsSecure(..), Header, BasicAuth, Capture,
QueryFlag, QueryParam, QueryParams, Header,
Raw, RemoteHost, ReqBody, Vault) IsSecure (..),
import Servant.API.ContentTypes (AcceptHeader (..), QueryFlag,
AllCTRender (..), QueryParam,
AllCTUnrender (..), QueryParams, Raw, ReflectMethod (reflectMethod),
AllMime, RemoteHost,
canHandleAcceptH) ReqBody, Vault,
import Servant.API.ResponseHeaders (GetHeaders, Headers, getHeaders, Verb)
getResponse) import Servant.API.ContentTypes (AcceptHeader (..),
AllCTRender (..),
AllCTUnrender (..),
AllMime,
canHandleAcceptH)
import Servant.API.ResponseHeaders (GetHeaders,
Headers,
getHeaders,
getResponse)
import Servant.Server.Internal.Auth import Servant.Server.Internal.Auth
import Servant.Server.Internal.Config import Servant.Server.Internal.Config
@ -466,6 +482,7 @@ instance HasServer api => HasServer (HttpVersion :> api) where
route Proxy cfg subserver = WithRequest $ \req -> route Proxy cfg subserver = WithRequest $ \req ->
route (Proxy :: Proxy api) cfg (passToServer subserver $ httpVersion req) route (Proxy :: Proxy api) cfg (passToServer subserver $ httpVersion req)
-- | Basic Authentication
instance (KnownSymbol realm, HasServer api) instance (KnownSymbol realm, HasServer api)
=> HasServer (BasicAuth realm usr :> api) where => HasServer (BasicAuth realm usr :> api) where
type ServerT (BasicAuth realm usr :> api) m = usr -> ServerT api m type ServerT (BasicAuth realm usr :> api) m = usr -> ServerT api m
@ -479,6 +496,17 @@ instance (KnownSymbol realm, HasServer api)
baCfg = getConfigEntry (Proxy :: Proxy realm) cfg baCfg = getConfigEntry (Proxy :: Proxy realm) cfg
authCheck req = runBasicAuth req realm baCfg authCheck req = runBasicAuth req realm baCfg
-- | General Authentication
instance HasServer api => HasServer (AuthProtect tag usr :> api) where
type ServerT (AuthProtect tag usr :> api) m = usr -> ServerT api m
type HasCfg (AuthProtect tag usr :> api) c
= (HasConfigEntry c tag (AuthHandler Request usr), HasCfg api c)
route Proxy cfg subserver = WithRequest $ \ request ->
route (Proxy :: Proxy api) cfg (subserver `addAuthCheck` authCheck request)
where
authHandler = unAuthHandler (getConfigEntry (Proxy :: Proxy tag) cfg)
authCheck = fmap (either FailFatal Route) . runExceptT . authHandler
pathIsEmpty :: Request -> Bool pathIsEmpty :: Request -> Bool
pathIsEmpty = go . pathInfo pathIsEmpty = go . pathInfo

View file

@ -4,6 +4,7 @@
module Servant.Server.Internal.Auth where module Servant.Server.Internal.Auth where
import Control.Monad (guard) import Control.Monad (guard)
import Control.Monad.Trans.Except (ExceptT)
import qualified Data.ByteString as BS import qualified Data.ByteString as BS
import Data.ByteString.Base64 (decodeLenient) import Data.ByteString.Base64 (decodeLenient)
import Data.Monoid ((<>)) import Data.Monoid ((<>))
@ -18,6 +19,13 @@ import Servant.Server.Internal.ServantErr
-- * General Auth -- * General Auth
-- | Handlers for AuthProtected resources
newtype AuthHandler r usr = AuthHandler
{ unAuthHandler :: r -> ExceptT ServantErr IO usr }
mkAuthHandler :: (r -> ExceptT ServantErr IO usr) -> AuthHandler r usr
mkAuthHandler = AuthHandler
-- | The result of authentication/authorization -- | The result of authentication/authorization
data AuthResult usr data AuthResult usr
= Unauthorized = Unauthorized

View file

@ -52,7 +52,7 @@ module Servant.API (
) where ) where
import Servant.API.Alternative ((:<|>) (..)) import Servant.API.Alternative ((:<|>) (..))
import Servant.API.Auth (BasicAuth) import Servant.API.Auth (BasicAuth, AuthProtect)
import Servant.API.Capture (Capture) import Servant.API.Capture (Capture)
import Servant.API.ContentTypes (Accept (..), FormUrlEncoded, import Servant.API.ContentTypes (Accept (..), FormUrlEncoded,
FromFormUrlEncoded (..), JSON, FromFormUrlEncoded (..), JSON,

View file

@ -18,3 +18,7 @@ import GHC.TypeLits (Symbol)
-- relatively efficient. -- relatively efficient.
data BasicAuth (realm :: Symbol) usr data BasicAuth (realm :: Symbol) usr
deriving (Typeable) deriving (Typeable)
-- | A generalized Authentication combinator.
data AuthProtect tag usr
deriving (Typeable)