Add authentication to servant-client
This commit is contained in:
parent
4fce780c44
commit
6672ee7918
4 changed files with 63 additions and 2 deletions
|
@ -26,6 +26,7 @@ source-repository head
|
||||||
library
|
library
|
||||||
exposed-modules:
|
exposed-modules:
|
||||||
Servant.Client
|
Servant.Client
|
||||||
|
Servant.Client.Authentication
|
||||||
Servant.Common.BaseUrl
|
Servant.Common.BaseUrl
|
||||||
Servant.Common.Req
|
Servant.Common.Req
|
||||||
build-depends:
|
build-depends:
|
||||||
|
@ -64,6 +65,7 @@ test-suite spec
|
||||||
, transformers
|
, transformers
|
||||||
, transformers-compat
|
, transformers-compat
|
||||||
, aeson
|
, aeson
|
||||||
|
, base64-bytestring
|
||||||
, bytestring
|
, bytestring
|
||||||
, deepseq
|
, deepseq
|
||||||
, hspec == 2.*
|
, hspec == 2.*
|
||||||
|
|
|
@ -37,6 +37,8 @@ import Network.HTTP.Media
|
||||||
import qualified Network.HTTP.Types as H
|
import qualified Network.HTTP.Types as H
|
||||||
import qualified Network.HTTP.Types.Header as HTTP
|
import qualified Network.HTTP.Types.Header as HTTP
|
||||||
import Servant.API
|
import Servant.API
|
||||||
|
import Servant.API.Authentication (AuthProtect)
|
||||||
|
import Servant.Client.Authentication (AuthenticateRequest(authReq))
|
||||||
import Servant.Common.BaseUrl
|
import Servant.Common.BaseUrl
|
||||||
import Servant.Common.Req
|
import Servant.Common.Req
|
||||||
|
|
||||||
|
@ -119,6 +121,15 @@ instance (KnownSymbol capture, ToText a, HasClient sublayout)
|
||||||
|
|
||||||
where p = unpack (toText val)
|
where p = unpack (toText val)
|
||||||
|
|
||||||
|
-- | Authentication
|
||||||
|
instance (AuthenticateRequest authdata, HasClient sublayout) => HasClient (AuthProtect authdata (usr :: *) policy :> sublayout) where
|
||||||
|
type Client (AuthProtect authdata usr policy :> sublayout) = authdata -> Client sublayout
|
||||||
|
|
||||||
|
clientWithRoute Proxy req baseurl val =
|
||||||
|
clientWithRoute (Proxy :: Proxy sublayout)
|
||||||
|
(authReq val req)
|
||||||
|
baseurl
|
||||||
|
|
||||||
-- | If you have a 'Delete' endpoint in your API, the client
|
-- | If you have a 'Delete' endpoint in your API, the client
|
||||||
-- side querying function that is created when calling 'client'
|
-- side querying function that is created when calling 'client'
|
||||||
-- will just require an argument that specifies the scheme, host
|
-- will just require an argument that specifies the scheme, host
|
||||||
|
|
12
servant-client/src/Servant/Client/Authentication.hs
Normal file
12
servant-client/src/Servant/Client/Authentication.hs
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
-- | Authentication for clients
|
||||||
|
|
||||||
|
module Servant.Client.Authentication (
|
||||||
|
AuthenticateRequest ( authReq )
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Servant.Common.Req (Req)
|
||||||
|
|
||||||
|
-- | Class to represent the ability to authenticate a 'Request'
|
||||||
|
-- object. For example, we may add special headers to the 'Request'.
|
||||||
|
class AuthenticateRequest a where
|
||||||
|
authReq :: a -> Req -> Req
|
|
@ -20,12 +20,14 @@ import Control.Concurrent
|
||||||
import Control.Exception
|
import Control.Exception
|
||||||
import Control.Monad.Trans.Except
|
import Control.Monad.Trans.Except
|
||||||
import Data.Aeson
|
import Data.Aeson
|
||||||
|
import qualified Data.ByteString.Base64 as B64
|
||||||
import Data.ByteString.Lazy (ByteString)
|
import Data.ByteString.Lazy (ByteString)
|
||||||
import Data.Char
|
import Data.Char
|
||||||
import Data.Foldable (forM_)
|
import Data.Foldable (forM_)
|
||||||
import Data.Monoid
|
import Data.Monoid
|
||||||
import Data.Proxy
|
import Data.Proxy
|
||||||
import qualified Data.Text as T
|
import qualified Data.Text as T
|
||||||
|
import qualified Data.Text.Encoding as TE
|
||||||
import GHC.Generics
|
import GHC.Generics
|
||||||
import qualified Network.HTTP.Client as C
|
import qualified Network.HTTP.Client as C
|
||||||
import Network.HTTP.Media
|
import Network.HTTP.Media
|
||||||
|
@ -40,8 +42,12 @@ import Test.HUnit
|
||||||
import Test.QuickCheck
|
import Test.QuickCheck
|
||||||
|
|
||||||
import Servant.API
|
import Servant.API
|
||||||
|
import Servant.API.Authentication
|
||||||
import Servant.Client
|
import Servant.Client
|
||||||
|
import qualified Servant.Common.Req as SCR
|
||||||
|
import Servant.Client.Authentication (AuthenticateRequest(authReq))
|
||||||
import Servant.Server
|
import Servant.Server
|
||||||
|
import Servant.Server.Internal.Authentication
|
||||||
|
|
||||||
-- * test data types
|
-- * test data types
|
||||||
|
|
||||||
|
@ -94,6 +100,24 @@ type Api =
|
||||||
Get '[JSON] (String, Maybe Int, Bool, [(String, [Rational])])
|
Get '[JSON] (String, Maybe Int, Bool, [(String, [Rational])])
|
||||||
:<|> "headers" :> Get '[JSON] (Headers TestHeaders Bool)
|
:<|> "headers" :> Get '[JSON] (Headers TestHeaders Bool)
|
||||||
:<|> "deleteContentType" :> Delete '[JSON] ()
|
:<|> "deleteContentType" :> Delete '[JSON] ()
|
||||||
|
:<|> AuthProtect (BasicAuth "realm") Person 'Strict :> Get '[JSON] Person
|
||||||
|
|
||||||
|
-- base64-encoded "servant:server"
|
||||||
|
base64ServantColonServer :: ByteString
|
||||||
|
base64ServantColonServer = "c2VydmFudDpzZXJ2ZXI="
|
||||||
|
|
||||||
|
type AuthUser = T.Text
|
||||||
|
|
||||||
|
basicAuthCheck :: BasicAuth "realm" -> IO (Maybe Person)
|
||||||
|
basicAuthCheck (BasicAuth user pass) = if user == "servant" && pass == "server"
|
||||||
|
then return (Just $ Person "servant" 17)
|
||||||
|
else return Nothing
|
||||||
|
|
||||||
|
instance AuthenticateRequest (BasicAuth realm) where
|
||||||
|
authReq (BasicAuth user pass) req =
|
||||||
|
let authText = TE.decodeUtf8 ("Basic " <> B64.encode (user <> ":" <> pass)) in
|
||||||
|
SCR.addHeader "Authorization" authText req
|
||||||
|
|
||||||
api :: Proxy Api
|
api :: Proxy Api
|
||||||
api = Proxy
|
api = Proxy
|
||||||
|
|
||||||
|
@ -120,6 +144,7 @@ server = serve api (
|
||||||
:<|> (\ a b c d -> return (a, b, c, d))
|
:<|> (\ a b c d -> return (a, b, c, d))
|
||||||
:<|> (return $ addHeader 1729 $ addHeader "eg2" True)
|
:<|> (return $ addHeader 1729 $ addHeader "eg2" True)
|
||||||
:<|> return ()
|
:<|> return ()
|
||||||
|
:<|> basicAuthStrict basicAuthCheck (const . return $ alice)
|
||||||
)
|
)
|
||||||
|
|
||||||
withServer :: (BaseUrl -> IO a) -> IO a
|
withServer :: (BaseUrl -> IO a) -> IO a
|
||||||
|
@ -144,6 +169,7 @@ withFailServer action = withWaiDaemon (return failServer) action
|
||||||
|
|
||||||
spec :: IO ()
|
spec :: IO ()
|
||||||
spec = withServer $ \ baseUrl -> do
|
spec = withServer $ \ baseUrl -> do
|
||||||
|
<<<<<<< HEAD
|
||||||
manager <- C.newManager C.defaultManagerSettings
|
manager <- C.newManager C.defaultManagerSettings
|
||||||
let getGet :: ExceptT ServantError IO Person
|
let getGet :: ExceptT ServantError IO Person
|
||||||
getDeleteEmpty :: ExceptT ServantError IO ()
|
getDeleteEmpty :: ExceptT ServantError IO ()
|
||||||
|
@ -174,7 +200,8 @@ spec = withServer $ \ baseUrl -> do
|
||||||
:<|> getRawFailure
|
:<|> getRawFailure
|
||||||
:<|> getMultiple
|
:<|> getMultiple
|
||||||
:<|> getRespHeaders
|
:<|> getRespHeaders
|
||||||
:<|> getDeleteContentType)
|
:<|> getDeleteContentType
|
||||||
|
:<|> getPrivatePerson)
|
||||||
= client api baseUrl manager
|
= client api baseUrl manager
|
||||||
|
|
||||||
hspec $ do
|
hspec $ do
|
||||||
|
@ -249,6 +276,15 @@ spec = withServer $ \ baseUrl -> do
|
||||||
Left e -> assertFailure $ show e
|
Left e -> assertFailure $ show e
|
||||||
Right val -> getHeaders val `shouldBe` [("X-Example1", "1729"), ("X-Example2", "eg2")]
|
Right val -> getHeaders val `shouldBe` [("X-Example1", "1729"), ("X-Example2", "eg2")]
|
||||||
|
|
||||||
|
it "Handles Authentication appropriatley" $ withServer $ \ _ -> do
|
||||||
|
(Arrow.left show <$> runExceptT (getPrivatePerson (BasicAuth "servant" "server"))) `shouldReturn` Right alice
|
||||||
|
|
||||||
|
it "returns 401 when not properly authenticated" $ do
|
||||||
|
Left res <- runExceptT (getPrivatePerson (BasicAuth "xxx" "yyy"))
|
||||||
|
case res of
|
||||||
|
FailureResponse (Status 401 _) _ _ -> return ()
|
||||||
|
_ -> fail $ "expcted 401 response, but got " <> show res
|
||||||
|
|
||||||
modifyMaxSuccess (const 20) $ do
|
modifyMaxSuccess (const 20) $ do
|
||||||
it "works for a combination of Capture, QueryParam, QueryFlag and ReqBody" $
|
it "works for a combination of Capture, QueryParam, QueryFlag and ReqBody" $
|
||||||
property $ forAllShrink pathGen shrink $ \(NonEmpty cap) num flag body ->
|
property $ forAllShrink pathGen shrink $ \(NonEmpty cap) num flag body ->
|
||||||
|
@ -257,7 +293,6 @@ spec = withServer $ \ baseUrl -> do
|
||||||
return $
|
return $
|
||||||
result === Right (cap, num, flag, body)
|
result === Right (cap, num, flag, body)
|
||||||
|
|
||||||
|
|
||||||
context "client correctly handles error status codes" $ do
|
context "client correctly handles error status codes" $ do
|
||||||
let test :: (WrappedApi, String) -> Spec
|
let test :: (WrappedApi, String) -> Spec
|
||||||
test (WrappedApi api, desc) =
|
test (WrappedApi api, desc) =
|
||||||
|
@ -323,6 +358,7 @@ failSpec = withFailServer $ \ baseUrl -> do
|
||||||
InvalidContentTypeHeader "fooooo" _ -> return ()
|
InvalidContentTypeHeader "fooooo" _ -> return ()
|
||||||
_ -> fail $ "expected InvalidContentTypeHeader, but got " <> show res
|
_ -> fail $ "expected InvalidContentTypeHeader, but got " <> show res
|
||||||
|
|
||||||
|
|
||||||
data WrappedApi where
|
data WrappedApi where
|
||||||
WrappedApi :: (HasServer api, Server api ~ ExceptT ServantErr IO a,
|
WrappedApi :: (HasServer api, Server api ~ ExceptT ServantErr IO a,
|
||||||
HasClient api, Client api ~ ExceptT ServantError IO ()) =>
|
HasClient api, Client api ~ ExceptT ServantError IO ()) =>
|
||||||
|
|
Loading…
Reference in a new issue