Add general auth support to servant-client

This commit is contained in:
aaron levin 2016-02-17 21:45:08 +01:00
parent 0461c4642d
commit 23da4879ef
4 changed files with 90 additions and 4 deletions

View file

@ -27,6 +27,7 @@ source-repository head
library library
exposed-modules: exposed-modules:
Servant.Client Servant.Client
Servant.Common.Auth
Servant.Common.BaseUrl Servant.Common.BaseUrl
Servant.Common.BasicAuth Servant.Common.BasicAuth
Servant.Common.Req Servant.Common.Req

View file

@ -15,8 +15,11 @@
-- querying functions for each endpoint just from the type representing your -- querying functions for each endpoint just from the type representing your
-- API. -- API.
module Servant.Client module Servant.Client
( client ( AuthClientData
, AuthenticateReq(..)
, client
, HasClient(..) , HasClient(..)
, mkAuthenticateReq
, ServantError(..) , ServantError(..)
, module Servant.Common.BaseUrl , module Servant.Common.BaseUrl
) where ) where
@ -36,6 +39,7 @@ import Network.HTTP.Media
import qualified Network.HTTP.Types as H import qualified Network.HTTP.Types as H
import qualified Network.HTTP.Types.Header as HTTP import qualified Network.HTTP.Types.Header as HTTP
import Servant.API import Servant.API
import Servant.Common.Auth
import Servant.Common.BaseUrl import Servant.Common.BaseUrl
import Servant.Common.BasicAuth import Servant.Common.BasicAuth
import Servant.Common.Req import Servant.Common.Req
@ -424,6 +428,13 @@ instance HasClient subapi =>
type Client (WithNamedContext name context subapi) = Client subapi type Client (WithNamedContext name context subapi) = Client subapi
clientWithRoute Proxy = clientWithRoute (Proxy :: Proxy subapi) clientWithRoute Proxy = clientWithRoute (Proxy :: Proxy subapi)
instance ( HasClient api
) => HasClient (AuthProtect tag :> api) where
type Client (AuthProtect tag :> api)
= AuthenticateReq (AuthProtect tag) -> Client api
clientWithRoute Proxy req baseurl manager (AuthenticateReq (val,func)) =
clientWithRoute (Proxy :: Proxy api) (func val req) baseurl manager
-- * Basic Authentication -- * Basic Authentication

View file

@ -0,0 +1,30 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE TypeFamilies #-}
-- | Authentication for clients
module Servant.Common.Auth (
AuthenticateReq(AuthenticateReq, unAuthReq)
, AuthClientData
, mkAuthenticateReq
) where
import Servant.Common.Req (Req)
-- | For a resource protected by authentication (e.g. AuthProtect), we need
-- to provide the client with some data used to add authentication data
-- to a request
type family AuthClientData a :: *
-- | For better type inference and to avoid usage of a data family, we newtype
-- wrap the combination of some 'AuthClientData' and a function to add authentication
-- data to a request
newtype AuthenticateReq a =
AuthenticateReq { unAuthReq :: (AuthClientData a, AuthClientData a -> Req -> Req) }
-- | Handy helper to avoid wrapping datatypes in tuples everywhere.
mkAuthenticateReq :: AuthClientData a
-> (AuthClientData a -> Req -> Req)
-> AuthenticateReq a
mkAuthenticateReq val func = AuthenticateReq (val, func)

View file

