Add authentication to servant-client

This commit is contained in:
aaron levin 2015-08-18 15:35:36 -04:00 committed by aaron levin
parent 3f5106da77
commit 948951b6e9
4 changed files with 98 additions and 2 deletions

View file

@ -26,6 +26,7 @@ source-repository head
library library
exposed-modules: exposed-modules:
Servant.Client Servant.Client
Servant.Client.Authentication
Servant.Common.BaseUrl Servant.Common.BaseUrl
Servant.Common.Req Servant.Common.Req
build-depends: build-depends:
@ -65,6 +66,7 @@ test-suite spec
, transformers , transformers
, transformers-compat , transformers-compat
, aeson , aeson
, base64-bytestring
, bytestring , bytestring
, deepseq , deepseq
, hspec == 2.* , hspec == 2.*

View file

@ -37,6 +37,8 @@ import Network.HTTP.Media
import qualified Network.HTTP.Types as H import qualified Network.HTTP.Types as H
import qualified Network.HTTP.Types.Header as HTTP import qualified Network.HTTP.Types.Header as HTTP
import Servant.API import Servant.API
import Servant.API.Authentication (AuthProtect)
import Servant.Client.Authentication (AuthenticateRequest(authReq))
import Servant.Common.BaseUrl import Servant.Common.BaseUrl
import Servant.Common.Req import Servant.Common.Req
@ -119,6 +121,15 @@ instance (KnownSymbol capture, ToHttpApiData a, HasClient sublayout)
where p = unpack (toUrlPiece val) where p = unpack (toUrlPiece val)
-- | Authentication
instance (AuthenticateRequest authdata, HasClient sublayout) => HasClient (AuthProtect authdata (usr :: *) policy :> sublayout) where
type Client (AuthProtect authdata usr policy :> sublayout) = authdata -> Client sublayout
clientWithRoute Proxy req baseurl val =
clientWithRoute (Proxy :: Proxy sublayout)
(authReq val req)
baseurl
-- | If you have a 'Delete' endpoint in your API, the client -- | If you have a 'Delete' endpoint in your API, the client
-- side querying function that is created when calling 'client' -- side querying function that is created when calling 'client'
-- will just require an argument that specifies the scheme, host -- will just require an argument that specifies the scheme, host

View file

@ -0,0 +1,12 @@
-- | Authentication for clients
module Servant.Client.Authentication (
AuthenticateRequest ( authReq )
) where
import Servant.Common.Req (Req)
-- | Class to represent the ability to authenticate a 'Request'
-- object. For example, we may add special headers to the 'Request'.
class AuthenticateRequest a where
authReq :: a -> Req -> Req

View file

