Merge pull request #1104 from michaelsdunn1/master
Update CookieJar with intermediate request/responses using Network.HTTP.Client.HistoriedResponse.
This commit is contained in:
commit
28ac8c072b
3 changed files with 66 additions and 25 deletions
|
@ -102,6 +102,7 @@ test-suite spec
|
||||||
, kan-extensions
|
, kan-extensions
|
||||||
, servant-client
|
, servant-client
|
||||||
, servant-client-core
|
, servant-client-core
|
||||||
|
, stm
|
||||||
, text
|
, text
|
||||||
, transformers
|
, transformers
|
||||||
, transformers-compat
|
, transformers-compat
|
||||||
|
|
|
@ -23,21 +23,25 @@ import Control.Monad.Catch
|
||||||
(MonadCatch, MonadThrow)
|
(MonadCatch, MonadThrow)
|
||||||
import Control.Monad.Error.Class
|
import Control.Monad.Error.Class
|
||||||
(MonadError (..))
|
(MonadError (..))
|
||||||
|
import Control.Monad.IO.Class
|
||||||
|
(liftIO)
|
||||||
import Control.Monad.Reader
|
import Control.Monad.Reader
|
||||||
import Control.Monad.STM
|
import Control.Monad.STM
|
||||||
(atomically)
|
(STM, atomically)
|
||||||
import Control.Monad.Trans.Control
|
import Control.Monad.Trans.Control
|
||||||
(MonadBaseControl (..))
|
(MonadBaseControl (..))
|
||||||
import Control.Monad.Trans.Except
|
import Control.Monad.Trans.Except
|
||||||
import Data.ByteString.Builder
|
import Data.ByteString.Builder
|
||||||
(toLazyByteString)
|
(toLazyByteString)
|
||||||
import qualified Data.ByteString.Lazy as BSL
|
import qualified Data.ByteString.Lazy as BSL
|
||||||
|
import Data.Either
|
||||||
|
(either)
|
||||||
import Data.Foldable
|
import Data.Foldable
|
||||||
(for_, toList)
|
(toList)
|
||||||
import Data.Functor.Alt
|
import Data.Functor.Alt
|
||||||
(Alt (..))
|
(Alt (..))
|
||||||
import Data.Maybe
|
import Data.Maybe
|
||||||
(maybeToList)
|
(maybe, maybeToList)
|
||||||
import Data.Proxy
|
import Data.Proxy
|
||||||
(Proxy (..))
|
(Proxy (..))
|
||||||
import Data.Semigroup
|
import Data.Semigroup
|
||||||
|
@ -48,7 +52,7 @@ import Data.String
|
||||||
(fromString)
|
(fromString)
|
||||||
import qualified Data.Text as T
|
import qualified Data.Text as T
|
||||||
import Data.Time.Clock
|
import Data.Time.Clock
|
||||||
(getCurrentTime)
|
(UTCTime, getCurrentTime)
|
||||||
import GHC.Generics
|
import GHC.Generics
|
||||||
import Network.HTTP.Media
|
import Network.HTTP.Media
|
||||||
(renderHeader)
|
(renderHeader)
|
||||||
|
@ -158,19 +162,38 @@ performRequest req = do
|
||||||
writeTVar cj newCookieJar
|
writeTVar cj newCookieJar
|
||||||
pure newRequest
|
pure newRequest
|
||||||
|
|
||||||
eResponse <- liftIO $ catchConnectionError $ Client.httpLbs request m
|
response <- maybe (requestWithoutCookieJar m request) (requestWithCookieJar m request) cookieJar'
|
||||||
case eResponse of
|
let status = Client.responseStatus response
|
||||||
Left err -> throwError err
|
status_code = statusCode status
|
||||||
Right response -> do
|
ourResponse = clientResponseToResponse response
|
||||||
for_ cookieJar' $ \cj -> liftIO $ do
|
unless (status_code >= 200 && status_code < 300) $
|
||||||
now' <- getCurrentTime
|
throwError $ FailureResponse ourResponse
|
||||||
atomically $ modifyTVar' cj (fst . Client.updateCookieJar response request now')
|
return ourResponse
|
||||||
let status = Client.responseStatus response
|
where
|
||||||
status_code = statusCode status
|
requestWithoutCookieJar :: Client.Manager -> Client.Request -> ClientM (Client.Response BSL.ByteString)
|
||||||
ourResponse = clientResponseToResponse response
|
requestWithoutCookieJar m' request' = do
|
||||||
unless (status_code >= 200 && status_code < 300) $
|
eResponse <- liftIO . catchConnectionError $ Client.httpLbs request' m'
|
||||||
throwError $ FailureResponse ourResponse
|
either throwError return eResponse
|
||||||
return ourResponse
|
|
||||||
|
requestWithCookieJar :: Client.Manager -> Client.Request -> TVar Client.CookieJar -> ClientM (Client.Response BSL.ByteString)
|
||||||
|
requestWithCookieJar m' request' cj = do
|
||||||
|
eResponse <- liftIO . catchConnectionError . Client.withResponseHistory request' m' $ updateWithResponseCookies cj
|
||||||
|
either throwError return eResponse
|
||||||
|
|
||||||
|
updateWithResponseCookies :: TVar Client.CookieJar -> Client.HistoriedResponse Client.BodyReader -> IO (Client.Response BSL.ByteString)
|
||||||
|
updateWithResponseCookies cj responses = do
|
||||||
|
now <- getCurrentTime
|
||||||
|
bss <- Client.brConsume $ Client.responseBody fRes
|
||||||
|
let fRes' = fRes { Client.responseBody = BSL.fromChunks bss }
|
||||||
|
allResponses = Client.hrRedirects responses <> [(fReq, fRes')]
|
||||||
|
atomically $ mapM_ (updateCookieJar now) allResponses
|
||||||
|
return fRes'
|
||||||
|
where
|
||||||
|
updateCookieJar :: UTCTime -> (Client.Request, Client.Response BSL.ByteString) -> STM ()
|
||||||
|
updateCookieJar now' (req', res') = modifyTVar' cj (fst . Client.updateCookieJar res' req' now')
|
||||||
|
|
||||||
|
fReq = Client.hrFinalRequest responses
|
||||||
|
fRes = Client.hrFinalResponse responses
|
||||||
|
|
||||||
clientResponseToResponse :: Client.Response a -> GenResponse a
|
clientResponseToResponse :: Client.Response a -> GenResponse a
|
||||||
clientResponseToResponse r = Response
|
clientResponseToResponse r = Response
|
||||||
|
|
|
@ -28,6 +28,10 @@ import Control.Arrow
|
||||||
(left)
|
(left)
|
||||||
import Control.Concurrent
|
import Control.Concurrent
|
||||||
(ThreadId, forkIO, killThread)
|
(ThreadId, forkIO, killThread)
|
||||||
|
import Control.Concurrent.STM
|
||||||
|
(atomically)
|
||||||
|
import Control.Concurrent.STM.TVar
|
||||||
|
(newTVar, readTVar)
|
||||||
import Control.Exception
|
import Control.Exception
|
||||||
(bracket)
|
(bracket)
|
||||||
import Control.Monad.Error.Class
|
import Control.Monad.Error.Class
|
||||||
|
@ -37,17 +41,19 @@ import Data.Char
|
||||||
(chr, isPrint)
|
(chr, isPrint)
|
||||||
import Data.Foldable
|
import Data.Foldable
|
||||||
(forM_)
|
(forM_)
|
||||||
|
import Data.Maybe
|
||||||
|
(listToMaybe)
|
||||||
import Data.Monoid ()
|
import Data.Monoid ()
|
||||||
import Data.Proxy
|
import Data.Proxy
|
||||||
import Data.Semigroup
|
import Data.Semigroup
|
||||||
((<>))
|
((<>))
|
||||||
import qualified Generics.SOP as SOP
|
import qualified Generics.SOP as SOP
|
||||||
import GHC.Generics
|
import GHC.Generics
|
||||||
(Generic)
|
(Generic)
|
||||||
import qualified Network.HTTP.Client as C
|
import qualified Network.HTTP.Client as C
|
||||||
import qualified Network.HTTP.Types as HTTP
|
import qualified Network.HTTP.Types as HTTP
|
||||||
import Network.Socket
|
import Network.Socket
|
||||||
import qualified Network.Wai as Wai
|
import qualified Network.Wai as Wai
|
||||||
import Network.Wai.Handler.Warp
|
import Network.Wai.Handler.Warp
|
||||||
import System.IO.Unsafe
|
import System.IO.Unsafe
|
||||||
(unsafePerformIO)
|
(unsafePerformIO)
|
||||||
|
@ -64,12 +70,12 @@ import Servant.API
|
||||||
DeleteNoContent, EmptyAPI, FormUrlEncoded, Get, Header,
|
DeleteNoContent, EmptyAPI, FormUrlEncoded, Get, Header,
|
||||||
Headers, JSON, NoContent (NoContent), Post, Put, QueryFlag,
|
Headers, JSON, NoContent (NoContent), Post, Put, QueryFlag,
|
||||||
QueryParam, QueryParams, Raw, ReqBody, addHeader, getHeaders)
|
QueryParam, QueryParams, Raw, ReqBody, addHeader, getHeaders)
|
||||||
import Servant.Test.ComprehensiveAPI
|
|
||||||
import Servant.Client
|
import Servant.Client
|
||||||
import qualified Servant.Client.Core.Internal.Auth as Auth
|
import qualified Servant.Client.Core.Internal.Auth as Auth
|
||||||
import qualified Servant.Client.Core.Internal.Request as Req
|
import qualified Servant.Client.Core.Internal.Request as Req
|
||||||
import Servant.Server
|
import Servant.Server
|
||||||
import Servant.Server.Experimental.Auth
|
import Servant.Server.Experimental.Auth
|
||||||
|
import Servant.Test.ComprehensiveAPI
|
||||||
|
|
||||||
-- This declaration simply checks that all instances are in place.
|
-- This declaration simply checks that all instances are in place.
|
||||||
_ = client comprehensiveAPIWithoutStreaming
|
_ = client comprehensiveAPIWithoutStreaming
|
||||||
|
@ -128,6 +134,7 @@ type Api =
|
||||||
Get '[JSON] (String, Maybe Int, Bool, [(String, [Rational])])
|
Get '[JSON] (String, Maybe Int, Bool, [(String, [Rational])])
|
||||||
:<|> "headers" :> Get '[JSON] (Headers TestHeaders Bool)
|
:<|> "headers" :> Get '[JSON] (Headers TestHeaders Bool)
|
||||||
:<|> "deleteContentType" :> DeleteNoContent '[JSON] NoContent
|
:<|> "deleteContentType" :> DeleteNoContent '[JSON] NoContent
|
||||||
|
:<|> "redirectWithCookie" :> Raw
|
||||||
:<|> "empty" :> EmptyAPI
|
:<|> "empty" :> EmptyAPI
|
||||||
|
|
||||||
api :: Proxy Api
|
api :: Proxy Api
|
||||||
|
@ -148,6 +155,7 @@ getMultiple :: String -> Maybe Int -> Bool -> [(String, [Rational])]
|
||||||
-> ClientM (String, Maybe Int, Bool, [(String, [Rational])])
|
-> ClientM (String, Maybe Int, Bool, [(String, [Rational])])
|
||||||
getRespHeaders :: ClientM (Headers TestHeaders Bool)
|
getRespHeaders :: ClientM (Headers TestHeaders Bool)
|
||||||
getDeleteContentType :: ClientM NoContent
|
getDeleteContentType :: ClientM NoContent
|
||||||
|
getRedirectWithCookie :: HTTP.Method -> ClientM Response
|
||||||
|
|
||||||
getRoot
|
getRoot
|
||||||
:<|> getGet
|
:<|> getGet
|
||||||
|
@ -163,6 +171,7 @@ getRoot
|
||||||
:<|> getMultiple
|
:<|> getMultiple
|
||||||
:<|> getRespHeaders
|
:<|> getRespHeaders
|
||||||
:<|> getDeleteContentType
|
:<|> getDeleteContentType
|
||||||
|
:<|> getRedirectWithCookie
|
||||||
:<|> EmptyClient = client api
|
:<|> EmptyClient = client api
|
||||||
|
|
||||||
server :: Application
|
server :: Application
|
||||||
|
@ -184,9 +193,9 @@ server = serve api (
|
||||||
:<|> (\ a b c d -> return (a, b, c, d))
|
:<|> (\ a b c d -> return (a, b, c, d))
|
||||||
:<|> (return $ addHeader 1729 $ addHeader "eg2" True)
|
:<|> (return $ addHeader 1729 $ addHeader "eg2" True)
|
||||||
:<|> return NoContent
|
:<|> return NoContent
|
||||||
|
:<|> (Tagged $ \ _request respond -> respond $ Wai.responseLBS HTTP.found302 [("Location", "testlocation"), ("Set-Cookie", "testcookie=test")] "")
|
||||||
:<|> emptyServer)
|
:<|> emptyServer)
|
||||||
|
|
||||||
|
|
||||||
type FailApi =
|
type FailApi =
|
||||||
"get" :> Raw
|
"get" :> Raw
|
||||||
:<|> "capture" :> Capture "name" String :> Raw
|
:<|> "capture" :> Capture "name" String :> Raw
|
||||||
|
@ -364,6 +373,14 @@ sucessSpec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do
|
||||||
Left e -> assertFailure $ show e
|
Left e -> assertFailure $ show e
|
||||||
Right val -> getHeaders val `shouldBe` [("X-Example1", "1729"), ("X-Example2", "eg2")]
|
Right val -> getHeaders val `shouldBe` [("X-Example1", "1729"), ("X-Example2", "eg2")]
|
||||||
|
|
||||||
|
it "Stores Cookie in CookieJar after a redirect" $ \(_, baseUrl) -> do
|
||||||
|
mgr <- C.newManager C.defaultManagerSettings
|
||||||
|
cj <- atomically . newTVar $ C.createCookieJar []
|
||||||
|
_ <- runClientM (getRedirectWithCookie HTTP.methodGet) (ClientEnv mgr baseUrl (Just cj))
|
||||||
|
cookie <- listToMaybe . C.destroyCookieJar <$> atomically (readTVar cj)
|
||||||
|
C.cookie_name <$> cookie `shouldBe` Just "testcookie"
|
||||||
|
C.cookie_value <$> cookie `shouldBe` Just "test"
|
||||||
|
|
||||||
modifyMaxSuccess (const 20) $ do
|
modifyMaxSuccess (const 20) $ do
|
||||||
it "works for a combination of Capture, QueryParam, QueryFlag and ReqBody" $ \(_, baseUrl) ->
|
it "works for a combination of Capture, QueryParam, QueryFlag and ReqBody" $ \(_, baseUrl) ->
|
||||||
property $ forAllShrink pathGen shrink $ \(NonEmpty cap) num flag body ->
|
property $ forAllShrink pathGen shrink $ \(NonEmpty cap) num flag body ->
|
||||||
|
|
Loading…
Reference in a new issue