From 948951b6e9a828a4996b40ea653e6c79eb8ec5de Mon Sep 17 00:00:00 2001 From: aaron levin Date: Tue, 18 Aug 2015 15:35:36 -0400 Subject: [PATCH] Add authentication to servant-client --- servant-client/servant-client.cabal | 2 + servant-client/src/Servant/Client.hs | 11 +++ .../src/Servant/Client/Authentication.hs | 12 +++ servant-client/test/Servant/ClientSpec.hs | 75 ++++++++++++++++++- 4 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 servant-client/src/Servant/Client/Authentication.hs diff --git a/servant-client/servant-client.cabal b/servant-client/servant-client.cabal index 7fe69521..0119c78a 100644 --- a/servant-client/servant-client.cabal +++ b/servant-client/servant-client.cabal @@ -26,6 +26,7 @@ source-repository head library exposed-modules: Servant.Client + Servant.Client.Authentication Servant.Common.BaseUrl Servant.Common.Req build-depends: @@ -65,6 +66,7 @@ test-suite spec , transformers , transformers-compat , aeson + , base64-bytestring , bytestring , deepseq , hspec == 2.* diff --git a/servant-client/src/Servant/Client.hs b/servant-client/src/Servant/Client.hs index 987a2bd4..20bebba7 100644 --- a/servant-client/src/Servant/Client.hs +++ b/servant-client/src/Servant/Client.hs @@ -37,6 +37,8 @@ import Network.HTTP.Media import qualified Network.HTTP.Types as H import qualified Network.HTTP.Types.Header as HTTP import Servant.API +import Servant.API.Authentication (AuthProtect) +import Servant.Client.Authentication (AuthenticateRequest(authReq)) import Servant.Common.BaseUrl import Servant.Common.Req @@ -119,6 +121,15 @@ instance (KnownSymbol capture, ToHttpApiData a, HasClient sublayout) where p = unpack (toUrlPiece val) +-- | Authentication +instance (AuthenticateRequest authdata, HasClient sublayout) => HasClient (AuthProtect authdata (usr :: *) policy :> sublayout) where + type Client (AuthProtect authdata usr policy :> sublayout) = authdata -> Client sublayout + + clientWithRoute Proxy req baseurl val = + clientWithRoute (Proxy :: Proxy sublayout) + (authReq val req) + baseurl + -- | If you have a 'Delete' endpoint in your API, the client -- side querying function that is created when calling 'client' -- will just require an argument that specifies the scheme, host diff --git a/servant-client/src/Servant/Client/Authentication.hs b/servant-client/src/Servant/Client/Authentication.hs new file mode 100644 index 00000000..1d7d43d9 --- /dev/null +++ b/servant-client/src/Servant/Client/Authentication.hs @@ -0,0 +1,12 @@ +-- | Authentication for clients + +module Servant.Client.Authentication ( + AuthenticateRequest ( authReq ) + ) where + +import Servant.Common.Req (Req) + +-- | Class to represent the ability to authenticate a 'Request' +-- object. For example, we may add special headers to the 'Request'. +class AuthenticateRequest a where + authReq :: a -> Req -> Req diff --git a/servant-client/test/Servant/ClientSpec.hs b/servant-client/test/Servant/ClientSpec.hs index fc3cdcfb..f82da60a 100644 --- a/servant-client/test/Servant/ClientSpec.hs +++ b/servant-client/test/Servant/ClientSpec.hs @@ -30,13 +30,16 @@ import Control.Concurrent (forkIO, killThread, ThreadId) import Control.Exception (bracket) import Control.Monad.Trans.Except (ExceptT, runExceptT, throwE) import Data.Aeson -import Data.Char (chr, isPrint) +import qualified Data.ByteString.Base64 as B64 +import Data.ByteString.Lazy (ByteString) import Data.Foldable (forM_) import Data.Monoid hiding (getLast) import Data.Proxy import qualified Data.Text as T import GHC.Generics (Generic) import GHC.TypeLits +import qualified Data.Text.Encoding as TE +import GHC.Generics import qualified Network.HTTP.Client as C import Network.HTTP.Media import Network.HTTP.Types (Status (..), badRequest400, @@ -51,8 +54,12 @@ import Test.HUnit import Test.QuickCheck import Servant.API +import Servant.API.Authentication import Servant.Client +import qualified Servant.Common.Req as SCR +import Servant.Client.Authentication (AuthenticateRequest(authReq)) import Servant.Server +import Servant.Server.Internal.Authentication spec :: Spec spec = describe "Servant.Client" $ do @@ -108,6 +115,24 @@ type Api = Get '[JSON] (String, Maybe Int, Bool, [(String, [Rational])]) :<|> "headers" :> Get '[JSON] (Headers TestHeaders Bool) :<|> "deleteContentType" :> Delete '[JSON] () + :<|> AuthProtect (BasicAuth "realm") Person 'Strict :> Get '[JSON] Person + +-- base64-encoded "servant:server" +base64ServantColonServer :: ByteString +base64ServantColonServer = "c2VydmFudDpzZXJ2ZXI=" + +type AuthUser = T.Text + +basicAuthCheck :: BasicAuth "realm" -> IO (Maybe Person) +basicAuthCheck (BasicAuth user pass) = if user == "servant" && pass == "server" + then return (Just $ Person "servant" 17) + else return Nothing + +instance AuthenticateRequest (BasicAuth realm) where + authReq (BasicAuth user pass) req = + let authText = TE.decodeUtf8 ("Basic " <> B64.encode (user <> ":" <> pass)) in + SCR.addHeader "Authorization" authText req + api :: Proxy Api api = Proxy @@ -128,6 +153,7 @@ server = serve api ( :<|> (\ a b c d -> return (a, b, c, d)) :<|> (return $ addHeader 1729 $ addHeader "eg2" True) :<|> return () + :<|> basicAuthStrict basicAuthCheck (const . return $ alice) ) @@ -154,6 +180,42 @@ sucessSpec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do it "Servant.API.Get" $ \(_, baseUrl) -> do let getGet = getNth (Proxy :: Proxy 0) $ client api baseUrl manager + manager <- C.newManager C.defaultManagerSettings + let getGet :: ExceptT ServantError IO Person + getDeleteEmpty :: ExceptT ServantError IO () + getCapture :: String -> ExceptT ServantError IO Person + getBody :: Person -> ExceptT ServantError IO Person + getQueryParam :: Maybe String -> ExceptT ServantError IO Person + getQueryParams :: [String] -> ExceptT ServantError IO [Person] + getQueryFlag :: Bool -> ExceptT ServantError IO Bool + getMatrixParam :: Maybe String -> ExceptT ServantError IO Person + getMatrixParams :: [String] -> ExceptT ServantError IO [Person] + getMatrixFlag :: Bool -> ExceptT ServantError IO Bool + getRawSuccess :: Method -> ExceptT ServantError IO (Int, ByteString, MediaType, [HTTP.Header], C.Response ByteString) + getRawFailure :: Method -> ExceptT ServantError IO (Int, ByteString, MediaType, [HTTP.Header], C.Response ByteString) + getMultiple :: String -> Maybe Int -> Bool -> [(String, [Rational])] -> ExceptT ServantError IO (String, Maybe Int, Bool, [(String, [Rational])]) + getRespHeaders :: ExceptT ServantError IO (Headers TestHeaders Bool) + getDeleteContentType :: ExceptT ServantError IO () + ( getGet + :<|> getDeleteEmpty + :<|> getCapture + :<|> getBody + :<|> getQueryParam + :<|> getQueryParams + :<|> getQueryFlag + :<|> getMatrixParam + :<|> getMatrixParams + :<|> getMatrixFlag + :<|> getRawSuccess + :<|> getRawFailure + :<|> getMultiple + :<|> getRespHeaders + :<|> getDeleteContentType + :<|> getPrivatePerson) + = client api baseUrl manager + + hspec $ do + it "Servant.API.Get" $ do (left show <$> runExceptT getGet) `shouldReturn` Right alice describe "Servant.API.Delete" $ do @@ -217,6 +279,15 @@ sucessSpec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do Left e -> assertFailure $ show e Right val -> getHeaders val `shouldBe` [("X-Example1", "1729"), ("X-Example2", "eg2")] + it "Handles Authentication appropriatley" $ withServer $ \ _ -> do + (Arrow.left show <$> runExceptT (getPrivatePerson (BasicAuth "servant" "server"))) `shouldReturn` Right alice + + it "returns 401 when not properly authenticated" $ do + Left res <- runExceptT (getPrivatePerson (BasicAuth "xxx" "yyy")) + case res of + FailureResponse (Status 401 _) _ _ -> return () + _ -> fail $ "expcted 401 response, but got " <> show res + modifyMaxSuccess (const 20) $ do it "works for a combination of Capture, QueryParam, QueryFlag and ReqBody" $ \(_, baseUrl) -> let getMultiple = getNth (Proxy :: Proxy 9) $ client api baseUrl manager @@ -226,7 +297,6 @@ sucessSpec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do return $ result === Right (cap, num, flag, body) - wrappedApiSpec :: Spec wrappedApiSpec = describe "error status codes" $ do let serveW api = serve api $ throwE $ ServantErr 500 "error message" "" [] @@ -284,6 +354,7 @@ failSpec = beforeAll (startWaiApp failServer) $ afterAll endWaiApp $ do InvalidContentTypeHeader "fooooo" _ -> return () _ -> fail $ "expected InvalidContentTypeHeader, but got " <> show res + data WrappedApi where WrappedApi :: (HasServer api, Server api ~ ExceptT ServantErr IO a, HasClient api, Client api ~ ExceptT ServantError IO ()) =>