Add BasicAuth tests

This commit is contained in:
aaron levin 2015-05-12 15:23:44 -04:00 committed by Arian van Putten
parent 17885bc50f
commit d2e2122933
2 changed files with 74 additions and 12 deletions

View file

@ -40,18 +40,18 @@ import Data.Word8 (isSpace, _colon, toLower)
import GHC.TypeLits (KnownSymbol, symbolVal) import GHC.TypeLits (KnownSymbol, symbolVal)
import Network.HTTP.Types hiding (Header, ResponseHeaders) import Network.HTTP.Types hiding (Header, ResponseHeaders)
import Network.Socket (SockAddr) import Network.Socket (SockAddr)
import Network.Wai (Application, Request, Response, import Network.Wai (Application, isSecure, httpVersion, Request, Response,
ResponseReceived, lazyRequestBody, ResponseReceived, lazyRequestBody,
pathInfo, rawQueryString, pathInfo, rawQueryString, remoteHost,
requestBody, requestHeaders, requestBody, requestHeaders,
requestMethod, responseLBS, requestMethod, responseLBS,
strictRequestBody) strictRequestBody, vault)
import Servant.API ((:<|>) (..), (:>), BasicAuth, Capture, import Servant.API ((:<|>) (..), (:>), BasicAuth, Capture,
Delete, Get, Header, Delete, Get, Header, IsSecure(Secure, NotSecure),
MatrixFlag, MatrixParam, MatrixParams, MatrixFlag, MatrixParam, MatrixParams,
Patch, Post, Put, QueryFlag, Patch, Post, Put, QueryFlag,
QueryParam, QueryParams, Raw, QueryParam, QueryParams, Raw,
ReqBody) RemoteHost, ReqBody, Vault)
import Servant.API.ContentTypes (AcceptHeader (..), import Servant.API.ContentTypes (AcceptHeader (..),
AllCTRender (..), AllCTRender (..),
AllCTUnrender (..)) AllCTUnrender (..))
@ -80,8 +80,8 @@ type Server layout = ServerT layout (ExceptT ServantErr IO)
-- > type BasicAuthVal = ExampleUser -- > type BasicAuthVal = ExampleUser
-- > basicAuthLookup _ _ _ = return Nothing -- > basicAuthLookup _ _ _ = return Nothing
class BasicAuthLookup lookup where class BasicAuthLookup lookup where
type BasicAuthVal type BasicAuthVal lookup :: *
basicAuthLookup :: Proxy lookup -> B.ByteString -> B.ByteString -> IO (Maybe BasicAuthVal) basicAuthLookup :: Proxy lookup -> B.ByteString -> B.ByteString -> IO (Maybe (BasicAuthVal lookup))
-- * Instances -- * Instances
@ -258,7 +258,7 @@ instance
=> HasServer (BasicAuth realm lookup :> sublayout) where => HasServer (BasicAuth realm lookup :> sublayout) where
type ServerT (BasicAuth realm lookup :> sublayout) m type ServerT (BasicAuth realm lookup :> sublayout) m
= BasicAuthVal -> ServerT sublayout m = BasicAuthVal lookup -> ServerT sublayout m
route _ action request respond = route _ action request respond =
case lookup "Authorization" (requestHeaders request) of case lookup "Authorization" (requestHeaders request) of

View file

