diff --git a/servant-client-core/src/Servant/Client/Core.hs b/servant-client-core/src/Servant/Client/Core.hs index 6d216d1f..822df161 100644 --- a/servant-client-core/src/Servant/Client/Core.hs +++ b/servant-client-core/src/Servant/Client/Core.hs @@ -225,11 +225,10 @@ instance OVERLAPPABLE_ { requestAccept = fromList $ toList accept , requestMethod = method } - case mimeUnrender (Proxy :: Proxy ct) $ responseBody response of - Left err -> throwError $ DecodeFailure (pack err) response - Right val -> return val - where method = reflectMethod (Proxy :: Proxy method) - accept = contentTypes (Proxy :: Proxy ct) + response `decodedAs` (Proxy :: Proxy ct) + where + accept = contentTypes (Proxy :: Proxy ct) + method = reflectMethod (Proxy :: Proxy method) instance OVERLAPPING_ ( RunClient m, ReflectMethod method diff --git a/servant-client-core/src/Servant/Client/Core/Internal/Class.hs b/servant-client-core/src/Servant/Client/Core/Internal/Class.hs index 0428fcb8..37287fd9 100644 --- a/servant-client-core/src/Servant/Client/Core/Internal/Class.hs +++ b/servant-client-core/src/Servant/Client/Core/Internal/Class.hs @@ -1,11 +1,44 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE OverloadedStrings #-} {-| Types for possible backends to run client-side `Request` queries -} module Servant.Client.Core.Internal.Class where -import Control.Monad.Error.Class (MonadError) -import Servant.Client.Core.Internal.Request (Request, Response, - ServantError) +import Control.Monad (unless) +import Control.Monad.Error.Class (MonadError, throwError) +import Data.Proxy (Proxy) +import qualified Data.Text as T +import Network.HTTP.Media (MediaType, matches, + parseAccept, (//)) +import Servant.API (MimeUnrender, + contentTypes, + mimeUnrender) +import Servant.Client.Core.Internal.Request (Request, Response (..), + ServantError (..)) +import Data.Foldable (toList) class (MonadError ServantError m) => RunClient m where + -- | How to make a request. runRequest :: Request -> m Response + +checkContentTypeHeader :: RunClient m => Response -> m MediaType +checkContentTypeHeader response = + case lookup "Content-Type" $ toList $ responseHeaders response of + Nothing -> pure $ "application"//"octet-stream" + Just t -> case parseAccept t of + Nothing -> throwError $ InvalidContentTypeHeader response + Just t' -> pure t' + +decodedAs :: forall ct a m. (MimeUnrender ct a, RunClient m) + => Response -> Proxy ct -> m a +decodedAs response contentType = do + responseContentType <- checkContentTypeHeader response + unless (any (matches responseContentType) accept) $ + throwError $ UnsupportedContentType responseContentType response + case mimeUnrender contentType $ responseBody response of + Left err -> throwError $ DecodeFailure (T.pack err) response + Right val -> return val + where + accept = toList $ contentTypes contentType diff --git a/servant-client/test/Servant/ClientSpec.hs b/servant-client/test/Servant/ClientSpec.hs index 46835495..3a5ef1f6 100644 --- a/servant-client/test/Servant/ClientSpec.hs +++ b/servant-client/test/Servant/ClientSpec.hs @@ -24,7 +24,7 @@ {-# OPTIONS_GHC -fno-warn-name-shadowing #-} #include "overlapping-compat.h" -module Servant.ClientSpec where +module Servant.ClientSpec (spec) where import Prelude () import Prelude.Compat @@ -35,18 +35,13 @@ import Control.Concurrent (ThreadId, forkIO, import Control.Exception (bracket) import Control.Monad.Error.Class (throwError) import Data.Aeson -import qualified Data.ByteString.Lazy as BS import Data.Char (chr, isPrint) -import Data.Foldable (toList) import Data.Foldable (forM_) -import Data.Maybe (isJust) import Data.Monoid hiding (getLast) import Data.Proxy -import Data.Sequence (findIndexL) import qualified Generics.SOP as SOP import GHC.Generics (Generic) import qualified Network.HTTP.Client as C -import Network.HTTP.Media import qualified Network.HTTP.Types as HTTP import Network.Socket import qualified Network.Wai as Wai @@ -96,8 +91,8 @@ spec = describe "Servant.Client" $ do -- * test data types data Person = Person - { name :: String - , age :: Integer + { _name :: String + , _age :: Integer } deriving (Eq, Show, Generic) instance ToJSON Person @@ -233,14 +228,14 @@ genAuthAPI = Proxy type instance AuthServerData (AuthProtect "auth-tag") = () type instance AuthClientData (AuthProtect "auth-tag") = () -genAuthHandler :: AuthHandler Request () +genAuthHandler :: AuthHandler Wai.Request () genAuthHandler = - let handler req = case lookup "AuthHeader" (toList $ requestHeaders req) of + let handler req = case lookup "AuthHeader" (Wai.requestHeaders req) of Nothing -> throwError (err401 { errBody = "Missing auth header" }) Just _ -> return () in mkAuthHandler handler -genAuthServerContext :: Context '[ AuthHandler Request () ] +genAuthServerContext :: Context '[ AuthHandler Wai.Request () ] genAuthServerContext = genAuthHandler :. EmptyContext genAuthServer :: Application @@ -297,6 +292,7 @@ genericClientServer = serve (Proxy :: Proxy GenericClientAPI) ( manager' :: C.Manager manager' = unsafePerformIO $ C.newManager C.defaultManagerSettings +runClient :: ClientM a -> BaseUrl -> IO (Either ServantError a) runClient x baseUrl' = runClientM x (ClientEnv manager' baseUrl') sucessSpec :: Spec @@ -344,8 +340,6 @@ sucessSpec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do Right r -> do responseStatusCode r `shouldBe` HTTP.status200 responseBody r `shouldBe` "rawSuccess" - findIndexL (\x -> fst x == HTTP.hContentType) (responseHeaders r) - `shouldSatisfy` isJust it "Servant.API.Raw should return a Left in case of failure" $ \(_, baseUrl) -> do res <- runClient (getRawFailure HTTP.methodGet) baseUrl