Make auth tests

This commit is contained in:
Arian van Putten 2015-10-31 22:18:59 +01:00 committed by aaron levin
parent dc699a93e0
commit 667dbbc8cd
4 changed files with 107 additions and 33 deletions

View file

@ -124,6 +124,7 @@ test-suite spec
, wai
, wai-extra
, warp
, jwt
test-suite doctests
build-depends: base

View file

@ -17,7 +17,7 @@ module Servant.Server.Internal.Authentication
, jwtAuthStrict
) where
import Control.Monad (guard, (<=<))
import Control.Monad (guard)
import qualified Data.ByteString as B
import Data.ByteString.Base64 (decodeLenient)
#if !MIN_VERSION_base(4,8,0)
@ -30,7 +30,7 @@ import Data.String (fromString)
import Data.Word8 (isSpace, toLower, _colon)
import GHC.TypeLits (KnownSymbol, symbolVal)
import Data.Text.Encoding (decodeUtf8)
import Data.Text (splitOn, Text)
import Data.Text (splitOn)
import Network.HTTP.Types.Status (status401)
import Network.Wai (Request, Response, requestHeaders,
responseBuilder)
@ -38,8 +38,10 @@ import Servant.API.Authentication (AuthPolicy (Strict, Lax),
AuthProtected,
BasicAuth (BasicAuth),
JWTAuth (..))
JWTAuth)
import Servant.Server.Internal.RoutingApplication (RouteResult(FailFatal))
import Web.JWT (JWT, UnverifiedJWT, VerifiedJWT, Secret, JSON)
import Web.JWT (JWT, VerifiedJWT, Secret)
import qualified Web.JWT as JWT (decode, decodeAndVerifySignature, secret)
-- | Class to represent the ability to extract authentication-related
@ -50,10 +52,10 @@ class AuthData a where
-- | handlers to deal with authentication failures.
data AuthHandlers authData = AuthHandlers
{ -- we couldn't find the right type of auth data (or any, for that matter)
onMissingAuthData :: IO Response
onMissingAuthData :: IO ServantError
,
-- we found the right type of auth data in the request but the check failed
onUnauthenticated :: authData -> IO Response
onUnauthenticated :: authData -> IO ServantError
}
-- | concrete type to provide when in 'Strict' mode.
@ -119,10 +121,9 @@ basicAuthLax = laxProtect
instance AuthData JWTAuth where
authData req = do
-- We might want to write a proper parser for this? but split works fine...
hdr <- lookup "Authorization" . requestHeaders $ req
["Bearer", token] <- return . splitOn " " . decodeUtf8 $ hdr
JWT.decode token -- try decode it. otherwise it's not a proper token
_ <- JWT.decode token -- try decode it. otherwise it's not a proper token
return . JWTAuth $ token

View file

