Add instances for auth in servant-mock

This commit is contained in:
Arian van Putten 2015-10-03 01:35:27 +02:00
parent 81f48c6b14
commit 9ccb7203e4
3 changed files with 38 additions and 7 deletions

View file

@ -5,6 +5,7 @@
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeOperators #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -fno-warn-orphans #-} {-# OPTIONS_GHC -fno-warn-orphans #-}
-- | -- |
-- Module : Servant.Mock -- Module : Servant.Mock
@ -61,10 +62,13 @@ import GHC.TypeLits
import Network.HTTP.Types.Status import Network.HTTP.Types.Status
import Network.Wai import Network.Wai
import Servant import Servant
import Servant.API.Authentication
import Servant.API.ContentTypes import Servant.API.ContentTypes
import Servant.Server.Internal.Authentication
import Test.QuickCheck.Arbitrary (Arbitrary (..), vector) import Test.QuickCheck.Arbitrary (Arbitrary (..), vector)
import Test.QuickCheck.Gen (Gen, generate) import Test.QuickCheck.Gen (Gen, generate)
-- | 'HasMock' defines an interpretation of API types -- | 'HasMock' defines an interpretation of API types
-- than turns them into random-response-generating -- than turns them into random-response-generating
-- request handlers, hence providing an instance for -- request handlers, hence providing an instance for
@ -150,6 +154,21 @@ instance (KnownSymbol s, HasMock rest) => HasMock (MatrixFlag s :> rest) where
instance (KnownSymbol h, FromText a, HasMock rest) => HasMock (Header h a :> rest) where instance (KnownSymbol h, FromText a, HasMock rest) => HasMock (Header h a :> rest) where
mock _ = \_ -> mock (Proxy :: Proxy rest) mock _ = \_ -> mock (Proxy :: Proxy rest)
instance (HasMock rest, AuthData authdata, Arbitrary usr) => HasMock (AuthProtect authdata (usr :: *) 'Lax :> rest) where
mock _ = laxProtect (\_ -> do { a <- generate arbitrary; return (Just a)}) (\_ -> mock (Proxy :: Proxy rest))
instance (HasMock rest, Arbitrary usr, KnownSymbol realm)
=> HasMock (AuthProtect (BasicAuth realm) (usr :: *) 'Strict :> rest) where
mock _ = basicAuthStrict (\_ -> do { a <- generate arbitrary; return (Just a)})
(\_ -> mock (Proxy :: Proxy rest))
instance (HasMock rest, Arbitrary usr)
=> HasMock (AuthProtect JWTAuth (usr :: *) 'Strict :> rest) where
mock _ = strictProtect (\_ -> do { a <- generate arbitrary; return (Just a)})
(AuthHandlers (return authFailure) ((const . return) authFailure))
(\_ -> mock (Proxy :: Proxy rest))
where authFailure = responseBuilder status401 [] mempty
instance (Arbitrary a, AllCTRender ctypes a) => HasMock (Delete ctypes a) where instance (Arbitrary a, AllCTRender ctypes a) => HasMock (Delete ctypes a) where
mock _ = mockArbitrary mock _ = mockArbitrary

View file

@ -30,16 +30,17 @@ import Data.String (fromString)
import Data.Word8 (isSpace, toLower, _colon) import Data.Word8 (isSpace, toLower, _colon)
import GHC.TypeLits (KnownSymbol, symbolVal) import GHC.TypeLits (KnownSymbol, symbolVal)
import Data.Text.Encoding (decodeUtf8) import Data.Text.Encoding (decodeUtf8)
import Data.Text (splitOn) import Data.Text (splitOn, Text)
import Network.HTTP.Types.Status (status401) import Network.HTTP.Types.Status (status401)
import Network.Wai (Request, Response, requestHeaders, import Network.Wai (Request, Response, requestHeaders,
responseBuilder) responseBuilder)
import Servant.API.Authentication (AuthPolicy (Strict, Lax), import Servant.API.Authentication (AuthPolicy (Strict, Lax),
AuthProtected, AuthProtected,
BasicAuth (BasicAuth)) BasicAuth (BasicAuth),
JWTAuth)
import Web.JWT (JWT, UnverifiedJWT, VerifiedJWT, Secret, JSON) import Web.JWT (JWT, UnverifiedJWT, VerifiedJWT, Secret, JSON)
import qualified Web.JWT as JWT (decode, verify) import qualified Web.JWT as JWT (decode, verify, secret)
-- | Class to represent the ability to extract authentication-related -- | Class to represent the ability to extract authentication-related
-- data from a 'Request' object. -- data from a 'Request' object.
@ -116,7 +117,7 @@ basicAuthLax = laxProtect
instance AuthData JSON where instance AuthData JWTAuth where
authData req = do authData req = do
-- We might want to write a proper parser for this? but split works fine... -- We might want to write a proper parser for this? but split works fine...
hdr <- lookup "Authorization" . requestHeaders $ req hdr <- lookup "Authorization" . requestHeaders $ req
@ -135,6 +136,7 @@ jwtAuthHandlers =
-- Use this to quickly add jwt authentication to your project. -- Use this to quickly add jwt authentication to your project.
-- One can use strictProtect and laxProtect to make more complex authentication -- One can use strictProtect and laxProtect to make more complex authentication
-- and authorization schemes. For an example of that, see our tutorial: @placeholder@ -- and authorization schemes. For an example of that, see our tutorial: @placeholder@
jwtAuth :: Secret -> subserver -> AuthProtected JSON (JWT VerifiedJWT) subserver 'Strict -- TODO more advanced one
jwtAuth secret subserver = strictProtect (return . (JWT.verify secret <=< JWT.decode)) jwtAuthHandlers subserver jwtAuth :: Text -> subserver -> AuthProtected JSON (JWT VerifiedJWT) subserver 'Strict
jwtAuth secret subserver = strictProtect (return . (JWT.verify (JWT.secret secret) <=< JWT.decode)) jwtAuthHandlers subserver

View file

@ -4,11 +4,19 @@
{-# LANGUAGE PolyKinds #-} {-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_HADDOCK not-home #-} {-# OPTIONS_HADDOCK not-home #-}
module Servant.API.Authentication where module Servant.API.Authentication
( AuthPolicy (..)
, AuthProtect (..)
, AuthProtected (..)
, BasicAuth (..)
, JWTAuth
) where
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import Data.Typeable (Typeable) import Data.Typeable (Typeable)
import GHC.TypeLits (Symbol) import GHC.TypeLits (Symbol)
import Data.Text (Text)
-- | we can be either Strict or Lax. -- | we can be either Strict or Lax.
-- Strict: all handlers under 'AuthProtect' take a 'usr' argument. -- Strict: all handlers under 'AuthProtect' take a 'usr' argument.
@ -29,3 +37,5 @@ data family AuthProtected authdata usr subserver :: AuthPolicy -> *
data BasicAuth (realm :: Symbol) = BasicAuth { baUser :: ByteString data BasicAuth (realm :: Symbol) = BasicAuth { baUser :: ByteString
, baPass :: ByteString , baPass :: ByteString
} deriving (Eq, Show, Typeable) } deriving (Eq, Show, Typeable)
type JWTAuth = Text