diff --git a/servant-client/src/Servant/Client.hs b/servant-client/src/Servant/Client.hs index 6ccc1777..a6d6d942 100644 --- a/servant-client/src/Servant/Client.hs +++ b/servant-client/src/Servant/Client.hs @@ -17,7 +17,6 @@ module Servant.Client ( AuthClientData , AuthenticateReq(..) - , BasicAuthData(..) , client , HasClient(..) , mkAuthenticateReq @@ -40,6 +39,7 @@ import Network.HTTP.Media import qualified Network.HTTP.Types as H import qualified Network.HTTP.Types.Header as HTTP import Servant.API +import Servant.API.Auth (BasicAuthData) import Servant.Common.Auth import Servant.Common.BaseUrl import Servant.Common.Req diff --git a/servant-client/src/Servant/Common/Auth.hs b/servant-client/src/Servant/Common/Auth.hs index a8502be2..fad23c08 100644 --- a/servant-client/src/Servant/Common/Auth.hs +++ b/servant-client/src/Servant/Common/Auth.hs @@ -7,22 +7,15 @@ module Servant.Common.Auth ( AuthenticateReq(AuthenticateReq, unAuthReq) , AuthClientData - , BasicAuthData (BasicAuthData, username, password) , basicAuthReq , mkAuthenticateReq ) where -import Data.ByteString (ByteString) import Data.ByteString.Base64 (encode) import Data.Monoid ((<>)) import Data.Text.Encoding (decodeUtf8) import Servant.Common.Req (addHeader, Req) - - --- | A simple datatype to hold data required to decorate a request -data BasicAuthData = BasicAuthData { username :: ByteString - , password :: ByteString - } +import Servant.API.Auth (BasicAuthData(BasicAuthData)) -- | Authenticate a request using Basic Authentication basicAuthReq :: BasicAuthData -> Req -> Req diff --git a/servant-client/test/Servant/ClientSpec.hs b/servant-client/test/Servant/ClientSpec.hs index 04b7e55b..bb0ff44d 100644 --- a/servant-client/test/Servant/ClientSpec.hs +++ b/servant-client/test/Servant/ClientSpec.hs @@ -51,6 +51,7 @@ import Test.HUnit import Test.QuickCheck import Servant.API +import Servant.API.Auth (BasicAuthData(BasicAuthData)) import Servant.API.Internal.Test.ComprehensiveAPI import Servant.Client import Servant.Server @@ -167,7 +168,7 @@ type instance AuthClientData (AuthProtect "auth-tag") = () basicAuthHandler :: BasicAuthCheck () basicAuthHandler = - let check username password = + let check (BasicAuthData username password) = if username == "servant" && password == "server" then return (Authorized ()) else return Unauthorized diff --git a/servant-examples/basic-auth/basic-auth.hs b/servant-examples/basic-auth/basic-auth.hs index c409f6ca..0ee22666 100644 --- a/servant-examples/basic-auth/basic-auth.hs +++ b/servant-examples/basic-auth/basic-auth.hs @@ -14,6 +14,7 @@ import GHC.Generics (Generic) import Network.Wai.Handler.Warp (run) import Servant.API ((:<|>) ((:<|>)), (:>), BasicAuth, Get, JSON) +import Servant.API.Auth (BasicAuthData(BasicAuthData)) import Servant.Server (AuthReturnType, BasicAuthResult (Authorized, Unauthorized), Config ((:.), EmptyConfig), Server, serve, BasicAuthCheck(BasicAuthCheck)) @@ -59,7 +60,7 @@ type instance AuthReturnType (BasicAuth "foo-realm") = User -- | 'BasicAuthCheck' holds the handler we'll use to verify a username and password. authCheck :: BasicAuthCheck User authCheck = - let check username password = + let check (BasicAuthData username password) = if username == "servant" && password == "server" then return (Authorized (User "servant")) else return Unauthorized diff --git a/servant-server/src/Servant/Server/Internal/Auth.hs b/servant-server/src/Servant/Server/Internal/Auth.hs index 6e15c7a5..fd279232 100644 --- a/servant-server/src/Servant/Server/Internal/Auth.hs +++ b/servant-server/src/Servant/Server/Internal/Auth.hs @@ -17,6 +17,7 @@ import GHC.Generics import Network.HTTP.Types (Header) import Network.Wai (Request, requestHeaders) +import Servant.API.Auth (BasicAuthData(BasicAuthData)) import Servant.Server.Internal.RoutingApplication import Servant.Server.Internal.ServantErr @@ -34,6 +35,8 @@ newtype AuthHandler r usr = AuthHandler mkAuthHandler :: (r -> ExceptT ServantErr IO usr) -> AuthHandler r usr mkAuthHandler = AuthHandler +-- * Basic Auth + -- | The result of authentication/authorization data BasicAuthResult usr = Unauthorized @@ -42,11 +45,9 @@ data BasicAuthResult usr | Authorized usr deriving (Eq, Show, Read, Generic, Typeable, Functor) --- * Basic Auth newtype BasicAuthCheck usr = BasicAuthCheck - { unBasicAuthCheck :: BS.ByteString -- Username - -> BS.ByteString -- Password + { unBasicAuthCheck :: BasicAuthData -> IO (BasicAuthResult usr) } deriving (Generic, Typeable, Functor) @@ -55,7 +56,7 @@ mkBAChallengerHdr :: BS.ByteString -> Header mkBAChallengerHdr realm = ("WWW-Authenticate", "Basic realm=\"" <> realm <> "\"") -- | Find and decode an 'Authorization' header from the request as Basic Auth -decodeBAHdr :: Request -> Maybe (BS.ByteString, BS.ByteString) +decodeBAHdr :: Request -> Maybe BasicAuthData decodeBAHdr req = do ah <- lookup "Authorization" $ requestHeaders req let (b, rest) = BS.break isSpace ah @@ -63,13 +64,13 @@ decodeBAHdr req = do let decoded = decodeLenient (BS.dropWhile isSpace rest) let (username, passWithColonAtHead) = BS.break (== _colon) decoded (_, password) <- BS.uncons passWithColonAtHead - return (username, password) + return (BasicAuthData username password) runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> IO (RouteResult usr) runBasicAuth req realm (BasicAuthCheck ba) = case decodeBAHdr req of Nothing -> plzAuthenticate - Just e -> uncurry ba e >>= \res -> case res of + Just e -> ba e >>= \res -> case res of BadPassword -> plzAuthenticate NoSuchUser -> plzAuthenticate Unauthorized -> return $ Fail err403 diff --git a/servant-server/test/Servant/ServerSpec.hs b/servant-server/test/Servant/ServerSpec.hs index 9729428b..813c05c3 100644 --- a/servant-server/test/Servant/ServerSpec.hs +++ b/servant-server/test/Servant/ServerSpec.hs @@ -33,7 +33,7 @@ import Network.HTTP.Types (Status (..), hAccept, hContentType, import Network.Wai (Application, Request, requestHeaders, pathInfo, queryString, rawQueryString, responseBuilder, responseLBS) -import Network.Wai.Internal (Response (ResponseBuilder), requestHeaders) +import Network.Wai.Internal (Response (ResponseBuilder)) import Network.Wai.Test (defaultRequest, request, runSession, simpleBody, simpleHeaders, simpleStatus) @@ -55,20 +55,18 @@ import qualified Test.Hspec.Wai as THW import Test.Hspec.Wai (get, liftIO, matchHeaders, matchStatus, request, shouldRespondWith, with, (<:>)) -import qualified Test.Hspec.Wai as THW +import Servant.API.Auth (BasicAuthData(BasicAuthData)) import Servant.Server.Internal.Auth (AuthHandler, AuthReturnType, BasicAuthCheck (BasicAuthCheck), BasicAuthResult (Authorized, Unauthorized), mkAuthHandler) - -import Servant.Server.Internal.Auth import Servant.Server.Internal.RoutingApplication (toApplication, RouteResult(..)) import Servant.Server.Internal.Router (tweakResponse, runRouter, Router, Router'(LeafRouter)) import Servant.Server.Internal.Config - (Config(..), NamedConfig(..)) + (NamedConfig(NamedConfig)) -- * comprehensive api test @@ -554,7 +552,7 @@ authConfig :: Config '[ BasicAuthCheck () , AuthHandler Request () ] authConfig = - let basicHandler = BasicAuthCheck $ (\usr pass -> + let basicHandler = BasicAuthCheck $ (\(BasicAuthData usr pass) -> if usr == "servant" && pass == "server" then return (Authorized ()) else return Unauthorized diff --git a/servant/src/Servant/API/Auth.hs b/servant/src/Servant/API/Auth.hs index 00a11adf..5aa2638c 100644 --- a/servant/src/Servant/API/Auth.hs +++ b/servant/src/Servant/API/Auth.hs @@ -4,6 +4,7 @@ {-# LANGUAGE PolyKinds #-} module Servant.API.Auth where +import Data.ByteString (ByteString) import Data.Typeable (Typeable) import GHC.TypeLits (Symbol) @@ -20,6 +21,11 @@ import GHC.TypeLits (Symbol) data BasicAuth (realm :: Symbol) deriving (Typeable) +-- | A simple datatype to hold data required to decorate a request +data BasicAuthData = BasicAuthData { basicAuthUsername :: !ByteString + , basicAuthPassword :: !ByteString + } + -- | A generalized Authentication combinator. data AuthProtect (tag :: k) deriving (Typeable)