diff --git a/servant-client.cabal b/servant-client.cabal index b27e5bc2..a5576cbc 100644 --- a/servant-client.cabal +++ b/servant-client.cabal @@ -71,6 +71,7 @@ test-suite spec , deepseq , either , hspec == 2.* + , http-media , http-types , network >= 2.6 , QuickCheck >= 2.7 diff --git a/src/Servant/Client.hs b/src/Servant/Client.hs index 69a8d71d..88bdc249 100644 --- a/src/Servant/Client.hs +++ b/src/Servant/Client.hs @@ -412,7 +412,7 @@ instance (KnownSymbol sym, HasClient sublayout) -- | Pick a 'Method' and specify where the server you want to query is. You get -- back the status code and the response body as a 'ByteString'. instance HasClient Raw where - type Client Raw = H.Method -> BaseUrl -> EitherT String IO (Int, ByteString) + type Client Raw = H.Method -> BaseUrl -> EitherT String IO (Int, ByteString, MediaType) clientWithRoute :: Proxy Raw -> Req -> Client Raw clientWithRoute Proxy req httpMethod host = diff --git a/src/Servant/Common/Req.hs b/src/Servant/Common/Req.hs index 0e8cf1c6..13b12285 100644 --- a/src/Servant/Common/Req.hs +++ b/src/Servant/Common/Req.hs @@ -14,7 +14,6 @@ import Data.Aeson.Parser import Data.Aeson.Types import Data.Attoparsec.ByteString import Data.ByteString.Lazy hiding (pack) -import qualified Data.ByteString.Char8 as BS import Data.String import Data.String.Conversions import Data.Text @@ -85,7 +84,7 @@ reqToRequest req (BaseUrl reqScheme reqHost reqPort) = setrqb r = case (reqBody req) of Nothing -> r Just (b,t) -> r { requestBody = RequestBodyLBS b - , requestHeaders = [(hContentType, BS.pack . show $ t)] } + , requestHeaders = [(hContentType, cs . show $ t)] } setQS = setQueryString $ queryTextToQuery (qs req) setheaders r = r { requestHeaders = requestHeaders r ++ Prelude.map toProperHeader (headers req) } @@ -110,7 +109,7 @@ displayHttpRequest :: Method -> String displayHttpRequest httpmethod = "HTTP " ++ cs httpmethod ++ " request" -performRequest :: Method -> Req -> (Int -> Bool) -> BaseUrl -> EitherT String IO (Int, ByteString) +performRequest :: Method -> Req -> (Int -> Bool) -> BaseUrl -> EitherT String IO (Int, ByteString, MediaType) performRequest reqMethod req isWantedStatus reqHost = do partialRequest <- liftIO $ reqToRequest req reqHost @@ -129,7 +128,12 @@ performRequest reqMethod req isWantedStatus reqHost = do let status = Client.responseStatus response unless (isWantedStatus (statusCode status)) $ left (displayHttpRequest reqMethod ++ " failed with status: " ++ showStatus status) - return $ (statusCode status, Client.responseBody response) + ct <- case lookup "Content-Type" $ Client.responseHeaders response of + Nothing -> pure $ "application"//"octet-stream" + Just t -> case parseAccept t of + Nothing -> left $ "invalid Content-Type header: " <> cs t + Just t' -> pure t' + return $ (statusCode status, Client.responseBody response, ct) where showStatus (Status code message) = show code ++ " - " ++ cs message @@ -138,7 +142,7 @@ performRequest reqMethod req isWantedStatus reqHost = do performRequestJSON :: FromJSON result => Method -> Req -> Int -> BaseUrl -> EitherT String IO result performRequestJSON reqMethod req wantedStatus reqHost = do - (_status, respBody) <- performRequest reqMethod req (== wantedStatus) reqHost + (_status, respBody, _) <- performRequest reqMethod req (== wantedStatus) reqHost either (\ message -> left (displayHttpRequest reqMethod ++ " returned invalid json: " ++ message)) return diff --git a/test/Servant/ClientSpec.hs b/test/Servant/ClientSpec.hs index 429a2fe4..70919bdf 100644 --- a/test/Servant/ClientSpec.hs +++ b/test/Servant/ClientSpec.hs @@ -17,6 +17,7 @@ import Data.Foldable (forM_) import Data.Proxy import Data.Typeable import GHC.Generics +import Network.HTTP.Media import Network.HTTP.Types import Network.Socket import Network.Wai @@ -101,8 +102,8 @@ getQueryFlag :: Bool -> BaseUrl -> EitherT String IO Bool getMatrixParam :: Maybe String -> BaseUrl -> EitherT String IO Person getMatrixParams :: [String] -> BaseUrl -> EitherT String IO [Person] getMatrixFlag :: Bool -> BaseUrl -> EitherT String IO Bool -getRawSuccess :: Method -> BaseUrl -> EitherT String IO (Int, ByteString) -getRawFailure :: Method -> BaseUrl -> EitherT String IO (Int, ByteString) +getRawSuccess :: Method -> BaseUrl -> EitherT String IO (Int, ByteString, MediaType) +getRawFailure :: Method -> BaseUrl -> EitherT String IO (Int, ByteString, MediaType) getMultiple :: String -> Maybe Int -> Bool -> [(String, [Rational])] -> BaseUrl -> EitherT String IO (String, Maybe Int, Bool, [(String, [Rational])]) @@ -167,10 +168,10 @@ spec = do runEitherT (getMatrixFlag flag host) `shouldReturn` Right flag it "Servant.API.Raw on success" $ withServer $ \ host -> do - runEitherT (getRawSuccess methodGet host) `shouldReturn` Right (200, "rawSuccess") + runEitherT (getRawSuccess methodGet host) `shouldReturn` Right (200, "rawSuccess", "application"//"octet-stream") it "Servant.API.Raw on failure" $ withServer $ \ host -> do - runEitherT (getRawFailure methodGet host) `shouldReturn` Right (400, "rawFailure") + runEitherT (getRawFailure methodGet host) `shouldReturn` Right (400, "rawFailure", "application"//"octet-stream") modifyMaxSuccess (const 20) $ do it "works for a combination of Capture, QueryParam, QueryFlag and ReqBody" $