diff --git a/servant-client.cabal b/servant-client.cabal index 2b17dc21..77b832b7 100644 --- a/servant-client.cabal +++ b/servant-client.cabal @@ -45,6 +45,7 @@ library , exceptions , http-client , http-client-tls + , http-media , http-types , network-uri >= 2.6 , safe @@ -70,11 +71,13 @@ test-suite spec , deepseq , either , hspec == 2.* + , http-media , http-types , network >= 2.6 , QuickCheck >= 2.7 , servant >= 0.2.1 , servant-client , servant-server >= 0.2.1 + , text , wai , warp diff --git a/src/Servant/Client.hs b/src/Servant/Client.hs index 6a0f4f6b..46887186 100644 --- a/src/Servant/Client.hs +++ b/src/Servant/Client.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} @@ -22,8 +23,10 @@ import Data.Proxy import Data.String.Conversions import Data.Text (unpack) import GHC.TypeLits +import Network.HTTP.Media import qualified Network.HTTP.Types as H import Servant.API +import Servant.API.ContentTypes import Servant.Common.BaseUrl import Servant.Common.Req import Servant.Common.Text @@ -115,10 +118,10 @@ instance HasClient Delete where -- side querying function that is created when calling 'client' -- will just require an argument that specifies the scheme, host -- and port to send the request to. -instance FromJSON result => HasClient (Get result) where - type Client (Get result) = BaseUrl -> EitherT String IO result +instance (MimeUnrender ct result) => HasClient (Get (ct ': cts) result) where + type Client (Get (ct ': cts) result) = BaseUrl -> EitherT String IO result clientWithRoute Proxy req host = - performRequestJSON H.methodGet req 200 host + performRequestCT (Proxy :: Proxy ct) H.methodGet req 200 host -- | If you use a 'Header' in one of your endpoints in your API, -- the corresponding querying function will automatically take @@ -161,21 +164,21 @@ instance (KnownSymbol sym, ToText a, HasClient sublayout) -- side querying function that is created when calling 'client' -- will just require an argument that specifies the scheme, host -- and port to send the request to. -instance FromJSON a => HasClient (Post a) where - type Client (Post a) = BaseUrl -> EitherT String IO a +instance (MimeUnrender ct a) => HasClient (Post (ct ': cts) a) where + type Client (Post (ct ': cts) a) = BaseUrl -> EitherT String IO a clientWithRoute Proxy req uri = - performRequestJSON H.methodPost req 201 uri + performRequestCT (Proxy :: Proxy ct) H.methodPost req 201 uri -- | If you have a 'Put' endpoint in your API, the client -- side querying function that is created when calling 'client' -- will just require an argument that specifies the scheme, host -- and port to send the request to. -instance FromJSON a => HasClient (Put a) where - type Client (Put a) = BaseUrl -> EitherT String IO a +instance (MimeUnrender ct a) => HasClient (Put (ct ': cts) a) where + type Client (Put (ct ': cts) a) = BaseUrl -> EitherT String IO a clientWithRoute Proxy req host = - performRequestJSON H.methodPut req 200 host + performRequestCT (Proxy :: Proxy ct) H.methodPut req 200 host -- | If you use a 'QueryParam' in one of your endpoints in your API, -- the corresponding querying function will automatically take @@ -411,7 +414,7 @@ instance (KnownSymbol sym, HasClient sublayout) -- | Pick a 'Method' and specify where the server you want to query is. You get -- back the status code and the response body as a 'ByteString'. instance HasClient Raw where - type Client Raw = H.Method -> BaseUrl -> EitherT String IO (Int, ByteString) + type Client Raw = H.Method -> BaseUrl -> EitherT String IO (Int, ByteString, MediaType) clientWithRoute :: Proxy Raw -> Req -> Client Raw clientWithRoute Proxy req httpMethod host = @@ -435,15 +438,16 @@ instance HasClient Raw where -- > addBook :: Book -> BaseUrl -> EitherT String IO Book -- > addBook = client myApi -- > -- then you can just use "addBook" to query that endpoint -instance (ToJSON a, HasClient sublayout) - => HasClient (ReqBody a :> sublayout) where +instance (MimeRender ct a, HasClient sublayout) + => HasClient (ReqBody (ct ': cts) a :> sublayout) where - type Client (ReqBody a :> sublayout) = + type Client (ReqBody (ct ': cts) a :> sublayout) = a -> Client sublayout clientWithRoute Proxy req body = - clientWithRoute (Proxy :: Proxy sublayout) $ - setRQBody (encode body) req + clientWithRoute (Proxy :: Proxy sublayout) $ do + let ctProxy = Proxy :: Proxy ct + setRQBody (toByteString ctProxy body) (contentType ctProxy) req -- | Make the querying function append @path@ to the request path. instance (KnownSymbol path, HasClient sublayout) => HasClient (path :> sublayout) where diff --git a/src/Servant/Common/Req.hs b/src/Servant/Common/Req.hs index da85c02a..446bfd12 100644 --- a/src/Servant/Common/Req.hs +++ b/src/Servant/Common/Req.hs @@ -9,19 +9,18 @@ import Control.Monad import Control.Monad.Catch (MonadThrow) import Control.Monad.IO.Class import Control.Monad.Trans.Either -import Data.Aeson -import Data.Aeson.Parser -import Data.Aeson.Types -import Data.Attoparsec.ByteString -import Data.ByteString.Lazy hiding (pack) +import Data.ByteString.Lazy hiding (pack, filter, map, null) import Data.String import Data.String.Conversions -import Data.Text +import Data.Proxy +import Data.Text (Text) import Data.Text.Encoding -import Network.HTTP.Client +import Network.HTTP.Client hiding (Proxy) import Network.HTTP.Client.TLS +import Network.HTTP.Media import Network.HTTP.Types import Network.URI +import Servant.API.ContentTypes import Servant.Common.BaseUrl import Servant.Common.Text import System.IO.Unsafe @@ -29,14 +28,15 @@ import System.IO.Unsafe import qualified Network.HTTP.Client as Client data Req = Req - { reqPath :: String - , qs :: QueryText - , reqBody :: ByteString - , headers :: [(String, Text)] + { reqPath :: String + , qs :: QueryText + , reqBody :: Maybe (ByteString, MediaType) + , reqAccept :: [MediaType] + , headers :: [(String, Text)] } defReq :: Req -defReq = Req "" [] "" [] +defReq = Req "" [] Nothing [] [] appendToPath :: String -> Req -> Req appendToPath p req = @@ -62,12 +62,12 @@ addHeader name val req = req { headers = headers req ++ [(name, toText val)] } -setRQBody :: ByteString -> Req -> Req -setRQBody b req = req { reqBody = b } +setRQBody :: ByteString -> MediaType -> Req -> Req +setRQBody b t req = req { reqBody = Just (b, t) } reqToRequest :: (Functor m, MonadThrow m) => Req -> BaseUrl -> m Request reqToRequest req (BaseUrl reqScheme reqHost reqPort) = - fmap (setheaders . setrqb . setQS ) $ parseUrl url + fmap (setheaders . setAccept . setrqb . setQS ) $ parseUrl url where url = show $ nullURI { uriScheme = case reqScheme of Http -> "http:" @@ -80,10 +80,17 @@ reqToRequest req (BaseUrl reqScheme reqHost reqPort) = , uriPath = reqPath req } - setrqb r = r { requestBody = RequestBodyLBS (reqBody req) } + setrqb r = case reqBody req of + Nothing -> r + Just (b,t) -> r { requestBody = RequestBodyLBS b + , requestHeaders = requestHeaders r + ++ [(hContentType, cs . show $ t)] } setQS = setQueryString $ queryTextToQuery (qs req) - setheaders r = r { requestHeaders = Prelude.map toProperHeader (headers req) } - + setheaders r = r { requestHeaders = requestHeaders r + <> fmap toProperHeader (headers req) } + setAccept r = r { requestHeaders = filter ((/= "Accept") . fst) (requestHeaders r) + <> [("Accept", renderHeader $ reqAccept req) + | not . null . reqAccept $ req] } toProperHeader (name, val) = (fromString name, encodeUtf8 val) @@ -104,7 +111,7 @@ displayHttpRequest :: Method -> String displayHttpRequest httpmethod = "HTTP " ++ cs httpmethod ++ " request" -performRequest :: Method -> Req -> (Int -> Bool) -> BaseUrl -> EitherT String IO (Int, ByteString) +performRequest :: Method -> Req -> (Int -> Bool) -> BaseUrl -> EitherT String IO (Int, ByteString, MediaType) performRequest reqMethod req isWantedStatus reqHost = do partialRequest <- liftIO $ reqToRequest req reqHost @@ -123,20 +130,28 @@ performRequest reqMethod req isWantedStatus reqHost = do let status = Client.responseStatus response unless (isWantedStatus (statusCode status)) $ left (displayHttpRequest reqMethod ++ " failed with status: " ++ showStatus status) - return $ (statusCode status, Client.responseBody response) + ct <- case lookup "Content-Type" $ Client.responseHeaders response of + Nothing -> pure $ "application"//"octet-stream" + Just t -> case parseAccept t of + Nothing -> left $ "invalid Content-Type header: " <> cs t + Just t' -> pure t' + return (statusCode status, Client.responseBody response, ct) where showStatus (Status code message) = show code ++ " - " ++ cs message - -performRequestJSON :: FromJSON result => - Method -> Req -> Int -> BaseUrl -> EitherT String IO result -performRequestJSON reqMethod req wantedStatus reqHost = do - (_status, respBody) <- performRequest reqMethod req (== wantedStatus) reqHost +performRequestCT :: MimeUnrender ct result => + Proxy ct -> Method -> Req -> Int -> BaseUrl -> EitherT String IO result +performRequestCT ct reqMethod req wantedStatus reqHost = do + let acceptCT = contentType ct + (_status, respBody, respCT) <- + performRequest reqMethod (req { reqAccept = [acceptCT] }) (== wantedStatus) reqHost + unless (matches respCT (acceptCT)) $ + left $ "requested Content-Type " <> show acceptCT <> ", but got " <> show respCT either - (\ message -> left (displayHttpRequest reqMethod ++ " returned invalid json: " ++ message)) + (left . ((displayHttpRequest reqMethod ++ " returned invalid response of type" ++ show respCT) ++)) return - (decodeLenient respBody) + (fromByteString ct respBody) catchStatusCodeException :: IO a -> IO (Either Status a) @@ -145,10 +160,3 @@ catchStatusCodeException action = case e of Client.StatusCodeException status _ _ -> return $ Left status exc -> throwIO exc - --- | Like 'Data.Aeson.decode' but allows all JSON values instead of just --- objects and arrays. -decodeLenient :: FromJSON a => ByteString -> Either String a -decodeLenient input = do - v :: Value <- parseOnly (Data.Aeson.Parser.value <* endOfInput) (cs input) - parseEither parseJSON v diff --git a/test/Servant/ClientSpec.hs b/test/Servant/ClientSpec.hs index 429a2fe4..063c6345 100644 --- a/test/Servant/ClientSpec.hs +++ b/test/Servant/ClientSpec.hs @@ -3,6 +3,7 @@ {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE OverloadedStrings #-} {-# OPTIONS_GHC -fcontext-stack=25 #-} module Servant.ClientSpec where @@ -14,9 +15,11 @@ import Data.Aeson import Data.ByteString.Lazy (ByteString) import Data.Char import Data.Foldable (forM_) +import Data.Monoid import Data.Proxy -import Data.Typeable +import qualified Data.Text as T import GHC.Generics +import Network.HTTP.Media import Network.HTTP.Types import Network.Socket import Network.Wai @@ -26,6 +29,7 @@ import Test.Hspec.QuickCheck import Test.QuickCheck import Servant.API +import Servant.API.ContentTypes import Servant.Client import Servant.Server @@ -40,28 +44,43 @@ data Person = Person { instance ToJSON Person instance FromJSON Person +instance ToFormUrlEncoded Person where + toFormUrlEncoded Person{..} = + [("name", T.pack name), ("age", T.pack (show age))] + +lookupEither :: (Show a, Eq a) => a -> [(a,b)] -> Either String b +lookupEither x xs = do + maybe (Left $ "could not find key " <> show x) return $ lookup x xs + +instance FromFormUrlEncoded Person where + fromFormUrlEncoded xs = do + n <- lookupEither "name" xs + a <- lookupEither "age" xs + return $ Person (T.unpack n) (read $ T.unpack a) + + alice :: Person alice = Person "Alice" 42 type Api = - "get" :> Get Person + "get" :> Get '[JSON] Person :<|> "delete" :> Delete - :<|> "capture" :> Capture "name" String :> Get Person - :<|> "body" :> ReqBody Person :> Post Person - :<|> "param" :> QueryParam "name" String :> Get Person - :<|> "params" :> QueryParams "names" String :> Get [Person] - :<|> "flag" :> QueryFlag "flag" :> Get Bool - :<|> "matrixparam" :> MatrixParam "name" String :> Get Person - :<|> "matrixparams" :> MatrixParams "name" String :> Get [Person] - :<|> "matrixflag" :> MatrixFlag "flag" :> Get Bool + :<|> "capture" :> Capture "name" String :> Get '[JSON,FormUrlEncoded] Person + :<|> "body" :> ReqBody '[FormUrlEncoded,JSON] Person :> Post '[JSON] Person + :<|> "param" :> QueryParam "name" String :> Get '[FormUrlEncoded,JSON] Person + :<|> "params" :> QueryParams "names" String :> Get '[JSON] [Person] + :<|> "flag" :> QueryFlag "flag" :> Get '[JSON] Bool + :<|> "matrixparam" :> MatrixParam "name" String :> Get '[JSON] Person + :<|> "matrixparams" :> MatrixParams "name" String :> Get '[JSON] [Person] + :<|> "matrixflag" :> MatrixFlag "flag" :> Get '[JSON] Bool :<|> "rawSuccess" :> Raw :<|> "rawFailure" :> Raw :<|> "multiple" :> Capture "first" String :> QueryParam "second" Int :> QueryFlag "third" :> - ReqBody [(String, [Rational])] :> - Get (String, Maybe Int, Bool, [(String, [Rational])]) + ReqBody '[JSON] [(String, [Rational])] :> + Get '[JSON] (String, Maybe Int, Bool, [(String, [Rational])]) api :: Proxy Api api = Proxy @@ -101,8 +120,8 @@ getQueryFlag :: Bool -> BaseUrl -> EitherT String IO Bool getMatrixParam :: Maybe String -> BaseUrl -> EitherT String IO Person getMatrixParams :: [String] -> BaseUrl -> EitherT String IO [Person] getMatrixFlag :: Bool -> BaseUrl -> EitherT String IO Bool -getRawSuccess :: Method -> BaseUrl -> EitherT String IO (Int, ByteString) -getRawFailure :: Method -> BaseUrl -> EitherT String IO (Int, ByteString) +getRawSuccess :: Method -> BaseUrl -> EitherT String IO (Int, ByteString, MediaType) +getRawFailure :: Method -> BaseUrl -> EitherT String IO (Int, ByteString, MediaType) getMultiple :: String -> Maybe Int -> Bool -> [(String, [Rational])] -> BaseUrl -> EitherT String IO (String, Maybe Int, Bool, [(String, [Rational])]) @@ -151,6 +170,7 @@ spec = do it (show flag) $ withServer $ \ host -> do runEitherT (getQueryFlag flag host) `shouldReturn` Right flag +{- it "Servant.API.MatrixParam" $ withServer $ \ host -> do runEitherT (getMatrixParam (Just "alice") host) `shouldReturn` Right alice Left result <- runEitherT (getMatrixParam (Just "bob") host) @@ -165,12 +185,13 @@ spec = do forM_ [False, True] $ \ flag -> it (show flag) $ withServer $ \ host -> do runEitherT (getMatrixFlag flag host) `shouldReturn` Right flag +-} it "Servant.API.Raw on success" $ withServer $ \ host -> do - runEitherT (getRawSuccess methodGet host) `shouldReturn` Right (200, "rawSuccess") + runEitherT (getRawSuccess methodGet host) `shouldReturn` Right (200, "rawSuccess", "application"//"octet-stream") it "Servant.API.Raw on failure" $ withServer $ \ host -> do - runEitherT (getRawFailure methodGet host) `shouldReturn` Right (400, "rawFailure") + runEitherT (getRawFailure methodGet host) `shouldReturn` Right (400, "rawFailure", "application"//"octet-stream") modifyMaxSuccess (const 20) $ do it "works for a combination of Capture, QueryParam, QueryFlag and ReqBody" $ @@ -183,9 +204,9 @@ spec = do context "client correctly handles error status codes" $ do - let test :: WrappedApi -> Spec - test (WrappedApi api) = - it (show (typeOf api)) $ + let test :: (WrappedApi, String) -> Spec + test (WrappedApi api, desc) = + it desc $ withWaiDaemon (return (serve api (left (500, "error message")))) $ \ host -> do let getResponse :: BaseUrl -> EitherT String IO () @@ -193,16 +214,15 @@ spec = do Left result <- runEitherT (getResponse host) result `shouldContain` "error message" mapM_ test $ - (WrappedApi (Proxy :: Proxy Delete)) : - (WrappedApi (Proxy :: Proxy (Get ()))) : - (WrappedApi (Proxy :: Proxy (Post ()))) : - (WrappedApi (Proxy :: Proxy (Put ()))) : + (WrappedApi (Proxy :: Proxy Delete), "Delete") : + (WrappedApi (Proxy :: Proxy (Get '[JSON] ())), "Delete") : + (WrappedApi (Proxy :: Proxy (Post '[JSON] ())), "Delete") : + (WrappedApi (Proxy :: Proxy (Put '[JSON] ())), "Delete") : [] data WrappedApi where WrappedApi :: (HasServer api, Server api ~ EitherT (Int, String) IO a, - HasClient api, Client api ~ (BaseUrl -> EitherT String IO ()), - Typeable api) => + HasClient api, Client api ~ (BaseUrl -> EitherT String IO ())) => Proxy api -> WrappedApi