Make auth tests

This commit is contained in:
Arian van Putten 2015-10-31 22:18:59 +01:00 committed by aaron levin
parent dc699a93e0
commit 667dbbc8cd
4 changed files with 107 additions and 33 deletions

View file

@ -124,6 +124,7 @@ test-suite spec
, wai , wai
, wai-extra , wai-extra
, warp , warp
, jwt
test-suite doctests test-suite doctests
build-depends: base build-depends: base

View file

@ -17,7 +17,7 @@ module Servant.Server.Internal.Authentication
, jwtAuthStrict , jwtAuthStrict
) where ) where
import Control.Monad (guard, (<=<)) import Control.Monad (guard)
import qualified Data.ByteString as B import qualified Data.ByteString as B
import Data.ByteString.Base64 (decodeLenient) import Data.ByteString.Base64 (decodeLenient)
#if !MIN_VERSION_base(4,8,0) #if !MIN_VERSION_base(4,8,0)
@ -30,7 +30,7 @@ import Data.String (fromString)
import Data.Word8 (isSpace, toLower, _colon) import Data.Word8 (isSpace, toLower, _colon)
import GHC.TypeLits (KnownSymbol, symbolVal) import GHC.TypeLits (KnownSymbol, symbolVal)
import Data.Text.Encoding (decodeUtf8) import Data.Text.Encoding (decodeUtf8)
import Data.Text (splitOn, Text) import Data.Text (splitOn)
import Network.HTTP.Types.Status (status401) import Network.HTTP.Types.Status (status401)
import Network.Wai (Request, Response, requestHeaders, import Network.Wai (Request, Response, requestHeaders,
responseBuilder) responseBuilder)
@ -38,8 +38,10 @@ import Servant.API.Authentication (AuthPolicy (Strict, Lax),
AuthProtected, AuthProtected,
BasicAuth (BasicAuth), BasicAuth (BasicAuth),
JWTAuth (..)) JWTAuth (..))
JWTAuth)
import Servant.Server.Internal.RoutingApplication (RouteResult(FailFatal))
import Web.JWT (JWT, UnverifiedJWT, VerifiedJWT, Secret, JSON) import Web.JWT (JWT, VerifiedJWT, Secret)
import qualified Web.JWT as JWT (decode, decodeAndVerifySignature, secret) import qualified Web.JWT as JWT (decode, decodeAndVerifySignature, secret)
-- | Class to represent the ability to extract authentication-related -- | Class to represent the ability to extract authentication-related
@ -50,10 +52,10 @@ class AuthData a where
-- | handlers to deal with authentication failures. -- | handlers to deal with authentication failures.
data AuthHandlers authData = AuthHandlers data AuthHandlers authData = AuthHandlers
{ -- we couldn't find the right type of auth data (or any, for that matter) { -- we couldn't find the right type of auth data (or any, for that matter)
onMissingAuthData :: IO Response onMissingAuthData :: IO ServantError
, ,
-- we found the right type of auth data in the request but the check failed -- we found the right type of auth data in the request but the check failed
onUnauthenticated :: authData -> IO Response onUnauthenticated :: authData -> IO ServantError
} }
-- | concrete type to provide when in 'Strict' mode. -- | concrete type to provide when in 'Strict' mode.
@ -119,10 +121,9 @@ basicAuthLax = laxProtect
instance AuthData JWTAuth where instance AuthData JWTAuth where
authData req = do authData req = do
-- We might want to write a proper parser for this? but split works fine...
hdr <- lookup "Authorization" . requestHeaders $ req hdr <- lookup "Authorization" . requestHeaders $ req
["Bearer", token] <- return . splitOn " " . decodeUtf8 $ hdr ["Bearer", token] <- return . splitOn " " . decodeUtf8 $ hdr
JWT.decode token -- try decode it. otherwise it's not a proper token _ <- JWT.decode token -- try decode it. otherwise it's not a proper token
return . JWTAuth $ token return . JWTAuth $ token

View file

