diff --git a/servant-server/test/Servant/ServerSpec.hs b/servant-server/test/Servant/ServerSpec.hs index e671a0cd..6ba833ce 100644 --- a/servant-server/test/Servant/ServerSpec.hs +++ b/servant-server/test/Servant/ServerSpec.hs @@ -28,7 +28,7 @@ import Network.HTTP.Types (hAccept, hContentType, ok200, parseQuery, status409) import Network.Wai (Application, Request, pathInfo, queryString, rawQueryString, - responseLBS) + responseLBS, responseBuilder) import Network.Wai.Test (assertHeader, defaultRequest, request, runSession, simpleBody, SResponse) import Test.Hspec (Spec, describe, it, shouldBe) @@ -43,9 +43,10 @@ import Servant.API ((:<|>) (..), (:>), MatrixParam, MatrixParams, Patch, PlainText, Post, Put, RemoteHost, QueryFlag, QueryParam, QueryParams, Raw, ReqBody) -import Servant.API.Authentication (BasicAuth) +import Servant.API.Authentication +import Servant.Server.Internal.Authentication import Servant.Server (Server, serve, ServantErr(..), err404) -import Servant.Server.Internal (RouteMismatch (..), BasicAuthLookup(basicAuthLookup, BasicAuthVal)) +import Servant.Server.Internal (RouteMismatch (..)) -- * test data types @@ -656,6 +657,17 @@ prioErrorsSpec = describe "PrioErrors" $ do check put' "/bar" vjson 404 check put' "/foo" vjson 405 + +-- | fake equality to use for testing the RouteMismatch spec (errorSpec). +-- this is a hack around RouteMismatch not having an `Eq` instance. +(=:=) :: RouteMismatch -> RouteMismatch -> Bool +NotFound =:= NotFound = True +WrongMethod =:= WrongMethod = True +(InvalidBody ib1) =:= (InvalidBody ib2) = ib1 == ib2 +(HttpError s1 hs1 mb1) =:= (HttpError s2 hs2 mb2) = s1 == s2 && hs1 == hs2 && mb1 == mb2 +(RouteMismatch _) =:= (RouteMismatch _) = True +_ =:= _ = False + -- | Test server error functionality. errorsSpec :: Spec errorsSpec = do @@ -663,43 +675,52 @@ errorsSpec = do let ib = InvalidBody "The body is invalid" let wm = WrongMethod let nf = NotFound + let rm = RouteMismatch (responseBuilder status409 [] mempty) describe "Servant.Server.Internal.RouteMismatch" $ do - it "HttpError > *" $ do - ib <> he `shouldBe` he - wm <> he `shouldBe` he - nf <> he `shouldBe` he + it "RouteMismatch > *" $ do + (ib <> rm) =:= rm `shouldBe` True + (wm <> rm) =:= rm `shouldBe` True + (nf <> rm) =:= rm `shouldBe` True + (he <> rm) =:= rm `shouldBe` True - he <> ib `shouldBe` he - he <> wm `shouldBe` he - he <> nf `shouldBe` he + (rm <> ib) =:= rm `shouldBe` True + (rm <> wm) =:= rm `shouldBe` True + (rm <> nf) =:= rm `shouldBe` True + (rm <> he) =:= rm `shouldBe` True + + it "RouteMismatch > HttpError > *" $ do + (ib <> he) =:= he `shouldBe` True + (wm <> he) =:= he `shouldBe` True + (nf <> he) =:= he `shouldBe` True + + (he <> ib) =:= he `shouldBe` True + (he <> wm) =:= he `shouldBe` True + (he <> nf) =:= he `shouldBe` True it "HE > InvalidBody > (WM,NF)" $ do - he <> ib `shouldBe` he - wm <> ib `shouldBe` ib - nf <> ib `shouldBe` ib + (wm <> ib) =:= ib `shouldBe` True + (nf <> ib) =:= ib `shouldBe` True - ib <> he `shouldBe` he - ib <> wm `shouldBe` ib - ib <> nf `shouldBe` ib + (ib <> wm) =:= ib `shouldBe` True + (ib <> nf) =:= ib `shouldBe` True it "HE > IB > WrongMethod > NF" $ do - he <> wm `shouldBe` he - ib <> wm `shouldBe` ib - nf <> wm `shouldBe` wm + (nf <> wm) =:= wm `shouldBe` True - wm <> he `shouldBe` he - wm <> ib `shouldBe` ib - wm <> nf `shouldBe` wm + (wm <> nf) =:= wm `shouldBe` True + -- TODO: this is redundant, but maybe helpful for clarity. it "* > NotFound" $ do - he <> nf `shouldBe` he - ib <> nf `shouldBe` ib - wm <> nf `shouldBe` wm + (he <> nf) =:= he `shouldBe` True + (ib <> nf) =:= ib `shouldBe` True + (wm <> nf) =:= wm `shouldBe` True + (rm <> nf) =:= rm `shouldBe` True - nf <> he `shouldBe` he - nf <> ib `shouldBe` ib - nf <> wm `shouldBe` wm + (nf <> he) =:= he `shouldBe` True + (nf <> ib) =:= ib `shouldBe` True + (nf <> wm) =:= wm `shouldBe` True + (nf <> rm) =:= rm `shouldBe` True type MiscCombinatorsAPI = "version" :> HttpVersion :> Get '[JSON] String @@ -733,28 +754,40 @@ miscReqCombinatorsSpec = with (return $ serve miscApi miscServ) $ where go path res = Test.Hspec.Wai.get path `shouldRespondWith` res -data AuthDB -instance BasicAuthLookup AuthDB where - type BasicAuthVal = Person - basicAuthLookup _ user pass = if user == "servant" && pass == "server" - then return (Just alice) - else return Nothing + -- | we include two endpoints /foo and /bar and we put the BasicAuth -- portion in two different places -type AuthRequiredAPI = - BasicAuth "foo-realm" AuthDB :> "foo" :> Get '[JSON] Person - :<|> "bar" :> BasicAuth "bar-realm" AuthDB :> Get '[JSON] Animal +type AuthUser = ByteString +type BasicAuthFooRealm = AuthProtect (BasicAuth "foo-realm") AuthUser 'Strict +type BasicAuthBarRealm = AuthProtect (BasicAuth "bar-realm") AuthUser 'Strict +type AuthRequiredAPI = BasicAuthFooRealm :> "foo" :> Get '[JSON] Person + :<|> "bar" :> BasicAuthBarRealm :> Get '[JSON] Animal +basicAuthFooCheck :: BasicAuth "foo-realm" -> IO (Maybe AuthUser) +basicAuthFooCheck (BasicAuth user pass) = if user == "servant" && pass == "server" + then return (Just "servant") + else return Nothing + +basicAuthBarCheck :: BasicAuth "bar-realm" -> IO (Maybe AuthUser) +basicAuthBarCheck (BasicAuth usr pass) = if usr == "bar" && pass == "bar" + then return (Just "bar") + else return Nothing authRequiredApi :: Proxy AuthRequiredAPI authRequiredApi = Proxy authRequiredServer :: Server AuthRequiredAPI -authRequiredServer = const (return alice) :<|> const (return jerry) +authRequiredServer = basicAuthStrict basicAuthFooCheck (const . return $ alice) + :<|> basicAuthStrict basicAuthBarCheck (const . return $ jerry) +-- authRequiredServer = const (return alice) :<|> const (return jerry) -- base64-encoded "servant:server" base64ServantColonServer :: ByteString base64ServantColonServer = "c2VydmFudDpzZXJ2ZXI=" +-- base64-encoded "bar:bar" +base64BarColonPassword :: ByteString +base64BarColonPassword = "YmFyOmJhcg==" + -- base64-encoded "user:password" base64UserColonPassword :: ByteString base64UserColonPassword = "dXNlcjpwYXNzd29yZA==" @@ -769,17 +802,15 @@ authRequiredSpec = do it "allows access with the correct username and password" $ do response <- authGet "/foo" base64ServantColonServer liftIO $ do - decode' (simpleBody response) `shouldBe` - Just alice + decode' (simpleBody response) `shouldBe` Just alice - response <- authGet "/bar" base64ServantColonServer + response <- authGet "/bar" base64BarColonPassword liftIO $ do - decode' (simpleBody response) `shouldBe` - Just jerry + decode' (simpleBody response) `shouldBe` Just jerry it "rejects requests with the incorrect username and password" $ do - authGet "/foo" base64UserColonPassword `shouldRespondWith` 403 - authGet "/bar" base64UserColonPassword `shouldRespondWith` 403 + authGet "/foo" base64UserColonPassword `shouldRespondWith` 401 + authGet "/bar" base64UserColonPassword `shouldRespondWith` 401 it "does not respond to non-authenticated requests" $ do get "/foo" `shouldRespondWith` 401