diff --git a/servant-client/test/Servant/ClientTestUtils.hs b/servant-client/test/Servant/ClientTestUtils.hs index 4b70a7a9..7f92fafb 100644 --- a/servant-client/test/Servant/ClientTestUtils.hs +++ b/servant-client/test/Servant/ClientTestUtils.hs @@ -237,7 +237,7 @@ basicAuthHandler = if username == "servant" && password == "server" then return (Authorized ()) else return Unauthorized - in BasicAuthCheck check + in BasicAuthCheck True check basicServerContext :: Context '[ BasicAuthCheck () ] basicServerContext = basicAuthHandler :. EmptyContext diff --git a/servant-http-streams/test/Servant/ClientSpec.hs b/servant-http-streams/test/Servant/ClientSpec.hs index 41e7fbe4..c2a21fbe 100644 --- a/servant-http-streams/test/Servant/ClientSpec.hs +++ b/servant-http-streams/test/Servant/ClientSpec.hs @@ -222,7 +222,7 @@ basicAuthHandler = if username == "servant" && password == "server" then return (Authorized ()) else return Unauthorized - in BasicAuthCheck check + in BasicAuthCheck True check basicServerContext :: Context '[ BasicAuthCheck () ] basicServerContext = basicAuthHandler :. EmptyContext diff --git a/servant-server/src/Servant/Server.hs b/servant-server/src/Servant/Server.hs index 5d40eb6f..a38689d2 100644 --- a/servant-server/src/Servant/Server.hs +++ b/servant-server/src/Servant/Server.hs @@ -43,7 +43,7 @@ module Servant.Server , descendIntoNamedContext -- * Basic Authentication - , BasicAuthCheck(BasicAuthCheck, unBasicAuthCheck) + , BasicAuthCheck(BasicAuthCheck, basicAuthRunCheck, basicAuthPresentChallenge) , BasicAuthResult(..) -- * General Authentication diff --git a/servant-server/src/Servant/Server/Internal/BasicAuth.hs b/servant-server/src/Servant/Server/Internal/BasicAuth.hs index b92e4b02..6c68d10a 100644 --- a/servant-server/src/Servant/Server/Internal/BasicAuth.hs +++ b/servant-server/src/Servant/Server/Internal/BasicAuth.hs @@ -44,9 +44,12 @@ data BasicAuthResult usr deriving (Eq, Show, Read, Generic, Typeable, Functor) -- | Datatype wrapping a function used to check authentication. -newtype BasicAuthCheck usr = BasicAuthCheck - { unBasicAuthCheck :: BasicAuthData - -> IO (BasicAuthResult usr) +data BasicAuthCheck usr + = BasicAuthCheck + { basicAuthPresentChallenge :: Bool + -- ^ Decides if we'll send a @WWW-Authenticate@ HTTP header. Sending the header causes browser to + -- surface a prompt for user name and password, which may be undesirable for APIs. + , basicAuthRunCheck :: BasicAuthData -> IO (BasicAuthResult usr) } deriving (Generic, Typeable, Functor) @@ -68,7 +71,7 @@ decodeBAHdr req = do -- | Run and check basic authentication, returning the appropriate http error per -- the spec. runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> DelayedIO usr -runBasicAuth req realm (BasicAuthCheck ba) = +runBasicAuth req realm (BasicAuthCheck presentChallenge ba) = case decodeBAHdr req of Nothing -> plzAuthenticate Just e -> liftIO (ba e) >>= \res -> case res of @@ -76,4 +79,6 @@ runBasicAuth req realm (BasicAuthCheck ba) = NoSuchUser -> plzAuthenticate Unauthorized -> delayedFailFatal err403 Authorized usr -> return usr - where plzAuthenticate = delayedFailFatal err401 { errHeaders = [mkBAChallengerHdr realm] } + where + plzAuthenticate = + delayedFailFatal err401 { errHeaders = [mkBAChallengerHdr realm | presentChallenge] } diff --git a/servant-server/test/Servant/Server/ErrorSpec.hs b/servant-server/test/Servant/Server/ErrorSpec.hs index 72251b21..b9a8f2bf 100644 --- a/servant-server/test/Servant/Server/ErrorSpec.hs +++ b/servant-server/test/Servant/Server/ErrorSpec.hs @@ -44,7 +44,7 @@ errorOrderAuthCheck = if username == "servant" && password == "server" then return (Authorized ()) else return Unauthorized - in BasicAuthCheck check + in BasicAuthCheck True check ------------------------------------------------------------------------------ -- * Error Order {{{ diff --git a/servant-server/test/Servant/ServerSpec.hs b/servant-server/test/Servant/ServerSpec.hs index e3dec48e..09419ac4 100644 --- a/servant-server/test/Servant/ServerSpec.hs +++ b/servant-server/test/Servant/ServerSpec.hs @@ -744,7 +744,7 @@ basicAuthServer = basicAuthContext :: Context '[ BasicAuthCheck () ] basicAuthContext = - let basicHandler = BasicAuthCheck $ \(BasicAuthData usr pass) -> + let basicHandler = BasicAuthCheck True $ \(BasicAuthData usr pass) -> if usr == "servant" && pass == "server" then return (Authorized ()) else return Unauthorized