add realm flag

This commit is contained in:
Alexander Thiemann 2021-01-06 21:03:55 -08:00
parent d06b65c4e6
commit 2d3b40dfeb
6 changed files with 15 additions and 10 deletions

View file

@ -237,7 +237,7 @@ basicAuthHandler =
if username == "servant" && password == "server" if username == "servant" && password == "server"
then return (Authorized ()) then return (Authorized ())
else return Unauthorized else return Unauthorized
in BasicAuthCheck check in BasicAuthCheck True check
basicServerContext :: Context '[ BasicAuthCheck () ] basicServerContext :: Context '[ BasicAuthCheck () ]
basicServerContext = basicAuthHandler :. EmptyContext basicServerContext = basicAuthHandler :. EmptyContext

View file

@ -222,7 +222,7 @@ basicAuthHandler =
if username == "servant" && password == "server" if username == "servant" && password == "server"
then return (Authorized ()) then return (Authorized ())
else return Unauthorized else return Unauthorized
in BasicAuthCheck check in BasicAuthCheck True check
basicServerContext :: Context '[ BasicAuthCheck () ] basicServerContext :: Context '[ BasicAuthCheck () ]
basicServerContext = basicAuthHandler :. EmptyContext basicServerContext = basicAuthHandler :. EmptyContext

View file

@ -43,7 +43,7 @@ module Servant.Server
, descendIntoNamedContext , descendIntoNamedContext
-- * Basic Authentication -- * Basic Authentication
, BasicAuthCheck(BasicAuthCheck, unBasicAuthCheck) , BasicAuthCheck(BasicAuthCheck, basicAuthRunCheck, basicAuthPresentChallenge)
, BasicAuthResult(..) , BasicAuthResult(..)
-- * General Authentication -- * General Authentication

View file

@ -44,9 +44,12 @@ data BasicAuthResult usr
deriving (Eq, Show, Read, Generic, Typeable, Functor) deriving (Eq, Show, Read, Generic, Typeable, Functor)
-- | Datatype wrapping a function used to check authentication. -- | Datatype wrapping a function used to check authentication.
newtype BasicAuthCheck usr = BasicAuthCheck data BasicAuthCheck usr
{ unBasicAuthCheck :: BasicAuthData = BasicAuthCheck
-> IO (BasicAuthResult usr) { basicAuthPresentChallenge :: Bool
-- ^ Decides if we'll send a @WWW-Authenticate@ HTTP header. Sending the header causes browser to
-- surface a prompt for user name and password, which may be undesirable for APIs.
, basicAuthRunCheck :: BasicAuthData -> IO (BasicAuthResult usr)
} }
deriving (Generic, Typeable, Functor) deriving (Generic, Typeable, Functor)
@ -68,7 +71,7 @@ decodeBAHdr req = do
-- | Run and check basic authentication, returning the appropriate http error per -- | Run and check basic authentication, returning the appropriate http error per
-- the spec. -- the spec.
runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> DelayedIO usr runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> DelayedIO usr
runBasicAuth req realm (BasicAuthCheck ba) = runBasicAuth req realm (BasicAuthCheck presentChallenge ba) =
case decodeBAHdr req of case decodeBAHdr req of
Nothing -> plzAuthenticate Nothing -> plzAuthenticate
Just e -> liftIO (ba e) >>= \res -> case res of Just e -> liftIO (ba e) >>= \res -> case res of
@ -76,4 +79,6 @@ runBasicAuth req realm (BasicAuthCheck ba) =
NoSuchUser -> plzAuthenticate NoSuchUser -> plzAuthenticate
Unauthorized -> delayedFailFatal err403 Unauthorized -> delayedFailFatal err403
Authorized usr -> return usr Authorized usr -> return usr
where plzAuthenticate = delayedFailFatal err401 { errHeaders = [mkBAChallengerHdr realm] } where
plzAuthenticate =
delayedFailFatal err401 { errHeaders = [mkBAChallengerHdr realm | presentChallenge] }

View file

@ -44,7 +44,7 @@ errorOrderAuthCheck =
if username == "servant" && password == "server" if username == "servant" && password == "server"
then return (Authorized ()) then return (Authorized ())
else return Unauthorized else return Unauthorized
in BasicAuthCheck check in BasicAuthCheck True check
------------------------------------------------------------------------------ ------------------------------------------------------------------------------
-- * Error Order {{{ -- * Error Order {{{

View file

@ -744,7 +744,7 @@ basicAuthServer =
basicAuthContext :: Context '[ BasicAuthCheck () ] basicAuthContext :: Context '[ BasicAuthCheck () ]
basicAuthContext = basicAuthContext =
let basicHandler = BasicAuthCheck $ \(BasicAuthData usr pass) -> let basicHandler = BasicAuthCheck True $ \(BasicAuthData usr pass) ->
if usr == "servant" && pass == "server" if usr == "servant" && pass == "server"
then return (Authorized ()) then return (Authorized ())
else return Unauthorized else return Unauthorized