Add gen. authentication support to servant-server
This commit is contained in:
parent
038abb433d
commit
0461c4642d
5 changed files with 95 additions and 6 deletions
|
@ -38,6 +38,7 @@ library
|
||||||
Servant.Server
|
Servant.Server
|
||||||
Servant.Server.Internal
|
Servant.Server.Internal
|
||||||
Servant.Server.Internal.Context
|
Servant.Server.Internal.Context
|
||||||
|
Servant.Server.Internal.Auth
|
||||||
Servant.Server.Internal.BasicAuth
|
Servant.Server.Internal.BasicAuth
|
||||||
Servant.Server.Internal.Enter
|
Servant.Server.Internal.Enter
|
||||||
Servant.Server.Internal.Router
|
Servant.Server.Internal.Router
|
||||||
|
|
|
@ -45,11 +45,15 @@ module Servant.Server
|
||||||
, NamedContext(..)
|
, NamedContext(..)
|
||||||
, descendIntoNamedContext
|
, descendIntoNamedContext
|
||||||
|
|
||||||
|
|
||||||
-- * Basic Authentication
|
-- * Basic Authentication
|
||||||
, BasicAuthCheck(BasicAuthCheck, unBasicAuthCheck)
|
, BasicAuthCheck(BasicAuthCheck, unBasicAuthCheck)
|
||||||
, BasicAuthResult(..)
|
, BasicAuthResult(..)
|
||||||
|
|
||||||
|
-- * General Authentication
|
||||||
|
, AuthHandler(unAuthHandler)
|
||||||
|
, AuthServerData
|
||||||
|
, mkAuthHandler
|
||||||
|
|
||||||
-- * Default error type
|
-- * Default error type
|
||||||
, ServantErr(..)
|
, ServantErr(..)
|
||||||
-- ** 3XX
|
-- ** 3XX
|
||||||
|
|
|
@ -10,11 +10,13 @@
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
{-# LANGUAGE TypeFamilies #-}
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
{-# LANGUAGE TypeOperators #-}
|
{-# LANGUAGE TypeOperators #-}
|
||||||
|
{-# LANGUAGE UndecidableInstances #-}
|
||||||
|
|
||||||
#include "overlapping-compat.h"
|
#include "overlapping-compat.h"
|
||||||
|
|
||||||
module Servant.Server.Internal
|
module Servant.Server.Internal
|
||||||
( module Servant.Server.Internal
|
( module Servant.Server.Internal
|
||||||
|
, module Servant.Server.Internal.Auth
|
||||||
, module Servant.Server.Internal.Context
|
, module Servant.Server.Internal.Context
|
||||||
, module Servant.Server.Internal.BasicAuth
|
, module Servant.Server.Internal.BasicAuth
|
||||||
, module Servant.Server.Internal.Router
|
, module Servant.Server.Internal.Router
|
||||||
|
@ -25,7 +27,7 @@ module Servant.Server.Internal
|
||||||
#if !MIN_VERSION_base(4,8,0)
|
#if !MIN_VERSION_base(4,8,0)
|
||||||
import Control.Applicative ((<$>))
|
import Control.Applicative ((<$>))
|
||||||
#endif
|
#endif
|
||||||
import Control.Monad.Trans.Except (ExceptT)
|
import Control.Monad.Trans.Except (ExceptT, runExceptT)
|
||||||
import qualified Data.ByteString as B
|
import qualified Data.ByteString as B
|
||||||
import qualified Data.ByteString.Char8 as BC8
|
import qualified Data.ByteString.Char8 as BC8
|
||||||
import qualified Data.ByteString.Lazy as BL
|
import qualified Data.ByteString.Lazy as BL
|
||||||
|
@ -50,7 +52,7 @@ import Web.HttpApiData.Internal (parseHeaderMaybe,
|
||||||
parseQueryParamMaybe,
|
parseQueryParamMaybe,
|
||||||
parseUrlPieceMaybe)
|
parseUrlPieceMaybe)
|
||||||
|
|
||||||
import Servant.API ((:<|>) (..), (:>), BasicAuth, Capture,
|
import Servant.API ((:<|>) (..), (:>), AuthProtect, BasicAuth, Capture,
|
||||||
Verb, ReflectMethod(reflectMethod),
|
Verb, ReflectMethod(reflectMethod),
|
||||||
IsSecure(..), Header,
|
IsSecure(..), Header,
|
||||||
QueryFlag, QueryParam, QueryParams,
|
QueryFlag, QueryParam, QueryParams,
|
||||||
|
@ -64,6 +66,7 @@ import Servant.API.ContentTypes (AcceptHeader (..),
|
||||||
import Servant.API.ResponseHeaders (GetHeaders, Headers, getHeaders,
|
import Servant.API.ResponseHeaders (GetHeaders, Headers, getHeaders,
|
||||||
getResponse)
|
getResponse)
|
||||||
|
|
||||||
|
import Servant.Server.Internal.Auth
|
||||||
import Servant.Server.Internal.Context
|
import Servant.Server.Internal.Context
|
||||||
import Servant.Server.Internal.BasicAuth
|
import Servant.Server.Internal.BasicAuth
|
||||||
import Servant.Server.Internal.Router
|
import Servant.Server.Internal.Router
|
||||||
|
@ -482,6 +485,22 @@ pathIsEmpty = go . pathInfo
|
||||||
ct_wildcard :: B.ByteString
|
ct_wildcard :: B.ByteString
|
||||||
ct_wildcard = "*" <> "/" <> "*" -- Because CPP
|
ct_wildcard = "*" <> "/" <> "*" -- Because CPP
|
||||||
|
|
||||||
|
-- * General Authentication
|
||||||
|
|
||||||
|
instance ( HasServer api context
|
||||||
|
, HasContextEntry context (AuthHandler Request (AuthServerData (AuthProtect tag)))
|
||||||
|
)
|
||||||
|
=> HasServer (AuthProtect tag :> api) context where
|
||||||
|
|
||||||
|
type ServerT (AuthProtect tag :> api) m =
|
||||||
|
AuthServerData (AuthProtect tag) -> ServerT api m
|
||||||
|
|
||||||
|
route Proxy context subserver = WithRequest $ \ request ->
|
||||||
|
route (Proxy :: Proxy api) context (subserver `addAuthCheck` authCheck request)
|
||||||
|
where
|
||||||
|
authHandler = unAuthHandler (getContextEntry context)
|
||||||
|
authCheck = fmap (either FailFatal Route) . runExceptT . authHandler
|
||||||
|
|
||||||
-- * contexts
|
-- * contexts
|
||||||
|
|
||||||
instance (HasContextEntry context (NamedContext name subContext), HasServer subApi subContext)
|
instance (HasContextEntry context (NamedContext name subContext), HasServer subApi subContext)
|
||||||
|
|
27
servant-server/src/Servant/Server/Internal/Auth.hs
Normal file
27
servant-server/src/Servant/Server/Internal/Auth.hs
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
{-# LANGUAGE DeriveDataTypeable #-}
|
||||||
|
{-# LANGUAGE DeriveFunctor #-}
|
||||||
|
{-# LANGUAGE DeriveGeneric #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
|
||||||
|
module Servant.Server.Internal.Auth where
|
||||||
|
|
||||||
|
import Control.Monad.Trans.Except (ExceptT)
|
||||||
|
import Data.Typeable (Typeable)
|
||||||
|
import GHC.Generics (Generic)
|
||||||
|
|
||||||
|
import Servant.Server.Internal.ServantErr (ServantErr)
|
||||||
|
|
||||||
|
-- * General Auth
|
||||||
|
|
||||||
|
-- | Specify the type of data returned after we've authenticated a request.
|
||||||
|
-- quite often this is some `User` datatype.
|
||||||
|
type family AuthServerData a :: *
|
||||||
|
|
||||||
|
-- | Handlers for AuthProtected resources
|
||||||
|
newtype AuthHandler r usr = AuthHandler
|
||||||
|
{ unAuthHandler :: r -> ExceptT ServantErr IO usr }
|
||||||
|
deriving (Generic, Typeable)
|
||||||
|
|
||||||
|
mkAuthHandler :: (r -> ExceptT ServantErr IO usr) -> AuthHandler r usr
|
||||||
|
mkAuthHandler = AuthHandler
|
|
@ -31,14 +31,15 @@ import Network.HTTP.Types (Status (..), hAccept, hContentType,
|
||||||
methodHead, methodPatch,
|
methodHead, methodPatch,
|
||||||
methodPost, methodPut, ok200,
|
methodPost, methodPut, ok200,
|
||||||
parseQuery)
|
parseQuery)
|
||||||
import Network.Wai (Application, Request, pathInfo,
|
import Network.Wai (Application, Request, requestHeaders, pathInfo,
|
||||||
queryString, rawQueryString,
|
queryString, rawQueryString,
|
||||||
responseBuilder, responseLBS)
|
responseBuilder, responseLBS)
|
||||||
import Network.Wai.Internal (Response (ResponseBuilder))
|
import Network.Wai.Internal (Response (ResponseBuilder))
|
||||||
import Network.Wai.Test (defaultRequest, request,
|
import Network.Wai.Test (defaultRequest, request,
|
||||||
runSession, simpleBody,
|
runSession, simpleBody,
|
||||||
simpleHeaders, simpleStatus)
|
simpleHeaders, simpleStatus)
|
||||||
import Servant.API ((:<|>) (..), (:>), BasicAuth, BasicAuthData(BasicAuthData),
|
import Servant.API ((:<|>) (..), (:>), AuthProtect,
|
||||||
|
BasicAuth, BasicAuthData(BasicAuthData),
|
||||||
Capture, Delete, Get, Header (..),
|
Capture, Delete, Get, Header (..),
|
||||||
Headers, HttpVersion,
|
Headers, HttpVersion,
|
||||||
IsSecure (..), JSON,
|
IsSecure (..), JSON,
|
||||||
|
@ -59,6 +60,9 @@ import Test.Hspec.Wai (get, liftIO, matchHeaders,
|
||||||
|
|
||||||
import Servant.Server.Internal.BasicAuth (BasicAuthCheck(BasicAuthCheck),
|
import Servant.Server.Internal.BasicAuth (BasicAuthCheck(BasicAuthCheck),
|
||||||
BasicAuthResult(Authorized,Unauthorized))
|
BasicAuthResult(Authorized,Unauthorized))
|
||||||
|
import Servant.Server.Internal.Auth
|
||||||
|
(AuthHandler, AuthServerData,
|
||||||
|
mkAuthHandler)
|
||||||
import Servant.Server.Internal.RoutingApplication
|
import Servant.Server.Internal.RoutingApplication
|
||||||
(toApplication, RouteResult(..))
|
(toApplication, RouteResult(..))
|
||||||
import Servant.Server.Internal.Router
|
import Servant.Server.Internal.Router
|
||||||
|
@ -90,6 +94,7 @@ spec = do
|
||||||
routerSpec
|
routerSpec
|
||||||
miscCombinatorSpec
|
miscCombinatorSpec
|
||||||
basicAuthSpec
|
basicAuthSpec
|
||||||
|
genAuthSpec
|
||||||
|
|
||||||
------------------------------------------------------------------------------
|
------------------------------------------------------------------------------
|
||||||
-- * verbSpec {{{
|
-- * verbSpec {{{
|
||||||
|
@ -534,7 +539,7 @@ miscCombinatorSpec = with (return $ serve miscApi miscServ) $
|
||||||
|
|
||||||
-- }}}
|
-- }}}
|
||||||
------------------------------------------------------------------------------
|
------------------------------------------------------------------------------
|
||||||
-- * Authentication {{{
|
-- * Basic Authentication {{{
|
||||||
------------------------------------------------------------------------------
|
------------------------------------------------------------------------------
|
||||||
|
|
||||||
type BasicAuthAPI = BasicAuth "foo" () :> "basic" :> Get '[JSON] Animal
|
type BasicAuthAPI = BasicAuth "foo" () :> "basic" :> Get '[JSON] Animal
|
||||||
|
@ -564,6 +569,39 @@ basicAuthSpec = do
|
||||||
it "returns 200 with the right password" $ do
|
it "returns 200 with the right password" $ do
|
||||||
THW.request methodGet "/basic" [("Authorization","Basic c2VydmFudDpzZXJ2ZXI=")] "" `shouldRespondWith` 200
|
THW.request methodGet "/basic" [("Authorization","Basic c2VydmFudDpzZXJ2ZXI=")] "" `shouldRespondWith` 200
|
||||||
|
|
||||||
|
-- }}}
|
||||||
|
------------------------------------------------------------------------------
|
||||||
|
-- * General Authentication {{{
|
||||||
|
------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type GenAuthAPI = AuthProtect "auth" :> "auth" :> Get '[JSON] Animal
|
||||||
|
authApi :: Proxy GenAuthAPI
|
||||||
|
authApi = Proxy
|
||||||
|
authServer :: Server GenAuthAPI
|
||||||
|
authServer = const (return tweety)
|
||||||
|
|
||||||
|
type instance AuthServerData (AuthProtect "auth") = ()
|
||||||
|
|
||||||
|
genAuthContext :: Context '[ AuthHandler Request () ]
|
||||||
|
genAuthContext =
|
||||||
|
let authHandler = (\req ->
|
||||||
|
if elem ("Auth", "secret") (requestHeaders req)
|
||||||
|
then return ()
|
||||||
|
else throwE err401
|
||||||
|
)
|
||||||
|
in mkAuthHandler authHandler :. EmptyContext
|
||||||
|
|
||||||
|
genAuthSpec :: Spec
|
||||||
|
genAuthSpec = do
|
||||||
|
describe "Servant.API.Auth" $ do
|
||||||
|
with (return (serveWithContext authApi genAuthContext authServer)) $ do
|
||||||
|
|
||||||
|
context "Custom Auth Protection" $ do
|
||||||
|
it "returns 401 when missing headers" $ do
|
||||||
|
get "/auth" `shouldRespondWith` 401
|
||||||
|
it "returns 200 with the right header" $ do
|
||||||
|
THW.request methodGet "/auth" [("Auth","secret")] "" `shouldRespondWith` 200
|
||||||
|
|
||||||
-- }}}
|
-- }}}
|
||||||
------------------------------------------------------------------------------
|
------------------------------------------------------------------------------
|
||||||
-- * Test data types {{{
|
-- * Test data types {{{
|
||||||
|
|
Loading…
Reference in a new issue