@ -1,4 +1,5 @@
{-# LANGUAGE CPP #-} {-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeOperators #-}
@ -19,6 +20,8 @@ import Network.Wai (Application, Request,
Response, ResponseReceived, Response, ResponseReceived,
requestBody, requestBody,
strictRequestBody) strictRequestBody)
import Servant.API.Authentication (AuthPolicy(Strict,Lax))
import Servant.Server.Internal.Authentication
import Servant.Server.Internal.ServantErr import Servant.Server.Internal.ServantErr
type RoutingApplication = type RoutingApplication =
@ -155,32 +158,60 @@ data Delayed :: * -> * where
Delayed :: IO (RouteResult a) Delayed :: IO (RouteResult a)
-> IO (RouteResult ()) -> IO (RouteResult ())
-> IO (RouteResult b) -> IO (RouteResult b)
-> (a -> b -> RouteResult c) -> IO (RouteResult c)
-> Delayed c -> (a -> b -> c -> RouteResult d)
-> Delayed d
instance Functor Delayed where instance Functor Delayed where
fmap f (Delayed a b c g) = Delayed a b c ((fmap.fmap.fmap) f g) fmap f (Delayed a b c d g) = Delayed a b c d ((fmap.fmap.fmap.fmap) f g)
-- | Add a capture to the end of the capture block. -- | Add a capture to the end of the capture block.
addCapture :: Delayed (a -> b) addCapture :: Delayed (a -> b)
-> IO (RouteResult a) -> IO (RouteResult a)
-> Delayed b -> Delayed b
addCapture (Delayed captures method body server) new = addCapture (Delayed captures method auth body server) new =
Delayed (combineRouteResults (,) captures new) method body (\ (x, v) y -> ($ v) <$> server x y) Delayed (combineRouteResults (,) captures new) method auth body (\ (x, v) y z -> ($ v) <$> server x y z)
-- | Add a method check to the end of the method block. -- | Add a method check to the end of the method block.
addMethodCheck :: Delayed a addMethodCheck :: Delayed a
-> IO (RouteResult ()) -> IO (RouteResult ())
-> Delayed a -> Delayed a
addMethodCheck (Delayed captures method body server) new = addMethodCheck (Delayed captures method auth body server) new =
Delayed captures (combineRouteResults const method new) body server Delayed captures (combineRouteResults const method new) auth body server
addAuthStrictCheck :: Delayed (AuthProtected auth usr (usr -> a) 'Strict)
-> IO (RouteResult (Maybe usr))
-> Delayed a
addAuthStrictCheck delayed@(Delayed captures method _ body server) new =
let newAuth = runDelayed delayed `bindRouteResults` \ authProtectionStrict -> new `bindRouteResults` \ mAuthData -> case mAuthData of
Nothing -> do
-- we're in strict mode: don't let the request go
-- call the provided "on missing auth" handler
resp <- onMissingAuthData (authHandlers authProtectionStrict)
return $ FailFatal resp
-- successfully pulled auth data out of the request
Just authData -> do
mUsr <- (checkAuthStrict authProtectionStrict) authData
case mUsr of
-- this user is not authenticated
Nothing -> do
resp <- onUnauthenticated (authHandlers authProtectionStrict) authData
return $ FailFatal resp
-- this user is authenticated
Just usr ->
(return . Route . subServerStrict authProtectionStrict) usr
in Delayed captures method newAuth body server
-- | Add a body check to the end of the body block. -- | Add a body check to the end of the body block.
addBodyCheck :: Delayed (a -> b) addBodyCheck :: Delayed (a -> b)
-> IO (RouteResult a) -> IO (RouteResult a)
-> Delayed b -> Delayed b
addBodyCheck (Delayed captures method body server) new = addBodyCheck (Delayed captures method auth body server) new =
Delayed captures method (combineRouteResults (,) body new) (\ x (y, v) -> ($ v) <$> server x y) Delayed captures method auth (combineRouteResults (,) body new) (\ x (y, v) z -> ($ v) <$> server x y z)
-- | Add an accept header check to the end of the body block. -- | Add an accept header check to the end of the body block.
-- The accept header check should occur after the body check, -- The accept header check should occur after the body check,
@ -189,8 +220,8 @@ addBodyCheck (Delayed captures method body server) new =
addAcceptCheck :: Delayed a addAcceptCheck :: Delayed a
-> IO (RouteResult ()) -> IO (RouteResult ())
-> Delayed a -> Delayed a
addAcceptCheck (Delayed captures method body server) new = addAcceptCheck (Delayed captures method auth body server) new =
Delayed captures method (combineRouteResults const body new) server Delayed captures method auth (combineRouteResults const body new) server
-- | Many combinators extract information that is passed to -- | Many combinators extract information that is passed to
-- the handler without the possibility of failure. In such a -- the handler without the possibility of failure. In such a

View file

@ -67,6 +67,7 @@ import Servant.API.Authentication
import Servant.Server.Internal.Authentication import Servant.Server.Internal.Authentication
import Servant.Server (Server, serve, ServantErr(..), err404) import Servant.Server (Server, serve, ServantErr(..), err404)
import Servant.Server.Internal (RouteMismatch (..)) import Servant.Server.Internal (RouteMismatch (..))
import Web.JWT hiding (JSON)
-- * test data types -- * test data types
@ -116,7 +117,8 @@ spec = do
routerSpec routerSpec
responseHeadersSpec responseHeadersSpec
miscReqCombinatorsSpec miscReqCombinatorsSpec
authRequiredSpec basicAuthRequiredSpec
jwtAuthRequiredSpec
type CaptureApi = Capture "legs" Integer :> Get '[JSON] Animal type CaptureApi = Capture "legs" Integer :> Get '[JSON] Animal
@ -714,7 +716,7 @@ miscReqCombinatorsSpec = with (return $ serve miscApi miscServ) $
type AuthUser = ByteString type AuthUser = ByteString
type BasicAuthFooRealm = AuthProtect (BasicAuth "foo-realm") AuthUser 'Strict type BasicAuthFooRealm = AuthProtect (BasicAuth "foo-realm") AuthUser 'Strict
type BasicAuthBarRealm = AuthProtect (BasicAuth "bar-realm") AuthUser 'Strict type BasicAuthBarRealm = AuthProtect (BasicAuth "bar-realm") AuthUser 'Strict
type AuthRequiredAPI = BasicAuthFooRealm :> "foo" :> Get '[JSON] Person type BasicAuthRequiredAPI = BasicAuthFooRealm :> "foo" :> Get '[JSON] Person
:<|> "bar" :> BasicAuthBarRealm :> Get '[JSON] Animal :<|> "bar" :> BasicAuthBarRealm :> Get '[JSON] Animal
basicAuthFooCheck :: BasicAuth "foo-realm" -> IO (Maybe AuthUser) basicAuthFooCheck :: BasicAuth "foo-realm" -> IO (Maybe AuthUser)
@ -726,11 +728,11 @@ basicAuthBarCheck :: BasicAuth "bar-realm" -> IO (Maybe AuthUser)
basicAuthBarCheck (BasicAuth usr pass) = if usr == "bar" && pass == "bar" basicAuthBarCheck (BasicAuth usr pass) = if usr == "bar" && pass == "bar"
then return (Just "bar") then return (Just "bar")
else return Nothing else return Nothing
authRequiredApi :: Proxy AuthRequiredAPI basicBasicAuthRequiredApi :: Proxy BasicAuthRequiredAPI
authRequiredApi = Proxy basicBasicAuthRequiredApi = Proxy
authRequiredServer :: Server AuthRequiredAPI basicAuthRequiredServer :: Server BasicAuthRequiredAPI
authRequiredServer = basicAuthStrict basicAuthFooCheck (const . return $ alice) basicAuthRequiredServer = basicAuthStrict basicAuthFooCheck (const . return $ alice)
:<|> basicAuthStrict basicAuthBarCheck (const . return $ jerry) :<|> basicAuthStrict basicAuthBarCheck (const . return $ jerry)
-- base64-encoded "servant:server" -- base64-encoded "servant:server"
@ -745,25 +747,25 @@ base64BarColonPassword = "YmFyOmJhcg=="
base64UserColonPassword :: ByteString base64UserColonPassword :: ByteString
base64UserColonPassword = "dXNlcjpwYXNzd29yZA==" base64UserColonPassword = "dXNlcjpwYXNzd29yZA=="
authGet :: ByteString -> ByteString -> WaiSession SResponse basicAuthGet :: ByteString -> ByteString -> WaiSession SResponse
authGet path base64EncodedAuth = Test.Hspec.Wai.request methodGet path [("Authorization", "Basic " <> base64EncodedAuth)] "" basicAuthGet path base64EncodedAuth = Test.Hspec.Wai.request methodGet path [("Authorization", "Basic " <> base64EncodedAuth)] ""
authRequiredSpec :: Spec basicAuthRequiredSpec :: Spec
authRequiredSpec = do basicAuthRequiredSpec = do
describe "Servant.API.Authentication" $ do describe "Servant.API.Authentication" $ do
with (return $ serve authRequiredApi authRequiredServer) $ do with (return $ serve basicBasicAuthRequiredApi basicAuthRequiredServer) $ do
it "allows access with the correct username and password" $ do it "allows access with the correct username and password" $ do
response <- authGet "/foo" base64ServantColonServer response <- basicAuthGet "/foo" base64ServantColonServer
liftIO $ do liftIO $ do
decode' (simpleBody response) `shouldBe` Just alice decode' (simpleBody response) `shouldBe` Just alice
response <- authGet "/bar" base64BarColonPassword response <- basicAuthGet "/bar" base64BarColonPassword
liftIO $ do liftIO $ do
decode' (simpleBody response) `shouldBe` Just jerry decode' (simpleBody response) `shouldBe` Just jerry
it "rejects requests with the incorrect username and password" $ do it "rejects requests with the incorrect username and password" $ do
authGet "/foo" base64UserColonPassword `shouldRespondWith` 401 basicAuthGet "/foo" base64UserColonPassword `shouldRespondWith` 401
authGet "/bar" base64UserColonPassword `shouldRespondWith` 401 basicAuthGet "/bar" base64UserColonPassword `shouldRespondWith` 401
it "does not respond to non-authenticated requests" $ do it "does not respond to non-authenticated requests" $ do
get "/foo" `shouldRespondWith` 401 get "/foo" `shouldRespondWith` 401
@ -774,3 +776,42 @@ authRequiredSpec = do
bar401 <- get "/bar" bar401 <- get "/bar"
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" foo401) WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" foo401)
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"bar-realm\"" bar401) WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"bar-realm\"" bar401)
type JWTAuthProtect = AuthProtect JWTAuth (JWT VerifiedJWT) 'Strict
type JWTAuthRequiredAPI = JWTAuthProtect :> "foo" :> Get '[JSON] Person
jwtAuthRequiredApi :: Proxy JWTAuthRequiredAPI
jwtAuthRequiredApi = Proxy
jwtAuthRequiredServer :: Server JWTAuthRequiredAPI
jwtAuthRequiredServer = jwtAuthStrict (secret "secret") (const . return $ alice)
correctToken = "blah"
incorrectToken = "blah"
jwtAuthGet :: ByteString -> ByteString -> WaiSession SResponse
jwtAuthGet path token = Test.Hspec.Wai.request methodGet path [("Authorization", "Bearer " <> token)] ""
jwtAuthRequiredSpec :: Spec
jwtAuthRequiredSpec = do
describe "JWT Auth" $ do
with (return $ serve jwtAuthRequiredApi jwtAuthRequiredServer) $ do
it "allows access with the correct token" $ do
response <- jwtAuthGet "/foo" correctToken
liftIO $ do
decode' (simpleBody response) `shouldBe` Just alice
it "rejects requests with an incorrect token" $ do
jwtAuthGet "/foo" incorrectToken `shouldRespondWith` 401
it "rejects requests without auth data" $ do
get "/foo" `shouldRespondWith` 401
it "responds correctly to requests without auth data" $ do
a <- jwtAuthGet "/foo" incorrectToken
WaiSession (assertHeader "WWW-Authenticate" "Bearer error=\"invalid_token\"" a)
it "respond correctly to requests with incorrect auth data" $ do
a <- get "/foo"
WaiSession (assertHeader "WWW-Authenticate" "Bearer error=\"invalid_request\"" a)