Make auth tests
This commit is contained in:
parent
dc699a93e0
commit
667dbbc8cd
4 changed files with 107 additions and 33 deletions
|
@ -124,6 +124,7 @@ test-suite spec
|
|||
, wai
|
||||
, wai-extra
|
||||
, warp
|
||||
, jwt
|
||||
|
||||
test-suite doctests
|
||||
build-depends: base
|
||||
|
|
|
@ -17,7 +17,7 @@ module Servant.Server.Internal.Authentication
|
|||
, jwtAuthStrict
|
||||
) where
|
||||
|
||||
import Control.Monad (guard, (<=<))
|
||||
import Control.Monad (guard)
|
||||
import qualified Data.ByteString as B
|
||||
import Data.ByteString.Base64 (decodeLenient)
|
||||
#if !MIN_VERSION_base(4,8,0)
|
||||
|
@ -30,7 +30,7 @@ import Data.String (fromString)
|
|||
import Data.Word8 (isSpace, toLower, _colon)
|
||||
import GHC.TypeLits (KnownSymbol, symbolVal)
|
||||
import Data.Text.Encoding (decodeUtf8)
|
||||
import Data.Text (splitOn, Text)
|
||||
import Data.Text (splitOn)
|
||||
import Network.HTTP.Types.Status (status401)
|
||||
import Network.Wai (Request, Response, requestHeaders,
|
||||
responseBuilder)
|
||||
|
@ -38,8 +38,10 @@ import Servant.API.Authentication (AuthPolicy (Strict, Lax),
|
|||
AuthProtected,
|
||||
BasicAuth (BasicAuth),
|
||||
JWTAuth (..))
|
||||
JWTAuth)
|
||||
import Servant.Server.Internal.RoutingApplication (RouteResult(FailFatal))
|
||||
|
||||
import Web.JWT (JWT, UnverifiedJWT, VerifiedJWT, Secret, JSON)
|
||||
import Web.JWT (JWT, VerifiedJWT, Secret)
|
||||
import qualified Web.JWT as JWT (decode, decodeAndVerifySignature, secret)
|
||||
|
||||
-- | Class to represent the ability to extract authentication-related
|
||||
|
@ -50,10 +52,10 @@ class AuthData a where
|
|||
-- | handlers to deal with authentication failures.
|
||||
data AuthHandlers authData = AuthHandlers
|
||||
{ -- we couldn't find the right type of auth data (or any, for that matter)
|
||||
onMissingAuthData :: IO Response
|
||||
onMissingAuthData :: IO ServantError
|
||||
,
|
||||
-- we found the right type of auth data in the request but the check failed
|
||||
onUnauthenticated :: authData -> IO Response
|
||||
onUnauthenticated :: authData -> IO ServantError
|
||||
}
|
||||
|
||||
-- | concrete type to provide when in 'Strict' mode.
|
||||
|
@ -119,10 +121,9 @@ basicAuthLax = laxProtect
|
|||
|
||||
instance AuthData JWTAuth where
|
||||
authData req = do
|
||||
-- We might want to write a proper parser for this? but split works fine...
|
||||
hdr <- lookup "Authorization" . requestHeaders $ req
|
||||
["Bearer", token] <- return . splitOn " " . decodeUtf8 $ hdr
|
||||
JWT.decode token -- try decode it. otherwise it's not a proper token
|
||||
_ <- JWT.decode token -- try decode it. otherwise it's not a proper token
|
||||
return . JWTAuth $ token
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveFunctor #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
|
@ -19,6 +20,8 @@ import Network.Wai (Application, Request,
|
|||
Response, ResponseReceived,
|
||||
requestBody,
|
||||
strictRequestBody)
|
||||
import Servant.API.Authentication (AuthPolicy(Strict,Lax))
|
||||
import Servant.Server.Internal.Authentication
|
||||
import Servant.Server.Internal.ServantErr
|
||||
|
||||
type RoutingApplication =
|
||||
|
@ -155,32 +158,60 @@ data Delayed :: * -> * where
|
|||
Delayed :: IO (RouteResult a)
|
||||
-> IO (RouteResult ())
|
||||
-> IO (RouteResult b)
|
||||
-> (a -> b -> RouteResult c)
|
||||
-> Delayed c
|
||||
-> IO (RouteResult c)
|
||||
-> (a -> b -> c -> RouteResult d)
|
||||
-> Delayed d
|
||||
|
||||
instance Functor Delayed where
|
||||
fmap f (Delayed a b c g) = Delayed a b c ((fmap.fmap.fmap) f g)
|
||||
fmap f (Delayed a b c d g) = Delayed a b c d ((fmap.fmap.fmap.fmap) f g)
|
||||
|
||||
-- | Add a capture to the end of the capture block.
|
||||
addCapture :: Delayed (a -> b)
|
||||
-> IO (RouteResult a)
|
||||
-> Delayed b
|
||||
addCapture (Delayed captures method body server) new =
|
||||
Delayed (combineRouteResults (,) captures new) method body (\ (x, v) y -> ($ v) <$> server x y)
|
||||
addCapture (Delayed captures method auth body server) new =
|
||||
Delayed (combineRouteResults (,) captures new) method auth body (\ (x, v) y z -> ($ v) <$> server x y z)
|
||||
|
||||
-- | Add a method check to the end of the method block.
|
||||
addMethodCheck :: Delayed a
|
||||
-> IO (RouteResult ())
|
||||
-> Delayed a
|
||||
addMethodCheck (Delayed captures method body server) new =
|
||||
Delayed captures (combineRouteResults const method new) body server
|
||||
addMethodCheck (Delayed captures method auth body server) new =
|
||||
Delayed captures (combineRouteResults const method new) auth body server
|
||||
|
||||
addAuthStrictCheck :: Delayed (AuthProtected auth usr (usr -> a) 'Strict)
|
||||
-> IO (RouteResult (Maybe usr))
|
||||
-> Delayed a
|
||||
addAuthStrictCheck delayed@(Delayed captures method _ body server) new =
|
||||
let newAuth = runDelayed delayed `bindRouteResults` \ authProtectionStrict -> new `bindRouteResults` \ mAuthData -> case mAuthData of
|
||||
|
||||
Nothing -> do
|
||||
-- we're in strict mode: don't let the request go
|
||||
-- call the provided "on missing auth" handler
|
||||
resp <- onMissingAuthData (authHandlers authProtectionStrict)
|
||||
return $ FailFatal resp
|
||||
|
||||
-- successfully pulled auth data out of the request
|
||||
Just authData -> do
|
||||
mUsr <- (checkAuthStrict authProtectionStrict) authData
|
||||
case mUsr of
|
||||
-- this user is not authenticated
|
||||
Nothing -> do
|
||||
resp <- onUnauthenticated (authHandlers authProtectionStrict) authData
|
||||
return $ FailFatal resp
|
||||
|
||||
-- this user is authenticated
|
||||
Just usr ->
|
||||
(return . Route . subServerStrict authProtectionStrict) usr
|
||||
in Delayed captures method newAuth body server
|
||||
|
||||
|
||||
-- | Add a body check to the end of the body block.
|
||||
addBodyCheck :: Delayed (a -> b)
|
||||
-> IO (RouteResult a)
|
||||
-> Delayed b
|
||||
addBodyCheck (Delayed captures method body server) new =
|
||||
Delayed captures method (combineRouteResults (,) body new) (\ x (y, v) -> ($ v) <$> server x y)
|
||||
addBodyCheck (Delayed captures method auth body server) new =
|
||||
Delayed captures method auth (combineRouteResults (,) body new) (\ x (y, v) z -> ($ v) <$> server x y z)
|
||||
|
||||
-- | Add an accept header check to the end of the body block.
|
||||
-- The accept header check should occur after the body check,
|
||||
|
@ -189,8 +220,8 @@ addBodyCheck (Delayed captures method body server) new =
|
|||
addAcceptCheck :: Delayed a
|
||||
-> IO (RouteResult ())
|
||||
-> Delayed a
|
||||
addAcceptCheck (Delayed captures method body server) new =
|
||||
Delayed captures method (combineRouteResults const body new) server
|
||||
addAcceptCheck (Delayed captures method auth body server) new =
|
||||
Delayed captures method auth (combineRouteResults const body new) server
|
||||
|
||||
-- | Many combinators extract information that is passed to
|
||||
-- the handler without the possibility of failure. In such a
|
||||
|
|
|
@ -67,6 +67,7 @@ import Servant.API.Authentication
|
|||
import Servant.Server.Internal.Authentication
|
||||
import Servant.Server (Server, serve, ServantErr(..), err404)
|
||||
import Servant.Server.Internal (RouteMismatch (..))
|
||||
import Web.JWT hiding (JSON)
|
||||
|
||||
|
||||
-- * test data types
|
||||
|
@ -116,7 +117,8 @@ spec = do
|
|||
routerSpec
|
||||
responseHeadersSpec
|
||||
miscReqCombinatorsSpec
|
||||
authRequiredSpec
|
||||
basicAuthRequiredSpec
|
||||
jwtAuthRequiredSpec
|
||||
|
||||
|
||||
type CaptureApi = Capture "legs" Integer :> Get '[JSON] Animal
|
||||
|
@ -714,7 +716,7 @@ miscReqCombinatorsSpec = with (return $ serve miscApi miscServ) $
|
|||
type AuthUser = ByteString
|
||||
type BasicAuthFooRealm = AuthProtect (BasicAuth "foo-realm") AuthUser 'Strict
|
||||
type BasicAuthBarRealm = AuthProtect (BasicAuth "bar-realm") AuthUser 'Strict
|
||||
type AuthRequiredAPI = BasicAuthFooRealm :> "foo" :> Get '[JSON] Person
|
||||
type BasicAuthRequiredAPI = BasicAuthFooRealm :> "foo" :> Get '[JSON] Person
|
||||
:<|> "bar" :> BasicAuthBarRealm :> Get '[JSON] Animal
|
||||
|
||||
basicAuthFooCheck :: BasicAuth "foo-realm" -> IO (Maybe AuthUser)
|
||||
|
@ -726,11 +728,11 @@ basicAuthBarCheck :: BasicAuth "bar-realm" -> IO (Maybe AuthUser)
|
|||
basicAuthBarCheck (BasicAuth usr pass) = if usr == "bar" && pass == "bar"
|
||||
then return (Just "bar")
|
||||
else return Nothing
|
||||
authRequiredApi :: Proxy AuthRequiredAPI
|
||||
authRequiredApi = Proxy
|
||||
basicBasicAuthRequiredApi :: Proxy BasicAuthRequiredAPI
|
||||
basicBasicAuthRequiredApi = Proxy
|
||||
|
||||
authRequiredServer :: Server AuthRequiredAPI
|
||||
authRequiredServer = basicAuthStrict basicAuthFooCheck (const . return $ alice)
|
||||
basicAuthRequiredServer :: Server BasicAuthRequiredAPI
|
||||
basicAuthRequiredServer = basicAuthStrict basicAuthFooCheck (const . return $ alice)
|
||||
:<|> basicAuthStrict basicAuthBarCheck (const . return $ jerry)
|
||||
|
||||
-- base64-encoded "servant:server"
|
||||
|
@ -745,25 +747,25 @@ base64BarColonPassword = "YmFyOmJhcg=="
|
|||
base64UserColonPassword :: ByteString
|
||||
base64UserColonPassword = "dXNlcjpwYXNzd29yZA=="
|
||||
|
||||
authGet :: ByteString -> ByteString -> WaiSession SResponse
|
||||
authGet path base64EncodedAuth = Test.Hspec.Wai.request methodGet path [("Authorization", "Basic " <> base64EncodedAuth)] ""
|
||||
basicAuthGet :: ByteString -> ByteString -> WaiSession SResponse
|
||||
basicAuthGet path base64EncodedAuth = Test.Hspec.Wai.request methodGet path [("Authorization", "Basic " <> base64EncodedAuth)] ""
|
||||
|
||||
authRequiredSpec :: Spec
|
||||
authRequiredSpec = do
|
||||
basicAuthRequiredSpec :: Spec
|
||||
basicAuthRequiredSpec = do
|
||||
describe "Servant.API.Authentication" $ do
|
||||
with (return $ serve authRequiredApi authRequiredServer) $ do
|
||||
with (return $ serve basicBasicAuthRequiredApi basicAuthRequiredServer) $ do
|
||||
it "allows access with the correct username and password" $ do
|
||||
response <- authGet "/foo" base64ServantColonServer
|
||||
response <- basicAuthGet "/foo" base64ServantColonServer
|
||||
liftIO $ do
|
||||
decode' (simpleBody response) `shouldBe` Just alice
|
||||
|
||||
response <- authGet "/bar" base64BarColonPassword
|
||||
response <- basicAuthGet "/bar" base64BarColonPassword
|
||||
liftIO $ do
|
||||
decode' (simpleBody response) `shouldBe` Just jerry
|
||||
|
||||
it "rejects requests with the incorrect username and password" $ do
|
||||
authGet "/foo" base64UserColonPassword `shouldRespondWith` 401
|
||||
authGet "/bar" base64UserColonPassword `shouldRespondWith` 401
|
||||
basicAuthGet "/foo" base64UserColonPassword `shouldRespondWith` 401
|
||||
basicAuthGet "/bar" base64UserColonPassword `shouldRespondWith` 401
|
||||
|
||||
it "does not respond to non-authenticated requests" $ do
|
||||
get "/foo" `shouldRespondWith` 401
|
||||
|
@ -774,3 +776,42 @@ authRequiredSpec = do
|
|||
bar401 <- get "/bar"
|
||||
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" foo401)
|
||||
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"bar-realm\"" bar401)
|
||||
|
||||
|
||||
|
||||
type JWTAuthProtect = AuthProtect JWTAuth (JWT VerifiedJWT) 'Strict
|
||||
|
||||
type JWTAuthRequiredAPI = JWTAuthProtect :> "foo" :> Get '[JSON] Person
|
||||
|
||||
|
||||
jwtAuthRequiredApi :: Proxy JWTAuthRequiredAPI
|
||||
jwtAuthRequiredApi = Proxy
|
||||
|
||||
jwtAuthRequiredServer :: Server JWTAuthRequiredAPI
|
||||
jwtAuthRequiredServer = jwtAuthStrict (secret "secret") (const . return $ alice)
|
||||
|
||||
correctToken = "blah"
|
||||
incorrectToken = "blah"
|
||||
|
||||
jwtAuthGet :: ByteString -> ByteString -> WaiSession SResponse
|
||||
jwtAuthGet path token = Test.Hspec.Wai.request methodGet path [("Authorization", "Bearer " <> token)] ""
|
||||
|
||||
jwtAuthRequiredSpec :: Spec
|
||||
jwtAuthRequiredSpec = do
|
||||
describe "JWT Auth" $ do
|
||||
with (return $ serve jwtAuthRequiredApi jwtAuthRequiredServer) $ do
|
||||
it "allows access with the correct token" $ do
|
||||
response <- jwtAuthGet "/foo" correctToken
|
||||
liftIO $ do
|
||||
decode' (simpleBody response) `shouldBe` Just alice
|
||||
it "rejects requests with an incorrect token" $ do
|
||||
jwtAuthGet "/foo" incorrectToken `shouldRespondWith` 401
|
||||
it "rejects requests without auth data" $ do
|
||||
get "/foo" `shouldRespondWith` 401
|
||||
it "responds correctly to requests without auth data" $ do
|
||||
a <- jwtAuthGet "/foo" incorrectToken
|
||||
WaiSession (assertHeader "WWW-Authenticate" "Bearer error=\"invalid_token\"" a)
|
||||
it "respond correctly to requests with incorrect auth data" $ do
|
||||
a <- get "/foo"
|
||||
WaiSession (assertHeader "WWW-Authenticate" "Bearer error=\"invalid_request\"" a)
|
||||
|
||||
|
|
Loading…
Reference in a new issue