diff --git a/servant-mock/src/Servant/Mock.hs b/servant-mock/src/Servant/Mock.hs index 1bd93a04..5489d33e 100644 --- a/servant-mock/src/Servant/Mock.hs +++ b/servant-mock/src/Servant/Mock.hs @@ -5,6 +5,7 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE OverloadedStrings #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Servant.Mock @@ -61,10 +62,13 @@ import GHC.TypeLits import Network.HTTP.Types.Status import Network.Wai import Servant +import Servant.API.Authentication import Servant.API.ContentTypes +import Servant.Server.Internal.Authentication import Test.QuickCheck.Arbitrary (Arbitrary (..), vector) import Test.QuickCheck.Gen (Gen, generate) + -- | 'HasMock' defines an interpretation of API types -- than turns them into random-response-generating -- request handlers, hence providing an instance for @@ -150,6 +154,21 @@ instance (KnownSymbol s, HasMock rest) => HasMock (MatrixFlag s :> rest) where instance (KnownSymbol h, FromText a, HasMock rest) => HasMock (Header h a :> rest) where mock _ = \_ -> mock (Proxy :: Proxy rest) +instance (HasMock rest, AuthData authdata, Arbitrary usr) => HasMock (AuthProtect authdata (usr :: *) 'Lax :> rest) where + mock _ = laxProtect (\_ -> do { a <- generate arbitrary; return (Just a)}) (\_ -> mock (Proxy :: Proxy rest)) + +instance (HasMock rest, Arbitrary usr, KnownSymbol realm) + => HasMock (AuthProtect (BasicAuth realm) (usr :: *) 'Strict :> rest) where + mock _ = basicAuthStrict (\_ -> do { a <- generate arbitrary; return (Just a)}) + (\_ -> mock (Proxy :: Proxy rest)) + +instance (HasMock rest, Arbitrary usr) + => HasMock (AuthProtect JWTAuth (usr :: *) 'Strict :> rest) where + mock _ = strictProtect (\_ -> do { a <- generate arbitrary; return (Just a)}) + (AuthHandlers (return authFailure) ((const . return) authFailure)) + (\_ -> mock (Proxy :: Proxy rest)) + where authFailure = responseBuilder status401 [] mempty + instance (Arbitrary a, AllCTRender ctypes a) => HasMock (Delete ctypes a) where mock _ = mockArbitrary diff --git a/servant-server/src/Servant/Server/Internal/Authentication.hs b/servant-server/src/Servant/Server/Internal/Authentication.hs index e14d3625..2c0d16e6 100644 --- a/servant-server/src/Servant/Server/Internal/Authentication.hs +++ b/servant-server/src/Servant/Server/Internal/Authentication.hs @@ -30,16 +30,17 @@ import Data.String (fromString) import Data.Word8 (isSpace, toLower, _colon) import GHC.TypeLits (KnownSymbol, symbolVal) import Data.Text.Encoding (decodeUtf8) -import Data.Text (splitOn) +import Data.Text (splitOn, Text) import Network.HTTP.Types.Status (status401) import Network.Wai (Request, Response, requestHeaders, responseBuilder) import Servant.API.Authentication (AuthPolicy (Strict, Lax), AuthProtected, - BasicAuth (BasicAuth)) + BasicAuth (BasicAuth), + JWTAuth) import Web.JWT (JWT, UnverifiedJWT, VerifiedJWT, Secret, JSON) -import qualified Web.JWT as JWT (decode, verify) +import qualified Web.JWT as JWT (decode, verify, secret) -- | Class to represent the ability to extract authentication-related -- data from a 'Request' object. @@ -116,7 +117,7 @@ basicAuthLax = laxProtect -instance AuthData JSON where +instance AuthData JWTAuth where authData req = do -- We might want to write a proper parser for this? but split works fine... hdr <- lookup "Authorization" . requestHeaders $ req @@ -135,6 +136,7 @@ jwtAuthHandlers = -- Use this to quickly add jwt authentication to your project. -- One can use strictProtect and laxProtect to make more complex authentication -- and authorization schemes. For an example of that, see our tutorial: @placeholder@ -jwtAuth :: Secret -> subserver -> AuthProtected JSON (JWT VerifiedJWT) subserver 'Strict -jwtAuth secret subserver = strictProtect (return . (JWT.verify secret <=< JWT.decode)) jwtAuthHandlers subserver +-- TODO more advanced one +jwtAuth :: Text -> subserver -> AuthProtected JSON (JWT VerifiedJWT) subserver 'Strict +jwtAuth secret subserver = strictProtect (return . (JWT.verify (JWT.secret secret) <=< JWT.decode)) jwtAuthHandlers subserver diff --git a/servant/src/Servant/API/Authentication.hs b/servant/src/Servant/API/Authentication.hs index 4b6c9981..207662b3 100644 --- a/servant/src/Servant/API/Authentication.hs +++ b/servant/src/Servant/API/Authentication.hs @@ -4,11 +4,19 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TypeFamilies #-} {-# OPTIONS_HADDOCK not-home #-} -module Servant.API.Authentication where +module Servant.API.Authentication +( AuthPolicy (..) +, AuthProtect (..) +, AuthProtected (..) +, BasicAuth (..) +, JWTAuth +) where + import Data.ByteString (ByteString) import Data.Typeable (Typeable) import GHC.TypeLits (Symbol) +import Data.Text (Text) -- | we can be either Strict or Lax. -- Strict: all handlers under 'AuthProtect' take a 'usr' argument. @@ -29,3 +37,5 @@ data family AuthProtected authdata usr subserver :: AuthPolicy -> * data BasicAuth (realm :: Symbol) = BasicAuth { baUser :: ByteString , baPass :: ByteString } deriving (Eq, Show, Typeable) + +type JWTAuth = Text