diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index a94ce045..49b587ed 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -9,7 +9,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} #if !MIN_VERSION_base(4,8,0) -{-# LANGUAGE OverlappingInstances #-} +{-# LANGUAGE OverlappingInstances #-} #endif module Servant.Server.Internal @@ -27,6 +27,8 @@ import Control.Monad.Trans.Except (ExceptT) import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as BL import qualified Data.Map as M +import Data.ByteString.Base64 (decodeLenient) +import qualified Data.ByteString.Lazy as BL import Data.Maybe (mapMaybe, fromMaybe) import Data.String (fromString) import Data.String.Conversions (cs, (<>), ConvertibleStrings) @@ -34,20 +36,22 @@ import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (decodeUtf8, encodeUtf8) import Data.Typeable +import Data.Word8 (isSpace, _colon, toLower) import GHC.TypeLits (KnownSymbol, symbolVal) import Network.HTTP.Types hiding (Header, ResponseHeaders) import Network.Socket (SockAddr) -import Network.Wai (Application, lazyRequestBody, - rawQueryString, requestHeaders, - requestMethod, responseLBS, remoteHost, - isSecure, vault, httpVersion, Response, - Request) -import Servant.API ((:<|>) (..), (:>), Capture, - Delete, Get, Header, - IsSecure(..), MatrixFlag, MatrixParam, - MatrixParams, Patch, Post, Put, - QueryFlag, QueryParam, QueryParams, - Raw, RemoteHost, ReqBody, Vault) +import Network.Wai (Application, Request, Response, + ResponseReceived, lazyRequestBody, + pathInfo, rawQueryString, + requestBody, requestHeaders, + requestMethod, responseLBS, + strictRequestBody) +import Servant.API ((:<|>) (..), (:>), BasicAuth, Capture, + Delete, Get, Header, + MatrixFlag, MatrixParam, MatrixParams, + Patch, Post, Put, QueryFlag, + QueryParam, QueryParams, Raw, + ReqBody) import Servant.API.ContentTypes (AcceptHeader (..), AllCTRender (..), AllCTUnrender (..)) @@ -67,6 +71,11 @@ class HasServer layout where type Server layout = ServerT layout (ExceptT ServantErr IO) +-- | A type-indexed class to encapsulate Basic authentication handling. +-- Authentication handling is indexed by the lookup type. +class BasicAuthLookup lookup a | lookup -> a where + basicAuthLookup :: Proxy lookup -> B.ByteString -> B.ByteString -> IO (Maybe a) + -- * Instances -- | A server for @a ':<|>' b@ first tries to match the request against the route @@ -230,6 +239,42 @@ instance route Proxy = methodRouterHeaders methodDelete (Proxy :: Proxy ctypes) ok200 +-- | Authentication +instance +#if MIN_VERSION_base(4,8,0) + {-# OVERLAPPABLE #-} +#endif + (HasServer sublayout, BasicAuthLookup lookup authVal) => HasServer (BasicAuth realm lookup authVal :> sublayout) where + type ServerT (BasicAuth realm lookup authVal :> sublayout) m = authVal -> ServerT sublayout m + route proxy action request response = + case lookup "Authorization" (requestHeaders request) of + Nothing -> error "handle no authorization header" -- 401 + Just authBs -> + -- ripped from: https://hackage.haskell.org/package/wai-extra-1.3.4.5/docs/src/Network-Wai-Middleware-HttpAuth.html#basicAuth + let (x,y) = B.break isSpace authBs in + if B.map toLower x == "basic" + then checkB64 (B.dropWhile isSpace y) + else error "not basic authentication" -- 401 + where + checkB64 encoded = + case B.uncons passwordWithColonAtHead of + Just (_, password) -> do + -- let's check these credentials using the user-provided lookup method + maybeAuthData <- basicAuthLookup (Proxy :: Proxy lookup) username password + case maybeAuthData of + Nothing -> error "bad password" -- 403 + (Just authData) -> + route (Proxy :: Proxy sublayout) (action authData) request response + + -- no username:password present + Nothing -> error "No password" -- 403 + where + raw = decodeLenient encoded + -- split username and password at the colon ':' char. + (username, passwordWithColonAtHead) = B.breakByte _colon raw + + + -- | When implementing the handler for a 'Get' endpoint, -- just like for 'Servant.API.Delete.Delete', 'Servant.API.Post.Post' -- and 'Servant.API.Put.Put', the handler code runs in the diff --git a/servant/src/Servant/API/Authentication.hs b/servant/src/Servant/API/Authentication.hs index d82b4472..573c85c2 100644 --- a/servant/src/Servant/API/Authentication.hs +++ b/servant/src/Servant/API/Authentication.hs @@ -12,7 +12,7 @@ import GHC.TypeLits (Symbol) -- -- Example: -- >>> type MyApi = BasicAuth "book-realm" DB :> "books" :> Get '[JSON] [Book] -data BasicAuth (realm :: Symbol) lookup +data BasicAuth (realm :: Symbol) lookup a deriving (Typeable) -- $setup