Make auth tests
This commit is contained in:
parent
dc699a93e0
commit
667dbbc8cd
4 changed files with 107 additions and 33 deletions
|
@ -124,6 +124,7 @@ test-suite spec
|
||||||
, wai
|
, wai
|
||||||
, wai-extra
|
, wai-extra
|
||||||
, warp
|
, warp
|
||||||
|
, jwt
|
||||||
|
|
||||||
test-suite doctests
|
test-suite doctests
|
||||||
build-depends: base
|
build-depends: base
|
||||||
|
|
|
@ -17,7 +17,7 @@ module Servant.Server.Internal.Authentication
|
||||||
, jwtAuthStrict
|
, jwtAuthStrict
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Control.Monad (guard, (<=<))
|
import Control.Monad (guard)
|
||||||
import qualified Data.ByteString as B
|
import qualified Data.ByteString as B
|
||||||
import Data.ByteString.Base64 (decodeLenient)
|
import Data.ByteString.Base64 (decodeLenient)
|
||||||
#if !MIN_VERSION_base(4,8,0)
|
#if !MIN_VERSION_base(4,8,0)
|
||||||
|
@ -30,7 +30,7 @@ import Data.String (fromString)
|
||||||
import Data.Word8 (isSpace, toLower, _colon)
|
import Data.Word8 (isSpace, toLower, _colon)
|
||||||
import GHC.TypeLits (KnownSymbol, symbolVal)
|
import GHC.TypeLits (KnownSymbol, symbolVal)
|
||||||
import Data.Text.Encoding (decodeUtf8)
|
import Data.Text.Encoding (decodeUtf8)
|
||||||
import Data.Text (splitOn, Text)
|
import Data.Text (splitOn)
|
||||||
import Network.HTTP.Types.Status (status401)
|
import Network.HTTP.Types.Status (status401)
|
||||||
import Network.Wai (Request, Response, requestHeaders,
|
import Network.Wai (Request, Response, requestHeaders,
|
||||||
responseBuilder)
|
responseBuilder)
|
||||||
|
@ -38,8 +38,10 @@ import Servant.API.Authentication (AuthPolicy (Strict, Lax),
|
||||||
AuthProtected,
|
AuthProtected,
|
||||||
BasicAuth (BasicAuth),
|
BasicAuth (BasicAuth),
|
||||||
JWTAuth (..))
|
JWTAuth (..))
|
||||||
|
JWTAuth)
|
||||||
|
import Servant.Server.Internal.RoutingApplication (RouteResult(FailFatal))
|
||||||
|
|
||||||
import Web.JWT (JWT, UnverifiedJWT, VerifiedJWT, Secret, JSON)
|
import Web.JWT (JWT, VerifiedJWT, Secret)
|
||||||
import qualified Web.JWT as JWT (decode, decodeAndVerifySignature, secret)
|
import qualified Web.JWT as JWT (decode, decodeAndVerifySignature, secret)
|
||||||
|
|
||||||
-- | Class to represent the ability to extract authentication-related
|
-- | Class to represent the ability to extract authentication-related
|
||||||
|
@ -50,10 +52,10 @@ class AuthData a where
|
||||||
-- | handlers to deal with authentication failures.
|
-- | handlers to deal with authentication failures.
|
||||||
data AuthHandlers authData = AuthHandlers
|
data AuthHandlers authData = AuthHandlers
|
||||||
{ -- we couldn't find the right type of auth data (or any, for that matter)
|
{ -- we couldn't find the right type of auth data (or any, for that matter)
|
||||||
onMissingAuthData :: IO Response
|
onMissingAuthData :: IO ServantError
|
||||||
,
|
,
|
||||||
-- we found the right type of auth data in the request but the check failed
|
-- we found the right type of auth data in the request but the check failed
|
||||||
onUnauthenticated :: authData -> IO Response
|
onUnauthenticated :: authData -> IO ServantError
|
||||||
}
|
}
|
||||||
|
|
||||||
-- | concrete type to provide when in 'Strict' mode.
|
-- | concrete type to provide when in 'Strict' mode.
|
||||||
|
@ -119,10 +121,9 @@ basicAuthLax = laxProtect
|
||||||
|
|
||||||
instance AuthData JWTAuth where
|
instance AuthData JWTAuth where
|
||||||
authData req = do
|
authData req = do
|
||||||
-- We might want to write a proper parser for this? but split works fine...
|
|
||||||
hdr <- lookup "Authorization" . requestHeaders $ req
|
hdr <- lookup "Authorization" . requestHeaders $ req
|
||||||
["Bearer", token] <- return . splitOn " " . decodeUtf8 $ hdr
|
["Bearer", token] <- return . splitOn " " . decodeUtf8 $ hdr
|
||||||
JWT.decode token -- try decode it. otherwise it's not a proper token
|
_ <- JWT.decode token -- try decode it. otherwise it's not a proper token
|
||||||
return . JWTAuth $ token
|
return . JWTAuth $ token
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
{-# LANGUAGE CPP #-}
|
{-# LANGUAGE CPP #-}
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
{-# LANGUAGE DeriveFunctor #-}
|
{-# LANGUAGE DeriveFunctor #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE TypeOperators #-}
|
{-# LANGUAGE TypeOperators #-}
|
||||||
|
@ -19,6 +20,8 @@ import Network.Wai (Application, Request,
|
||||||
Response, ResponseReceived,
|
Response, ResponseReceived,
|
||||||
requestBody,
|
requestBody,
|
||||||
strictRequestBody)
|
strictRequestBody)
|
||||||
|
import Servant.API.Authentication (AuthPolicy(Strict,Lax))
|
||||||
|
import Servant.Server.Internal.Authentication
|
||||||
import Servant.Server.Internal.ServantErr
|
import Servant.Server.Internal.ServantErr
|
||||||
|
|
||||||
type RoutingApplication =
|
type RoutingApplication =
|
||||||
|
@ -155,32 +158,60 @@ data Delayed :: * -> * where
|
||||||
Delayed :: IO (RouteResult a)
|
Delayed :: IO (RouteResult a)
|
||||||
-> IO (RouteResult ())
|
-> IO (RouteResult ())
|
||||||
-> IO (RouteResult b)
|
-> IO (RouteResult b)
|
||||||
-> (a -> b -> RouteResult c)
|
-> IO (RouteResult c)
|
||||||
-> Delayed c
|
-> (a -> b -> c -> RouteResult d)
|
||||||
|
-> Delayed d
|
||||||
|
|
||||||
instance Functor Delayed where
|
instance Functor Delayed where
|
||||||
fmap f (Delayed a b c g) = Delayed a b c ((fmap.fmap.fmap) f g)
|
fmap f (Delayed a b c d g) = Delayed a b c d ((fmap.fmap.fmap.fmap) f g)
|
||||||
|
|
||||||
-- | Add a capture to the end of the capture block.
|
-- | Add a capture to the end of the capture block.
|
||||||
addCapture :: Delayed (a -> b)
|
addCapture :: Delayed (a -> b)
|
||||||
-> IO (RouteResult a)
|
-> IO (RouteResult a)
|
||||||
-> Delayed b
|
-> Delayed b
|
||||||
addCapture (Delayed captures method body server) new =
|
addCapture (Delayed captures method auth body server) new =
|
||||||
Delayed (combineRouteResults (,) captures new) method body (\ (x, v) y -> ($ v) <$> server x y)
|
Delayed (combineRouteResults (,) captures new) method auth body (\ (x, v) y z -> ($ v) <$> server x y z)
|
||||||
|
|
||||||
-- | Add a method check to the end of the method block.
|
-- | Add a method check to the end of the method block.
|
||||||
addMethodCheck :: Delayed a
|
addMethodCheck :: Delayed a
|
||||||
-> IO (RouteResult ())
|
-> IO (RouteResult ())
|
||||||
-> Delayed a
|
-> Delayed a
|
||||||
addMethodCheck (Delayed captures method body server) new =
|
addMethodCheck (Delayed captures method auth body server) new =
|
||||||
Delayed captures (combineRouteResults const method new) body server
|
Delayed captures (combineRouteResults const method new) auth body server
|
||||||
|
|
||||||
|
addAuthStrictCheck :: Delayed (AuthProtected auth usr (usr -> a) 'Strict)
|
||||||
|
-> IO (RouteResult (Maybe usr))
|
||||||
|
-> Delayed a
|
||||||
|
addAuthStrictCheck delayed@(Delayed captures method _ body server) new =
|
||||||
|
let newAuth = runDelayed delayed `bindRouteResults` \ authProtectionStrict -> new `bindRouteResults` \ mAuthData -> case mAuthData of
|
||||||
|
|
||||||
|
Nothing -> do
|
||||||
|
-- we're in strict mode: don't let the request go
|
||||||
|
-- call the provided "on missing auth" handler
|
||||||
|
resp <- onMissingAuthData (authHandlers authProtectionStrict)
|
||||||
|
return $ FailFatal resp
|
||||||
|
|
||||||
|
-- successfully pulled auth data out of the request
|
||||||
|
Just authData -> do
|
||||||
|
mUsr <- (checkAuthStrict authProtectionStrict) authData
|
||||||
|
case mUsr of
|
||||||
|
-- this user is not authenticated
|
||||||
|
Nothing -> do
|
||||||
|
resp <- onUnauthenticated (authHandlers authProtectionStrict) authData
|
||||||
|
return $ FailFatal resp
|
||||||
|
|
||||||
|
-- this user is authenticated
|
||||||
|
Just usr ->
|
||||||
|
(return . Route . subServerStrict authProtectionStrict) usr
|
||||||
|
in Delayed captures method newAuth body server
|
||||||
|
|
||||||
|
|
||||||
-- | Add a body check to the end of the body block.
|
-- | Add a body check to the end of the body block.
|
||||||
addBodyCheck :: Delayed (a -> b)
|
addBodyCheck :: Delayed (a -> b)
|
||||||
-> IO (RouteResult a)
|
-> IO (RouteResult a)
|
||||||
-> Delayed b
|
-> Delayed b
|
||||||
addBodyCheck (Delayed captures method body server) new =
|
addBodyCheck (Delayed captures method auth body server) new =
|
||||||
Delayed captures method (combineRouteResults (,) body new) (\ x (y, v) -> ($ v) <$> server x y)
|
Delayed captures method auth (combineRouteResults (,) body new) (\ x (y, v) z -> ($ v) <$> server x y z)
|
||||||
|
|
||||||
-- | Add an accept header check to the end of the body block.
|
-- | Add an accept header check to the end of the body block.
|
||||||
-- The accept header check should occur after the body check,
|
-- The accept header check should occur after the body check,
|
||||||
|
@ -189,8 +220,8 @@ addBodyCheck (Delayed captures method body server) new =
|
||||||
addAcceptCheck :: Delayed a
|
addAcceptCheck :: Delayed a
|
||||||
-> IO (RouteResult ())
|
-> IO (RouteResult ())
|
||||||
-> Delayed a
|
-> Delayed a
|
||||||
addAcceptCheck (Delayed captures method body server) new =
|
addAcceptCheck (Delayed captures method auth body server) new =
|
||||||
Delayed captures method (combineRouteResults const body new) server
|
Delayed captures method auth (combineRouteResults const body new) server
|
||||||
|
|
||||||
-- | Many combinators extract information that is passed to
|
-- | Many combinators extract information that is passed to
|
||||||
-- the handler without the possibility of failure. In such a
|
-- the handler without the possibility of failure. In such a
|
||||||
|
|
|
@ -67,6 +67,7 @@ import Servant.API.Authentication
|
||||||
import Servant.Server.Internal.Authentication
|
import Servant.Server.Internal.Authentication
|
||||||
import Servant.Server (Server, serve, ServantErr(..), err404)
|
import Servant.Server (Server, serve, ServantErr(..), err404)
|
||||||
import Servant.Server.Internal (RouteMismatch (..))
|
import Servant.Server.Internal (RouteMismatch (..))
|
||||||
|
import Web.JWT hiding (JSON)
|
||||||
|
|
||||||
|
|
||||||
-- * test data types
|
-- * test data types
|
||||||
|
@ -116,7 +117,8 @@ spec = do
|
||||||
routerSpec
|
routerSpec
|
||||||
responseHeadersSpec
|
responseHeadersSpec
|
||||||
miscReqCombinatorsSpec
|
miscReqCombinatorsSpec
|
||||||
authRequiredSpec
|
basicAuthRequiredSpec
|
||||||
|
jwtAuthRequiredSpec
|
||||||
|
|
||||||
|
|
||||||
type CaptureApi = Capture "legs" Integer :> Get '[JSON] Animal
|
type CaptureApi = Capture "legs" Integer :> Get '[JSON] Animal
|
||||||
|
@ -714,7 +716,7 @@ miscReqCombinatorsSpec = with (return $ serve miscApi miscServ) $
|
||||||
type AuthUser = ByteString
|
type AuthUser = ByteString
|
||||||
type BasicAuthFooRealm = AuthProtect (BasicAuth "foo-realm") AuthUser 'Strict
|
type BasicAuthFooRealm = AuthProtect (BasicAuth "foo-realm") AuthUser 'Strict
|
||||||
type BasicAuthBarRealm = AuthProtect (BasicAuth "bar-realm") AuthUser 'Strict
|
type BasicAuthBarRealm = AuthProtect (BasicAuth "bar-realm") AuthUser 'Strict
|
||||||
type AuthRequiredAPI = BasicAuthFooRealm :> "foo" :> Get '[JSON] Person
|
type BasicAuthRequiredAPI = BasicAuthFooRealm :> "foo" :> Get '[JSON] Person
|
||||||
:<|> "bar" :> BasicAuthBarRealm :> Get '[JSON] Animal
|
:<|> "bar" :> BasicAuthBarRealm :> Get '[JSON] Animal
|
||||||
|
|
||||||
basicAuthFooCheck :: BasicAuth "foo-realm" -> IO (Maybe AuthUser)
|
basicAuthFooCheck :: BasicAuth "foo-realm" -> IO (Maybe AuthUser)
|
||||||
|
@ -726,11 +728,11 @@ basicAuthBarCheck :: BasicAuth "bar-realm" -> IO (Maybe AuthUser)
|
||||||
basicAuthBarCheck (BasicAuth usr pass) = if usr == "bar" && pass == "bar"
|
basicAuthBarCheck (BasicAuth usr pass) = if usr == "bar" && pass == "bar"
|
||||||
then return (Just "bar")
|
then return (Just "bar")
|
||||||
else return Nothing
|
else return Nothing
|
||||||
authRequiredApi :: Proxy AuthRequiredAPI
|
basicBasicAuthRequiredApi :: Proxy BasicAuthRequiredAPI
|
||||||
authRequiredApi = Proxy
|
basicBasicAuthRequiredApi = Proxy
|
||||||
|
|
||||||
authRequiredServer :: Server AuthRequiredAPI
|
basicAuthRequiredServer :: Server BasicAuthRequiredAPI
|
||||||
authRequiredServer = basicAuthStrict basicAuthFooCheck (const . return $ alice)
|
basicAuthRequiredServer = basicAuthStrict basicAuthFooCheck (const . return $ alice)
|
||||||
:<|> basicAuthStrict basicAuthBarCheck (const . return $ jerry)
|
:<|> basicAuthStrict basicAuthBarCheck (const . return $ jerry)
|
||||||
|
|
||||||
-- base64-encoded "servant:server"
|
-- base64-encoded "servant:server"
|
||||||
|
@ -745,25 +747,25 @@ base64BarColonPassword = "YmFyOmJhcg=="
|
||||||
base64UserColonPassword :: ByteString
|
base64UserColonPassword :: ByteString
|
||||||
base64UserColonPassword = "dXNlcjpwYXNzd29yZA=="
|
base64UserColonPassword = "dXNlcjpwYXNzd29yZA=="
|
||||||
|
|
||||||
authGet :: ByteString -> ByteString -> WaiSession SResponse
|
basicAuthGet :: ByteString -> ByteString -> WaiSession SResponse
|
||||||
authGet path base64EncodedAuth = Test.Hspec.Wai.request methodGet path [("Authorization", "Basic " <> base64EncodedAuth)] ""
|
basicAuthGet path base64EncodedAuth = Test.Hspec.Wai.request methodGet path [("Authorization", "Basic " <> base64EncodedAuth)] ""
|
||||||
|
|
||||||
authRequiredSpec :: Spec
|
basicAuthRequiredSpec :: Spec
|
||||||
authRequiredSpec = do
|
basicAuthRequiredSpec = do
|
||||||
describe "Servant.API.Authentication" $ do
|
describe "Servant.API.Authentication" $ do
|
||||||
with (return $ serve authRequiredApi authRequiredServer) $ do
|
with (return $ serve basicBasicAuthRequiredApi basicAuthRequiredServer) $ do
|
||||||
it "allows access with the correct username and password" $ do
|
it "allows access with the correct username and password" $ do
|
||||||
response <- authGet "/foo" base64ServantColonServer
|
response <- basicAuthGet "/foo" base64ServantColonServer
|
||||||
liftIO $ do
|
liftIO $ do
|
||||||
decode' (simpleBody response) `shouldBe` Just alice
|
decode' (simpleBody response) `shouldBe` Just alice
|
||||||
|
|
||||||
response <- authGet "/bar" base64BarColonPassword
|
response <- basicAuthGet "/bar" base64BarColonPassword
|
||||||
liftIO $ do
|
liftIO $ do
|
||||||
decode' (simpleBody response) `shouldBe` Just jerry
|
decode' (simpleBody response) `shouldBe` Just jerry
|
||||||
|
|
||||||
it "rejects requests with the incorrect username and password" $ do
|
it "rejects requests with the incorrect username and password" $ do
|
||||||
authGet "/foo" base64UserColonPassword `shouldRespondWith` 401
|
basicAuthGet "/foo" base64UserColonPassword `shouldRespondWith` 401
|
||||||
authGet "/bar" base64UserColonPassword `shouldRespondWith` 401
|
basicAuthGet "/bar" base64UserColonPassword `shouldRespondWith` 401
|
||||||
|
|
||||||
it "does not respond to non-authenticated requests" $ do
|
it "does not respond to non-authenticated requests" $ do
|
||||||
get "/foo" `shouldRespondWith` 401
|
get "/foo" `shouldRespondWith` 401
|
||||||
|
@ -774,3 +776,42 @@ authRequiredSpec = do
|
||||||
bar401 <- get "/bar"
|
bar401 <- get "/bar"
|
||||||
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" foo401)
|
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" foo401)
|
||||||
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"bar-realm\"" bar401)
|
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"bar-realm\"" bar401)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
type JWTAuthProtect = AuthProtect JWTAuth (JWT VerifiedJWT) 'Strict
|
||||||
|
|
||||||
|
type JWTAuthRequiredAPI = JWTAuthProtect :> "foo" :> Get '[JSON] Person
|
||||||
|
|
||||||
|
|
||||||
|
jwtAuthRequiredApi :: Proxy JWTAuthRequiredAPI
|
||||||
|
jwtAuthRequiredApi = Proxy
|
||||||
|
|
||||||
|
jwtAuthRequiredServer :: Server JWTAuthRequiredAPI
|
||||||
|
jwtAuthRequiredServer = jwtAuthStrict (secret "secret") (const . return $ alice)
|
||||||
|
|
||||||
|
correctToken = "blah"
|
||||||
|
incorrectToken = "blah"
|
||||||
|
|
||||||
|
jwtAuthGet :: ByteString -> ByteString -> WaiSession SResponse
|
||||||
|
jwtAuthGet path token = Test.Hspec.Wai.request methodGet path [("Authorization", "Bearer " <> token)] ""
|
||||||
|
|
||||||
|
jwtAuthRequiredSpec :: Spec
|
||||||
|
jwtAuthRequiredSpec = do
|
||||||
|
describe "JWT Auth" $ do
|
||||||
|
with (return $ serve jwtAuthRequiredApi jwtAuthRequiredServer) $ do
|
||||||
|
it "allows access with the correct token" $ do
|
||||||
|
response <- jwtAuthGet "/foo" correctToken
|
||||||
|
liftIO $ do
|
||||||
|
decode' (simpleBody response) `shouldBe` Just alice
|
||||||
|
it "rejects requests with an incorrect token" $ do
|
||||||
|
jwtAuthGet "/foo" incorrectToken `shouldRespondWith` 401
|
||||||
|
it "rejects requests without auth data" $ do
|
||||||
|
get "/foo" `shouldRespondWith` 401
|
||||||
|
it "responds correctly to requests without auth data" $ do
|
||||||
|
a <- jwtAuthGet "/foo" incorrectToken
|
||||||
|
WaiSession (assertHeader "WWW-Authenticate" "Bearer error=\"invalid_token\"" a)
|
||||||
|
it "respond correctly to requests with incorrect auth data" $ do
|
||||||
|
a <- get "/foo"
|
||||||
|
WaiSession (assertHeader "WWW-Authenticate" "Bearer error=\"invalid_request\"" a)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue