Begin integrating upstream changes into auth

This commit is contained in:
aaron levin 2015-11-22 23:10:00 +01:00
parent 2424424ab2
commit 2b3df72fa2
2 changed files with 30 additions and 77 deletions

View file

@ -59,9 +59,6 @@ import Network.Wai (Application,
import Servant.API ((:<|>) (..), (:>),
Capture, Delete,
Get, Header, IsSecure (Secure, NotSecure),
MatrixFlag,
MatrixParam,
MatrixParams,
Patch, Post, Put,
QueryFlag,
QueryParam,
@ -73,7 +70,9 @@ import Servant.API.Authentication (AuthPolicy (Strict,
AuthProtected)
import Servant.API.ContentTypes (AcceptHeader (..),
AllCTRender (..),
AllCTUnrender (..))
AllCTUnrender (..),
AllMime (..),
canHandleAcceptH)
import Servant.API.ResponseHeaders (GetHeaders,
Headers,
getHeaders,
@ -83,7 +82,6 @@ import Servant.Server.Internal.Authentication (AuthData (authData)
checkAuthStrict,
AuthHandlers(onMissingAuthData,
onUnauthenticated))
import Servant.Server.Internal.PathInfo
import Servant.Server.Internal.Router
import Servant.Server.Internal.RoutingApplication
import Servant.Server.Internal.ServantErr
@ -280,38 +278,38 @@ instance
type ServerT (AuthProtect authdata usr 'Strict :> sublayout) m = AuthProtected authdata usr (usr -> ServerT sublayout m) 'Strict
route _ subserver = WithRequest $ \req ->
route (Proxy :: Proxy sublayout) $ do
-- Note: this may perform IO for each attempt at matching.
-- route :: Proxy (a :> s) -> Delayed (AuthProtected aData usr (usr -> servert) 'S) -> Router
route _ (Delayed _ _ _ m) = WithRequest $ \req ->
route (Proxy :: Proxy sublayout) $ addBodyCheck subserver $ case authData req of
Nothing -> do
-- we're in strict mode: don't let the request go
-- call the provided "on missing auth" handler
resp <- onMissingAuthData (authHandlers authProtectionStrict)
return $ FailFatal resp
-- succesfully pulled auth data out of the Request
Just authData' -> do
mUsr <- (checkAuthStrict authProtectionStrict) authData'
case mUsr of
-- this user is not authenticated.
Nothing -> do
resp <- onUnauthenticated (authHandlers authProtectionStrict) authData'
return $ FailFatal resp
-- this user is authenticated.
Just usr ->
(return . Route . subServerStrict authProtectionStrict) usr
rr <- routeResult <$> subserver
case rr of
-- Successful route match, so we extract the author-provided
-- auth data.
Right authProtectionStrict ->
case authData req of
-- could not pull authenticate data out of the request
Nothing -> do
-- we're in strict mode: don't let the request go
-- call the provided "on missing auth" handler
resp <- onMissingAuthData (authHandlers authProtectionStrict)
return $ failWith (RouteMismatch resp)
-- succesfully pulled auth data out of the Request
Just authData' -> do
mUsr <- (checkAuthStrict authProtectionStrict) authData'
case mUsr of
-- this user is not authenticated.
Nothing -> do
resp <- onUnauthenticated (authHandlers authProtectionStrict) authData'
return $ failWith (RouteMismatch resp)
-- this user is authenticated.
Just usr ->
(return . succeedWith . subServerStrict authProtectionStrict) usr
-- route did not match, propagate failure.
Left rMismatch ->
return (failWith rMismatch)
return (Fail rMismatch)
-- | Authentication in Lax mode.
instance
@ -335,10 +333,10 @@ instance
-- authentication on it. In Lax mode, we just pass `Maybe usr`
-- to the autho.
musr <- maybe (pure Nothing) (checkAuthLax authProtectionLax) (authData req)
(return . succeedWith . subServerLax authProtectionLax) musr
(return . Route . subServerLax authProtectionLax) musr
-- route did not match, propagate failure
Left rMismatch ->
return (failWith rMismatch)
return (Fail rMismatch)
-- | When implementing the handler for a 'Get' endpoint,
-- just like for 'Servant.API.Delete.Delete', 'Servant.API.Post.Post'

View file

@ -15,9 +15,6 @@ import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.IORef (newIORef, readIORef,
writeIORef)
import Data.Maybe (fromMaybe)
import Data.String (fromString)
import Network.HTTP.Types hiding (ResponseHeaders)
import Network.Wai (Application, Request,
Response, ResponseReceived,
requestBody,
@ -36,48 +33,6 @@ data RouteResult a =
| Route !a
deriving (Eq, Show, Read, Functor)
-- Note that the ordering of the constructors has great significance! It
-- determines the Ord instance and, consequently, the monoid instance.
data RouteMismatch =
NotFound -- ^ the usual "not found" error
| WrongMethod -- ^ a more informative "you just got the HTTP method wrong" error
| UnsupportedMediaType -- ^ request body has unsupported media type
| InvalidBody String -- ^ an even more informative "your json request body wasn't valid" error
| HttpError Status [Header] (Maybe BL.ByteString) -- ^ an even even more informative arbitrary HTTP response code error.
| RouteMismatch Response -- ^ an arbitrary mismatch with custom Response.
-- | specialized 'Less Than' for use with Monoid RouteMismatch
(<=:) :: RouteMismatch -> RouteMismatch -> Bool
{-# INLINE (<=:) #-}
NotFound <=: _ = True
WrongMethod <=: rmm = not (rmm <=: NotFound)
UnsupportedMediaType <=: rmm = not (rmm <=: WrongMethod)
InvalidBody _ <=: rmm = not (rmm <=: UnsupportedMediaType)
HttpError _ _ _ <=: rmm = not (rmm <=: (InvalidBody ""))
RouteMismatch _ <=: _ = False
instance Monoid RouteMismatch where
mempty = NotFound
-- The following isn't great, since it picks @InvalidBody@ based on
-- alphabetical ordering, but any choice would be arbitrary.
--
-- "As one judge said to the other, 'Be just and if you can't be just, be
-- arbitrary'" -- William Burroughs
--
-- It used to be the case that `mappend = max` but getting rid of the `Eq`
-- and `Ord` instance meant we had to roll out our own max ;\
rmm `mappend` NotFound = rmm
NotFound `mappend` rmm = rmm
WrongMethod `mappend` rmm | rmm <=: WrongMethod = WrongMethod
WrongMethod `mappend` rmm = rmm
UnsupportedMediaType `mappend` rmm | rmm <=: UnsupportedMediaType = UnsupportedMediaType
UnsupportedMediaType `mappend` rmm = rmm
i@(InvalidBody _) `mappend` rmm | rmm <=: i = i
InvalidBody _ `mappend` rmm = rmm
h@(HttpError _ _ _) `mappend` rmm | rmm <=: h = h
HttpError _ _ _ `mappend` rmm = rmm
r@(RouteMismatch _) `mappend` _ = r
data ReqBodyState = Uncalled
| Called !B.ByteString
| Done !B.ByteString