Second Iteration of Authentication

Implemented with the AuthProtected data family as per alp's suggestion.
This commit is contained in:
Andres Loeh 2015-06-01 22:39:12 +02:00 committed by Arian van Putten
parent d2e2122933
commit 272091effe
7 changed files with 305 additions and 118 deletions

View file

@ -36,6 +36,7 @@ library
Servant
Servant.Server
Servant.Server.Internal
Servant.Server.Internal.Authentication
Servant.Server.Internal.Enter
Servant.Server.Internal.PathInfo
Servant.Server.Internal.Router

View file

@ -3,6 +3,7 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
@ -27,38 +28,60 @@ import Control.Monad.Trans.Except (ExceptT)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.Map as M
import qualified Data.ByteString as B
import Data.ByteString.Base64 (decodeLenient)
import qualified Data.ByteString.Lazy as BL
import qualified Data.Map as M
import Data.Maybe (mapMaybe, fromMaybe)
import Data.String (fromString)
import Data.String.Conversions (cs, (<>), ConvertibleStrings)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Text.Encoding (decodeUtf8,
encodeUtf8)
import Data.Typeable
import Data.Word8 (isSpace, _colon, toLower)
import GHC.TypeLits (KnownSymbol, symbolVal)
import Network.HTTP.Types hiding (Header, ResponseHeaders)
import GHC.TypeLits (KnownSymbol,
symbolVal)
import Network.HTTP.Types hiding (Header,
ResponseHeaders)
import Network.Socket (SockAddr)
import Network.Wai (Application, isSecure, httpVersion, Request, Response,
ResponseReceived, lazyRequestBody,
pathInfo, rawQueryString, remoteHost,
requestBody, requestHeaders,
requestMethod, responseLBS,
strictRequestBody, vault)
import Servant.API ((:<|>) (..), (:>), BasicAuth, Capture,
Delete, Get, Header, IsSecure(Secure, NotSecure),
MatrixFlag, MatrixParam, MatrixParams,
Patch, Post, Put, QueryFlag,
QueryParam, QueryParams, Raw,
RemoteHost, ReqBody, Vault)
import Network.Wai (Application,
httpVersion,
isSecure,
lazyRequestBody,
rawQueryString,
remoteHost,
requestHeaders,
requestMethod,
responseLBS, vault)
import Servant.API ((:<|>) (..), (:>),
Capture, Delete,
Get, Header, IsSecure (Secure, NotSecure),
MatrixFlag,
MatrixParam,
MatrixParams,
Patch, Post, Put,
QueryFlag,
QueryParam,
QueryParams, Raw,
RemoteHost,
ReqBody, Vault)
import Servant.API.Authentication (AuthPolicy (Strict, Lax),
AuthProtect,
AuthProtected)
import Servant.API.ContentTypes (AcceptHeader (..),
AllCTRender (..),
AllCTUnrender (..))
import Servant.API.ResponseHeaders (Headers, getResponse, GetHeaders,
getHeaders)
import Servant.API.ResponseHeaders (GetHeaders,
Headers,
getHeaders,
getResponse)
import Servant.Common.Text (FromText, fromText)
import Servant.Server.Internal.Authentication (AuthData (authData),
AuthProtected (..),
checkAuthStrict,
onMissingAuthData,
onUnauthenticated)
import Servant.Server.Internal.PathInfo
import Servant.Server.Internal.Router
import Servant.Server.Internal.RoutingApplication
@ -71,18 +94,6 @@ class HasServer layout where
type Server layout = ServerT layout (ExceptT ServantErr IO)
-- | A type-indexed class to encapsulate Basic authentication handling.
-- Authentication handling is indexed by the lookup type.
--
-- > data ExampleAuthDB
-- > data ExampleUser
-- > instance BasicAuthLookup ExampleAuthDB where
-- > type BasicAuthVal = ExampleUser
-- > basicAuthLookup _ _ _ = return Nothing
class BasicAuthLookup lookup where
type BasicAuthVal lookup :: *
basicAuthLookup :: Proxy lookup -> B.ByteString -> B.ByteString -> IO (Maybe (BasicAuthVal lookup))
-- * Instances
-- | A server for @a ':<|>' b@ first tries to match the request against the route
@ -246,54 +257,74 @@ instance
route Proxy = methodRouterHeaders methodDelete (Proxy :: Proxy ctypes) ok200
-- | Authentication
-- | Authentication in Strict mode.
instance
#if MIN_VERSION_base(4,8,0)
{-# OVERLAPPABLE #-}
#endif
( HasServer sublayout
, BasicAuthLookup lookup
, KnownSymbol realm
)
=> HasServer (BasicAuth realm lookup :> sublayout) where
(AuthData authdata , HasServer sublayout) => HasServer (AuthProtect authdata (usr :: *) 'Strict :> sublayout) where
type ServerT (BasicAuth realm lookup :> sublayout) m
= BasicAuthVal lookup -> ServerT sublayout m
type ServerT (AuthProtect authdata usr 'Strict :> sublayout) m = AuthProtected authdata usr (usr -> ServerT sublayout m) 'Strict
route _ action request respond =
case lookup "Authorization" (requestHeaders request) of
Nothing -> respond . succeedWith $ authFailure401
Just authBs ->
-- ripped from: https://hackage.haskell.org/package/wai-extra-1.3.4.5/docs/src/Network-Wai-Middleware-HttpAuth.html#basicAuth
let (x,y) = B.break isSpace authBs in
if B.map toLower x == "basic"
-- check base64-encoded password
then checkB64AndRespond (B.dropWhile isSpace y)
-- Authenticaiton header is not Basic, fail with 401.
else respond . succeedWith $ authFailure401
where
realmBytes = (fromString . symbolVal) (Proxy :: Proxy realm)
headerBytes = "Basic realm=\"" <> realmBytes <> "\""
authFailure401 = responseLBS status401 [("WWW-Authenticate", headerBytes)] ""
checkB64AndRespond encoded =
case B.uncons passwordWithColonAtHead of
Just (_, password) -> do
-- let's check these credentials using the user-provided lookup method
maybeAuthData <- basicAuthLookup (Proxy :: Proxy lookup) username password
case maybeAuthData of
Nothing -> respond . succeedWith $ authFailure403
(Just authData) ->
route (Proxy :: Proxy sublayout) (action authData) request respond
route _ subserver = WithRequest $ \req ->
route (Proxy :: Proxy sublayout) $ do
-- Note: this may perform IO for each attempt at matching.
rr <- routeResult <$> subserver
-- no username:password present
Nothing -> respond . succeedWith $ authFailure401
where
authFailure403 = responseLBS status403 [] ""
raw = decodeLenient encoded
-- split username and password at the colon ':' char.
(username, passwordWithColonAtHead) = B.breakByte _colon raw
case rr of
-- Successful route match, so we extract the author-provided
-- auth data.
Right authProtectionStrict ->
case authData req of
-- could not pull authenticate data out of the request
Nothing -> do
-- we're in strict mode: don't let the request go
-- call the provided "on missing auth" handler
resp <- onMissingAuthData (authHandlers authProtectionStrict)
return $ failWith (RouteMismatch resp)
-- succesfully pulled auth data out of the Request
Just authData' -> do
mUsr <- (checkAuthStrict authProtectionStrict) authData'
case mUsr of
-- this user is not authenticated.
Nothing -> do
resp <- onUnauthenticated (authHandlers authProtectionStrict) authData'
return $ failWith (RouteMismatch resp)
-- this user is authenticated.
Just usr ->
(return . succeedWith . subServerStrict authProtectionStrict) usr
-- route did not match, propagate failure.
Left rMismatch ->
return (failWith rMismatch)
-- | Authentication in Lax mode.
instance
#if MIN_VERSION_base(4,8,0)
{-# OVERLAPPABLE #-}
#endif
(AuthData authdata , HasServer sublayout) => HasServer (AuthProtect authdata (usr :: *) 'Lax :> sublayout) where
type ServerT (AuthProtect authdata usr 'Lax :> sublayout) m = AuthProtected authdata usr (Maybe usr -> ServerT sublayout m) 'Lax
route _ subserver = WithRequest $ \req ->
route (Proxy :: Proxy sublayout) $ do
-- Note: this may perform IO for each attempt at matching.
rr <- routeResult <$> subserver
-- Successful route match, so we extract the author-provided
-- auth data.
case rr of
-- route matched, extract author-provided lax authentication data
Right authProtectionLax -> do
-- extract a user from the request object and perform
-- authentication on it. In Lax mode, we just pass `Maybe usr`
-- to the autho.
musr <- maybe (pure Nothing) (checkAuthLax authProtectionLax) (authData req)
(return . succeedWith . subServerLax authProtectionLax) musr
-- route did not match, propagate failure
Left rMismatch ->
return (failWith rMismatch)
-- | When implementing the handler for a 'Get' endpoint,
-- just like for 'Servant.API.Delete.Delete', 'Servant.API.Post.Post'
@ -722,7 +753,9 @@ instance (KnownSymbol sym, HasServer sublayout)
Just Nothing -> True -- param is there, with no value
Just (Just v) -> examine v -- param with a value
Nothing -> False -- param not in the query string
route (Proxy :: Proxy sublayout) (feedTo subserver param)
_ -> route (Proxy :: Proxy sublayout) (feedTo subserver False)
where paramname = cs $ symbolVal (Proxy :: Proxy sym)
examine v | v == "true" || v == "1" || v == "" = True

View file

@ -0,0 +1,102 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Servant.Server.Internal.Authentication
( AuthProtected (..)
, AuthData (..)
, AuthHandlers (..)
, basicAuthLax
, basicAuthStrict
, laxProtect
, strictProtect
) where
import Control.Monad (guard)
import qualified Data.ByteString as B
import Data.ByteString.Base64 (decodeLenient)
import Data.Monoid ((<>))
import Data.Proxy (Proxy (Proxy))
import Data.String (fromString)
import Data.Word8 (isSpace, toLower, _colon)
import GHC.TypeLits (KnownSymbol, symbolVal)
import Network.HTTP.Types.Status (status401)
import Network.Wai (Request, Response, requestHeaders,
responseBuilder)
import Servant.API.Authentication (AuthPolicy (Strict, Lax),
AuthProtected,
BasicAuth (BasicAuth))
-- | Class to represent the ability to extract authentication-related
-- data from a 'Request' object.
class AuthData a where
authData :: Request -> Maybe a
-- | handlers to deal with authentication failures.
data AuthHandlers authData = AuthHandlers
{ -- we couldn't find the right type of auth data (or any, for that matter)
onMissingauthData :: IO Response
,
-- we found the right type of auth data in the request but the check failed
onUnauthenticated :: authData -> IO Response
}
-- | concrete type to provide when in 'Strict' mode.
data instance AuthProtected authData usr subserver 'Strict =
AuthProtectedStrict { checkAuthStrict :: authData -> IO (Maybe usr)
, subServerStrict :: subserver
, authHandlers :: AuthHandlers authData
}
-- | concrete type to provide when in 'Lax' mode.
data instance AuthProtected authData usr subserver 'Lax =
AuthProtectedLax { checkAuthLax :: authData -> IO (Maybe usr)
, subServerLax :: subserver
}
-- | handy function to build an auth-protected bit of API with a Lax policy
laxProtect :: (authData -> IO (Maybe usr)) -- ^ check auth
-> subserver -- ^ the handlers for the auth-aware bits of the API
-> AuthProtected authData usr subserver 'Lax
laxProtect = AuthProtectedLax
-- | handy function to build an auth-protected bit of API with a Strict policy
strictProtect :: (authData -> IO (Maybe usr)) -- ^ check auth
-> subserver -- ^ handlers for the auth-protected bits of the API
-> AuthHandlers authData -- ^ functions to call on auth failure
-> AuthProtected authData usr subserver 'Strict
strictProtect = AuthProtectedStrict
-- | 'BasicAuth' instance for authData
instance AuthData (BasicAuth realm) where
authData request = do
authBs <- lookup "Authorization" (requestHeaders request)
let (x,y) = B.break isSpace authBs
guard (B.map toLower x == "basic")
-- decode the base64-encoded username and password
let (username, passWithColonAtHead) = B.break (== _colon) (decodeLenient (B.dropWhile isSpace y))
(_, password) <- B.uncons passWithColonAtHead
return $ BasicAuth username password
-- | handlers for Basic Authentication.
basicAuthHandlers :: forall realm. KnownSymbol realm => AuthHandlers (BasicAuth realm)
basicAuthHandlers =
let realmBytes = (fromString . symbolVal) (Proxy :: Proxy realm)
headerBytes = "Basic realm=\"" <> realmBytes <> "\""
authFailure = responseBuilder status401 [("WWW-Authenticate", headerBytes)] mempty in
AuthHandlers (return authFailure) ((const . return) authFailure)
-- | Basic authentication combinator with strict failure.
basicAuthStrict :: KnownSymbol realm
=> (BasicAuth realm -> IO (Maybe usr))
-> subserver
-> AuthProtected (BasicAuth realm) usr subserver 'Strict
basicAuthStrict check subserver = strictProtect check subserver basicAuthHandlers
-- | Basic authentication combinator with lax failure.
basicAuthLax :: KnownSymbol realm
=> (BasicAuth realm -> IO (Maybe usr))
-> subserver
-> AuthProtected (BasicAuth realm) usr subserver 'Lax
basicAuthLax = laxProtect

View file

@ -1,4 +1,5 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
@ -24,9 +25,16 @@ import qualified Control.Monad.State.Lazy as LState
import qualified Control.Monad.State.Strict as SState
import qualified Control.Monad.Writer.Lazy as LWriter
import qualified Control.Monad.Writer.Strict as SWriter
import qualified Control.Monad.State.Lazy as LState
import qualified Control.Monad.State.Strict as SState
import qualified Control.Monad.Writer.Lazy as LWriter
import qualified Control.Monad.Writer.Strict as SWriter
import Data.Typeable
import Servant.API
import Servant.API.Authentication
import Servant.Server.Internal.Authentication (AuthProtected (AuthProtectedStrict, AuthProtectedLax))
class Enter typ arg ret | typ arg -> ret, typ ret -> arg where
enter :: arg -> typ -> ret
@ -95,3 +103,12 @@ squashNat = Nat squash
-- | Like @mmorph@'s `generalize`.
generalizeNat :: Applicative m => Identity :~> m
generalizeNat = Nat (pure . runIdentity)
-- | 'Enter' instance for AuthProtectedStrict
instance Enter subserver arg ret => Enter (AuthProtected authData usr subserver 'Strict) arg (AuthProtected authData usr ret 'Strict) where
enter arg (AuthProtectedStrict check subserver handlers) = AuthProtectedStrict check (enter arg subserver) handlers
-- | 'Enter' instance for AuthProtectedLax
instance Enter subserver arg ret => Enter (AuthProtected authData usr subserver 'Lax) arg (AuthProtected authData usr ret 'Lax) where
enter arg (AuthProtectedLax check subserver) = AuthProtectedLax check (enter arg subserver)

View file

@ -16,8 +16,7 @@ import Data.IORef (newIORef, readIORef,
import Data.Maybe (fromMaybe)
import Data.Monoid ((<>))
import Data.String (fromString)
import Network.HTTP.Types hiding (Header,
ResponseHeaders)
import Network.HTTP.Types hiding (Header, ResponseHeaders)
import Network.Wai (Application, Request,
Response, ResponseReceived,
requestBody, responseLBS,
@ -32,7 +31,7 @@ type RoutingApplication =
-- | A wrapper around @'Either' 'RouteMismatch' a@.
newtype RouteResult a =
RR { routeResult :: Either RouteMismatch a }
deriving (Eq, Show, Functor, Applicative)
deriving (Show, Functor, Applicative, Monad)
-- | If we get a `Right`, it has precedence over everything else.
--
@ -52,8 +51,21 @@ data RouteMismatch =
| WrongMethod -- ^ a more informative "you just got the HTTP method wrong" error
| UnsupportedMediaType -- ^ request body has unsupported media type
| InvalidBody String -- ^ an even more informative "your json request body wasn't valid" error
| HttpError Status (Maybe BL.ByteString) -- ^ an even even more informative arbitrary HTTP response code error.
deriving (Eq, Ord, Show)
| HttpError Status [Header] (Maybe BL.ByteString) -- ^ an even even more informative arbitrary HTTP response code error.
| RouteMismatch Response -- ^ an arbitrary mismatch with custom Response.
instance Show RouteMismatch where
show = const "hello"
-- | specialized 'Less Than' for use with Monoid RouteMismatch
(<=:) :: RouteMismatch -> RouteMismatch -> Bool
{-# INLINE (<=:) #-}
NotFound <=: _ = True
WrongMethod <=: rmm = not (rmm <=: NotFound)
UnsupportedMediaType <=: rmm = not (rmm <=: WrongMethod)
InvalidBody _ <=: rmm = not (rmm <=: UnsupportedMediaType)
HttpError _ _ _ <=: rmm = not (rmm <=: (InvalidBody ""))
RouteMismatch _ <=: _ = False
instance Monoid RouteMismatch where
mempty = NotFound
@ -62,7 +74,20 @@ instance Monoid RouteMismatch where
--
-- "As one judge said to the other, 'Be just and if you can't be just, be
-- arbitrary'" -- William Burroughs
mappend = max
--
-- It used to be the case that `mappend = max` but getting rid of the `Eq`
-- and `Ord` instance meant we had to roll out our own max ;\
rmm `mappend` NotFound = rmm
NotFound `mappend` rmm = rmm
WrongMethod `mappend` rmm | rmm <=: WrongMethod = WrongMethod
WrongMethod `mappend` rmm = rmm
UnsupportedMediaType `mappend` rmm | rmm <=: UnsupportedMediaType = UnsupportedMediaType
UnsupportedMediaType `mappend` rmm = rmm
i@(InvalidBody _) `mappend` rmm | rmm <=: i = i
InvalidBody _ `mappend` rmm = rmm
h@(HttpError _ _ _) `mappend` rmm | rmm <=: h = h
HttpError _ _ _ `mappend` rmm = rmm
r@(RouteMismatch _) `mappend` _ = r
data ReqBodyState = Uncalled
| Called !B.ByteString
@ -102,8 +127,10 @@ toApplication ra request respond = do
respond $ responseLBS badRequest400 [] $ fromString $ "invalid request body: " ++ err
routingRespond (Left UnsupportedMediaType) =
respond $ responseLBS unsupportedMediaType415 [] "unsupported media type"
routingRespond (Left (HttpError status body)) =
respond $ responseLBS status [] $ fromMaybe (BL.fromStrict $ statusMessage status) body
routingRespond (Left (HttpError status headers body)) =
respond $ responseLBS status headers $ fromMaybe (BL.fromStrict $ statusMessage status) body
routingRespond (Left (RouteMismatch resp)) =
respond resp
routingRespond (Right response) =
respond response

View file

@ -659,7 +659,7 @@ prioErrorsSpec = describe "PrioErrors" $ do
-- | Test server error functionality.
errorsSpec :: Spec
errorsSpec = do
let he = HttpError status409 (Just "A custom error")
let he = HttpError status409 [] (Just "A custom error")
let ib = InvalidBody "The body is invalid"
let wm = WrongMethod
let nf = NotFound
@ -789,4 +789,4 @@ authRequiredSpec = do
foo401 <- get "/foo"
bar401 <- get "/bar"
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" foo401)
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" bar401)
WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"bar-realm\"" bar401)

View file

@ -1,24 +1,31 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_HADDOCK not-home #-}
module Servant.API.Authentication (BasicAuth) where
module Servant.API.Authentication where
import Data.ByteString (ByteString)
import Data.Typeable (Typeable)
import GHC.TypeLits (Symbol)
-- | we can be either Strict or Lax.
-- Strict: all handlers under 'AuthProtect' take a 'usr' argument.
-- when auth fails, we call user-supplied handlers to respond.
-- Lax: all handlers under 'AuthProtect' take a 'Maybe usr' argument.
-- when auth fails, we call the handlers with 'Nothing'.
data AuthPolicy = Strict | Lax
-- | the combinator to be used in API types
data AuthProtect authdata usr (policy :: AuthPolicy)
-- | what we'll ask user to provide at the server-level when we see a
-- 'AuthProtect' combinator in an API type
data family AuthProtected authdata usr subserver :: AuthPolicy -> *
-- | Basic Authentication with respect to a specified @realm@ and a @lookup@
-- type to encapsulate authentication logic.
--
-- Example:
-- >>> type MyApi = BasicAuth "book-realm" DB :> "books" :> Get '[JSON] [Book]
data BasicAuth (realm :: Symbol) lookup
deriving (Typeable)
-- $setup
-- >>> import Servant.API
-- >>> import Data.Aeson
-- >>> import Data.Text
-- >>> data DB
-- >>> data Book
-- >>> instance ToJSON Book where { toJSON = undefined }
data BasicAuth (realm :: Symbol) = BasicAuth { baUser :: ByteString
, baPass :: ByteString
} deriving (Eq, Show, Typeable)