Add authentication to servant-client
This commit is contained in:
parent
3f5106da77
commit
948951b6e9
4 changed files with 98 additions and 2 deletions
|
@ -26,6 +26,7 @@ source-repository head
|
||||||
library
|
library
|
||||||
exposed-modules:
|
exposed-modules:
|
||||||
Servant.Client
|
Servant.Client
|
||||||
|
Servant.Client.Authentication
|
||||||
Servant.Common.BaseUrl
|
Servant.Common.BaseUrl
|
||||||
Servant.Common.Req
|
Servant.Common.Req
|
||||||
build-depends:
|
build-depends:
|
||||||
|
@ -65,6 +66,7 @@ test-suite spec
|
||||||
, transformers
|
, transformers
|
||||||
, transformers-compat
|
, transformers-compat
|
||||||
, aeson
|
, aeson
|
||||||
|
, base64-bytestring
|
||||||
, bytestring
|
, bytestring
|
||||||
, deepseq
|
, deepseq
|
||||||
, hspec == 2.*
|
, hspec == 2.*
|
||||||
|
|
|
@ -37,6 +37,8 @@ import Network.HTTP.Media
|
||||||
import qualified Network.HTTP.Types as H
|
import qualified Network.HTTP.Types as H
|
||||||
import qualified Network.HTTP.Types.Header as HTTP
|
import qualified Network.HTTP.Types.Header as HTTP
|
||||||
import Servant.API
|
import Servant.API
|
||||||
|
import Servant.API.Authentication (AuthProtect)
|
||||||
|
import Servant.Client.Authentication (AuthenticateRequest(authReq))
|
||||||
import Servant.Common.BaseUrl
|
import Servant.Common.BaseUrl
|
||||||
import Servant.Common.Req
|
import Servant.Common.Req
|
||||||
|
|
||||||
|
@ -119,6 +121,15 @@ instance (KnownSymbol capture, ToHttpApiData a, HasClient sublayout)
|
||||||
|
|
||||||
where p = unpack (toUrlPiece val)
|
where p = unpack (toUrlPiece val)
|
||||||
|
|
||||||
|
-- | Authentication
|
||||||
|
instance (AuthenticateRequest authdata, HasClient sublayout) => HasClient (AuthProtect authdata (usr :: *) policy :> sublayout) where
|
||||||
|
type Client (AuthProtect authdata usr policy :> sublayout) = authdata -> Client sublayout
|
||||||
|
|
||||||
|
clientWithRoute Proxy req baseurl val =
|
||||||
|
clientWithRoute (Proxy :: Proxy sublayout)
|
||||||
|
(authReq val req)
|
||||||
|
baseurl
|
||||||
|
|
||||||
-- | If you have a 'Delete' endpoint in your API, the client
|
-- | If you have a 'Delete' endpoint in your API, the client
|
||||||
-- side querying function that is created when calling 'client'
|
-- side querying function that is created when calling 'client'
|
||||||
-- will just require an argument that specifies the scheme, host
|
-- will just require an argument that specifies the scheme, host
|
||||||
|
|
12
servant-client/src/Servant/Client/Authentication.hs
Normal file
12
servant-client/src/Servant/Client/Authentication.hs
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
-- | Authentication for clients
|
||||||
|
|
||||||
|
module Servant.Client.Authentication (
|
||||||
|
AuthenticateRequest ( authReq )
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Servant.Common.Req (Req)
|
||||||
|
|
||||||
|
-- | Class to represent the ability to authenticate a 'Request'
|
||||||
|
-- object. For example, we may add special headers to the 'Request'.
|
||||||
|
class AuthenticateRequest a where
|
||||||
|
authReq :: a -> Req -> Req
|
|
@ -30,13 +30,16 @@ import Control.Concurrent (forkIO, killThread, ThreadId)
|
||||||
import Control.Exception (bracket)
|
import Control.Exception (bracket)
|
||||||
import Control.Monad.Trans.Except (ExceptT, runExceptT, throwE)
|
import Control.Monad.Trans.Except (ExceptT, runExceptT, throwE)
|
||||||
import Data.Aeson
|
import Data.Aeson
|
||||||
import Data.Char (chr, isPrint)
|
import qualified Data.ByteString.Base64 as B64
|
||||||
|
import Data.ByteString.Lazy (ByteString)
|
||||||
import Data.Foldable (forM_)
|
import Data.Foldable (forM_)
|
||||||
import Data.Monoid hiding (getLast)
|
import Data.Monoid hiding (getLast)
|
||||||
import Data.Proxy
|
import Data.Proxy
|
||||||
import qualified Data.Text as T
|
import qualified Data.Text as T
|
||||||
import GHC.Generics (Generic)
|
import GHC.Generics (Generic)
|
||||||
import GHC.TypeLits
|
import GHC.TypeLits
|
||||||
|
import qualified Data.Text.Encoding as TE
|
||||||
|
import GHC.Generics
|
||||||
import qualified Network.HTTP.Client as C
|
import qualified Network.HTTP.Client as C
|
||||||
import Network.HTTP.Media
|
import Network.HTTP.Media
|
||||||
import Network.HTTP.Types (Status (..), badRequest400,
|
import Network.HTTP.Types (Status (..), badRequest400,
|
||||||
|
@ -51,8 +54,12 @@ import Test.HUnit
|
||||||
import Test.QuickCheck
|
import Test.QuickCheck
|
||||||
|
|
||||||
import Servant.API
|
import Servant.API
|
||||||
|
import Servant.API.Authentication
|
||||||
import Servant.Client
|
import Servant.Client
|
||||||
|
import qualified Servant.Common.Req as SCR
|
||||||
|
import Servant.Client.Authentication (AuthenticateRequest(authReq))
|
||||||
import Servant.Server
|
import Servant.Server
|
||||||
|
import Servant.Server.Internal.Authentication
|
||||||
|
|
||||||
spec :: Spec
|
spec :: Spec
|
||||||
spec = describe "Servant.Client" $ do
|
spec = describe "Servant.Client" $ do
|
||||||
|
@ -108,6 +115,24 @@ type Api =
|
||||||
Get '[JSON] (String, Maybe Int, Bool, [(String, [Rational])])
|
Get '[JSON] (String, Maybe Int, Bool, [(String, [Rational])])
|
||||||
:<|> "headers" :> Get '[JSON] (Headers TestHeaders Bool)
|
:<|> "headers" :> Get '[JSON] (Headers TestHeaders Bool)
|
||||||
:<|> "deleteContentType" :> Delete '[JSON] ()
|
:<|> "deleteContentType" :> Delete '[JSON] ()
|
||||||
|
:<|> AuthProtect (BasicAuth "realm") Person 'Strict :> Get '[JSON] Person
|
||||||
|
|
||||||
|
-- base64-encoded "servant:server"
|
||||||
|
base64ServantColonServer :: ByteString
|
||||||
|
base64ServantColonServer = "c2VydmFudDpzZXJ2ZXI="
|
||||||
|
|
||||||
|
type AuthUser = T.Text
|
||||||
|
|
||||||
|
basicAuthCheck :: BasicAuth "realm" -> IO (Maybe Person)
|
||||||
|
basicAuthCheck (BasicAuth user pass) = if user == "servant" && pass == "server"
|
||||||
|
then return (Just $ Person "servant" 17)
|
||||||
|
else return Nothing
|
||||||
|
|
||||||
|
instance AuthenticateRequest (BasicAuth realm) where
|
||||||
|
authReq (BasicAuth user pass) req =
|
||||||
|
let authText = TE.decodeUtf8 ("Basic " <> B64.encode (user <> ":" <> pass)) in
|
||||||
|
SCR.addHeader "Authorization" authText req
|
||||||
|
|
||||||
api :: Proxy Api
|
api :: Proxy Api
|
||||||
api = Proxy
|
api = Proxy
|
||||||
|
|
||||||
|
@ -128,6 +153,7 @@ server = serve api (
|
||||||
:<|> (\ a b c d -> return (a, b, c, d))
|
:<|> (\ a b c d -> return (a, b, c, d))
|
||||||
:<|> (return $ addHeader 1729 $ addHeader "eg2" True)
|
:<|> (return $ addHeader 1729 $ addHeader "eg2" True)
|
||||||
:<|> return ()
|
:<|> return ()
|
||||||
|
:<|> basicAuthStrict basicAuthCheck (const . return $ alice)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -154,6 +180,42 @@ sucessSpec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do
|
||||||
|
|
||||||
it "Servant.API.Get" $ \(_, baseUrl) -> do
|
it "Servant.API.Get" $ \(_, baseUrl) -> do
|
||||||
let getGet = getNth (Proxy :: Proxy 0) $ client api baseUrl manager
|
let getGet = getNth (Proxy :: Proxy 0) $ client api baseUrl manager
|
||||||
|
manager <- C.newManager C.defaultManagerSettings
|
||||||
|
let getGet :: ExceptT ServantError IO Person
|
||||||
|
getDeleteEmpty :: ExceptT ServantError IO ()
|
||||||
|
getCapture :: String -> ExceptT ServantError IO Person
|
||||||
|
getBody :: Person -> ExceptT ServantError IO Person
|
||||||
|
getQueryParam :: Maybe String -> ExceptT ServantError IO Person
|
||||||
|
getQueryParams :: [String] -> ExceptT ServantError IO [Person]
|
||||||
|
getQueryFlag :: Bool -> ExceptT ServantError IO Bool
|
||||||
|
getMatrixParam :: Maybe String -> ExceptT ServantError IO Person
|
||||||
|
getMatrixParams :: [String] -> ExceptT ServantError IO [Person]
|
||||||
|
getMatrixFlag :: Bool -> ExceptT ServantError IO Bool
|
||||||
|
getRawSuccess :: Method -> ExceptT ServantError IO (Int, ByteString, MediaType, [HTTP.Header], C.Response ByteString)
|
||||||
|
getRawFailure :: Method -> ExceptT ServantError IO (Int, ByteString, MediaType, [HTTP.Header], C.Response ByteString)
|
||||||
|
getMultiple :: String -> Maybe Int -> Bool -> [(String, [Rational])] -> ExceptT ServantError IO (String, Maybe Int, Bool, [(String, [Rational])])
|
||||||
|
getRespHeaders :: ExceptT ServantError IO (Headers TestHeaders Bool)
|
||||||
|
getDeleteContentType :: ExceptT ServantError IO ()
|
||||||
|
( getGet
|
||||||
|
:<|> getDeleteEmpty
|
||||||
|
:<|> getCapture
|
||||||
|
:<|> getBody
|
||||||
|
:<|> getQueryParam
|
||||||
|
:<|> getQueryParams
|
||||||
|
:<|> getQueryFlag
|
||||||
|
:<|> getMatrixParam
|
||||||
|
:<|> getMatrixParams
|
||||||
|
:<|> getMatrixFlag
|
||||||
|
:<|> getRawSuccess
|
||||||
|
:<|> getRawFailure
|
||||||
|
:<|> getMultiple
|
||||||
|
:<|> getRespHeaders
|
||||||
|
:<|> getDeleteContentType
|
||||||
|
:<|> getPrivatePerson)
|
||||||
|
= client api baseUrl manager
|
||||||
|
|
||||||
|
hspec $ do
|
||||||
|
it "Servant.API.Get" $ do
|
||||||
(left show <$> runExceptT getGet) `shouldReturn` Right alice
|
(left show <$> runExceptT getGet) `shouldReturn` Right alice
|
||||||
|
|
||||||
describe "Servant.API.Delete" $ do
|
describe "Servant.API.Delete" $ do
|
||||||
|
@ -217,6 +279,15 @@ sucessSpec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do
|
||||||
Left e -> assertFailure $ show e
|
Left e -> assertFailure $ show e
|
||||||
Right val -> getHeaders val `shouldBe` [("X-Example1", "1729"), ("X-Example2", "eg2")]
|
Right val -> getHeaders val `shouldBe` [("X-Example1", "1729"), ("X-Example2", "eg2")]
|
||||||
|
|
||||||
|
it "Handles Authentication appropriatley" $ withServer $ \ _ -> do
|
||||||
|
(Arrow.left show <$> runExceptT (getPrivatePerson (BasicAuth "servant" "server"))) `shouldReturn` Right alice
|
||||||
|
|
||||||
|
it "returns 401 when not properly authenticated" $ do
|
||||||
|
Left res <- runExceptT (getPrivatePerson (BasicAuth "xxx" "yyy"))
|
||||||
|
case res of
|
||||||
|
FailureResponse (Status 401 _) _ _ -> return ()
|
||||||
|
_ -> fail $ "expcted 401 response, but got " <> show res
|
||||||
|
|
||||||
modifyMaxSuccess (const 20) $ do
|
modifyMaxSuccess (const 20) $ do
|
||||||
it "works for a combination of Capture, QueryParam, QueryFlag and ReqBody" $ \(_, baseUrl) ->
|
it "works for a combination of Capture, QueryParam, QueryFlag and ReqBody" $ \(_, baseUrl) ->
|
||||||
let getMultiple = getNth (Proxy :: Proxy 9) $ client api baseUrl manager
|
let getMultiple = getNth (Proxy :: Proxy 9) $ client api baseUrl manager
|
||||||
|
@ -226,7 +297,6 @@ sucessSpec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do
|
||||||
return $
|
return $
|
||||||
result === Right (cap, num, flag, body)
|
result === Right (cap, num, flag, body)
|
||||||
|
|
||||||
|
|
||||||
wrappedApiSpec :: Spec
|
wrappedApiSpec :: Spec
|
||||||
wrappedApiSpec = describe "error status codes" $ do
|
wrappedApiSpec = describe "error status codes" $ do
|
||||||
let serveW api = serve api $ throwE $ ServantErr 500 "error message" "" []
|
let serveW api = serve api $ throwE $ ServantErr 500 "error message" "" []
|
||||||
|
@ -284,6 +354,7 @@ failSpec = beforeAll (startWaiApp failServer) $ afterAll endWaiApp $ do
|
||||||
InvalidContentTypeHeader "fooooo" _ -> return ()
|
InvalidContentTypeHeader "fooooo" _ -> return ()
|
||||||
_ -> fail $ "expected InvalidContentTypeHeader, but got " <> show res
|
_ -> fail $ "expected InvalidContentTypeHeader, but got " <> show res
|
||||||
|
|
||||||
|
|
||||||
data WrappedApi where
|
data WrappedApi where
|
||||||
WrappedApi :: (HasServer api, Server api ~ ExceptT ServantErr IO a,
|
WrappedApi :: (HasServer api, Server api ~ ExceptT ServantErr IO a,
|
||||||
HasClient api, Client api ~ ExceptT ServantError IO ()) =>
|
HasClient api, Client api ~ ExceptT ServantError IO ()) =>
|
||||||
|
|
Loading…
Reference in a new issue