@ -1,4 +1,5 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeOperators #-}
@ -19,6 +20,8 @@ import Network.Wai (Application, Request,
Response, ResponseReceived,
requestBody,
strictRequestBody)
import Servant.API.Authentication (AuthPolicy(Strict,Lax))
import Servant.Server.Internal.Authentication
import Servant.Server.Internal.ServantErr
type RoutingApplication =
@ -155,32 +158,60 @@ data Delayed :: * -> * where
Delayed :: IO (RouteResult a)
-> IO (RouteResult ())
-> IO (RouteResult b)
-> (a -> b -> RouteResult c)
-> Delayed c
-> IO (RouteResult c)
-> (a -> b -> c -> RouteResult d)
-> Delayed d
instance Functor Delayed where
fmap f (Delayed a b c g) = Delayed a b c ((fmap.fmap.fmap) f g)
fmap f (Delayed a b c d g) = Delayed a b c d ((fmap.fmap.fmap.fmap) f g)
-- | Add a capture to the end of the capture block.
addCapture :: Delayed (a -> b)
-> IO (RouteResult a)
-> Delayed b
addCapture (Delayed captures method body server) new =
Delayed (combineRouteResults (,) captures new) method body (\ (x, v) y -> ($ v) <$> server x y)
addCapture (Delayed captures method auth body server) new =
Delayed (combineRouteResults (,) captures new) method auth body (\ (x, v) y z -> ($ v) <$> server x y z)
-- | Add a method check to the end of the method block.
addMethodCheck :: Delayed a
-> IO (RouteResult ())
-> Delayed a
addMethodCheck (Delayed captures method body server) new =
Delayed captures (combineRouteResults const method new) body server
addMethodCheck (Delayed captures method auth body server) new =
Delayed captures (combineRouteResults const method new) auth body server
addAuthStrictCheck :: Delayed (AuthProtected auth usr (usr -> a) 'Strict)
-> IO (RouteResult (Maybe usr))
-> Delayed a
addAuthStrictCheck delayed@(Delayed captures method _ body server) new =
let newAuth = runDelayed delayed `bindRouteResults` \ authProtectionStrict -> new `bindRouteResults` \ mAuthData -> case mAuthData of
Nothing -> do
-- we're in strict mode: don't let the request go
-- call the provided "on missing auth" handler
resp <- onMissingAuthData (authHandlers authProtectionStrict)
return $ FailFatal resp
-- successfully pulled auth data out of the request
Just authData -> do
mUsr <- (checkAuthStrict authProtectionStrict) authData
case mUsr of
-- this user is not authenticated
Nothing -> do
resp <- onUnauthenticated (authHandlers authProtectionStrict) authData
return $ FailFatal resp
-- this user is authenticated
Just usr ->
(return . Route . subServerStrict authProtectionStrict) usr
in Delayed captures method newAuth body server
-- | Add a body check to the end of the body block.
addBodyCheck :: Delayed (a -> b)
-> IO (RouteResult a)
-> Delayed b
addBodyCheck (Delayed captures method body server) new =
Delayed captures method (combineRouteResults (,) body new) (\ x (y, v) -> ($ v) <$> server x y)
addBodyCheck (Delayed captures method auth body server) new =
Delayed captures method auth (combineRouteResults (,) body new) (\ x (y, v) z -> ($ v) <$> server x y z)
-- | Add an accept header check to the end of the body block.
-- The accept header check should occur after the body check,
@ -189,8 +220,8 @@ addBodyCheck (Delayed captures method body server) new =
addAcceptCheck :: Delayed a
-> IO (RouteResult ())
-> Delayed a
addAcceptCheck (Delayed captures method body server) new =
Delayed captures method (combineRouteResults const body new) server
addAcceptCheck (Delayed captures method auth body server) new =
Delayed captures method auth (combineRouteResults const body new) server
-- | Many combinators extract information that is passed to
-- the handler without the possibility of failure. In such a

View file

@ -67,6 +67,7 @@ import Servant.API.Authentication
import Servant.Server.Internal.Authentication
import Servant.Server (Server, serve, ServantErr(..), err404)
import Servant.Server.Internal (RouteMismatch (..))
import Web.JWT hiding (JSON)
-- * test data types
@ -116,7 +117,8 @@ spec = do
routerSpec
responseHeadersSpec
miscReqCombinatorsSpec
authRequiredSpec
basicAuthRequiredSpec
jwtAuthRequiredSpec
type CaptureApi = Capture "legs" Integer :> Get '[JSON] Animal
@ -714,7 +716,7 @@ miscReqCombinatorsSpec = with (return $ serve miscApi miscServ) $
type AuthUser = ByteString
type BasicAuthFooRealm = AuthProtect (BasicAuth "foo-realm") AuthUser 'Strict
type BasicAuthBarRealm = AuthProtect (BasicAuth "bar-realm") AuthUser 'Strict
type AuthRequiredAPI = BasicAuthFooRealm :> "foo" :> Get '[JSON] Person
type BasicAuthRequiredAPI = BasicAuthFooRealm :> "foo" :> Get '[JSON] Person
:<|> "bar" :> BasicAuthBarRealm :> Get '[JSON] Animal
basicAuthFooCheck :: BasicAuth "foo-realm" -> IO (Maybe AuthUser)
@ -726,11 +728,11 @@ basicAuthBarCheck :: BasicAuth "bar-realm" -> IO (Maybe AuthUser)
basicAuthBarCheck (BasicAuth usr pass) = if usr == "bar" && pass == "bar"
then return (Just "bar")
else return Nothing
authRequiredApi :: Proxy AuthRequiredAPI
authRequiredApi = Proxy
basicBasicAuthRequiredApi :: Proxy BasicAuthRequiredAPI
basicBasicAuthRequiredApi = Proxy
authRequiredServer :: Server AuthRequiredAPI
authRequiredServer = basicAuthStrict basicAuthFooCheck (const . return $ alice)
basicAuthRequiredServer :: Server BasicAuthRequiredAPI
basicAuthRequiredServer = basicAuthStrict basicAuthFooCheck (const . return $ alice)
:<|> basicAuthStrict basicAuthBarCheck (const . return $ jerry)
-- base64-encoded "servant:server"
@ -745,25 +747,25 @@ base64BarColonPassword = "YmFyOmJhcg=="
base64UserColonPassword :: ByteString
base64UserColonPassword = "dXNlcjpwYXNzd29yZA=="
authGet :: ByteString -> ByteString -> WaiSession SResponse
authGet path base64EncodedAuth = Test.Hspec.Wai.request methodGet path [("Authorization", "Basic " <> base64EncodedAuth)] ""
basicAuthGet :: ByteString -> ByteString -> WaiSession SResponse
basicAuthGet path base64EncodedAuth = Test.Hspec.Wai.request methodGet path [("Authorization", "Basic " <> base64EncodedAuth)] ""
authRequiredSpec :: Spec
authRequiredSpec = do
basicAuthRequiredSpec :: Spec
basicAuthRequiredSpec = do
describe "Servant.API.Authentication" $ do
with (return $ serve authRequiredApi authRequiredServer) $ do
with (return $ serve basicBasicAuthRequiredApi basicAuthRequiredServer) $ do
it "allows access with the correct username and password" $ do
response <- authGet "/foo" base64ServantColonServer
response <- basicAuthGet "/foo" base64ServantColonServer
liftIO $ do
decode' (simpleBody response) `shouldBe` Just alice
response <- authGet "/bar" base64BarColonPassword
response <- basicAuthGet "/bar" base64BarColonPassword
liftIO $ do
decode' (simpleBody response) `shouldBe` Just jerry
it "rejects requests with the incorrect username and password" $ do
authGet "/foo" base64UserColonPassword `shouldRespondWith` 401
authGet "/bar" base64UserColonPassword `shouldRespondWith` 401
basicAuthGet "/foo" base64UserColonPassword `shouldRespondWith` 401
basicAuthGet "/bar" base64UserColonPassword `shouldRespondWith` 401
it "does not respond to non-authenticated requests" $ do
get "/foo" `shouldRespondWith` 401
@ -774,3 +776,42 @@ authRequiredSpec = do
bar401 <- get "/bar"
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" foo401)
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"bar-realm\"" bar401)
type JWTAuthProtect = AuthProtect JWTAuth (JWT VerifiedJWT) 'Strict
type JWTAuthRequiredAPI = JWTAuthProtect :> "foo" :> Get '[JSON] Person
jwtAuthRequiredApi :: Proxy JWTAuthRequiredAPI
jwtAuthRequiredApi = Proxy
jwtAuthRequiredServer :: Server JWTAuthRequiredAPI
jwtAuthRequiredServer = jwtAuthStrict (secret "secret") (const . return $ alice)
correctToken = "blah"
incorrectToken = "blah"
jwtAuthGet :: ByteString -> ByteString -> WaiSession SResponse
jwtAuthGet path token = Test.Hspec.Wai.request methodGet path [("Authorization", "Bearer " <> token)] ""
jwtAuthRequiredSpec :: Spec
jwtAuthRequiredSpec = do
describe "JWT Auth" $ do
with (return $ serve jwtAuthRequiredApi jwtAuthRequiredServer) $ do
it "allows access with the correct token" $ do
response <- jwtAuthGet "/foo" correctToken
liftIO $ do
decode' (simpleBody response) `shouldBe` Just alice
it "rejects requests with an incorrect token" $ do
jwtAuthGet "/foo" incorrectToken `shouldRespondWith` 401
it "rejects requests without auth data" $ do
get "/foo" `shouldRespondWith` 401
it "responds correctly to requests without auth data" $ do
a <- jwtAuthGet "/foo" incorrectToken
WaiSession (assertHeader "WWW-Authenticate" "Bearer error=\"invalid_token\"" a)
it "respond correctly to requests with incorrect auth data" $ do
a <- get "/foo"
WaiSession (assertHeader "WWW-Authenticate" "Bearer error=\"invalid_request\"" a)