diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index f0d60d1e..b3bdbbea 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -40,18 +40,18 @@ import Data.Word8 (isSpace, _colon, toLower) import GHC.TypeLits (KnownSymbol, symbolVal) import Network.HTTP.Types hiding (Header, ResponseHeaders) import Network.Socket (SockAddr) -import Network.Wai (Application, Request, Response, +import Network.Wai (Application, isSecure, httpVersion, Request, Response, ResponseReceived, lazyRequestBody, - pathInfo, rawQueryString, + pathInfo, rawQueryString, remoteHost, requestBody, requestHeaders, requestMethod, responseLBS, - strictRequestBody) + strictRequestBody, vault) import Servant.API ((:<|>) (..), (:>), BasicAuth, Capture, - Delete, Get, Header, + Delete, Get, Header, IsSecure(Secure, NotSecure), MatrixFlag, MatrixParam, MatrixParams, Patch, Post, Put, QueryFlag, QueryParam, QueryParams, Raw, - ReqBody) + RemoteHost, ReqBody, Vault) import Servant.API.ContentTypes (AcceptHeader (..), AllCTRender (..), AllCTUnrender (..)) @@ -80,8 +80,8 @@ type Server layout = ServerT layout (ExceptT ServantErr IO) -- > type BasicAuthVal = ExampleUser -- > basicAuthLookup _ _ _ = return Nothing class BasicAuthLookup lookup where - type BasicAuthVal - basicAuthLookup :: Proxy lookup -> B.ByteString -> B.ByteString -> IO (Maybe BasicAuthVal) + type BasicAuthVal lookup :: * + basicAuthLookup :: Proxy lookup -> B.ByteString -> B.ByteString -> IO (Maybe (BasicAuthVal lookup)) -- * Instances @@ -258,7 +258,7 @@ instance => HasServer (BasicAuth realm lookup :> sublayout) where type ServerT (BasicAuth realm lookup :> sublayout) m - = BasicAuthVal -> ServerT sublayout m + = BasicAuthVal lookup -> ServerT sublayout m route _ action request respond = case lookup "Authorization" (requestHeaders request) of diff --git a/servant-server/test/Servant/ServerSpec.hs b/servant-server/test/Servant/ServerSpec.hs index 45519e42..4fe6271b 100644 --- a/servant-server/test/Servant/ServerSpec.hs +++ b/servant-server/test/Servant/ServerSpec.hs @@ -2,6 +2,7 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeSynonymInstances #-} {-# LANGUAGE FlexibleInstances #-} @@ -12,6 +13,7 @@ module Servant.ServerSpec where import Control.Monad (forM_, when) import Control.Monad.Trans.Except (ExceptT, throwE) import Data.Aeson (FromJSON, ToJSON, decode', encode) +import Data.ByteString (ByteString) import Data.ByteString.Conversion () import Data.Char (toUpper) import Data.Monoid ((<>)) @@ -27,12 +29,13 @@ import Network.HTTP.Types (hAccept, hContentType, import Network.Wai (Application, Request, pathInfo, queryString, rawQueryString, responseLBS) -import Network.Wai.Test (defaultRequest, request, - runSession, simpleBody) +import Network.Wai.Test (assertHeader, defaultRequest, request, + runSession, simpleBody, SResponse) import Test.Hspec (Spec, describe, it, shouldBe) import Test.Hspec.Wai (get, liftIO, matchHeaders, matchStatus, post, request, shouldRespondWith, with, (<:>)) +import Test.Hspec.Wai.Internal (WaiSession(WaiSession)) import Servant.API ((:<|>) (..), (:>), addHeader, Capture, Delete, Get, Header (..), Headers, @@ -40,9 +43,9 @@ import Servant.API ((:<|>) (..), (:>), MatrixParam, MatrixParams, Patch, PlainText, Post, Put, RemoteHost, QueryFlag, QueryParam, QueryParams, Raw, ReqBody) +import Servant.API.Authentication (BasicAuth) import Servant.Server (Server, serve, ServantErr(..), err404) -import Servant.Server.Internal.RoutingApplication - (RouteMismatch (..)) +import Servant.Server.Internal (RouteMismatch (..), BasicAuthLookup(basicAuthLookup, BasicAuthVal)) -- * test data types @@ -94,6 +97,7 @@ spec = do errorsSpec responseHeadersSpec miscReqCombinatorsSpec + authRequiredSpec type CaptureApi = Capture "legs" Integer :> Get '[JSON] Animal @@ -728,3 +732,61 @@ miscReqCombinatorsSpec = with (return $ serve miscApi miscServ) $ go "/host" "\"0.0.0.0:0\"" where go path res = Test.Hspec.Wai.get path `shouldRespondWith` res + +data AuthDB +instance BasicAuthLookup AuthDB where + type BasicAuthVal = Person + basicAuthLookup _ user pass = if user == "servant" && pass == "server" + then return (Just alice) + else return Nothing +-- | we include two endpoints /foo and /bar and we put the BasicAuth +-- portion in two different places +type AuthRequiredAPI = + BasicAuth "foo-realm" AuthDB :> "foo" :> Get '[JSON] Person + :<|> "bar" :> BasicAuth "bar-realm" AuthDB :> Get '[JSON] Animal + +authRequiredApi :: Proxy AuthRequiredAPI +authRequiredApi = Proxy + +authRequiredServer :: Server AuthRequiredAPI +authRequiredServer = const (return alice) :<|> const (return jerry) + +-- base64-encoded "servant:server" +base64ServantColonServer :: ByteString +base64ServantColonServer = "c2VydmFudDpzZXJ2ZXI=" + +-- base64-encoded "user:password" +base64UserColonPassword :: ByteString +base64UserColonPassword = "dXNlcjpwYXNzd29yZA==" + +authGet :: ByteString -> ByteString -> WaiSession SResponse +authGet path base64EncodedAuth = Test.Hspec.Wai.request methodGet path [("Authorization", "Basic " <> base64EncodedAuth)] "" + +authRequiredSpec :: Spec +authRequiredSpec = do + describe "Servant.API.Authentication" $ do + with (return $ serve authRequiredApi authRequiredServer) $ do + it "allows access with the correct username and password" $ do + response <- authGet "/foo" base64ServantColonServer + liftIO $ do + decode' (simpleBody response) `shouldBe` + Just alice + + response <- authGet "/bar" base64ServantColonServer + liftIO $ do + decode' (simpleBody response) `shouldBe` + Just jerry + + it "rejects requests with the incorrect username and password" $ do + authGet "/foo" base64UserColonPassword `shouldRespondWith` 403 + authGet "/bar" base64UserColonPassword `shouldRespondWith` 403 + + it "does not respond to non-authenticated requests" $ do + get "/foo" `shouldRespondWith` 401 + get "/bar" `shouldRespondWith` 401 + + it "adds the appropriate header to rejected 401 requests" $ do + foo401 <- get "/foo" + bar401 <- get "/bar" + WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" foo401) + WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" bar401)