diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index a885497a..2258f8a8 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -78,8 +78,8 @@ type Server layout = ServerT layout (ExceptT ServantErr IO) -- > type BasicAuthVal = ExampleUser -- > basicAuthLookup _ _ _ = return Nothing class BasicAuthLookup lookup where - type BasicAuthVal - basicAuthLookup :: Proxy lookup -> B.ByteString -> B.ByteString -> IO (Maybe BasicAuthVal) + type BasicAuthVal lookup :: * + basicAuthLookup :: Proxy lookup -> B.ByteString -> B.ByteString -> IO (Maybe (BasicAuthVal lookup)) -- * Instances @@ -266,7 +266,7 @@ instance => HasServer (BasicAuth realm lookup :> sublayout) where type ServerT (BasicAuth realm lookup :> sublayout) m - = BasicAuthVal -> ServerT sublayout m + = BasicAuthVal lookup -> ServerT sublayout m route _ action request respond = case lookup "Authorization" (requestHeaders request) of diff --git a/servant-server/test/Servant/ServerSpec.hs b/servant-server/test/Servant/ServerSpec.hs index e017d399..fd7cdb1d 100644 --- a/servant-server/test/Servant/ServerSpec.hs +++ b/servant-server/test/Servant/ServerSpec.hs @@ -16,6 +16,7 @@ import Control.Applicative ((<$>)) import Control.Monad (forM_, when) import Control.Monad.Trans.Except (ExceptT, throwE) import Data.Aeson (FromJSON, ToJSON, decode', encode) +import Data.ByteString (ByteString) import Data.ByteString.Conversion () import Data.Char (toUpper) import Data.Proxy (Proxy (Proxy)) @@ -31,8 +32,8 @@ import Network.Wai (Application, Request, pathInfo, queryString, rawQueryString, responseLBS, responseBuilder) import Network.Wai.Internal (Response(ResponseBuilder)) -import Network.Wai.Test (defaultRequest, request, - runSession, simpleBody) +import Network.Wai.Test (asertHeader, defaultRequest, request, + runSession, simpleBody, SResponse) import Servant.API ((:<|>) (..), (:>), Capture, Delete, Get, Header (..), Headers, HttpVersion, IsSecure (..), JSON, @@ -45,10 +46,21 @@ import Test.Hspec (Spec, describe, it, shouldBe) import Test.Hspec.Wai (get, liftIO, matchHeaders, matchStatus, post, request, shouldRespondWith, with, (<:>)) -import Servant.Server.Internal.RoutingApplication (toApplication, RouteResult(..)) +import Test.Hspec.Wai.Internal (WaiSession(WaiSession)) +import Servant.Server.Internal.RoutingApplication (toApplication) import Servant.Server.Internal.Router (tweakResponse, runRouter, Router, Router'(LeafRouter)) +import Servant.API ((:<|>) (..), (:>), + addHeader, Capture, + Delete, Get, Header (..), Headers, + HttpVersion, IsSecure(..), JSON, MatrixFlag, + MatrixParam, MatrixParams, Patch, PlainText, + Post, Put, RemoteHost, QueryFlag, QueryParam, + QueryParams, Raw, ReqBody) +import Servant.API.Authentication (BasicAuth) +import Servant.Server (Server, serve, ServantErr(..), err404) +import Servant.Server.Internal (RouteMismatch (..), BasicAuthLookup(basicAuthLookup, BasicAuthVal)) -- * test data types @@ -98,6 +110,7 @@ spec = do routerSpec responseHeadersSpec miscReqCombinatorsSpec + authRequiredSpec type CaptureApi = Capture "legs" Integer :> Get '[JSON] Animal @@ -574,3 +587,61 @@ miscReqCombinatorsSpec = with (return $ serve miscApi miscServ) $ go "/host" "\"0.0.0.0:0\"" where go path res = Test.Hspec.Wai.get path `shouldRespondWith` res + +data AuthDB +instance BasicAuthLookup AuthDB where + type BasicAuthVal = Person + basicAuthLookup _ user pass = if user == "servant" && pass == "server" + then return (Just alice) + else return Nothing +-- | we include two endpoints /foo and /bar and we put the BasicAuth +-- portion in two different places +type AuthRequiredAPI = + BasicAuth "foo-realm" AuthDB :> "foo" :> Get '[JSON] Person + :<|> "bar" :> BasicAuth "bar-realm" AuthDB :> Get '[JSON] Animal + +authRequiredApi :: Proxy AuthRequiredAPI +authRequiredApi = Proxy + +authRequiredServer :: Server AuthRequiredAPI +authRequiredServer = const (return alice) :<|> const (return jerry) + +-- base64-encoded "servant:server" +base64ServantColonServer :: ByteString +base64ServantColonServer = "c2VydmFudDpzZXJ2ZXI=" + +-- base64-encoded "user:password" +base64UserColonPassword :: ByteString +base64UserColonPassword = "dXNlcjpwYXNzd29yZA==" + +authGet :: ByteString -> ByteString -> WaiSession SResponse +authGet path base64EncodedAuth = Test.Hspec.Wai.request methodGet path [("Authorization", "Basic " <> base64EncodedAuth)] "" + +authRequiredSpec :: Spec +authRequiredSpec = do + describe "Servant.API.Authentication" $ do + with (return $ serve authRequiredApi authRequiredServer) $ do + it "allows access with the correct username and password" $ do + response <- authGet "/foo" base64ServantColonServer + liftIO $ do + decode' (simpleBody response) `shouldBe` + Just alice + + response <- authGet "/bar" base64ServantColonServer + liftIO $ do + decode' (simpleBody response) `shouldBe` + Just jerry + + it "rejects requests with the incorrect username and password" $ do + authGet "/foo" base64UserColonPassword `shouldRespondWith` 403 + authGet "/bar" base64UserColonPassword `shouldRespondWith` 403 + + it "does not respond to non-authenticated requests" $ do + get "/foo" `shouldRespondWith` 401 + get "/bar" `shouldRespondWith` 401 + + it "adds the appropriate header to rejected 401 requests" $ do + foo401 <- get "/foo" + bar401 <- get "/bar" + WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" foo401) + WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" bar401)