diff --git a/servant-client/servant-client.cabal b/servant-client/servant-client.cabal index a739d7c7..b93c75e8 100644 --- a/servant-client/servant-client.cabal +++ b/servant-client/servant-client.cabal @@ -102,6 +102,7 @@ test-suite spec , kan-extensions , servant-client , servant-client-core + , stm , text , transformers , transformers-compat diff --git a/servant-client/src/Servant/Client/Internal/HttpClient.hs b/servant-client/src/Servant/Client/Internal/HttpClient.hs index eb566d6b..b0e6a83c 100644 --- a/servant-client/src/Servant/Client/Internal/HttpClient.hs +++ b/servant-client/src/Servant/Client/Internal/HttpClient.hs @@ -23,21 +23,25 @@ import Control.Monad.Catch (MonadCatch, MonadThrow) import Control.Monad.Error.Class (MonadError (..)) +import Control.Monad.IO.Class + (liftIO) import Control.Monad.Reader import Control.Monad.STM - (atomically) + (STM, atomically) import Control.Monad.Trans.Control (MonadBaseControl (..)) import Control.Monad.Trans.Except import Data.ByteString.Builder (toLazyByteString) import qualified Data.ByteString.Lazy as BSL +import Data.Either + (either) import Data.Foldable - (for_, toList) + (toList) import Data.Functor.Alt (Alt (..)) import Data.Maybe - (maybeToList) + (maybe, maybeToList) import Data.Proxy (Proxy (..)) import Data.Semigroup @@ -48,7 +52,7 @@ import Data.String (fromString) import qualified Data.Text as T import Data.Time.Clock - (getCurrentTime) + (UTCTime, getCurrentTime) import GHC.Generics import Network.HTTP.Media (renderHeader) @@ -158,19 +162,38 @@ performRequest req = do writeTVar cj newCookieJar pure newRequest - eResponse <- liftIO $ catchConnectionError $ Client.httpLbs request m - case eResponse of - Left err -> throwError err - Right response -> do - for_ cookieJar' $ \cj -> liftIO $ do - now' <- getCurrentTime - atomically $ modifyTVar' cj (fst . Client.updateCookieJar response request now') - let status = Client.responseStatus response - status_code = statusCode status - ourResponse = clientResponseToResponse response - unless (status_code >= 200 && status_code < 300) $ - throwError $ FailureResponse ourResponse - return ourResponse + response <- maybe (requestWithoutCookieJar m request) (requestWithCookieJar m request) cookieJar' + let status = Client.responseStatus response + status_code = statusCode status + ourResponse = clientResponseToResponse response + unless (status_code >= 200 && status_code < 300) $ + throwError $ FailureResponse ourResponse + return ourResponse + where + requestWithoutCookieJar :: Client.Manager -> Client.Request -> ClientM (Client.Response BSL.ByteString) + requestWithoutCookieJar m' request' = do + eResponse <- liftIO . catchConnectionError $ Client.httpLbs request' m' + either throwError return eResponse + + requestWithCookieJar :: Client.Manager -> Client.Request -> TVar Client.CookieJar -> ClientM (Client.Response BSL.ByteString) + requestWithCookieJar m' request' cj = do + eResponse <- liftIO . catchConnectionError . Client.withResponseHistory request' m' $ updateWithResponseCookies cj + either throwError return eResponse + + updateWithResponseCookies :: TVar Client.CookieJar -> Client.HistoriedResponse Client.BodyReader -> IO (Client.Response BSL.ByteString) + updateWithResponseCookies cj responses = do + now <- getCurrentTime + bss <- Client.brConsume $ Client.responseBody fRes + let fRes' = fRes { Client.responseBody = BSL.fromChunks bss } + allResponses = Client.hrRedirects responses <> [(fReq, fRes')] + atomically $ mapM_ (updateCookieJar now) allResponses + return fRes' + where + updateCookieJar :: UTCTime -> (Client.Request, Client.Response BSL.ByteString) -> STM () + updateCookieJar now' (req', res') = modifyTVar' cj (fst . Client.updateCookieJar res' req' now') + + fReq = Client.hrFinalRequest responses + fRes = Client.hrFinalResponse responses clientResponseToResponse :: Client.Response a -> GenResponse a clientResponseToResponse r = Response diff --git a/servant-client/test/Servant/ClientSpec.hs b/servant-client/test/Servant/ClientSpec.hs index d864fef0..636ab351 100644 --- a/servant-client/test/Servant/ClientSpec.hs +++ b/servant-client/test/Servant/ClientSpec.hs @@ -28,6 +28,10 @@ import Control.Arrow (left) import Control.Concurrent (ThreadId, forkIO, killThread) +import Control.Concurrent.STM + (atomically) +import Control.Concurrent.STM.TVar + (newTVar, readTVar) import Control.Exception (bracket) import Control.Monad.Error.Class @@ -37,17 +41,19 @@ import Data.Char (chr, isPrint) import Data.Foldable (forM_) +import Data.Maybe + (listToMaybe) import Data.Monoid () import Data.Proxy import Data.Semigroup ((<>)) -import qualified Generics.SOP as SOP +import qualified Generics.SOP as SOP import GHC.Generics (Generic) -import qualified Network.HTTP.Client as C -import qualified Network.HTTP.Types as HTTP +import qualified Network.HTTP.Client as C +import qualified Network.HTTP.Types as HTTP import Network.Socket -import qualified Network.Wai as Wai +import qualified Network.Wai as Wai import Network.Wai.Handler.Warp import System.IO.Unsafe (unsafePerformIO) @@ -64,12 +70,12 @@ import Servant.API DeleteNoContent, EmptyAPI, FormUrlEncoded, Get, Header, Headers, JSON, NoContent (NoContent), Post, Put, QueryFlag, QueryParam, QueryParams, Raw, ReqBody, addHeader, getHeaders) -import Servant.Test.ComprehensiveAPI import Servant.Client -import qualified Servant.Client.Core.Internal.Auth as Auth -import qualified Servant.Client.Core.Internal.Request as Req +import qualified Servant.Client.Core.Internal.Auth as Auth +import qualified Servant.Client.Core.Internal.Request as Req import Servant.Server import Servant.Server.Experimental.Auth +import Servant.Test.ComprehensiveAPI -- This declaration simply checks that all instances are in place. _ = client comprehensiveAPIWithoutStreaming @@ -128,6 +134,7 @@ type Api = Get '[JSON] (String, Maybe Int, Bool, [(String, [Rational])]) :<|> "headers" :> Get '[JSON] (Headers TestHeaders Bool) :<|> "deleteContentType" :> DeleteNoContent '[JSON] NoContent + :<|> "redirectWithCookie" :> Raw :<|> "empty" :> EmptyAPI api :: Proxy Api @@ -148,6 +155,7 @@ getMultiple :: String -> Maybe Int -> Bool -> [(String, [Rational])] -> ClientM (String, Maybe Int, Bool, [(String, [Rational])]) getRespHeaders :: ClientM (Headers TestHeaders Bool) getDeleteContentType :: ClientM NoContent +getRedirectWithCookie :: HTTP.Method -> ClientM Response getRoot :<|> getGet @@ -163,6 +171,7 @@ getRoot :<|> getMultiple :<|> getRespHeaders :<|> getDeleteContentType + :<|> getRedirectWithCookie :<|> EmptyClient = client api server :: Application @@ -184,9 +193,9 @@ server = serve api ( :<|> (\ a b c d -> return (a, b, c, d)) :<|> (return $ addHeader 1729 $ addHeader "eg2" True) :<|> return NoContent + :<|> (Tagged $ \ _request respond -> respond $ Wai.responseLBS HTTP.found302 [("Location", "testlocation"), ("Set-Cookie", "testcookie=test")] "") :<|> emptyServer) - type FailApi = "get" :> Raw :<|> "capture" :> Capture "name" String :> Raw @@ -364,6 +373,14 @@ sucessSpec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do Left e -> assertFailure $ show e Right val -> getHeaders val `shouldBe` [("X-Example1", "1729"), ("X-Example2", "eg2")] + it "Stores Cookie in CookieJar after a redirect" $ \(_, baseUrl) -> do + mgr <- C.newManager C.defaultManagerSettings + cj <- atomically . newTVar $ C.createCookieJar [] + _ <- runClientM (getRedirectWithCookie HTTP.methodGet) (ClientEnv mgr baseUrl (Just cj)) + cookie <- listToMaybe . C.destroyCookieJar <$> atomically (readTVar cj) + C.cookie_name <$> cookie `shouldBe` Just "testcookie" + C.cookie_value <$> cookie `shouldBe` Just "test" + modifyMaxSuccess (const 20) $ do it "works for a combination of Capture, QueryParam, QueryFlag and ReqBody" $ \(_, baseUrl) -> property $ forAllShrink pathGen shrink $ \(NonEmpty cap) num flag body ->