Add general auth support to servant-client

This commit is contained in:
aaron levin 2016-02-17 21:45:08 +01:00
parent 0461c4642d
commit 23da4879ef
4 changed files with 90 additions and 4 deletions

View file

@ -27,6 +27,7 @@ source-repository head
library
exposed-modules:
Servant.Client
Servant.Common.Auth
Servant.Common.BaseUrl
Servant.Common.BasicAuth
Servant.Common.Req

View file

@ -15,8 +15,11 @@
-- querying functions for each endpoint just from the type representing your
-- API.
module Servant.Client
( client
( AuthClientData
, AuthenticateReq(..)
, client
, HasClient(..)
, mkAuthenticateReq
, ServantError(..)
, module Servant.Common.BaseUrl
) where
@ -36,6 +39,7 @@ import Network.HTTP.Media
import qualified Network.HTTP.Types as H
import qualified Network.HTTP.Types.Header as HTTP
import Servant.API
import Servant.Common.Auth
import Servant.Common.BaseUrl
import Servant.Common.BasicAuth
import Servant.Common.Req
@ -424,6 +428,13 @@ instance HasClient subapi =>
type Client (WithNamedContext name context subapi) = Client subapi
clientWithRoute Proxy = clientWithRoute (Proxy :: Proxy subapi)
instance ( HasClient api
) => HasClient (AuthProtect tag :> api) where
type Client (AuthProtect tag :> api)
= AuthenticateReq (AuthProtect tag) -> Client api
clientWithRoute Proxy req baseurl manager (AuthenticateReq (val,func)) =
clientWithRoute (Proxy :: Proxy api) (func val req) baseurl manager
-- * Basic Authentication

View file

@ -0,0 +1,30 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE TypeFamilies #-}
-- | Authentication for clients
module Servant.Common.Auth (
AuthenticateReq(AuthenticateReq, unAuthReq)
, AuthClientData
, mkAuthenticateReq
) where
import Servant.Common.Req (Req)
-- | For a resource protected by authentication (e.g. AuthProtect), we need
-- to provide the client with some data used to add authentication data
-- to a request
type family AuthClientData a :: *
-- | For better type inference and to avoid usage of a data family, we newtype
-- wrap the combination of some 'AuthClientData' and a function to add authentication
-- data to a request
newtype AuthenticateReq a =
AuthenticateReq { unAuthReq :: (AuthClientData a, AuthClientData a -> Req -> Req) }
-- | Handy helper to avoid wrapping datatypes in tuples everywhere.
mkAuthenticateReq :: AuthClientData a
-> (AuthClientData a -> Req -> Req)
-> AuthenticateReq a
mkAuthenticateReq val func = AuthenticateReq (val, func)

View file

@ -12,6 +12,7 @@
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fcontext-stack=100 #-}
@ -41,7 +42,8 @@ import Network.HTTP.Media
import Network.HTTP.Types (Status (..), badRequest400,
methodGet, ok200, status400)
import Network.Socket
import Network.Wai (Application, responseLBS)
import Network.Wai (Application, Request,
requestHeaders, responseLBS)
import Network.Wai.Handler.Warp
import System.IO.Unsafe (unsafePerformIO)
import Test.Hspec
@ -53,6 +55,7 @@ import Servant.API
import Servant.API.Internal.Test.ComprehensiveAPI
import Servant.Client
import Servant.Server
import qualified Servant.Common.Req as SCR
-- This declaration simply checks that all instances are in place.
_ = client comprehensiveAPI
@ -63,6 +66,7 @@ spec = describe "Servant.Client" $ do
failSpec
wrappedApiSpec
basicAuthSpec
genAuthSpec
-- * test data types
@ -149,8 +153,7 @@ failServer = serve failApi (
:<|> (\_request respond -> respond $ responseLBS ok200 [("content-type", "fooooo")] "")
)
-- * auth stuff
-- * basic auth stuff
type BasicAuthAPI =
BasicAuth "foo-realm" () :> "private" :> "basic" :> Get '[JSON] Person
@ -172,6 +175,30 @@ serverContext = basicAuthHandler :. EmptyContext
basicAuthServer :: Application
basicAuthServer = serveWithContext basicAuthAPI serverContext (const (return alice))
-- * general auth stuff
type GenAuthAPI =
AuthProtect "auth-tag" :> "private" :> "auth" :> Get '[JSON] Person
genAuthAPI :: Proxy GenAuthAPI
genAuthAPI = Proxy
type instance AuthServerData (AuthProtect "auth-tag") = ()
type instance AuthClientData (AuthProtect "auth-tag") = ()
genAuthHandler :: AuthHandler Request ()
genAuthHandler =
let handler req = case lookup "AuthHeader" (requestHeaders req) of
Nothing -> throwE (err401 { errBody = "Missing auth header" })
Just _ -> return ()
in mkAuthHandler handler
serverConfig :: Config '[ AuthHandler Request () ]
serverConfig = genAuthHandler :. EmptyConfig
genAuthServer :: Application
genAuthServer = serve genAuthAPI serverConfig (const (return alice))
{-# NOINLINE manager #-}
manager :: C.Manager
manager = unsafePerformIO $ C.newManager C.defaultManagerSettings
@ -333,6 +360,23 @@ basicAuthSpec = beforeAll (startWaiApp basicAuthServer) $ afterAll endWaiApp $ d
Left FailureResponse{..} <- runExceptT (getBasic basicAuthData)
responseStatus `shouldBe` Status 403 "Forbidden"
genAuthSpec :: Spec
genAuthSpec = beforeAll (startWaiApp genAuthServer) $ afterAll endWaiApp $ do
context "Authentication works when requests are properly authenticated" $ do
it "Authenticates a AuthProtect protected server appropriately" $ \(_, baseUrl) -> do
let getProtected = client genAuthAPI baseUrl manager
let authRequest = mkAuthenticateReq () (\_ req -> SCR.addHeader "AuthHeader" ("cool" :: String) req)
(left show <$> runExceptT (getProtected authRequest)) `shouldReturn` Right alice
context "Authentication is rejected when requests are not authenticated properly" $ do
it "Authenticates a AuthProtect protected server appropriately" $ \(_, baseUrl) -> do
let getProtected = client genAuthAPI baseUrl manager
let authRequest = mkAuthenticateReq () (\_ req -> SCR.addHeader "Wrong" ("header" :: String) req)
Left FailureResponse{..} <- runExceptT (getProtected authRequest)
responseStatus `shouldBe` (Status 401 "Unauthorized")
-- * utils
startWaiApp :: Application -> IO (ThreadId, BaseUrl)