Second Iteration of Authentication
Implemented with the AuthProtected data family as per alp's suggestion. (during rebase, removed monoid instance for RouteResult)
This commit is contained in:
parent
42d0234cdc
commit
4c83843489
7 changed files with 421 additions and 96 deletions
|
@ -36,6 +36,7 @@ library
|
||||||
Servant
|
Servant
|
||||||
Servant.Server
|
Servant.Server
|
||||||
Servant.Server.Internal
|
Servant.Server.Internal
|
||||||
|
Servant.Server.Internal.Authentication
|
||||||
Servant.Server.Internal.Enter
|
Servant.Server.Internal.Enter
|
||||||
Servant.Server.Internal.Router
|
Servant.Server.Internal.Router
|
||||||
Servant.Server.Internal.RoutingApplication
|
Servant.Server.Internal.RoutingApplication
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||||
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE PolyKinds #-}
|
{-# LANGUAGE PolyKinds #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
@ -20,41 +21,67 @@ module Servant.Server.Internal
|
||||||
) where
|
) where
|
||||||
|
|
||||||
#if !MIN_VERSION_base(4,8,0)
|
#if !MIN_VERSION_base(4,8,0)
|
||||||
import Control.Applicative ((<$>))
|
import Control.Applicative ((<$>))
|
||||||
#endif
|
#endif
|
||||||
import Control.Monad.Trans.Except (ExceptT)
|
import Control.Monad.Trans.Except (ExceptT)
|
||||||
import qualified Data.ByteString as B
|
import qualified Data.ByteString as B
|
||||||
import qualified Data.ByteString.Lazy as BL
|
import qualified Data.ByteString.Lazy as BL
|
||||||
import qualified Data.Map as M
|
import qualified Data.Map as M
|
||||||
|
import qualified Data.ByteString as B
|
||||||
import Data.ByteString.Base64 (decodeLenient)
|
import Data.ByteString.Base64 (decodeLenient)
|
||||||
import qualified Data.ByteString.Lazy as BL
|
import qualified Data.ByteString.Lazy as BL
|
||||||
|
import qualified Data.Map as M
|
||||||
import Data.Maybe (mapMaybe, fromMaybe)
|
import Data.Maybe (mapMaybe, fromMaybe)
|
||||||
import Data.String (fromString)
|
import Data.String (fromString)
|
||||||
import Data.String.Conversions (ConvertibleStrings, cs, (<>))
|
import Data.String.Conversions (cs, (<>), ConvertibleStrings)
|
||||||
import Data.Text (Text)
|
import Data.Text (Text)
|
||||||
|
import qualified Data.Text as T
|
||||||
|
import Data.Text.Encoding (decodeUtf8,
|
||||||
|
encodeUtf8)
|
||||||
import Data.Typeable
|
import Data.Typeable
|
||||||
import Data.Word8 (isSpace, _colon, toLower)
|
import GHC.TypeLits (KnownSymbol,
|
||||||
import GHC.TypeLits (KnownSymbol, symbolVal)
|
symbolVal)
|
||||||
import Network.HTTP.Types hiding (Header, ResponseHeaders)
|
import Network.HTTP.Types hiding (Header,
|
||||||
import Network.Socket (SockAddr)
|
ResponseHeaders)
|
||||||
import Network.Wai (Application, lazyRequestBody,
|
import Network.Socket (SockAddr)
|
||||||
rawQueryString, requestHeaders,
|
import Network.Wai (Application,
|
||||||
requestMethod, responseLBS, remoteHost,
|
httpVersion,
|
||||||
isSecure, vault, httpVersion, Response,
|
isSecure,
|
||||||
Request, pathInfo)
|
lazyRequestBody,
|
||||||
import Servant.API ((:<|>) (..), (:>), Capture,
|
rawQueryString,
|
||||||
Delete, Get, Header,
|
remoteHost,
|
||||||
IsSecure(..), Patch, Post, Put,
|
requestHeaders,
|
||||||
QueryFlag, QueryParam, QueryParams,
|
requestMethod,
|
||||||
Raw, RemoteHost, ReqBody, Vault)
|
responseLBS, vault)
|
||||||
import Servant.API.ContentTypes (AcceptHeader (..),
|
import Servant.API ((:<|>) (..), (:>),
|
||||||
AllCTRender (..),
|
Capture, Delete,
|
||||||
AllCTUnrender (..),
|
Get, Header, IsSecure (Secure, NotSecure),
|
||||||
AllMime,
|
MatrixFlag,
|
||||||
canHandleAcceptH)
|
MatrixParam,
|
||||||
import Servant.API.ResponseHeaders (GetHeaders, Headers, getHeaders,
|
MatrixParams,
|
||||||
getResponse)
|
Patch, Post, Put,
|
||||||
|
QueryFlag,
|
||||||
|
QueryParam,
|
||||||
|
QueryParams, Raw,
|
||||||
|
RemoteHost,
|
||||||
|
ReqBody, Vault)
|
||||||
|
import Servant.API.Authentication (AuthPolicy (Strict, Lax),
|
||||||
|
AuthProtect,
|
||||||
|
AuthProtected)
|
||||||
|
import Servant.API.ContentTypes (AcceptHeader (..),
|
||||||
|
AllCTRender (..),
|
||||||
|
AllCTUnrender (..))
|
||||||
|
import Servant.API.ResponseHeaders (GetHeaders,
|
||||||
|
Headers,
|
||||||
|
getHeaders,
|
||||||
|
getResponse)
|
||||||
|
import Servant.Common.Text (FromText, fromText)
|
||||||
|
import Servant.Server.Internal.Authentication (AuthData (authData),
|
||||||
|
AuthProtected (..),
|
||||||
|
checkAuthStrict,
|
||||||
|
onMissingAuthData,
|
||||||
|
onUnauthenticated)
|
||||||
|
import Servant.Server.Internal.PathInfo
|
||||||
import Servant.Server.Internal.Router
|
import Servant.Server.Internal.Router
|
||||||
import Servant.Server.Internal.RoutingApplication
|
import Servant.Server.Internal.RoutingApplication
|
||||||
import Servant.Server.Internal.ServantErr
|
import Servant.Server.Internal.ServantErr
|
||||||
|
@ -69,18 +96,6 @@ class HasServer layout where
|
||||||
|
|
||||||
type Server layout = ServerT layout (ExceptT ServantErr IO)
|
type Server layout = ServerT layout (ExceptT ServantErr IO)
|
||||||
|
|
||||||
-- | A type-indexed class to encapsulate Basic authentication handling.
|
|
||||||
-- Authentication handling is indexed by the lookup type.
|
|
||||||
--
|
|
||||||
-- > data ExampleAuthDB
|
|
||||||
-- > data ExampleUser
|
|
||||||
-- > instance BasicAuthLookup ExampleAuthDB where
|
|
||||||
-- > type BasicAuthVal = ExampleUser
|
|
||||||
-- > basicAuthLookup _ _ _ = return Nothing
|
|
||||||
class BasicAuthLookup lookup where
|
|
||||||
type BasicAuthVal lookup :: *
|
|
||||||
basicAuthLookup :: Proxy lookup -> B.ByteString -> B.ByteString -> IO (Maybe (BasicAuthVal lookup))
|
|
||||||
|
|
||||||
-- * Instances
|
-- * Instances
|
||||||
|
|
||||||
-- | A server for @a ':<|>' b@ first tries to match the request against the route
|
-- | A server for @a ':<|>' b@ first tries to match the request against the route
|
||||||
|
@ -254,54 +269,74 @@ instance
|
||||||
|
|
||||||
route Proxy = methodRouterHeaders methodDelete (Proxy :: Proxy ctypes) ok200
|
route Proxy = methodRouterHeaders methodDelete (Proxy :: Proxy ctypes) ok200
|
||||||
|
|
||||||
-- | Authentication
|
-- | Authentication in Strict mode.
|
||||||
instance
|
instance
|
||||||
#if MIN_VERSION_base(4,8,0)
|
#if MIN_VERSION_base(4,8,0)
|
||||||
{-# OVERLAPPABLE #-}
|
{-# OVERLAPPABLE #-}
|
||||||
#endif
|
#endif
|
||||||
( HasServer sublayout
|
(AuthData authdata , HasServer sublayout) => HasServer (AuthProtect authdata (usr :: *) 'Strict :> sublayout) where
|
||||||
, BasicAuthLookup lookup
|
|
||||||
, KnownSymbol realm
|
|
||||||
)
|
|
||||||
=> HasServer (BasicAuth realm lookup :> sublayout) where
|
|
||||||
|
|
||||||
type ServerT (BasicAuth realm lookup :> sublayout) m
|
type ServerT (AuthProtect authdata usr 'Strict :> sublayout) m = AuthProtected authdata usr (usr -> ServerT sublayout m) 'Strict
|
||||||
= BasicAuthVal lookup -> ServerT sublayout m
|
|
||||||
|
|
||||||
route _ action request respond =
|
route _ subserver = WithRequest $ \req ->
|
||||||
case lookup "Authorization" (requestHeaders request) of
|
route (Proxy :: Proxy sublayout) $ do
|
||||||
Nothing -> respond . succeedWith $ authFailure401
|
-- Note: this may perform IO for each attempt at matching.
|
||||||
Just authBs ->
|
rr <- routeResult <$> subserver
|
||||||
-- ripped from: https://hackage.haskell.org/package/wai-extra-1.3.4.5/docs/src/Network-Wai-Middleware-HttpAuth.html#basicAuth
|
|
||||||
let (x,y) = B.break isSpace authBs in
|
|
||||||
if B.map toLower x == "basic"
|
|
||||||
-- check base64-encoded password
|
|
||||||
then checkB64AndRespond (B.dropWhile isSpace y)
|
|
||||||
-- Authenticaiton header is not Basic, fail with 401.
|
|
||||||
else respond . succeedWith $ authFailure401
|
|
||||||
where
|
|
||||||
realmBytes = (fromString . symbolVal) (Proxy :: Proxy realm)
|
|
||||||
headerBytes = "Basic realm=\"" <> realmBytes <> "\""
|
|
||||||
authFailure401 = responseLBS status401 [("WWW-Authenticate", headerBytes)] ""
|
|
||||||
checkB64AndRespond encoded =
|
|
||||||
case B.uncons passwordWithColonAtHead of
|
|
||||||
Just (_, password) -> do
|
|
||||||
-- let's check these credentials using the user-provided lookup method
|
|
||||||
maybeAuthData <- basicAuthLookup (Proxy :: Proxy lookup) username password
|
|
||||||
case maybeAuthData of
|
|
||||||
Nothing -> respond . succeedWith $ authFailure403
|
|
||||||
(Just authData) ->
|
|
||||||
route (Proxy :: Proxy sublayout) (action authData) request respond
|
|
||||||
|
|
||||||
-- no username:password present
|
case rr of
|
||||||
Nothing -> respond . succeedWith $ authFailure401
|
-- Successful route match, so we extract the author-provided
|
||||||
where
|
-- auth data.
|
||||||
authFailure403 = responseLBS status403 [] ""
|
Right authProtectionStrict ->
|
||||||
raw = decodeLenient encoded
|
case authData req of
|
||||||
-- split username and password at the colon ':' char.
|
-- could not pull authenticate data out of the request
|
||||||
(username, passwordWithColonAtHead) = B.breakByte _colon raw
|
Nothing -> do
|
||||||
|
-- we're in strict mode: don't let the request go
|
||||||
|
-- call the provided "on missing auth" handler
|
||||||
|
resp <- onMissingAuthData (authHandlers authProtectionStrict)
|
||||||
|
return $ failWith (RouteMismatch resp)
|
||||||
|
|
||||||
|
-- succesfully pulled auth data out of the Request
|
||||||
|
Just authData' -> do
|
||||||
|
mUsr <- (checkAuthStrict authProtectionStrict) authData'
|
||||||
|
case mUsr of
|
||||||
|
-- this user is not authenticated.
|
||||||
|
Nothing -> do
|
||||||
|
resp <- onUnauthenticated (authHandlers authProtectionStrict) authData'
|
||||||
|
return $ failWith (RouteMismatch resp)
|
||||||
|
|
||||||
|
-- this user is authenticated.
|
||||||
|
Just usr ->
|
||||||
|
(return . succeedWith . subServerStrict authProtectionStrict) usr
|
||||||
|
-- route did not match, propagate failure.
|
||||||
|
Left rMismatch ->
|
||||||
|
return (failWith rMismatch)
|
||||||
|
|
||||||
|
-- | Authentication in Lax mode.
|
||||||
|
instance
|
||||||
|
#if MIN_VERSION_base(4,8,0)
|
||||||
|
{-# OVERLAPPABLE #-}
|
||||||
|
#endif
|
||||||
|
(AuthData authdata , HasServer sublayout) => HasServer (AuthProtect authdata (usr :: *) 'Lax :> sublayout) where
|
||||||
|
|
||||||
|
type ServerT (AuthProtect authdata usr 'Lax :> sublayout) m = AuthProtected authdata usr (Maybe usr -> ServerT sublayout m) 'Lax
|
||||||
|
|
||||||
|
route _ subserver = WithRequest $ \req ->
|
||||||
|
route (Proxy :: Proxy sublayout) $ do
|
||||||
|
-- Note: this may perform IO for each attempt at matching.
|
||||||
|
rr <- routeResult <$> subserver
|
||||||
|
-- Successful route match, so we extract the author-provided
|
||||||
|
-- auth data.
|
||||||
|
case rr of
|
||||||
|
-- route matched, extract author-provided lax authentication data
|
||||||
|
Right authProtectionLax -> do
|
||||||
|
-- extract a user from the request object and perform
|
||||||
|
-- authentication on it. In Lax mode, we just pass `Maybe usr`
|
||||||
|
-- to the autho.
|
||||||
|
musr <- maybe (pure Nothing) (checkAuthLax authProtectionLax) (authData req)
|
||||||
|
(return . succeedWith . subServerLax authProtectionLax) musr
|
||||||
|
-- route did not match, propagate failure
|
||||||
|
Left rMismatch ->
|
||||||
|
return (failWith rMismatch)
|
||||||
|
|
||||||
-- | When implementing the handler for a 'Get' endpoint,
|
-- | When implementing the handler for a 'Get' endpoint,
|
||||||
-- just like for 'Servant.API.Delete.Delete', 'Servant.API.Post.Post'
|
-- just like for 'Servant.API.Delete.Delete', 'Servant.API.Post.Post'
|
||||||
|
|
102
servant-server/src/Servant/Server/Internal/Authentication.hs
Normal file
102
servant-server/src/Servant/Server/Internal/Authentication.hs
Normal file
|
@ -0,0 +1,102 @@
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
|
||||||
|
module Servant.Server.Internal.Authentication
|
||||||
|
( AuthProtected (..)
|
||||||
|
, AuthData (..)
|
||||||
|
, AuthHandlers (..)
|
||||||
|
, basicAuthLax
|
||||||
|
, basicAuthStrict
|
||||||
|
, laxProtect
|
||||||
|
, strictProtect
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Control.Monad (guard)
|
||||||
|
import qualified Data.ByteString as B
|
||||||
|
import Data.ByteString.Base64 (decodeLenient)
|
||||||
|
import Data.Monoid ((<>))
|
||||||
|
import Data.Proxy (Proxy (Proxy))
|
||||||
|
import Data.String (fromString)
|
||||||
|
import Data.Word8 (isSpace, toLower, _colon)
|
||||||
|
import GHC.TypeLits (KnownSymbol, symbolVal)
|
||||||
|
import Network.HTTP.Types.Status (status401)
|
||||||
|
import Network.Wai (Request, Response, requestHeaders,
|
||||||
|
responseBuilder)
|
||||||
|
import Servant.API.Authentication (AuthPolicy (Strict, Lax),
|
||||||
|
AuthProtected,
|
||||||
|
BasicAuth (BasicAuth))
|
||||||
|
|
||||||
|
-- | Class to represent the ability to extract authentication-related
|
||||||
|
-- data from a 'Request' object.
|
||||||
|
class AuthData a where
|
||||||
|
authData :: Request -> Maybe a
|
||||||
|
|
||||||
|
-- | handlers to deal with authentication failures.
|
||||||
|
data AuthHandlers authData = AuthHandlers
|
||||||
|
{ -- we couldn't find the right type of auth data (or any, for that matter)
|
||||||
|
onMissingauthData :: IO Response
|
||||||
|
,
|
||||||
|
-- we found the right type of auth data in the request but the check failed
|
||||||
|
onUnauthenticated :: authData -> IO Response
|
||||||
|
}
|
||||||
|
|
||||||
|
-- | concrete type to provide when in 'Strict' mode.
|
||||||
|
data instance AuthProtected authData usr subserver 'Strict =
|
||||||
|
AuthProtectedStrict { checkAuthStrict :: authData -> IO (Maybe usr)
|
||||||
|
, subServerStrict :: subserver
|
||||||
|
, authHandlers :: AuthHandlers authData
|
||||||
|
}
|
||||||
|
|
||||||
|
-- | concrete type to provide when in 'Lax' mode.
|
||||||
|
data instance AuthProtected authData usr subserver 'Lax =
|
||||||
|
AuthProtectedLax { checkAuthLax :: authData -> IO (Maybe usr)
|
||||||
|
, subServerLax :: subserver
|
||||||
|
}
|
||||||
|
|
||||||
|
-- | handy function to build an auth-protected bit of API with a Lax policy
|
||||||
|
laxProtect :: (authData -> IO (Maybe usr)) -- ^ check auth
|
||||||
|
-> subserver -- ^ the handlers for the auth-aware bits of the API
|
||||||
|
-> AuthProtected authData usr subserver 'Lax
|
||||||
|
laxProtect = AuthProtectedLax
|
||||||
|
|
||||||
|
-- | handy function to build an auth-protected bit of API with a Strict policy
|
||||||
|
strictProtect :: (authData -> IO (Maybe usr)) -- ^ check auth
|
||||||
|
-> subserver -- ^ handlers for the auth-protected bits of the API
|
||||||
|
-> AuthHandlers authData -- ^ functions to call on auth failure
|
||||||
|
-> AuthProtected authData usr subserver 'Strict
|
||||||
|
strictProtect = AuthProtectedStrict
|
||||||
|
|
||||||
|
-- | 'BasicAuth' instance for authData
|
||||||
|
instance AuthData (BasicAuth realm) where
|
||||||
|
authData request = do
|
||||||
|
authBs <- lookup "Authorization" (requestHeaders request)
|
||||||
|
let (x,y) = B.break isSpace authBs
|
||||||
|
guard (B.map toLower x == "basic")
|
||||||
|
-- decode the base64-encoded username and password
|
||||||
|
let (username, passWithColonAtHead) = B.break (== _colon) (decodeLenient (B.dropWhile isSpace y))
|
||||||
|
(_, password) <- B.uncons passWithColonAtHead
|
||||||
|
return $ BasicAuth username password
|
||||||
|
|
||||||
|
-- | handlers for Basic Authentication.
|
||||||
|
basicAuthHandlers :: forall realm. KnownSymbol realm => AuthHandlers (BasicAuth realm)
|
||||||
|
basicAuthHandlers =
|
||||||
|
let realmBytes = (fromString . symbolVal) (Proxy :: Proxy realm)
|
||||||
|
headerBytes = "Basic realm=\"" <> realmBytes <> "\""
|
||||||
|
authFailure = responseBuilder status401 [("WWW-Authenticate", headerBytes)] mempty in
|
||||||
|
AuthHandlers (return authFailure) ((const . return) authFailure)
|
||||||
|
|
||||||
|
-- | Basic authentication combinator with strict failure.
|
||||||
|
basicAuthStrict :: KnownSymbol realm
|
||||||
|
=> (BasicAuth realm -> IO (Maybe usr))
|
||||||
|
-> subserver
|
||||||
|
-> AuthProtected (BasicAuth realm) usr subserver 'Strict
|
||||||
|
basicAuthStrict check subserver = strictProtect check subserver basicAuthHandlers
|
||||||
|
|
||||||
|
-- | Basic authentication combinator with lax failure.
|
||||||
|
basicAuthLax :: KnownSymbol realm
|
||||||
|
=> (BasicAuth realm -> IO (Maybe usr))
|
||||||
|
-> subserver
|
||||||
|
-> AuthProtected (BasicAuth realm) usr subserver 'Lax
|
||||||
|
basicAuthLax = laxProtect
|
|
@ -1,4 +1,5 @@
|
||||||
{-# LANGUAGE CPP #-}
|
{-# LANGUAGE CPP #-}
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
{-# LANGUAGE DeriveDataTypeable #-}
|
{-# LANGUAGE DeriveDataTypeable #-}
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
{-# LANGUAGE FunctionalDependencies #-}
|
{-# LANGUAGE FunctionalDependencies #-}
|
||||||
|
@ -13,7 +14,7 @@ module Servant.Server.Internal.Enter where
|
||||||
#if !MIN_VERSION_base(4,8,0)
|
#if !MIN_VERSION_base(4,8,0)
|
||||||
import Control.Applicative
|
import Control.Applicative
|
||||||
#endif
|
#endif
|
||||||
import qualified Control.Category as C
|
import qualified Control.Category as C
|
||||||
#if MIN_VERSION_mtl(2,2,1)
|
#if MIN_VERSION_mtl(2,2,1)
|
||||||
import Control.Monad.Except
|
import Control.Monad.Except
|
||||||
#endif
|
#endif
|
||||||
|
@ -24,9 +25,16 @@ import qualified Control.Monad.State.Lazy as LState
|
||||||
import qualified Control.Monad.State.Strict as SState
|
import qualified Control.Monad.State.Strict as SState
|
||||||
import qualified Control.Monad.Writer.Lazy as LWriter
|
import qualified Control.Monad.Writer.Lazy as LWriter
|
||||||
import qualified Control.Monad.Writer.Strict as SWriter
|
import qualified Control.Monad.Writer.Strict as SWriter
|
||||||
|
import qualified Control.Monad.State.Lazy as LState
|
||||||
|
import qualified Control.Monad.State.Strict as SState
|
||||||
|
import qualified Control.Monad.Writer.Lazy as LWriter
|
||||||
|
import qualified Control.Monad.Writer.Strict as SWriter
|
||||||
import Data.Typeable
|
import Data.Typeable
|
||||||
import Servant.API
|
import Servant.API
|
||||||
|
|
||||||
|
import Servant.API.Authentication
|
||||||
|
import Servant.Server.Internal.Authentication (AuthProtected (AuthProtectedStrict, AuthProtectedLax))
|
||||||
|
|
||||||
class Enter typ arg ret | typ arg -> ret, typ ret -> arg where
|
class Enter typ arg ret | typ arg -> ret, typ ret -> arg where
|
||||||
enter :: arg -> typ -> ret
|
enter :: arg -> typ -> ret
|
||||||
|
|
||||||
|
@ -95,3 +103,12 @@ squashNat = Nat squash
|
||||||
-- | Like @mmorph@'s `generalize`.
|
-- | Like @mmorph@'s `generalize`.
|
||||||
generalizeNat :: Applicative m => Identity :~> m
|
generalizeNat :: Applicative m => Identity :~> m
|
||||||
generalizeNat = Nat (pure . runIdentity)
|
generalizeNat = Nat (pure . runIdentity)
|
||||||
|
|
||||||
|
-- | 'Enter' instance for AuthProtectedStrict
|
||||||
|
instance Enter subserver arg ret => Enter (AuthProtected authData usr subserver 'Strict) arg (AuthProtected authData usr ret 'Strict) where
|
||||||
|
enter arg (AuthProtectedStrict check subserver handlers) = AuthProtectedStrict check (enter arg subserver) handlers
|
||||||
|
|
||||||
|
|
||||||
|
-- | 'Enter' instance for AuthProtectedLax
|
||||||
|
instance Enter subserver arg ret => Enter (AuthProtected authData usr subserver 'Lax) arg (AuthProtected authData usr ret 'Lax) where
|
||||||
|
enter arg (AuthProtectedLax check subserver) = AuthProtectedLax check (enter arg subserver)
|
||||||
|
|
|
@ -15,6 +15,9 @@ import qualified Data.ByteString as B
|
||||||
import qualified Data.ByteString.Lazy as BL
|
import qualified Data.ByteString.Lazy as BL
|
||||||
import Data.IORef (newIORef, readIORef,
|
import Data.IORef (newIORef, readIORef,
|
||||||
writeIORef)
|
writeIORef)
|
||||||
|
import Data.Maybe (fromMaybe)
|
||||||
|
import Data.String (fromString)
|
||||||
|
import Network.HTTP.Types hiding (Header, ResponseHeaders)
|
||||||
import Network.Wai (Application, Request,
|
import Network.Wai (Application, Request,
|
||||||
Response, ResponseReceived,
|
Response, ResponseReceived,
|
||||||
requestBody,
|
requestBody,
|
||||||
|
@ -33,6 +36,52 @@ data RouteResult a =
|
||||||
| Route !a
|
| Route !a
|
||||||
deriving (Eq, Show, Read, Functor)
|
deriving (Eq, Show, Read, Functor)
|
||||||
|
|
||||||
|
-- Note that the ordering of the constructors has great significance! It
|
||||||
|
-- determines the Ord instance and, consequently, the monoid instance.
|
||||||
|
data RouteMismatch =
|
||||||
|
NotFound -- ^ the usual "not found" error
|
||||||
|
| WrongMethod -- ^ a more informative "you just got the HTTP method wrong" error
|
||||||
|
| UnsupportedMediaType -- ^ request body has unsupported media type
|
||||||
|
| InvalidBody String -- ^ an even more informative "your json request body wasn't valid" error
|
||||||
|
| HttpError Status [Header] (Maybe BL.ByteString) -- ^ an even even more informative arbitrary HTTP response code error.
|
||||||
|
| RouteMismatch Response -- ^ an arbitrary mismatch with custom Response.
|
||||||
|
|
||||||
|
instance Show RouteMismatch where
|
||||||
|
show = const "hello"
|
||||||
|
|
||||||
|
-- | specialized 'Less Than' for use with Monoid RouteMismatch
|
||||||
|
(<=:) :: RouteMismatch -> RouteMismatch -> Bool
|
||||||
|
{-# INLINE (<=:) #-}
|
||||||
|
NotFound <=: _ = True
|
||||||
|
WrongMethod <=: rmm = not (rmm <=: NotFound)
|
||||||
|
UnsupportedMediaType <=: rmm = not (rmm <=: WrongMethod)
|
||||||
|
InvalidBody _ <=: rmm = not (rmm <=: UnsupportedMediaType)
|
||||||
|
HttpError _ _ _ <=: rmm = not (rmm <=: (InvalidBody ""))
|
||||||
|
RouteMismatch _ <=: _ = False
|
||||||
|
|
||||||
|
instance Monoid RouteMismatch where
|
||||||
|
mempty = NotFound
|
||||||
|
-- The following isn't great, since it picks @InvalidBody@ based on
|
||||||
|
-- alphabetical ordering, but any choice would be arbitrary.
|
||||||
|
--
|
||||||
|
-- "As one judge said to the other, 'Be just and if you can't be just, be
|
||||||
|
-- arbitrary'" -- William Burroughs
|
||||||
|
--
|
||||||
|
-- It used to be the case that `mappend = max` but getting rid of the `Eq`
|
||||||
|
-- and `Ord` instance meant we had to roll out our own max ;\
|
||||||
|
rmm `mappend` NotFound = rmm
|
||||||
|
NotFound `mappend` rmm = rmm
|
||||||
|
WrongMethod `mappend` rmm | rmm <=: WrongMethod = WrongMethod
|
||||||
|
WrongMethod `mappend` rmm = rmm
|
||||||
|
UnsupportedMediaType `mappend` rmm | rmm <=: UnsupportedMediaType = UnsupportedMediaType
|
||||||
|
UnsupportedMediaType `mappend` rmm = rmm
|
||||||
|
i@(InvalidBody _) `mappend` rmm | rmm <=: i = i
|
||||||
|
InvalidBody _ `mappend` rmm = rmm
|
||||||
|
h@(HttpError _ _ _) `mappend` rmm | rmm <=: h = h
|
||||||
|
HttpError _ _ _ `mappend` rmm = rmm
|
||||||
|
r@(RouteMismatch _) `mappend` _ = r
|
||||||
|
>>>>>>> 272091e... Second Iteration of Authentication
|
||||||
|
|
||||||
data ReqBodyState = Uncalled
|
data ReqBodyState = Uncalled
|
||||||
| Called !B.ByteString
|
| Called !B.ByteString
|
||||||
| Done !B.ByteString
|
| Done !B.ByteString
|
||||||
|
@ -62,6 +111,7 @@ toApplication ra request respond = do
|
||||||
|
|
||||||
ra request{ requestBody = memoReqBody } routingRespond
|
ra request{ requestBody = memoReqBody } routingRespond
|
||||||
where
|
where
|
||||||
|
<<<<<<< HEAD
|
||||||
routingRespond :: RouteResult Response -> IO ResponseReceived
|
routingRespond :: RouteResult Response -> IO ResponseReceived
|
||||||
routingRespond (Fail err) = respond $ responseServantErr err
|
routingRespond (Fail err) = respond $ responseServantErr err
|
||||||
routingRespond (FailFatal err) = respond $ responseServantErr err
|
routingRespond (FailFatal err) = respond $ responseServantErr err
|
||||||
|
@ -235,6 +285,25 @@ runDelayed (Delayed captures method body server) =
|
||||||
-- Also takes a continuation for how to turn the
|
-- Also takes a continuation for how to turn the
|
||||||
-- result of the delayed server into a response.
|
-- result of the delayed server into a response.
|
||||||
runAction :: Delayed (ExceptT ServantErr IO a)
|
runAction :: Delayed (ExceptT ServantErr IO a)
|
||||||
|
=======
|
||||||
|
routingRespond :: Either RouteMismatch Response -> IO ResponseReceived
|
||||||
|
routingRespond (Left NotFound) =
|
||||||
|
respond $ responseLBS notFound404 [] "not found"
|
||||||
|
routingRespond (Left WrongMethod) =
|
||||||
|
respond $ responseLBS methodNotAllowed405 [] "method not allowed"
|
||||||
|
routingRespond (Left (InvalidBody err)) =
|
||||||
|
respond $ responseLBS badRequest400 [] $ fromString $ "invalid request body: " ++ err
|
||||||
|
routingRespond (Left UnsupportedMediaType) =
|
||||||
|
respond $ responseLBS unsupportedMediaType415 [] "unsupported media type"
|
||||||
|
routingRespond (Left (HttpError status headers body)) =
|
||||||
|
respond $ responseLBS status headers $ fromMaybe (BL.fromStrict $ statusMessage status) body
|
||||||
|
routingRespond (Left (RouteMismatch resp)) =
|
||||||
|
respond resp
|
||||||
|
routingRespond (Right response) =
|
||||||
|
respond response
|
||||||
|
|
||||||
|
runAction :: IO (RouteResult (ExceptT ServantErr IO a))
|
||||||
|
>>>>>>> 272091e... Second Iteration of Authentication
|
||||||
-> (RouteResult Response -> IO r)
|
-> (RouteResult Response -> IO r)
|
||||||
-> (a -> RouteResult Response)
|
-> (a -> RouteResult Response)
|
||||||
-> IO r
|
-> IO r
|
||||||
|
|
|
@ -556,6 +556,100 @@ routerSpec = do
|
||||||
it "calls f on route result" $ do
|
it "calls f on route result" $ do
|
||||||
get "" `shouldRespondWith` 202
|
get "" `shouldRespondWith` 202
|
||||||
|
|
||||||
|
type PrioErrorsApi = ReqBody '[JSON] Person :> "foo" :> Get '[JSON] Integer
|
||||||
|
|
||||||
|
prioErrorsApi :: Proxy PrioErrorsApi
|
||||||
|
prioErrorsApi = Proxy
|
||||||
|
|
||||||
|
-- | Test the relative priority of error responses from the server.
|
||||||
|
--
|
||||||
|
-- In particular, we check whether matching continues even if a 'ReqBody'
|
||||||
|
-- or similar construct is encountered early in a path. We don't want to
|
||||||
|
-- see a complaint about the request body unless the path actually matches.
|
||||||
|
--
|
||||||
|
prioErrorsSpec :: Spec
|
||||||
|
prioErrorsSpec = describe "PrioErrors" $ do
|
||||||
|
let server = return . age
|
||||||
|
with (return $ serve prioErrorsApi server) $ do
|
||||||
|
let check (mdescr, method) path (cdescr, ctype, body) resp =
|
||||||
|
it fulldescr $
|
||||||
|
Test.Hspec.Wai.request method path [(hContentType, ctype)] body
|
||||||
|
`shouldRespondWith` resp
|
||||||
|
where
|
||||||
|
fulldescr = "returns " ++ show (matchStatus resp) ++ " on " ++ mdescr
|
||||||
|
++ " " ++ cs path ++ " (" ++ cdescr ++ ")"
|
||||||
|
|
||||||
|
get' = ("GET", methodGet)
|
||||||
|
put' = ("PUT", methodPut)
|
||||||
|
|
||||||
|
txt = ("text" , "text/plain;charset=utf8" , "42" )
|
||||||
|
ijson = ("invalid json", "application/json;charset=utf8", "invalid" )
|
||||||
|
vjson = ("valid json" , "application/json;charset=utf8", encode alice)
|
||||||
|
|
||||||
|
check get' "/" txt 404
|
||||||
|
check get' "/bar" txt 404
|
||||||
|
check get' "/foo" txt 415
|
||||||
|
check put' "/" txt 404
|
||||||
|
check put' "/bar" txt 404
|
||||||
|
check put' "/foo" txt 405
|
||||||
|
check get' "/" ijson 404
|
||||||
|
check get' "/bar" ijson 404
|
||||||
|
check get' "/foo" ijson 400
|
||||||
|
check put' "/" ijson 404
|
||||||
|
check put' "/bar" ijson 404
|
||||||
|
check put' "/foo" ijson 405
|
||||||
|
check get' "/" vjson 404
|
||||||
|
check get' "/bar" vjson 404
|
||||||
|
check get' "/foo" vjson 200
|
||||||
|
check put' "/" vjson 404
|
||||||
|
check put' "/bar" vjson 404
|
||||||
|
check put' "/foo" vjson 405
|
||||||
|
|
||||||
|
-- | Test server error functionality.
|
||||||
|
errorsSpec :: Spec
|
||||||
|
errorsSpec = do
|
||||||
|
let he = HttpError status409 [] (Just "A custom error")
|
||||||
|
let ib = InvalidBody "The body is invalid"
|
||||||
|
let wm = WrongMethod
|
||||||
|
let nf = NotFound
|
||||||
|
|
||||||
|
describe "Servant.Server.Internal.RouteMismatch" $ do
|
||||||
|
it "HttpError > *" $ do
|
||||||
|
ib <> he `shouldBe` he
|
||||||
|
wm <> he `shouldBe` he
|
||||||
|
nf <> he `shouldBe` he
|
||||||
|
|
||||||
|
he <> ib `shouldBe` he
|
||||||
|
he <> wm `shouldBe` he
|
||||||
|
he <> nf `shouldBe` he
|
||||||
|
|
||||||
|
it "HE > InvalidBody > (WM,NF)" $ do
|
||||||
|
he <> ib `shouldBe` he
|
||||||
|
wm <> ib `shouldBe` ib
|
||||||
|
nf <> ib `shouldBe` ib
|
||||||
|
|
||||||
|
ib <> he `shouldBe` he
|
||||||
|
ib <> wm `shouldBe` ib
|
||||||
|
ib <> nf `shouldBe` ib
|
||||||
|
|
||||||
|
it "HE > IB > WrongMethod > NF" $ do
|
||||||
|
he <> wm `shouldBe` he
|
||||||
|
ib <> wm `shouldBe` ib
|
||||||
|
nf <> wm `shouldBe` wm
|
||||||
|
|
||||||
|
wm <> he `shouldBe` he
|
||||||
|
wm <> ib `shouldBe` ib
|
||||||
|
wm <> nf `shouldBe` wm
|
||||||
|
|
||||||
|
it "* > NotFound" $ do
|
||||||
|
he <> nf `shouldBe` he
|
||||||
|
ib <> nf `shouldBe` ib
|
||||||
|
wm <> nf `shouldBe` wm
|
||||||
|
|
||||||
|
nf <> he `shouldBe` he
|
||||||
|
nf <> ib `shouldBe` ib
|
||||||
|
nf <> wm `shouldBe` wm
|
||||||
|
|
||||||
type MiscCombinatorsAPI
|
type MiscCombinatorsAPI
|
||||||
= "version" :> HttpVersion :> Get '[JSON] String
|
= "version" :> HttpVersion :> Get '[JSON] String
|
||||||
:<|> "secure" :> IsSecure :> Get '[JSON] String
|
:<|> "secure" :> IsSecure :> Get '[JSON] String
|
||||||
|
@ -644,4 +738,4 @@ authRequiredSpec = do
|
||||||
foo401 <- get "/foo"
|
foo401 <- get "/foo"
|
||||||
bar401 <- get "/bar"
|
bar401 <- get "/bar"
|
||||||
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" foo401)
|
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" foo401)
|
||||||
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" bar401)
|
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"bar-realm\"" bar401)
|
||||||
|
|
|
@ -1,24 +1,31 @@
|
||||||
{-# LANGUAGE DataKinds #-}
|
{-# LANGUAGE DataKinds #-}
|
||||||
{-# LANGUAGE DeriveDataTypeable #-}
|
{-# LANGUAGE DeriveDataTypeable #-}
|
||||||
|
{-# LANGUAGE KindSignatures #-}
|
||||||
{-# LANGUAGE PolyKinds #-}
|
{-# LANGUAGE PolyKinds #-}
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
{-# OPTIONS_HADDOCK not-home #-}
|
{-# OPTIONS_HADDOCK not-home #-}
|
||||||
module Servant.API.Authentication (BasicAuth) where
|
module Servant.API.Authentication where
|
||||||
|
|
||||||
import Data.Typeable (Typeable)
|
import Data.ByteString (ByteString)
|
||||||
import GHC.TypeLits (Symbol)
|
import Data.Typeable (Typeable)
|
||||||
|
import GHC.TypeLits (Symbol)
|
||||||
|
|
||||||
|
-- | we can be either Strict or Lax.
|
||||||
|
-- Strict: all handlers under 'AuthProtect' take a 'usr' argument.
|
||||||
|
-- when auth fails, we call user-supplied handlers to respond.
|
||||||
|
-- Lax: all handlers under 'AuthProtect' take a 'Maybe usr' argument.
|
||||||
|
-- when auth fails, we call the handlers with 'Nothing'.
|
||||||
|
data AuthPolicy = Strict | Lax
|
||||||
|
|
||||||
|
-- | the combinator to be used in API types
|
||||||
|
data AuthProtect authdata usr (policy :: AuthPolicy)
|
||||||
|
|
||||||
|
-- | what we'll ask user to provide at the server-level when we see a
|
||||||
|
-- 'AuthProtect' combinator in an API type
|
||||||
|
data family AuthProtected authdata usr subserver :: AuthPolicy -> *
|
||||||
|
|
||||||
-- | Basic Authentication with respect to a specified @realm@ and a @lookup@
|
-- | Basic Authentication with respect to a specified @realm@ and a @lookup@
|
||||||
-- type to encapsulate authentication logic.
|
-- type to encapsulate authentication logic.
|
||||||
--
|
data BasicAuth (realm :: Symbol) = BasicAuth { baUser :: ByteString
|
||||||
-- Example:
|
, baPass :: ByteString
|
||||||
-- >>> type MyApi = BasicAuth "book-realm" DB :> "books" :> Get '[JSON] [Book]
|
} deriving (Eq, Show, Typeable)
|
||||||
data BasicAuth (realm :: Symbol) lookup
|
|
||||||
deriving (Typeable)
|
|
||||||
|
|
||||||
-- $setup
|
|
||||||
-- >>> import Servant.API
|
|
||||||
-- >>> import Data.Aeson
|
|
||||||
-- >>> import Data.Text
|
|
||||||
-- >>> data DB
|
|
||||||
-- >>> data Book
|
|
||||||
-- >>> instance ToJSON Book where { toJSON = undefined }
|
|
||||||
|
|
Loading…
Reference in a new issue