diff --git a/servant-client/servant-client.cabal b/servant-client/servant-client.cabal index 0d70cea6..c908be71 100644 --- a/servant-client/servant-client.cabal +++ b/servant-client/servant-client.cabal @@ -26,6 +26,7 @@ source-repository head library exposed-modules: Servant.Client + Servant.Client.Authentication Servant.Common.BaseUrl Servant.Common.Req build-depends: @@ -64,6 +65,7 @@ test-suite spec , transformers , transformers-compat , aeson + , base64-bytestring , bytestring , deepseq , hspec == 2.* diff --git a/servant-client/src/Servant/Client.hs b/servant-client/src/Servant/Client.hs index 657fe5af..d88905d5 100644 --- a/servant-client/src/Servant/Client.hs +++ b/servant-client/src/Servant/Client.hs @@ -37,6 +37,8 @@ import Network.HTTP.Media import qualified Network.HTTP.Types as H import qualified Network.HTTP.Types.Header as HTTP import Servant.API +import Servant.API.Authentication (AuthProtect) +import Servant.Client.Authentication (AuthenticateRequest(authReq)) import Servant.Common.BaseUrl import Servant.Common.Req @@ -119,6 +121,15 @@ instance (KnownSymbol capture, ToText a, HasClient sublayout) where p = unpack (toText val) +-- | Authentication +instance (AuthenticateRequest authdata, HasClient sublayout) => HasClient (AuthProtect authdata (usr :: *) policy :> sublayout) where + type Client (AuthProtect authdata usr policy :> sublayout) = authdata -> Client sublayout + + clientWithRoute Proxy req baseurl val = + clientWithRoute (Proxy :: Proxy sublayout) + (authReq val req) + baseurl + -- | If you have a 'Delete' endpoint in your API, the client -- side querying function that is created when calling 'client' -- will just require an argument that specifies the scheme, host diff --git a/servant-client/src/Servant/Client/Authentication.hs b/servant-client/src/Servant/Client/Authentication.hs new file mode 100644 index 00000000..1d7d43d9 --- /dev/null +++ b/servant-client/src/Servant/Client/Authentication.hs @@ -0,0 +1,12 @@ +-- | Authentication for clients + +module Servant.Client.Authentication ( + AuthenticateRequest ( authReq ) + ) where + +import Servant.Common.Req (Req) + +-- | Class to represent the ability to authenticate a 'Request' +-- object. For example, we may add special headers to the 'Request'. +class AuthenticateRequest a where + authReq :: a -> Req -> Req diff --git a/servant-client/test/Servant/ClientSpec.hs b/servant-client/test/Servant/ClientSpec.hs index 9db7c1a9..5acac845 100644 --- a/servant-client/test/Servant/ClientSpec.hs +++ b/servant-client/test/Servant/ClientSpec.hs @@ -20,12 +20,14 @@ import Control.Concurrent import Control.Exception import Control.Monad.Trans.Except import Data.Aeson +import qualified Data.ByteString.Base64 as B64 import Data.ByteString.Lazy (ByteString) import Data.Char import Data.Foldable (forM_) import Data.Monoid import Data.Proxy import qualified Data.Text as T +import qualified Data.Text.Encoding as TE import GHC.Generics import qualified Network.HTTP.Client as C import Network.HTTP.Media @@ -40,8 +42,12 @@ import Test.HUnit import Test.QuickCheck import Servant.API +import Servant.API.Authentication import Servant.Client +import qualified Servant.Common.Req as SCR +import Servant.Client.Authentication (AuthenticateRequest(authReq)) import Servant.Server +import Servant.Server.Internal.Authentication -- * test data types @@ -94,6 +100,24 @@ type Api = Get '[JSON] (String, Maybe Int, Bool, [(String, [Rational])]) :<|> "headers" :> Get '[JSON] (Headers TestHeaders Bool) :<|> "deleteContentType" :> Delete '[JSON] () + :<|> AuthProtect (BasicAuth "realm") Person 'Strict :> Get '[JSON] Person + +-- base64-encoded "servant:server" +base64ServantColonServer :: ByteString +base64ServantColonServer = "c2VydmFudDpzZXJ2ZXI=" + +type AuthUser = T.Text + +basicAuthCheck :: BasicAuth "realm" -> IO (Maybe Person) +basicAuthCheck (BasicAuth user pass) = if user == "servant" && pass == "server" + then return (Just $ Person "servant" 17) + else return Nothing + +instance AuthenticateRequest (BasicAuth realm) where + authReq (BasicAuth user pass) req = + let authText = TE.decodeUtf8 ("Basic " <> B64.encode (user <> ":" <> pass)) in + SCR.addHeader "Authorization" authText req + api :: Proxy Api api = Proxy @@ -120,6 +144,7 @@ server = serve api ( :<|> (\ a b c d -> return (a, b, c, d)) :<|> (return $ addHeader 1729 $ addHeader "eg2" True) :<|> return () + :<|> basicAuthStrict basicAuthCheck (const . return $ alice) ) withServer :: (BaseUrl -> IO a) -> IO a @@ -144,6 +169,7 @@ withFailServer action = withWaiDaemon (return failServer) action spec :: IO () spec = withServer $ \ baseUrl -> do +<<<<<<< HEAD manager <- C.newManager C.defaultManagerSettings let getGet :: ExceptT ServantError IO Person getDeleteEmpty :: ExceptT ServantError IO () @@ -174,7 +200,8 @@ spec = withServer $ \ baseUrl -> do :<|> getRawFailure :<|> getMultiple :<|> getRespHeaders - :<|> getDeleteContentType) + :<|> getDeleteContentType + :<|> getPrivatePerson) = client api baseUrl manager hspec $ do @@ -249,6 +276,15 @@ spec = withServer $ \ baseUrl -> do Left e -> assertFailure $ show e Right val -> getHeaders val `shouldBe` [("X-Example1", "1729"), ("X-Example2", "eg2")] + it "Handles Authentication appropriatley" $ withServer $ \ _ -> do + (Arrow.left show <$> runExceptT (getPrivatePerson (BasicAuth "servant" "server"))) `shouldReturn` Right alice + + it "returns 401 when not properly authenticated" $ do + Left res <- runExceptT (getPrivatePerson (BasicAuth "xxx" "yyy")) + case res of + FailureResponse (Status 401 _) _ _ -> return () + _ -> fail $ "expcted 401 response, but got " <> show res + modifyMaxSuccess (const 20) $ do it "works for a combination of Capture, QueryParam, QueryFlag and ReqBody" $ property $ forAllShrink pathGen shrink $ \(NonEmpty cap) num flag body -> @@ -257,7 +293,6 @@ spec = withServer $ \ baseUrl -> do return $ result === Right (cap, num, flag, body) - context "client correctly handles error status codes" $ do let test :: (WrappedApi, String) -> Spec test (WrappedApi api, desc) = @@ -323,6 +358,7 @@ failSpec = withFailServer $ \ baseUrl -> do InvalidContentTypeHeader "fooooo" _ -> return () _ -> fail $ "expected InvalidContentTypeHeader, but got " <> show res + data WrappedApi where WrappedApi :: (HasServer api, Server api ~ ExceptT ServantErr IO a, HasClient api, Client api ~ ExceptT ServantError IO ()) =>