@ -12,6 +12,7 @@
{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fcontext-stack=100 #-} {-# OPTIONS_GHC -fcontext-stack=100 #-}
@ -41,7 +42,8 @@ import Network.HTTP.Media
import Network.HTTP.Types (Status (..), badRequest400, import Network.HTTP.Types (Status (..), badRequest400,
methodGet, ok200, status400) methodGet, ok200, status400)
import Network.Socket import Network.Socket
import Network.Wai (Application, responseLBS) import Network.Wai (Application, Request,
requestHeaders, responseLBS)
import Network.Wai.Handler.Warp import Network.Wai.Handler.Warp
import System.IO.Unsafe (unsafePerformIO) import System.IO.Unsafe (unsafePerformIO)
import Test.Hspec import Test.Hspec
@ -53,6 +55,7 @@ import Servant.API
import Servant.API.Internal.Test.ComprehensiveAPI import Servant.API.Internal.Test.ComprehensiveAPI
import Servant.Client import Servant.Client
import Servant.Server import Servant.Server
import qualified Servant.Common.Req as SCR
-- This declaration simply checks that all instances are in place. -- This declaration simply checks that all instances are in place.
_ = client comprehensiveAPI _ = client comprehensiveAPI
@ -63,6 +66,7 @@ spec = describe "Servant.Client" $ do
failSpec failSpec
wrappedApiSpec wrappedApiSpec
basicAuthSpec basicAuthSpec
genAuthSpec
-- * test data types -- * test data types
@ -149,8 +153,7 @@ failServer = serve failApi (
:<|> (\_request respond -> respond $ responseLBS ok200 [("content-type", "fooooo")] "") :<|> (\_request respond -> respond $ responseLBS ok200 [("content-type", "fooooo")] "")
) )
-- * basic auth stuff
-- * auth stuff
type BasicAuthAPI = type BasicAuthAPI =
BasicAuth "foo-realm" () :> "private" :> "basic" :> Get '[JSON] Person BasicAuth "foo-realm" () :> "private" :> "basic" :> Get '[JSON] Person
@ -172,6 +175,30 @@ serverContext = basicAuthHandler :. EmptyContext
basicAuthServer :: Application basicAuthServer :: Application
basicAuthServer = serveWithContext basicAuthAPI serverContext (const (return alice)) basicAuthServer = serveWithContext basicAuthAPI serverContext (const (return alice))
-- * general auth stuff
type GenAuthAPI =
AuthProtect "auth-tag" :> "private" :> "auth" :> Get '[JSON] Person
genAuthAPI :: Proxy GenAuthAPI
genAuthAPI = Proxy
type instance AuthServerData (AuthProtect "auth-tag") = ()
type instance AuthClientData (AuthProtect "auth-tag") = ()
genAuthHandler :: AuthHandler Request ()
genAuthHandler =
let handler req = case lookup "AuthHeader" (requestHeaders req) of
Nothing -> throwE (err401 { errBody = "Missing auth header" })
Just _ -> return ()
in mkAuthHandler handler
serverConfig :: Config '[ AuthHandler Request () ]
serverConfig = genAuthHandler :. EmptyConfig
genAuthServer :: Application
genAuthServer = serve genAuthAPI serverConfig (const (return alice))
{-# NOINLINE manager #-} {-# NOINLINE manager #-}
manager :: C.Manager manager :: C.Manager
manager = unsafePerformIO $ C.newManager C.defaultManagerSettings manager = unsafePerformIO $ C.newManager C.defaultManagerSettings
@ -333,6 +360,23 @@ basicAuthSpec = beforeAll (startWaiApp basicAuthServer) $ afterAll endWaiApp $ d
Left FailureResponse{..} <- runExceptT (getBasic basicAuthData) Left FailureResponse{..} <- runExceptT (getBasic basicAuthData)
responseStatus `shouldBe` Status 403 "Forbidden" responseStatus `shouldBe` Status 403 "Forbidden"
genAuthSpec :: Spec
genAuthSpec = beforeAll (startWaiApp genAuthServer) $ afterAll endWaiApp $ do
context "Authentication works when requests are properly authenticated" $ do
it "Authenticates a AuthProtect protected server appropriately" $ \(_, baseUrl) -> do
let getProtected = client genAuthAPI baseUrl manager
let authRequest = mkAuthenticateReq () (\_ req -> SCR.addHeader "AuthHeader" ("cool" :: String) req)
(left show <$> runExceptT (getProtected authRequest)) `shouldReturn` Right alice
context "Authentication is rejected when requests are not authenticated properly" $ do
it "Authenticates a AuthProtect protected server appropriately" $ \(_, baseUrl) -> do
let getProtected = client genAuthAPI baseUrl manager
let authRequest = mkAuthenticateReq () (\_ req -> SCR.addHeader "Wrong" ("header" :: String) req)
Left FailureResponse{..} <- runExceptT (getProtected authRequest)
responseStatus `shouldBe` (Status 401 "Unauthorized")
-- * utils -- * utils
startWaiApp :: Application -> IO (ThreadId, BaseUrl) startWaiApp :: Application -> IO (ThreadId, BaseUrl)