@ -2,6 +2,7 @@
{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeSynonymInstances #-} {-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
@ -12,6 +13,7 @@ module Servant.ServerSpec where
import Control.Monad (forM_, when) import Control.Monad (forM_, when)
import Control.Monad.Trans.Except (ExceptT, throwE) import Control.Monad.Trans.Except (ExceptT, throwE)
import Data.Aeson (FromJSON, ToJSON, decode', encode) import Data.Aeson (FromJSON, ToJSON, decode', encode)
import Data.ByteString (ByteString)
import Data.ByteString.Conversion () import Data.ByteString.Conversion ()
import Data.Char (toUpper) import Data.Char (toUpper)
import Data.Monoid ((<>)) import Data.Monoid ((<>))
@ -27,12 +29,13 @@ import Network.HTTP.Types (hAccept, hContentType,
import Network.Wai (Application, Request, pathInfo, import Network.Wai (Application, Request, pathInfo,
queryString, rawQueryString, queryString, rawQueryString,
responseLBS) responseLBS)
import Network.Wai.Test (defaultRequest, request, import Network.Wai.Test (assertHeader, defaultRequest, request,
runSession, simpleBody) runSession, simpleBody, SResponse)
import Test.Hspec (Spec, describe, it, shouldBe) import Test.Hspec (Spec, describe, it, shouldBe)
import Test.Hspec.Wai (get, liftIO, matchHeaders, import Test.Hspec.Wai (get, liftIO, matchHeaders,
matchStatus, post, request, matchStatus, post, request,
shouldRespondWith, with, (<:>)) shouldRespondWith, with, (<:>))
import Test.Hspec.Wai.Internal (WaiSession(WaiSession))
import Servant.API ((:<|>) (..), (:>), import Servant.API ((:<|>) (..), (:>),
addHeader, Capture, addHeader, Capture,
Delete, Get, Header (..), Headers, Delete, Get, Header (..), Headers,
@ -40,9 +43,9 @@ import Servant.API ((:<|>) (..), (:>),
MatrixParam, MatrixParams, Patch, PlainText, MatrixParam, MatrixParams, Patch, PlainText,
Post, Put, RemoteHost, QueryFlag, QueryParam, Post, Put, RemoteHost, QueryFlag, QueryParam,
QueryParams, Raw, ReqBody) QueryParams, Raw, ReqBody)
import Servant.API.Authentication (BasicAuth)
import Servant.Server (Server, serve, ServantErr(..), err404) import Servant.Server (Server, serve, ServantErr(..), err404)
import Servant.Server.Internal.RoutingApplication import Servant.Server.Internal (RouteMismatch (..), BasicAuthLookup(basicAuthLookup, BasicAuthVal))
(RouteMismatch (..))
-- * test data types -- * test data types
@ -94,6 +97,7 @@ spec = do
errorsSpec errorsSpec
responseHeadersSpec responseHeadersSpec
miscReqCombinatorsSpec miscReqCombinatorsSpec
authRequiredSpec
type CaptureApi = Capture "legs" Integer :> Get '[JSON] Animal type CaptureApi = Capture "legs" Integer :> Get '[JSON] Animal
@ -728,3 +732,61 @@ miscReqCombinatorsSpec = with (return $ serve miscApi miscServ) $
go "/host" "\"0.0.0.0:0\"" go "/host" "\"0.0.0.0:0\""
where go path res = Test.Hspec.Wai.get path `shouldRespondWith` res where go path res = Test.Hspec.Wai.get path `shouldRespondWith` res
data AuthDB
instance BasicAuthLookup AuthDB where
type BasicAuthVal = Person
basicAuthLookup _ user pass = if user == "servant" && pass == "server"
then return (Just alice)
else return Nothing
-- | we include two endpoints /foo and /bar and we put the BasicAuth
-- portion in two different places
type AuthRequiredAPI =
BasicAuth "foo-realm" AuthDB :> "foo" :> Get '[JSON] Person
:<|> "bar" :> BasicAuth "bar-realm" AuthDB :> Get '[JSON] Animal
authRequiredApi :: Proxy AuthRequiredAPI
authRequiredApi = Proxy
authRequiredServer :: Server AuthRequiredAPI
authRequiredServer = const (return alice) :<|> const (return jerry)
-- base64-encoded "servant:server"
base64ServantColonServer :: ByteString
base64ServantColonServer = "c2VydmFudDpzZXJ2ZXI="
-- base64-encoded "user:password"
base64UserColonPassword :: ByteString
base64UserColonPassword = "dXNlcjpwYXNzd29yZA=="
authGet :: ByteString -> ByteString -> WaiSession SResponse
authGet path base64EncodedAuth = Test.Hspec.Wai.request methodGet path [("Authorization", "Basic " <> base64EncodedAuth)] ""
authRequiredSpec :: Spec
authRequiredSpec = do
describe "Servant.API.Authentication" $ do
with (return $ serve authRequiredApi authRequiredServer) $ do
it "allows access with the correct username and password" $ do
response <- authGet "/foo" base64ServantColonServer
liftIO $ do
decode' (simpleBody response) `shouldBe`
Just alice
response <- authGet "/bar" base64ServantColonServer
liftIO $ do
decode' (simpleBody response) `shouldBe`
Just jerry
it "rejects requests with the incorrect username and password" $ do
authGet "/foo" base64UserColonPassword `shouldRespondWith` 403
authGet "/bar" base64UserColonPassword `shouldRespondWith` 403
it "does not respond to non-authenticated requests" $ do
get "/foo" `shouldRespondWith` 401
get "/bar" `shouldRespondWith` 401
it "adds the appropriate header to rejected 401 requests" $ do
foo401 <- get "/foo"
bar401 <- get "/bar"
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" foo401)
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" bar401)