@ -30,13 +30,16 @@ import Control.Concurrent (forkIO, killThread, ThreadId)
import Control.Exception (bracket) import Control.Exception (bracket)
import Control.Monad.Trans.Except (ExceptT, runExceptT, throwE) import Control.Monad.Trans.Except (ExceptT, runExceptT, throwE)
import Data.Aeson import Data.Aeson
import Data.Char (chr, isPrint) import qualified Data.ByteString.Base64 as B64
import Data.ByteString.Lazy (ByteString)
import Data.Foldable (forM_) import Data.Foldable (forM_)
import Data.Monoid hiding (getLast) import Data.Monoid hiding (getLast)
import Data.Proxy import Data.Proxy
import qualified Data.Text as T import qualified Data.Text as T
import GHC.Generics (Generic) import GHC.Generics (Generic)
import GHC.TypeLits import GHC.TypeLits
import qualified Data.Text.Encoding as TE
import GHC.Generics
import qualified Network.HTTP.Client as C import qualified Network.HTTP.Client as C
import Network.HTTP.Media import Network.HTTP.Media
import Network.HTTP.Types (Status (..), badRequest400, import Network.HTTP.Types (Status (..), badRequest400,
@ -51,8 +54,12 @@ import Test.HUnit
import Test.QuickCheck import Test.QuickCheck
import Servant.API import Servant.API
import Servant.API.Authentication
import Servant.Client import Servant.Client
import qualified Servant.Common.Req as SCR
import Servant.Client.Authentication (AuthenticateRequest(authReq))
import Servant.Server import Servant.Server
import Servant.Server.Internal.Authentication
spec :: Spec spec :: Spec
spec = describe "Servant.Client" $ do spec = describe "Servant.Client" $ do
@ -108,6 +115,24 @@ type Api =
Get '[JSON] (String, Maybe Int, Bool, [(String, [Rational])]) Get '[JSON] (String, Maybe Int, Bool, [(String, [Rational])])
:<|> "headers" :> Get '[JSON] (Headers TestHeaders Bool) :<|> "headers" :> Get '[JSON] (Headers TestHeaders Bool)
:<|> "deleteContentType" :> Delete '[JSON] () :<|> "deleteContentType" :> Delete '[JSON] ()
:<|> AuthProtect (BasicAuth "realm") Person 'Strict :> Get '[JSON] Person
-- base64-encoded "servant:server"
base64ServantColonServer :: ByteString
base64ServantColonServer = "c2VydmFudDpzZXJ2ZXI="
type AuthUser = T.Text
basicAuthCheck :: BasicAuth "realm" -> IO (Maybe Person)
basicAuthCheck (BasicAuth user pass) = if user == "servant" && pass == "server"
then return (Just $ Person "servant" 17)
else return Nothing
instance AuthenticateRequest (BasicAuth realm) where
authReq (BasicAuth user pass) req =
let authText = TE.decodeUtf8 ("Basic " <> B64.encode (user <> ":" <> pass)) in
SCR.addHeader "Authorization" authText req
api :: Proxy Api api :: Proxy Api
api = Proxy api = Proxy
@ -128,6 +153,7 @@ server = serve api (
:<|> (\ a b c d -> return (a, b, c, d)) :<|> (\ a b c d -> return (a, b, c, d))
:<|> (return $ addHeader 1729 $ addHeader "eg2" True) :<|> (return $ addHeader 1729 $ addHeader "eg2" True)
:<|> return () :<|> return ()
:<|> basicAuthStrict basicAuthCheck (const . return $ alice)
) )
@ -154,6 +180,42 @@ sucessSpec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do
it "Servant.API.Get" $ \(_, baseUrl) -> do it "Servant.API.Get" $ \(_, baseUrl) -> do
let getGet = getNth (Proxy :: Proxy 0) $ client api baseUrl manager let getGet = getNth (Proxy :: Proxy 0) $ client api baseUrl manager
manager <- C.newManager C.defaultManagerSettings
let getGet :: ExceptT ServantError IO Person
getDeleteEmpty :: ExceptT ServantError IO ()
getCapture :: String -> ExceptT ServantError IO Person
getBody :: Person -> ExceptT ServantError IO Person
getQueryParam :: Maybe String -> ExceptT ServantError IO Person
getQueryParams :: [String] -> ExceptT ServantError IO [Person]
getQueryFlag :: Bool -> ExceptT ServantError IO Bool
getMatrixParam :: Maybe String -> ExceptT ServantError IO Person
getMatrixParams :: [String] -> ExceptT ServantError IO [Person]
getMatrixFlag :: Bool -> ExceptT ServantError IO Bool
getRawSuccess :: Method -> ExceptT ServantError IO (Int, ByteString, MediaType, [HTTP.Header], C.Response ByteString)
getRawFailure :: Method -> ExceptT ServantError IO (Int, ByteString, MediaType, [HTTP.Header], C.Response ByteString)
getMultiple :: String -> Maybe Int -> Bool -> [(String, [Rational])] -> ExceptT ServantError IO (String, Maybe Int, Bool, [(String, [Rational])])
getRespHeaders :: ExceptT ServantError IO (Headers TestHeaders Bool)
getDeleteContentType :: ExceptT ServantError IO ()
( getGet
:<|> getDeleteEmpty
:<|> getCapture
:<|> getBody
:<|> getQueryParam
:<|> getQueryParams
:<|> getQueryFlag
:<|> getMatrixParam
:<|> getMatrixParams
:<|> getMatrixFlag
:<|> getRawSuccess
:<|> getRawFailure
:<|> getMultiple
:<|> getRespHeaders
:<|> getDeleteContentType
:<|> getPrivatePerson)
= client api baseUrl manager
hspec $ do
it "Servant.API.Get" $ do
(left show <$> runExceptT getGet) `shouldReturn` Right alice (left show <$> runExceptT getGet) `shouldReturn` Right alice
describe "Servant.API.Delete" $ do describe "Servant.API.Delete" $ do
@ -217,6 +279,15 @@ sucessSpec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do
Left e -> assertFailure $ show e Left e -> assertFailure $ show e
Right val -> getHeaders val `shouldBe` [("X-Example1", "1729"), ("X-Example2", "eg2")] Right val -> getHeaders val `shouldBe` [("X-Example1", "1729"), ("X-Example2", "eg2")]
it "Handles Authentication appropriatley" $ withServer $ \ _ -> do
(Arrow.left show <$> runExceptT (getPrivatePerson (BasicAuth "servant" "server"))) `shouldReturn` Right alice
it "returns 401 when not properly authenticated" $ do
Left res <- runExceptT (getPrivatePerson (BasicAuth "xxx" "yyy"))
case res of
FailureResponse (Status 401 _) _ _ -> return ()
_ -> fail $ "expcted 401 response, but got " <> show res
modifyMaxSuccess (const 20) $ do modifyMaxSuccess (const 20) $ do
it "works for a combination of Capture, QueryParam, QueryFlag and ReqBody" $ \(_, baseUrl) -> it "works for a combination of Capture, QueryParam, QueryFlag and ReqBody" $ \(_, baseUrl) ->
let getMultiple = getNth (Proxy :: Proxy 9) $ client api baseUrl manager let getMultiple = getNth (Proxy :: Proxy 9) $ client api baseUrl manager
@ -226,7 +297,6 @@ sucessSpec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do
return $ return $
result === Right (cap, num, flag, body) result === Right (cap, num, flag, body)
wrappedApiSpec :: Spec wrappedApiSpec :: Spec
wrappedApiSpec = describe "error status codes" $ do wrappedApiSpec = describe "error status codes" $ do
let serveW api = serve api $ throwE $ ServantErr 500 "error message" "" [] let serveW api = serve api $ throwE $ ServantErr 500 "error message" "" []
@ -284,6 +354,7 @@ failSpec = beforeAll (startWaiApp failServer) $ afterAll endWaiApp $ do
InvalidContentTypeHeader "fooooo" _ -> return () InvalidContentTypeHeader "fooooo" _ -> return ()
_ -> fail $ "expected InvalidContentTypeHeader, but got " <> show res _ -> fail $ "expected InvalidContentTypeHeader, but got " <> show res
data WrappedApi where data WrappedApi where
WrappedApi :: (HasServer api, Server api ~ ExceptT ServantErr IO a, WrappedApi :: (HasServer api, Server api ~ ExceptT ServantErr IO a,
HasClient api, Client api ~ ExceptT ServantError IO ()) => HasClient api, Client api ~ ExceptT ServantError IO ()) =>