114 lines
3.2 KiB
Haskell
114 lines
3.2 KiB
Haskell
{-# LANGUAGE CPP #-}
|
|
module Servant.Auth.Server.Internal.Types where
|
|
|
|
import Control.Applicative
|
|
import Control.Monad (MonadPlus(..), ap)
|
|
import Control.Monad.Reader
|
|
import Control.Monad.Time
|
|
import Data.Monoid (Monoid (..))
|
|
import Data.Semigroup (Semigroup (..))
|
|
import Data.Time (getCurrentTime)
|
|
import GHC.Generics (Generic)
|
|
import Network.Wai (Request)
|
|
|
|
import qualified Control.Monad.Fail as Fail
|
|
|
|
-- | The result of an authentication attempt.
--
-- NOTE: 'Ord' is derived, so comparisons follow constructor order
-- ('BadPassword' < 'NoSuchUser' < 'Authenticated' < 'Indefinite');
-- reordering the constructors changes that ordering.
data AuthResult val
  = BadPassword
    -- ^ A definite failure, reported as a bad password.
  | NoSuchUser
    -- ^ A definite failure, reported as an unknown user.
  | Authenticated val
    -- ^ Authentication succeeded; carries the authenticated value.
  | Indefinite
    -- ^ The authentication procedure could not be carried out (for example,
    -- it expects a username and password in a header that is not present).
    -- This indicates that other authentication methods should be tried.
  deriving (Eq, Show, Read, Generic, Ord, Functor, Traversable, Foldable)
|
|
|
|
-- | Left-biased choice: the left operand wins unless it is 'Indefinite',
-- in which case the right operand is the result.
instance Semigroup (AuthResult val) where
  lhs <> rhs = case lhs of
    Indefinite -> rhs
    _          -> lhs
|
|
|
|
-- | 'Indefinite' is the identity element: combining with it always yields
-- the other operand.
instance Monoid (AuthResult val) where
  -- Spelled out so 'mappend' agrees with '<>' even on base versions where
  -- it does not default to '<>'.
  mappend = (<>)
  mempty = Indefinite
|
|
|
|
-- | 'pure' wraps a value as a successful ('Authenticated') result.
--
-- 'pure' is defined directly and 'return' is left to its default
-- (@return = pure@). This is the canonical form GHC's
-- @-Wnoncanonical-monad-instances@ asks for, instead of @pure = return@
-- with a 'return' override.
instance Applicative AuthResult where
  pure = Authenticated
  (<*>) = ap

-- | Binding continues only after 'Authenticated'. Every other result
-- short-circuits and is propagated unchanged.
instance Monad AuthResult where
  Authenticated v >>= f = f v
  BadPassword     >>= _ = BadPassword
  NoSuchUser      >>= _ = NoSuchUser
  Indefinite      >>= _ = Indefinite
|
|
|
|
-- | Choice follows the 'Semigroup' instance: the first non-'Indefinite'
-- result wins, and 'empty' is the 'Monoid' identity ('Indefinite').
instance Alternative AuthResult where
  empty = mempty
  (<|>) = (<>)

-- | Uses the class defaults @mzero = empty@ and @mplus = (<|>)@.
instance MonadPlus AuthResult
|
|
|
|
|
|
-- | An @AuthCheck@ decides the authentication status (an 'AuthResult') of a
-- 'Request'.
--
-- @AuthCheck@s can be combined through their 'Monoid' or 'Alternative'
-- instances. The checks are tried from left to right, the *first* result
-- that is not 'Indefinite' is used, and the remaining checks are ignored.
newtype AuthCheck val = AuthCheck
  { runAuthCheck :: Request -> IO (AuthResult val)
    -- ^ Run the check against an incoming request.
  }
  deriving (Generic, Functor)
|
|
|
|
-- | Runs the left check first. The right check runs against the same request
-- only when the left one yields 'Indefinite'; any other left result is
-- returned as is.
instance Semigroup (AuthCheck val) where
  AuthCheck primary <> AuthCheck fallback = AuthCheck runBoth
    where
      runBoth req = primary req >>= decide
        where
          decide Indefinite = fallback req
          decide settled    = pure settled
|
|
|
|
-- | The identity check ignores the request and yields the 'AuthResult'
-- identity ('mempty'), so it never overrides another check.
instance Monoid (AuthCheck val) where
  mempty = AuthCheck (\_ -> pure mempty)
  mappend = (<>)
|
|
|
|
-- | 'pure' builds a check that ignores the request and succeeds with the
-- given value.
--
-- 'pure' is defined directly and 'return' is left to its default
-- (@return = pure@). This replaces the non-canonical @pure = return@ plus
-- @return = AuthCheck . return . return . return@ pair that GHC's
-- @-Wnoncanonical-monad-instances@ flags.
instance Applicative AuthCheck where
  pure val = AuthCheck $ \_ -> pure (Authenticated val)
  (<*>) = ap

-- | Runs the first check. Only an 'Authenticated' result feeds into the
-- continuation, which is run against the same request. Every other result
-- short-circuits unchanged.
instance Monad AuthCheck where
  AuthCheck ac >>= f = AuthCheck $ \req -> do
    aresult <- ac req
    case aresult of
      Authenticated usr -> runAuthCheck (f usr) req
      BadPassword       -> pure BadPassword
      NoSuchUser        -> pure NoSuchUser
      Indefinite        -> pure Indefinite

#if !MIN_VERSION_base(4,13,0)
  -- Older base still has 'fail' in 'Monad'; route it to the MonadFail one.
  fail = Fail.fail
#endif
|
|
|
|
-- | Failure (e.g. a failed pattern match in @do@ notation) discards the
-- message and yields 'Indefinite', so other checks may still be tried.
instance Fail.MonadFail AuthCheck where
  fail _msg = AuthCheck (\_ -> pure Indefinite)
|
|
|
|
-- | The environment of an 'AuthCheck' is the incoming 'Request'. 'ask'
-- always succeeds with it, and 'local' runs a check against a transformed
-- request.
instance MonadReader Request AuthCheck where
  ask = AuthCheck (pure . Authenticated)
  local tweak (AuthCheck check) = AuthCheck (check . tweak)
|
|
|
|
-- | Lifted IO actions ignore the request, and their result is always
-- 'Authenticated'.
instance MonadIO AuthCheck where
  liftIO io = AuthCheck (\_ -> fmap Authenticated io)
|
|
|
|
-- | Reads the system clock via 'getCurrentTime'. The request is ignored and
-- the time is returned as an 'Authenticated' result.
instance MonadTime AuthCheck where
  currentTime = AuthCheck (\_ -> Authenticated <$> getCurrentTime)
|
|
|
|
-- | Choice follows the 'Semigroup' instance: checks run left to right and the
-- first non-'Indefinite' result wins. 'empty' is the 'Monoid' identity check.
instance Alternative AuthCheck where
  empty = mempty
  (<|>) = (<>)

-- | Uses the class defaults @mzero = empty@ and @mplus = (<|>)@.
instance MonadPlus AuthCheck
|