diff --git a/servant-client-core/src/Servant/Client/Core/Internal/HasClient.hs b/servant-client-core/src/Servant/Client/Core/Internal/HasClient.hs index 71200700..6be92ec6 100644 --- a/servant-client-core/src/Servant/Client/Core/Internal/HasClient.hs +++ b/servant-client-core/src/Servant/Client/Core/Internal/HasClient.hs @@ -16,6 +16,8 @@ module Servant.Client.Core.Internal.HasClient where import Prelude () import Prelude.Compat +import Control.Concurrent.MVar + (modifyMVar, newMVar) import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as BL import Data.Foldable @@ -36,13 +38,14 @@ import qualified Network.HTTP.Types as H import Servant.API ((:<|>) ((:<|>)), (:>), AuthProtect, BasicAuth, BasicAuthData, BuildHeadersTo (..), Capture', CaptureAll, Description, - EmptyAPI, FramingUnrender (..), FromSourceIO (..), Header', - Headers (..), HttpVersion, IsSecure, MimeRender (mimeRender), + EmptyAPI, FramingRender (..), FramingUnrender (..), + FromSourceIO (..), Header', Headers (..), HttpVersion, + IsSecure, MimeRender (mimeRender), MimeUnrender (mimeUnrender), NoContent (NoContent), QueryFlag, QueryParam', QueryParams, Raw, ReflectMethod (..), RemoteHost, ReqBody', SBoolI, Stream, StreamBody', Summary, ToHttpApiData, - Vault, Verb, WithNamedContext, contentType, getHeadersHList, - getResponse, toQueryParam, toUrlPiece) + ToSourceIO (..), Vault, Verb, WithNamedContext, contentType, + getHeadersHList, getResponse, toQueryParam, toUrlPiece) import Servant.API.ContentTypes (contentTypes) import Servant.API.Modifiers @@ -538,7 +541,7 @@ instance (MimeRender ct a, HasClient m api) hoistClientMonad pm (Proxy :: Proxy api) f (cl a) instance - ( HasClient m api + ( HasClient m api, MimeRender ctype chunk, FramingRender framing, ToSourceIO chunk a ) => HasClient m (StreamBody' mods framing ctype a :> api) where @@ -547,7 +550,39 @@ instance hoistClientMonad pm _ f cl = \a -> hoistClientMonad pm (Proxy :: Proxy api) f (cl a) - clientWithRoute _pm Proxy _req _body = error "HasClient @StreamBody" + clientWithRoute pm Proxy req body + = clientWithRoute pm (Proxy :: Proxy api) + $ setRequestBody (RequestBodyStreamChunked givesPopper) (contentType ctypeP) req + where + ctypeP = Proxy :: Proxy ctype + framingP = Proxy :: Proxy framing + + sourceIO = framingRender + framingP + (mimeRender ctypeP :: chunk -> BL.ByteString) + (toSourceIO body) + + -- not pretty. + givesPopper :: (IO BS.ByteString -> IO ()) -> IO () + givesPopper needsPopper = S.unSourceT sourceIO $ \step0 -> do + ref <- newMVar step0 + + -- Note sure we need locking, but it's feels safer. + let popper :: IO BS.ByteString + popper = modifyMVar ref nextBs + + needsPopper popper + + nextBs S.Stop = return (S.Stop, BS.empty) + nextBs (S.Error err) = fail err + nextBs (S.Skip s) = nextBs s + nextBs (S.Effect ms) = ms >>= nextBs + nextBs (S.Yield lbs s) = case BL.toChunks lbs of + [] -> nextBs s + (x:xs) | BS.null x -> nextBs step' + | otherwise -> return (step', x) + where + step' = S.Yield (BL.fromChunks xs) s -- | Make the querying function append @path@ to the request path. instance (KnownSymbol path, HasClient m api) => HasClient m (path :> api) where diff --git a/servant-client/test/Servant/StreamSpec.hs b/servant-client/test/Servant/StreamSpec.hs index e41aa370..0e9c557a 100644 --- a/servant-client/test/Servant/StreamSpec.hs +++ b/servant-client/test/Servant/StreamSpec.hs @@ -35,8 +35,9 @@ import qualified Network.HTTP.Client as C import Prelude () import Prelude.Compat import Servant.API - ((:<|>) ((:<|>)), (:>), JSON, NetstringFraming, - NewlineFraming, NoFraming, OctetStream, SourceIO, StreamGet) + ((:<|>) ((:<|>)), (:>), JSON, NetstringFraming, StreamBody, + NewlineFraming, NoFraming, OctetStream, SourceIO, StreamGet, + ) import Servant.Client.Streaming import Servant.ClientSpec (Person (..)) @@ -72,13 +73,15 @@ type StreamApi = "streamGetNewline" :> StreamGet NewlineFraming JSON (SourceIO Person) :<|> "streamGetNetstring" :> StreamGet NetstringFraming JSON (SourceIO Person) :<|> "streamALot" :> StreamGet NoFraming OctetStream (SourceIO BS.ByteString) + :<|> "streamBody" :> StreamBody NoFraming OctetStream (SourceIO BS.ByteString) :> StreamGet NoFraming OctetStream (SourceIO BS.ByteString) api :: Proxy StreamApi api = Proxy getGetNL, getGetNS :: ClientM (SourceIO Person) getGetALot :: ClientM (SourceIO BS.ByteString) -getGetNL :<|> getGetNS :<|> getGetALot = client api +getStreamBody :: SourceT IO BS.ByteString -> ClientM (SourceIO BS.ByteString) +getGetNL :<|> getGetNS :<|> getGetALot :<|> getStreamBody = client api alice :: Person alice = Person "Alice" 42 @@ -90,9 +93,9 @@ server :: Application server = serve api $ return (source [alice, bob, alice]) :<|> return (source [alice, bob, alice]) - -- 2 ^ (18 + 10) = 256M :<|> return (SourceT ($ lots (powerOfTwo 18))) + :<|> return where lots n | n < 0 = Stop @@ -126,6 +129,13 @@ streamSpec = beforeAll (CS.startWaiApp server) $ afterAll CS.endWaiApp $ do withClient getGetNS baseUrl $ \(Right res) -> testRunSourceIO res `shouldReturn` Right [alice, bob, alice] + it "works with Servant.API.StreamBody" $ \(_, baseUrl) -> do + withClient (getStreamBody (source input)) baseUrl $ \(Right res) -> + testRunSourceIO res `shouldReturn` Right output + where + input = ["foo", "", "bar"] + output = ["foo", "bar"] + {- it "streams in constant memory" $ \(_, baseUrl) -> do Right rs <- runClient getGetALot baseUrl