From 4c838434890fee67a35cab16654bcf51da36b270 Mon Sep 17 00:00:00 2001 From: Andres Loeh Date: Mon, 1 Jun 2015 22:39:12 +0200 Subject: [PATCH] Second Iteration of Authentication Implemented with the AuthProtected data family as per alp's suggestion. (during rebase, removed monoid instance for RouteResult) --- servant-server/servant-server.cabal | 1 + servant-server/src/Servant/Server/Internal.hs | 191 +++++++++++------- .../Servant/Server/Internal/Authentication.hs | 102 ++++++++++ .../src/Servant/Server/Internal/Enter.hs | 19 +- .../Server/Internal/RoutingApplication.hs | 69 +++++++ servant-server/test/Servant/ServerSpec.hs | 96 ++++++++- servant/src/Servant/API/Authentication.hs | 39 ++-- 7 files changed, 421 insertions(+), 96 deletions(-) create mode 100644 servant-server/src/Servant/Server/Internal/Authentication.hs diff --git a/servant-server/servant-server.cabal b/servant-server/servant-server.cabal index 83c4ca6a..2db7b129 100644 --- a/servant-server/servant-server.cabal +++ b/servant-server/servant-server.cabal @@ -36,6 +36,7 @@ library Servant Servant.Server Servant.Server.Internal + Servant.Server.Internal.Authentication Servant.Server.Internal.Enter Servant.Server.Internal.Router Servant.Server.Internal.RoutingApplication diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index 2258f8a8..ce7a80bd 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -3,6 +3,7 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -20,41 +21,67 @@ module Servant.Server.Internal ) where #if !MIN_VERSION_base(4,8,0) -import Control.Applicative ((<$>)) +import Control.Applicative ((<$>)) #endif import Control.Monad.Trans.Except (ExceptT) import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as BL import qualified Data.Map as M +import qualified Data.ByteString as B import Data.ByteString.Base64 (decodeLenient) -import qualified Data.ByteString.Lazy as BL +import qualified Data.ByteString.Lazy as BL +import qualified Data.Map as M import Data.Maybe (mapMaybe, fromMaybe) -import Data.String (fromString) -import Data.String.Conversions (ConvertibleStrings, cs, (<>)) -import Data.Text (Text) +import Data.String (fromString) +import Data.String.Conversions (cs, (<>), ConvertibleStrings) +import Data.Text (Text) +import qualified Data.Text as T +import Data.Text.Encoding (decodeUtf8, + encodeUtf8) import Data.Typeable -import Data.Word8 (isSpace, _colon, toLower) -import GHC.TypeLits (KnownSymbol, symbolVal) -import Network.HTTP.Types hiding (Header, ResponseHeaders) -import Network.Socket (SockAddr) -import Network.Wai (Application, lazyRequestBody, - rawQueryString, requestHeaders, - requestMethod, responseLBS, remoteHost, - isSecure, vault, httpVersion, Response, - Request, pathInfo) -import Servant.API ((:<|>) (..), (:>), Capture, - Delete, Get, Header, - IsSecure(..), Patch, Post, Put, - QueryFlag, QueryParam, QueryParams, - Raw, RemoteHost, ReqBody, Vault) -import Servant.API.ContentTypes (AcceptHeader (..), - AllCTRender (..), - AllCTUnrender (..), - AllMime, - canHandleAcceptH) -import Servant.API.ResponseHeaders (GetHeaders, Headers, getHeaders, - getResponse) - +import GHC.TypeLits (KnownSymbol, + symbolVal) +import Network.HTTP.Types hiding (Header, + ResponseHeaders) +import Network.Socket (SockAddr) +import Network.Wai (Application, + httpVersion, + isSecure, + lazyRequestBody, + rawQueryString, + remoteHost, + requestHeaders, + requestMethod, + responseLBS, vault) +import Servant.API ((:<|>) (..), (:>), + Capture, Delete, + Get, Header, IsSecure (Secure, NotSecure), + MatrixFlag, + MatrixParam, + MatrixParams, + Patch, Post, Put, + QueryFlag, + QueryParam, + QueryParams, Raw, + RemoteHost, + ReqBody, Vault) +import Servant.API.Authentication (AuthPolicy (Strict, Lax), + AuthProtect, + AuthProtected) +import Servant.API.ContentTypes (AcceptHeader (..), + AllCTRender (..), + AllCTUnrender (..)) +import Servant.API.ResponseHeaders (GetHeaders, + Headers, + getHeaders, + getResponse) +import Servant.Common.Text (FromText, fromText) +import Servant.Server.Internal.Authentication (AuthData (authData), + AuthProtected (..), + checkAuthStrict, + onMissingAuthData, + onUnauthenticated) +import Servant.Server.Internal.PathInfo import Servant.Server.Internal.Router import Servant.Server.Internal.RoutingApplication import Servant.Server.Internal.ServantErr @@ -69,18 +96,6 @@ class HasServer layout where type Server layout = ServerT layout (ExceptT ServantErr IO) --- | A type-indexed class to encapsulate Basic authentication handling. --- Authentication handling is indexed by the lookup type. --- --- > data ExampleAuthDB --- > data ExampleUser --- > instance BasicAuthLookup ExampleAuthDB where --- > type BasicAuthVal = ExampleUser --- > basicAuthLookup _ _ _ = return Nothing -class BasicAuthLookup lookup where - type BasicAuthVal lookup :: * - basicAuthLookup :: Proxy lookup -> B.ByteString -> B.ByteString -> IO (Maybe (BasicAuthVal lookup)) - -- * Instances -- | A server for @a ':<|>' b@ first tries to match the request against the route @@ -254,54 +269,74 @@ instance route Proxy = methodRouterHeaders methodDelete (Proxy :: Proxy ctypes) ok200 --- | Authentication +-- | Authentication in Strict mode. instance #if MIN_VERSION_base(4,8,0) {-# OVERLAPPABLE #-} #endif - ( HasServer sublayout - , BasicAuthLookup lookup - , KnownSymbol realm - ) - => HasServer (BasicAuth realm lookup :> sublayout) where + (AuthData authdata , HasServer sublayout) => HasServer (AuthProtect authdata (usr :: *) 'Strict :> sublayout) where - type ServerT (BasicAuth realm lookup :> sublayout) m - = BasicAuthVal lookup -> ServerT sublayout m + type ServerT (AuthProtect authdata usr 'Strict :> sublayout) m = AuthProtected authdata usr (usr -> ServerT sublayout m) 'Strict - route _ action request respond = - case lookup "Authorization" (requestHeaders request) of - Nothing -> respond . succeedWith $ authFailure401 - Just authBs -> - -- ripped from: https://hackage.haskell.org/package/wai-extra-1.3.4.5/docs/src/Network-Wai-Middleware-HttpAuth.html#basicAuth - let (x,y) = B.break isSpace authBs in - if B.map toLower x == "basic" - -- check base64-encoded password - then checkB64AndRespond (B.dropWhile isSpace y) - -- Authenticaiton header is not Basic, fail with 401. - else respond . succeedWith $ authFailure401 - where - realmBytes = (fromString . symbolVal) (Proxy :: Proxy realm) - headerBytes = "Basic realm=\"" <> realmBytes <> "\"" - authFailure401 = responseLBS status401 [("WWW-Authenticate", headerBytes)] "" - checkB64AndRespond encoded = - case B.uncons passwordWithColonAtHead of - Just (_, password) -> do - -- let's check these credentials using the user-provided lookup method - maybeAuthData <- basicAuthLookup (Proxy :: Proxy lookup) username password - case maybeAuthData of - Nothing -> respond . succeedWith $ authFailure403 - (Just authData) -> - route (Proxy :: Proxy sublayout) (action authData) request respond + route _ subserver = WithRequest $ \req -> + route (Proxy :: Proxy sublayout) $ do + -- Note: this may perform IO for each attempt at matching. + rr <- routeResult <$> subserver - -- no username:password present - Nothing -> respond . succeedWith $ authFailure401 - where - authFailure403 = responseLBS status403 [] "" - raw = decodeLenient encoded - -- split username and password at the colon ':' char. - (username, passwordWithColonAtHead) = B.breakByte _colon raw + case rr of + -- Successful route match, so we extract the author-provided + -- auth data. + Right authProtectionStrict -> + case authData req of + -- could not pull authenticate data out of the request + Nothing -> do + -- we're in strict mode: don't let the request go + -- call the provided "on missing auth" handler + resp <- onMissingAuthData (authHandlers authProtectionStrict) + return $ failWith (RouteMismatch resp) + -- succesfully pulled auth data out of the Request + Just authData' -> do + mUsr <- (checkAuthStrict authProtectionStrict) authData' + case mUsr of + -- this user is not authenticated. + Nothing -> do + resp <- onUnauthenticated (authHandlers authProtectionStrict) authData' + return $ failWith (RouteMismatch resp) + -- this user is authenticated. + Just usr -> + (return . succeedWith . subServerStrict authProtectionStrict) usr + -- route did not match, propagate failure. + Left rMismatch -> + return (failWith rMismatch) + +-- | Authentication in Lax mode. +instance +#if MIN_VERSION_base(4,8,0) + {-# OVERLAPPABLE #-} +#endif + (AuthData authdata , HasServer sublayout) => HasServer (AuthProtect authdata (usr :: *) 'Lax :> sublayout) where + + type ServerT (AuthProtect authdata usr 'Lax :> sublayout) m = AuthProtected authdata usr (Maybe usr -> ServerT sublayout m) 'Lax + + route _ subserver = WithRequest $ \req -> + route (Proxy :: Proxy sublayout) $ do + -- Note: this may perform IO for each attempt at matching. + rr <- routeResult <$> subserver + -- Successful route match, so we extract the author-provided + -- auth data. + case rr of + -- route matched, extract author-provided lax authentication data + Right authProtectionLax -> do + -- extract a user from the request object and perform + -- authentication on it. In Lax mode, we just pass `Maybe usr` + -- to the autho. + musr <- maybe (pure Nothing) (checkAuthLax authProtectionLax) (authData req) + (return . succeedWith . subServerLax authProtectionLax) musr + -- route did not match, propagate failure + Left rMismatch -> + return (failWith rMismatch) -- | When implementing the handler for a 'Get' endpoint, -- just like for 'Servant.API.Delete.Delete', 'Servant.API.Post.Post' diff --git a/servant-server/src/Servant/Server/Internal/Authentication.hs b/servant-server/src/Servant/Server/Internal/Authentication.hs new file mode 100644 index 00000000..02c09893 --- /dev/null +++ b/servant-server/src/Servant/Server/Internal/Authentication.hs @@ -0,0 +1,102 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} + +module Servant.Server.Internal.Authentication +( AuthProtected (..) +, AuthData (..) +, AuthHandlers (..) +, basicAuthLax +, basicAuthStrict +, laxProtect +, strictProtect + ) where + +import Control.Monad (guard) +import qualified Data.ByteString as B +import Data.ByteString.Base64 (decodeLenient) +import Data.Monoid ((<>)) +import Data.Proxy (Proxy (Proxy)) +import Data.String (fromString) +import Data.Word8 (isSpace, toLower, _colon) +import GHC.TypeLits (KnownSymbol, symbolVal) +import Network.HTTP.Types.Status (status401) +import Network.Wai (Request, Response, requestHeaders, + responseBuilder) +import Servant.API.Authentication (AuthPolicy (Strict, Lax), + AuthProtected, + BasicAuth (BasicAuth)) + +-- | Class to represent the ability to extract authentication-related +-- data from a 'Request' object. +class AuthData a where + authData :: Request -> Maybe a + +-- | handlers to deal with authentication failures. +data AuthHandlers authData = AuthHandlers + { -- we couldn't find the right type of auth data (or any, for that matter) + onMissingauthData :: IO Response + , + -- we found the right type of auth data in the request but the check failed + onUnauthenticated :: authData -> IO Response + } + +-- | concrete type to provide when in 'Strict' mode. +data instance AuthProtected authData usr subserver 'Strict = + AuthProtectedStrict { checkAuthStrict :: authData -> IO (Maybe usr) + , subServerStrict :: subserver + , authHandlers :: AuthHandlers authData + } + +-- | concrete type to provide when in 'Lax' mode. +data instance AuthProtected authData usr subserver 'Lax = + AuthProtectedLax { checkAuthLax :: authData -> IO (Maybe usr) + , subServerLax :: subserver + } + +-- | handy function to build an auth-protected bit of API with a Lax policy +laxProtect :: (authData -> IO (Maybe usr)) -- ^ check auth + -> subserver -- ^ the handlers for the auth-aware bits of the API + -> AuthProtected authData usr subserver 'Lax +laxProtect = AuthProtectedLax + +-- | handy function to build an auth-protected bit of API with a Strict policy +strictProtect :: (authData -> IO (Maybe usr)) -- ^ check auth + -> subserver -- ^ handlers for the auth-protected bits of the API + -> AuthHandlers authData -- ^ functions to call on auth failure + -> AuthProtected authData usr subserver 'Strict +strictProtect = AuthProtectedStrict + +-- | 'BasicAuth' instance for authData +instance AuthData (BasicAuth realm) where + authData request = do + authBs <- lookup "Authorization" (requestHeaders request) + let (x,y) = B.break isSpace authBs + guard (B.map toLower x == "basic") + -- decode the base64-encoded username and password + let (username, passWithColonAtHead) = B.break (== _colon) (decodeLenient (B.dropWhile isSpace y)) + (_, password) <- B.uncons passWithColonAtHead + return $ BasicAuth username password + +-- | handlers for Basic Authentication. +basicAuthHandlers :: forall realm. KnownSymbol realm => AuthHandlers (BasicAuth realm) +basicAuthHandlers = + let realmBytes = (fromString . symbolVal) (Proxy :: Proxy realm) + headerBytes = "Basic realm=\"" <> realmBytes <> "\"" + authFailure = responseBuilder status401 [("WWW-Authenticate", headerBytes)] mempty in + AuthHandlers (return authFailure) ((const . return) authFailure) + +-- | Basic authentication combinator with strict failure. +basicAuthStrict :: KnownSymbol realm + => (BasicAuth realm -> IO (Maybe usr)) + -> subserver + -> AuthProtected (BasicAuth realm) usr subserver 'Strict +basicAuthStrict check subserver = strictProtect check subserver basicAuthHandlers + +-- | Basic authentication combinator with lax failure. +basicAuthLax :: KnownSymbol realm + => (BasicAuth realm -> IO (Maybe usr)) + -> subserver + -> AuthProtected (BasicAuth realm) usr subserver 'Lax +basicAuthLax = laxProtect diff --git a/servant-server/src/Servant/Server/Internal/Enter.hs b/servant-server/src/Servant/Server/Internal/Enter.hs index 5bcebe9d..35e2991a 100644 --- a/servant-server/src/Servant/Server/Internal/Enter.hs +++ b/servant-server/src/Servant/Server/Internal/Enter.hs @@ -1,4 +1,5 @@ {-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FunctionalDependencies #-} @@ -13,7 +14,7 @@ module Servant.Server.Internal.Enter where #if !MIN_VERSION_base(4,8,0) import Control.Applicative #endif -import qualified Control.Category as C +import qualified Control.Category as C #if MIN_VERSION_mtl(2,2,1) import Control.Monad.Except #endif @@ -24,9 +25,16 @@ import qualified Control.Monad.State.Lazy as LState import qualified Control.Monad.State.Strict as SState import qualified Control.Monad.Writer.Lazy as LWriter import qualified Control.Monad.Writer.Strict as SWriter +import qualified Control.Monad.State.Lazy as LState +import qualified Control.Monad.State.Strict as SState +import qualified Control.Monad.Writer.Lazy as LWriter +import qualified Control.Monad.Writer.Strict as SWriter import Data.Typeable import Servant.API +import Servant.API.Authentication +import Servant.Server.Internal.Authentication (AuthProtected (AuthProtectedStrict, AuthProtectedLax)) + class Enter typ arg ret | typ arg -> ret, typ ret -> arg where enter :: arg -> typ -> ret @@ -95,3 +103,12 @@ squashNat = Nat squash -- | Like @mmorph@'s `generalize`. generalizeNat :: Applicative m => Identity :~> m generalizeNat = Nat (pure . runIdentity) + +-- | 'Enter' instance for AuthProtectedStrict +instance Enter subserver arg ret => Enter (AuthProtected authData usr subserver 'Strict) arg (AuthProtected authData usr ret 'Strict) where + enter arg (AuthProtectedStrict check subserver handlers) = AuthProtectedStrict check (enter arg subserver) handlers + + +-- | 'Enter' instance for AuthProtectedLax +instance Enter subserver arg ret => Enter (AuthProtected authData usr subserver 'Lax) arg (AuthProtected authData usr ret 'Lax) where + enter arg (AuthProtectedLax check subserver) = AuthProtectedLax check (enter arg subserver) diff --git a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs index 4b27c688..7116a8b2 100644 --- a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs +++ b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs @@ -15,6 +15,9 @@ import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as BL import Data.IORef (newIORef, readIORef, writeIORef) +import Data.Maybe (fromMaybe) +import Data.String (fromString) +import Network.HTTP.Types hiding (Header, ResponseHeaders) import Network.Wai (Application, Request, Response, ResponseReceived, requestBody, @@ -33,6 +36,52 @@ data RouteResult a = | Route !a deriving (Eq, Show, Read, Functor) +-- Note that the ordering of the constructors has great significance! It +-- determines the Ord instance and, consequently, the monoid instance. +data RouteMismatch = + NotFound -- ^ the usual "not found" error + | WrongMethod -- ^ a more informative "you just got the HTTP method wrong" error + | UnsupportedMediaType -- ^ request body has unsupported media type + | InvalidBody String -- ^ an even more informative "your json request body wasn't valid" error + | HttpError Status [Header] (Maybe BL.ByteString) -- ^ an even even more informative arbitrary HTTP response code error. + | RouteMismatch Response -- ^ an arbitrary mismatch with custom Response. + +instance Show RouteMismatch where + show = const "hello" + +-- | specialized 'Less Than' for use with Monoid RouteMismatch +(<=:) :: RouteMismatch -> RouteMismatch -> Bool +{-# INLINE (<=:) #-} +NotFound <=: _ = True +WrongMethod <=: rmm = not (rmm <=: NotFound) +UnsupportedMediaType <=: rmm = not (rmm <=: WrongMethod) +InvalidBody _ <=: rmm = not (rmm <=: UnsupportedMediaType) +HttpError _ _ _ <=: rmm = not (rmm <=: (InvalidBody "")) +RouteMismatch _ <=: _ = False + +instance Monoid RouteMismatch where + mempty = NotFound + -- The following isn't great, since it picks @InvalidBody@ based on + -- alphabetical ordering, but any choice would be arbitrary. + -- + -- "As one judge said to the other, 'Be just and if you can't be just, be + -- arbitrary'" -- William Burroughs + -- + -- It used to be the case that `mappend = max` but getting rid of the `Eq` + -- and `Ord` instance meant we had to roll out our own max ;\ + rmm `mappend` NotFound = rmm + NotFound `mappend` rmm = rmm + WrongMethod `mappend` rmm | rmm <=: WrongMethod = WrongMethod + WrongMethod `mappend` rmm = rmm + UnsupportedMediaType `mappend` rmm | rmm <=: UnsupportedMediaType = UnsupportedMediaType + UnsupportedMediaType `mappend` rmm = rmm + i@(InvalidBody _) `mappend` rmm | rmm <=: i = i + InvalidBody _ `mappend` rmm = rmm + h@(HttpError _ _ _) `mappend` rmm | rmm <=: h = h + HttpError _ _ _ `mappend` rmm = rmm + r@(RouteMismatch _) `mappend` _ = r +>>>>>>> 272091e... Second Iteration of Authentication + data ReqBodyState = Uncalled | Called !B.ByteString | Done !B.ByteString @@ -62,6 +111,7 @@ toApplication ra request respond = do ra request{ requestBody = memoReqBody } routingRespond where +<<<<<<< HEAD routingRespond :: RouteResult Response -> IO ResponseReceived routingRespond (Fail err) = respond $ responseServantErr err routingRespond (FailFatal err) = respond $ responseServantErr err @@ -235,6 +285,25 @@ runDelayed (Delayed captures method body server) = -- Also takes a continuation for how to turn the -- result of the delayed server into a response. runAction :: Delayed (ExceptT ServantErr IO a) +======= + routingRespond :: Either RouteMismatch Response -> IO ResponseReceived + routingRespond (Left NotFound) = + respond $ responseLBS notFound404 [] "not found" + routingRespond (Left WrongMethod) = + respond $ responseLBS methodNotAllowed405 [] "method not allowed" + routingRespond (Left (InvalidBody err)) = + respond $ responseLBS badRequest400 [] $ fromString $ "invalid request body: " ++ err + routingRespond (Left UnsupportedMediaType) = + respond $ responseLBS unsupportedMediaType415 [] "unsupported media type" + routingRespond (Left (HttpError status headers body)) = + respond $ responseLBS status headers $ fromMaybe (BL.fromStrict $ statusMessage status) body + routingRespond (Left (RouteMismatch resp)) = + respond resp + routingRespond (Right response) = + respond response + +runAction :: IO (RouteResult (ExceptT ServantErr IO a)) +>>>>>>> 272091e... Second Iteration of Authentication -> (RouteResult Response -> IO r) -> (a -> RouteResult Response) -> IO r diff --git a/servant-server/test/Servant/ServerSpec.hs b/servant-server/test/Servant/ServerSpec.hs index fd7cdb1d..a1df1076 100644 --- a/servant-server/test/Servant/ServerSpec.hs +++ b/servant-server/test/Servant/ServerSpec.hs @@ -556,6 +556,100 @@ routerSpec = do it "calls f on route result" $ do get "" `shouldRespondWith` 202 +type PrioErrorsApi = ReqBody '[JSON] Person :> "foo" :> Get '[JSON] Integer + +prioErrorsApi :: Proxy PrioErrorsApi +prioErrorsApi = Proxy + +-- | Test the relative priority of error responses from the server. +-- +-- In particular, we check whether matching continues even if a 'ReqBody' +-- or similar construct is encountered early in a path. We don't want to +-- see a complaint about the request body unless the path actually matches. +-- +prioErrorsSpec :: Spec +prioErrorsSpec = describe "PrioErrors" $ do + let server = return . age + with (return $ serve prioErrorsApi server) $ do + let check (mdescr, method) path (cdescr, ctype, body) resp = + it fulldescr $ + Test.Hspec.Wai.request method path [(hContentType, ctype)] body + `shouldRespondWith` resp + where + fulldescr = "returns " ++ show (matchStatus resp) ++ " on " ++ mdescr + ++ " " ++ cs path ++ " (" ++ cdescr ++ ")" + + get' = ("GET", methodGet) + put' = ("PUT", methodPut) + + txt = ("text" , "text/plain;charset=utf8" , "42" ) + ijson = ("invalid json", "application/json;charset=utf8", "invalid" ) + vjson = ("valid json" , "application/json;charset=utf8", encode alice) + + check get' "/" txt 404 + check get' "/bar" txt 404 + check get' "/foo" txt 415 + check put' "/" txt 404 + check put' "/bar" txt 404 + check put' "/foo" txt 405 + check get' "/" ijson 404 + check get' "/bar" ijson 404 + check get' "/foo" ijson 400 + check put' "/" ijson 404 + check put' "/bar" ijson 404 + check put' "/foo" ijson 405 + check get' "/" vjson 404 + check get' "/bar" vjson 404 + check get' "/foo" vjson 200 + check put' "/" vjson 404 + check put' "/bar" vjson 404 + check put' "/foo" vjson 405 + +-- | Test server error functionality. +errorsSpec :: Spec +errorsSpec = do + let he = HttpError status409 [] (Just "A custom error") + let ib = InvalidBody "The body is invalid" + let wm = WrongMethod + let nf = NotFound + + describe "Servant.Server.Internal.RouteMismatch" $ do + it "HttpError > *" $ do + ib <> he `shouldBe` he + wm <> he `shouldBe` he + nf <> he `shouldBe` he + + he <> ib `shouldBe` he + he <> wm `shouldBe` he + he <> nf `shouldBe` he + + it "HE > InvalidBody > (WM,NF)" $ do + he <> ib `shouldBe` he + wm <> ib `shouldBe` ib + nf <> ib `shouldBe` ib + + ib <> he `shouldBe` he + ib <> wm `shouldBe` ib + ib <> nf `shouldBe` ib + + it "HE > IB > WrongMethod > NF" $ do + he <> wm `shouldBe` he + ib <> wm `shouldBe` ib + nf <> wm `shouldBe` wm + + wm <> he `shouldBe` he + wm <> ib `shouldBe` ib + wm <> nf `shouldBe` wm + + it "* > NotFound" $ do + he <> nf `shouldBe` he + ib <> nf `shouldBe` ib + wm <> nf `shouldBe` wm + + nf <> he `shouldBe` he + nf <> ib `shouldBe` ib + nf <> wm `shouldBe` wm + type MiscCombinatorsAPI = "version" :> HttpVersion :> Get '[JSON] String :<|> "secure" :> IsSecure :> Get '[JSON] String @@ -644,4 +738,4 @@ authRequiredSpec = do foo401 <- get "/foo" bar401 <- get "/bar" WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" foo401) - WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" bar401) + WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"bar-realm\"" bar401) diff --git a/servant/src/Servant/API/Authentication.hs b/servant/src/Servant/API/Authentication.hs index d82b4472..4b6c9981 100644 --- a/servant/src/Servant/API/Authentication.hs +++ b/servant/src/Servant/API/Authentication.hs @@ -1,24 +1,31 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeFamilies #-} {-# OPTIONS_HADDOCK not-home #-} -module Servant.API.Authentication (BasicAuth) where +module Servant.API.Authentication where -import Data.Typeable (Typeable) -import GHC.TypeLits (Symbol) +import Data.ByteString (ByteString) +import Data.Typeable (Typeable) +import GHC.TypeLits (Symbol) + +-- | we can be either Strict or Lax. +-- Strict: all handlers under 'AuthProtect' take a 'usr' argument. +-- when auth fails, we call user-supplied handlers to respond. +-- Lax: all handlers under 'AuthProtect' take a 'Maybe usr' argument. +-- when auth fails, we call the handlers with 'Nothing'. +data AuthPolicy = Strict | Lax + +-- | the combinator to be used in API types +data AuthProtect authdata usr (policy :: AuthPolicy) + +-- | what we'll ask user to provide at the server-level when we see a +-- 'AuthProtect' combinator in an API type +data family AuthProtected authdata usr subserver :: AuthPolicy -> * -- | Basic Authentication with respect to a specified @realm@ and a @lookup@ -- type to encapsulate authentication logic. --- --- Example: --- >>> type MyApi = BasicAuth "book-realm" DB :> "books" :> Get '[JSON] [Book] -data BasicAuth (realm :: Symbol) lookup - deriving (Typeable) - --- $setup --- >>> import Servant.API --- >>> import Data.Aeson --- >>> import Data.Text --- >>> data DB --- >>> data Book --- >>> instance ToJSON Book where { toJSON = undefined } +data BasicAuth (realm :: Symbol) = BasicAuth { baUser :: ByteString + , baPass :: ByteString + } deriving (Eq, Show, Typeable)