diff --git a/servant-client/servant-client.cabal b/servant-client/servant-client.cabal index 8e20f1a3..124d6307 100644 --- a/servant-client/servant-client.cabal +++ b/servant-client/servant-client.cabal @@ -27,6 +27,7 @@ source-repository head library exposed-modules: Servant.Client + Servant.Common.Auth Servant.Common.BaseUrl Servant.Common.BasicAuth Servant.Common.Req diff --git a/servant-client/src/Servant/Client.hs b/servant-client/src/Servant/Client.hs index d3373708..d62515dc 100644 --- a/servant-client/src/Servant/Client.hs +++ b/servant-client/src/Servant/Client.hs @@ -15,8 +15,11 @@ -- querying functions for each endpoint just from the type representing your -- API. module Servant.Client - ( client + ( AuthClientData + , AuthenticateReq(..) + , client , HasClient(..) + , mkAuthenticateReq , ServantError(..) , module Servant.Common.BaseUrl ) where @@ -36,6 +39,7 @@ import Network.HTTP.Media import qualified Network.HTTP.Types as H import qualified Network.HTTP.Types.Header as HTTP import Servant.API +import Servant.Common.Auth import Servant.Common.BaseUrl import Servant.Common.BasicAuth import Servant.Common.Req @@ -424,6 +428,13 @@ instance HasClient subapi => type Client (WithNamedContext name context subapi) = Client subapi clientWithRoute Proxy = clientWithRoute (Proxy :: Proxy subapi) +instance ( HasClient api + ) => HasClient (AuthProtect tag :> api) where + type Client (AuthProtect tag :> api) + = AuthenticateReq (AuthProtect tag) -> Client api + + clientWithRoute Proxy req baseurl manager (AuthenticateReq (val,func)) = + clientWithRoute (Proxy :: Proxy api) (func val req) baseurl manager -- * Basic Authentication diff --git a/servant-client/src/Servant/Common/Auth.hs b/servant-client/src/Servant/Common/Auth.hs new file mode 100644 index 00000000..5e450bd8 --- /dev/null +++ b/servant-client/src/Servant/Common/Auth.hs @@ -0,0 +1,30 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE TypeFamilies #-} + +-- | Authentication for clients + +module Servant.Common.Auth ( + AuthenticateReq(AuthenticateReq, unAuthReq) + , AuthClientData + , mkAuthenticateReq + ) where + +import Servant.Common.Req (Req) + +-- | For a resource protected by authentication (e.g. AuthProtect), we need +-- to provide the client with some data used to add authentication data +-- to a request +type family AuthClientData a :: * + +-- | For better type inference and to avoid usage of a data family, we newtype +-- wrap the combination of some 'AuthClientData' and a function to add authentication +-- data to a request +newtype AuthenticateReq a = + AuthenticateReq { unAuthReq :: (AuthClientData a, AuthClientData a -> Req -> Req) } + +-- | Handy helper to avoid wrapping datatypes in tuples everywhere. +mkAuthenticateReq :: AuthClientData a + -> (AuthClientData a -> Req -> Req) + -> AuthenticateReq a +mkAuthenticateReq val func = AuthenticateReq (val, func) diff --git a/servant-client/test/Servant/ClientSpec.hs b/servant-client/test/Servant/ClientSpec.hs index 4b6ccbb9..0ee1ed01 100644 --- a/servant-client/test/Servant/ClientSpec.hs +++ b/servant-client/test/Servant/ClientSpec.hs @@ -12,6 +12,7 @@ {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fcontext-stack=100 #-} @@ -41,7 +42,8 @@ import Network.HTTP.Media import Network.HTTP.Types (Status (..), badRequest400, methodGet, ok200, status400) import Network.Socket -import Network.Wai (Application, responseLBS) +import Network.Wai (Application, Request, + requestHeaders, responseLBS) import Network.Wai.Handler.Warp import System.IO.Unsafe (unsafePerformIO) import Test.Hspec @@ -53,6 +55,7 @@ import Servant.API import Servant.API.Internal.Test.ComprehensiveAPI import Servant.Client import Servant.Server +import qualified Servant.Common.Req as SCR -- This declaration simply checks that all instances are in place. _ = client comprehensiveAPI @@ -63,6 +66,7 @@ spec = describe "Servant.Client" $ do failSpec wrappedApiSpec basicAuthSpec + genAuthSpec -- * test data types @@ -149,8 +153,7 @@ failServer = serve failApi ( :<|> (\_request respond -> respond $ responseLBS ok200 [("content-type", "fooooo")] "") ) - --- * auth stuff +-- * basic auth stuff type BasicAuthAPI = BasicAuth "foo-realm" () :> "private" :> "basic" :> Get '[JSON] Person @@ -172,6 +175,30 @@ serverContext = basicAuthHandler :. EmptyContext basicAuthServer :: Application basicAuthServer = serveWithContext basicAuthAPI serverContext (const (return alice)) +-- * general auth stuff + +type GenAuthAPI = + AuthProtect "auth-tag" :> "private" :> "auth" :> Get '[JSON] Person + +genAuthAPI :: Proxy GenAuthAPI +genAuthAPI = Proxy + +type instance AuthServerData (AuthProtect "auth-tag") = () +type instance AuthClientData (AuthProtect "auth-tag") = () + +genAuthHandler :: AuthHandler Request () +genAuthHandler = + let handler req = case lookup "AuthHeader" (requestHeaders req) of + Nothing -> throwE (err401 { errBody = "Missing auth header" }) + Just _ -> return () + in mkAuthHandler handler + +serverConfig :: Config '[ AuthHandler Request () ] +serverConfig = genAuthHandler :. EmptyConfig + +genAuthServer :: Application +genAuthServer = serve genAuthAPI serverConfig (const (return alice)) + {-# NOINLINE manager #-} manager :: C.Manager manager = unsafePerformIO $ C.newManager C.defaultManagerSettings @@ -333,6 +360,23 @@ basicAuthSpec = beforeAll (startWaiApp basicAuthServer) $ afterAll endWaiApp $ d Left FailureResponse{..} <- runExceptT (getBasic basicAuthData) responseStatus `shouldBe` Status 403 "Forbidden" +genAuthSpec :: Spec +genAuthSpec = beforeAll (startWaiApp genAuthServer) $ afterAll endWaiApp $ do + context "Authentication works when requests are properly authenticated" $ do + + it "Authenticates a AuthProtect protected server appropriately" $ \(_, baseUrl) -> do + let getProtected = client genAuthAPI baseUrl manager + let authRequest = mkAuthenticateReq () (\_ req -> SCR.addHeader "AuthHeader" ("cool" :: String) req) + (left show <$> runExceptT (getProtected authRequest)) `shouldReturn` Right alice + + context "Authentication is rejected when requests are not authenticated properly" $ do + + it "Authenticates a AuthProtect protected server appropriately" $ \(_, baseUrl) -> do + let getProtected = client genAuthAPI baseUrl manager + let authRequest = mkAuthenticateReq () (\_ req -> SCR.addHeader "Wrong" ("header" :: String) req) + Left FailureResponse{..} <- runExceptT (getProtected authRequest) + responseStatus `shouldBe` (Status 401 "Unauthorized") + -- * utils startWaiApp :: Application -> IO (ThreadId, BaseUrl)