Add gen. authentication support to servant-server

This commit is contained in:
aaron levin 2016-02-17 21:21:57 +01:00
parent 038abb433d
commit 0461c4642d
5 changed files with 95 additions and 6 deletions

View File

@ -38,6 +38,7 @@ library
Servant.Server
Servant.Server.Internal
Servant.Server.Internal.Context
Servant.Server.Internal.Auth
Servant.Server.Internal.BasicAuth
Servant.Server.Internal.Enter
Servant.Server.Internal.Router

View File

@ -45,11 +45,15 @@ module Servant.Server
, NamedContext(..)
, descendIntoNamedContext
-- * Basic Authentication
, BasicAuthCheck(BasicAuthCheck, unBasicAuthCheck)
, BasicAuthResult(..)
-- * General Authentication
, AuthHandler(unAuthHandler)
, AuthServerData
, mkAuthHandler
-- * Default error type
, ServantErr(..)
-- ** 3XX

View File

@ -10,11 +10,13 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
#include "overlapping-compat.h"
module Servant.Server.Internal
( module Servant.Server.Internal
, module Servant.Server.Internal.Auth
, module Servant.Server.Internal.Context
, module Servant.Server.Internal.BasicAuth
, module Servant.Server.Internal.Router
@ -25,7 +27,7 @@ module Servant.Server.Internal
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative ((<$>))
#endif
import Control.Monad.Trans.Except (ExceptT)
import Control.Monad.Trans.Except (ExceptT, runExceptT)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC8
import qualified Data.ByteString.Lazy as BL
@ -50,7 +52,7 @@ import Web.HttpApiData.Internal (parseHeaderMaybe,
parseQueryParamMaybe,
parseUrlPieceMaybe)
import Servant.API ((:<|>) (..), (:>), BasicAuth, Capture,
import Servant.API ((:<|>) (..), (:>), AuthProtect, BasicAuth, Capture,
Verb, ReflectMethod(reflectMethod),
IsSecure(..), Header,
QueryFlag, QueryParam, QueryParams,
@ -64,6 +66,7 @@ import Servant.API.ContentTypes (AcceptHeader (..),
import Servant.API.ResponseHeaders (GetHeaders, Headers, getHeaders,
getResponse)
import Servant.Server.Internal.Auth
import Servant.Server.Internal.Context
import Servant.Server.Internal.BasicAuth
import Servant.Server.Internal.Router
@ -482,6 +485,22 @@ pathIsEmpty = go . pathInfo
ct_wildcard :: B.ByteString
ct_wildcard = "*" <> "/" <> "*" -- Because CPP
-- * General Authentication
instance ( HasServer api context
, HasContextEntry context (AuthHandler Request (AuthServerData (AuthProtect tag)))
)
=> HasServer (AuthProtect tag :> api) context where
type ServerT (AuthProtect tag :> api) m =
AuthServerData (AuthProtect tag) -> ServerT api m
route Proxy context subserver = WithRequest $ \ request ->
route (Proxy :: Proxy api) context (subserver `addAuthCheck` authCheck request)
where
authHandler = unAuthHandler (getContextEntry context)
authCheck = fmap (either FailFatal Route) . runExceptT . authHandler
-- * contexts
instance (HasContextEntry context (NamedContext name subContext), HasServer subApi subContext)

View File

@ -0,0 +1,27 @@
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module Servant.Server.Internal.Auth where
import Control.Monad.Trans.Except (ExceptT)
import Data.Typeable (Typeable)
import GHC.Generics (Generic)
import Servant.Server.Internal.ServantErr (ServantErr)
-- * General Auth
-- | Specify the type of data returned after we've authenticated a request.
-- quite often this is some `User` datatype.
type family AuthServerData a :: *
-- | Handlers for AuthProtected resources
newtype AuthHandler r usr = AuthHandler
{ unAuthHandler :: r -> ExceptT ServantErr IO usr }
deriving (Generic, Typeable)
mkAuthHandler :: (r -> ExceptT ServantErr IO usr) -> AuthHandler r usr
mkAuthHandler = AuthHandler

View File

@ -31,14 +31,15 @@ import Network.HTTP.Types (Status (..), hAccept, hContentType,
methodHead, methodPatch,
methodPost, methodPut, ok200,
parseQuery)
import Network.Wai (Application, Request, pathInfo,
import Network.Wai (Application, Request, requestHeaders, pathInfo,
queryString, rawQueryString,
responseBuilder, responseLBS)
import Network.Wai.Internal (Response (ResponseBuilder))
import Network.Wai.Test (defaultRequest, request,
runSession, simpleBody,
simpleHeaders, simpleStatus)
import Servant.API ((:<|>) (..), (:>), BasicAuth, BasicAuthData(BasicAuthData),
import Servant.API ((:<|>) (..), (:>), AuthProtect,
BasicAuth, BasicAuthData(BasicAuthData),
Capture, Delete, Get, Header (..),
Headers, HttpVersion,
IsSecure (..), JSON,
@ -59,6 +60,9 @@ import Test.Hspec.Wai (get, liftIO, matchHeaders,
import Servant.Server.Internal.BasicAuth (BasicAuthCheck(BasicAuthCheck),
BasicAuthResult(Authorized,Unauthorized))
import Servant.Server.Internal.Auth
(AuthHandler, AuthServerData,
mkAuthHandler)
import Servant.Server.Internal.RoutingApplication
(toApplication, RouteResult(..))
import Servant.Server.Internal.Router
@ -90,6 +94,7 @@ spec = do
routerSpec
miscCombinatorSpec
basicAuthSpec
genAuthSpec
------------------------------------------------------------------------------
-- * verbSpec {{{
@ -534,7 +539,7 @@ miscCombinatorSpec = with (return $ serve miscApi miscServ) $
-- }}}
------------------------------------------------------------------------------
-- * Authentication {{{
-- * Basic Authentication {{{
------------------------------------------------------------------------------
type BasicAuthAPI = BasicAuth "foo" () :> "basic" :> Get '[JSON] Animal
@ -564,6 +569,39 @@ basicAuthSpec = do
it "returns 200 with the right password" $ do
THW.request methodGet "/basic" [("Authorization","Basic c2VydmFudDpzZXJ2ZXI=")] "" `shouldRespondWith` 200
-- }}}
------------------------------------------------------------------------------
-- * General Authentication {{{
------------------------------------------------------------------------------
type GenAuthAPI = AuthProtect "auth" :> "auth" :> Get '[JSON] Animal
authApi :: Proxy GenAuthAPI
authApi = Proxy
authServer :: Server GenAuthAPI
authServer = const (return tweety)
type instance AuthServerData (AuthProtect "auth") = ()
genAuthContext :: Context '[ AuthHandler Request () ]
genAuthContext =
let authHandler = (\req ->
if elem ("Auth", "secret") (requestHeaders req)
then return ()
else throwE err401
)
in mkAuthHandler authHandler :. EmptyContext
genAuthSpec :: Spec
genAuthSpec = do
describe "Servant.API.Auth" $ do
with (return (serveWithContext authApi genAuthContext authServer)) $ do
context "Custom Auth Protection" $ do
it "returns 401 when missing headers" $ do
get "/auth" `shouldRespondWith` 401
it "returns 200 with the right header" $ do
THW.request methodGet "/auth" [("Auth","secret")] "" `shouldRespondWith` 200
-- }}}
------------------------------------------------------------------------------
-- * Test data types {{{