From 9666f1956b95e23d65276909e1f06353828cb4c0 Mon Sep 17 00:00:00 2001 From: Ian Shipman Date: Fri, 27 Aug 2021 17:57:37 -0500 Subject: [PATCH] Addresses problems with URL encodings This changes the way URL encoding for query parameters is handled, making it possible to correctly encode arbitrary binary data into query parameter values. Closes #1418 --- .../src/Servant/Client/Core/HasClient.hs | 12 +++++-- .../src/Servant/Client/Core/Request.hs | 6 ++-- .../src/Servant/Client/Internal/HttpClient.hs | 13 ++++++-- .../test/Servant/ClientTestUtils.hs | 32 +++++++++++++++++-- servant-client/test/Servant/SuccessSpec.hs | 9 +++++- 5 files changed, 60 insertions(+), 12 deletions(-) diff --git a/servant-client-core/src/Servant/Client/Core/HasClient.hs b/servant-client-core/src/Servant/Client/Core/HasClient.hs index 5f7ad3b3..d598bf66 100644 --- a/servant-client-core/src/Servant/Client/Core/HasClient.hs +++ b/servant-client-core/src/Servant/Client/Core/HasClient.hs @@ -33,6 +33,9 @@ import Control.Arrow (left, (+++)) import Control.Monad (unless) +import qualified Data.ByteString as BS +import Data.ByteString.Builder + (toLazyByteString) import qualified Data.ByteString.Lazy as BL import Data.Either (partitionEithers) @@ -76,7 +79,7 @@ import Servant.API ReflectMethod (..), RemoteHost, ReqBody', SBoolI, Stream, StreamBody', Summary, ToHttpApiData, ToSourceIO (..), Vault, Verb, WithNamedContext, WithStatus (..), contentType, getHeadersHList, - getResponse, toQueryParam, toUrlPiece) + getResponse, toEncodedUrlPiece, toUrlPiece) import Servant.API.ContentTypes (contentTypes, AllMime (allMime), AllMimeUnrender (allMimeUnrender)) import Servant.API.TypeLevel (FragmentUnique, AtLeastOneFragment) @@ -554,7 +557,7 @@ instance (KnownSymbol sym, ToHttpApiData a, HasClient m api, SBoolI (FoldRequire (Proxy :: Proxy mods) add (maybe req add) mparam where add :: a -> Request - add param = appendToQueryString pname (Just $ toQueryParam param) req + add param = appendToQueryString pname (Just $ encodeQueryParam param) req pname :: Text pname = pack $ symbolVal (Proxy :: Proxy sym) @@ -562,6 +565,9 @@ instance (KnownSymbol sym, ToHttpApiData a, HasClient m api, SBoolI (FoldRequire hoistClientMonad pm _ f cl = \arg -> hoistClientMonad pm (Proxy :: Proxy api) f (cl arg) +encodeQueryParam :: ToHttpApiData a => a -> BS.ByteString +encodeQueryParam = BL.toStrict . toLazyByteString . toEncodedUrlPiece + -- | If you use a 'QueryParams' in one of your endpoints in your API, -- the corresponding querying function will automatically take -- an additional argument, a list of values of the type specified @@ -603,7 +609,7 @@ instance (KnownSymbol sym, ToHttpApiData a, HasClient m api) ) where pname = pack $ symbolVal (Proxy :: Proxy sym) - paramlist' = map (Just . toQueryParam) paramlist + paramlist' = map (Just . encodeQueryParam) paramlist hoistClientMonad pm _ f cl = \as -> hoistClientMonad pm (Proxy :: Proxy api) f (cl as) diff --git a/servant-client-core/src/Servant/Client/Core/Request.hs b/servant-client-core/src/Servant/Client/Core/Request.hs index 9196c795..bdc3e382 100644 --- a/servant-client-core/src/Servant/Client/Core/Request.hs +++ b/servant-client-core/src/Servant/Client/Core/Request.hs @@ -145,13 +145,13 @@ appendToPath :: Text -> Request -> Request appendToPath p req = req { requestPath = requestPath req <> "/" <> toEncodedUrlPiece p } -appendToQueryString :: Text -- ^ param name - -> Maybe Text -- ^ param value +appendToQueryString :: Text -- ^ param name + -> Maybe BS.ByteString -- ^ param value -> Request -> Request appendToQueryString pname pvalue req = req { requestQueryString = requestQueryString req - Seq.|> (encodeUtf8 pname, encodeUtf8 <$> pvalue)} + Seq.|> (encodeUtf8 pname, pvalue)} addHeader :: ToHttpApiData a => HeaderName -> a -> Request -> Request addHeader name val req diff --git a/servant-client/src/Servant/Client/Internal/HttpClient.hs b/servant-client/src/Servant/Client/Internal/HttpClient.hs index 61d51bc4..a2c6864d 100644 --- a/servant-client/src/Servant/Client/Internal/HttpClient.hs +++ b/servant-client/src/Servant/Client/Internal/HttpClient.hs @@ -46,7 +46,7 @@ import qualified Data.ByteString.Lazy as BSL import Data.Either (either) import Data.Foldable - (toList) + (foldl',toList) import Data.Functor.Alt (Alt (..)) import Data.Maybe @@ -63,7 +63,7 @@ import GHC.Generics import Network.HTTP.Media (renderHeader) import Network.HTTP.Types - (hContentType, renderQuery, statusCode, Status) + (hContentType, renderQuery, statusCode, urlEncode, Status) import Servant.Client.Core import qualified Network.HTTP.Client as Client @@ -238,7 +238,7 @@ defaultMakeClientRequest burl r = Client.defaultRequest , Client.path = BSL.toStrict $ fromString (baseUrlPath burl) <> toLazyByteString (requestPath r) - , Client.queryString = renderQuery True . toList $ requestQueryString r + , Client.queryString = buildQueryString . toList $ requestQueryString r , Client.requestHeaders = maybeToList acceptHdr ++ maybeToList contentTypeHdr ++ headers , Client.requestBody = body @@ -289,6 +289,13 @@ defaultMakeClientRequest burl r = Client.defaultRequest Http -> False Https -> True + -- Query string builder which does not do any encoding + buildQueryString = ("?" <>) . foldl' addQueryParam mempty + + addQueryParam qs (k, v) = + qs <> (if BS.null qs then mempty else "&") <> urlEncode True k <> foldMap ("=" <>) v + + catchConnectionError :: IO a -> IO (Either ClientError a) catchConnectionError action = catch (Right <$> action) $ \e -> diff --git a/servant-client/test/Servant/ClientTestUtils.hs b/servant-client/test/Servant/ClientTestUtils.hs index 842712e1..198c6462 100644 --- a/servant-client/test/Servant/ClientTestUtils.hs +++ b/servant-client/test/Servant/ClientTestUtils.hs @@ -24,9 +24,15 @@ import Prelude.Compat import Control.Concurrent (ThreadId, forkIO, killThread) +import Control.Monad + (join) import Control.Monad.Error.Class (throwError) import Data.Aeson +import Data.ByteString + (ByteString) +import Data.ByteString.Builder + (byteString) import qualified Data.ByteString.Lazy as LazyByteString import Data.Char (chr, isPrint) @@ -54,10 +60,10 @@ import Web.FormUrlEncoded import Servant.API ((:<|>) ((:<|>)), (:>), AuthProtect, BasicAuth, BasicAuthData (..), Capture, CaptureAll, DeleteNoContent, - EmptyAPI, FormUrlEncoded, Fragment, Get, Header, Headers, + EmptyAPI, FormUrlEncoded, Fragment, FromHttpApiData (..), Get, Header, Headers, JSON, MimeRender (mimeRender), MimeUnrender (mimeUnrender), NoContent (NoContent), PlainText, Post, QueryFlag, QueryParam, - QueryParams, Raw, ReqBody, StdMethod (GET), UVerb, Union, + QueryParams, Raw, ReqBody, StdMethod (GET), ToHttpApiData (..), UVerb, Union, WithStatus (WithStatus), addHeader) import Servant.Client import qualified Servant.Client.Core.Auth as Auth @@ -109,6 +115,10 @@ type Api = :<|> "captureAll" :> CaptureAll "names" String :> Get '[JSON] [Person] :<|> "body" :> ReqBody '[FormUrlEncoded,JSON] Person :> Post '[JSON] Person :<|> "param" :> QueryParam "name" String :> Get '[FormUrlEncoded,JSON] Person + -- This endpoint makes use of a 'Raw' server because it is not currently + -- possible to handle arbitrary binary query param values with + -- @servant-server@ + :<|> "param-binary" :> QueryParam "payload" UrlEncodedByteString :> Raw :<|> "params" :> QueryParams "names" String :> Get '[JSON] [Person] :<|> "flag" :> QueryFlag "flag" :> Get '[JSON] Bool :<|> "fragment" :> Fragment String :> Get '[JSON] Person @@ -143,6 +153,7 @@ getCapture :: String -> ClientM Person getCaptureAll :: [String] -> ClientM [Person] getBody :: Person -> ClientM Person getQueryParam :: Maybe String -> ClientM Person +getQueryParamBinary :: Maybe UrlEncodedByteString -> HTTP.Method -> ClientM Response getQueryParams :: [String] -> ClientM [Person] getQueryFlag :: Bool -> ClientM Bool getFragment :: ClientM Person @@ -167,6 +178,7 @@ getRoot :<|> getCaptureAll :<|> getBody :<|> getQueryParam + :<|> getQueryParamBinary :<|> getQueryParams :<|> getQueryFlag :<|> getFragment @@ -194,6 +206,13 @@ server = serve api ( Just "alice" -> return alice Just n -> throwError $ ServerError 400 (n ++ " not found") "" [] Nothing -> throwError $ ServerError 400 "missing parameter" "" []) + :<|> const (Tagged $ \request respond -> + respond . maybe (Wai.responseLBS HTTP.notFound404 [] "Missing: payload") + (Wai.responseLBS HTTP.ok200 [] . LazyByteString.fromStrict) + . join + . lookup "payload" + $ Wai.queryString request + ) :<|> (\ names -> return (zipWith Person names [0..])) :<|> return :<|> return alice @@ -310,3 +329,12 @@ pathGen = fmap NonEmpty path filter (not . (`elem` ("?%[]/#;" :: String))) $ filter isPrint $ map chr [0..127] + +newtype UrlEncodedByteString = UrlEncodedByteString { unUrlEncodedByteString :: ByteString } + +instance ToHttpApiData UrlEncodedByteString where + toEncodedUrlPiece = byteString . HTTP.urlEncode True . unUrlEncodedByteString + toUrlPiece = decodeUtf8 . HTTP.urlEncode True . unUrlEncodedByteString + +instance FromHttpApiData UrlEncodedByteString where + parseUrlPiece = pure . UrlEncodedByteString . HTTP.urlDecode True . encodeUtf8 diff --git a/servant-client/test/Servant/SuccessSpec.hs b/servant-client/test/Servant/SuccessSpec.hs index 8729caf0..6b9f3bd0 100644 --- a/servant-client/test/Servant/SuccessSpec.hs +++ b/servant-client/test/Servant/SuccessSpec.hs @@ -22,11 +22,13 @@ import Prelude () import Prelude.Compat import Control.Arrow - (left) + ((+++), left) import Control.Concurrent.STM (atomically) import Control.Concurrent.STM.TVar (newTVar, readTVar) +import qualified Data.ByteString as BS +import qualified Data.ByteString.Lazy as BL import Data.Foldable (forM_, toList) import Data.Maybe @@ -93,6 +95,11 @@ successSpec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do Left (FailureResponse _ r) <- runClient (getQueryParam (Just "bob")) baseUrl responseStatusCode r `shouldBe` HTTP.Status 400 "bob not found" + it "Servant.API.QueryParam binary data" $ \(_, baseUrl) -> do + let payload = BS.pack [0, 1, 2, 4, 8, 16, 32, 64, 128] + apiCall = getQueryParamBinary (Just $ UrlEncodedByteString payload) HTTP.methodGet + (show +++ responseBody) <$> runClient apiCall baseUrl `shouldReturn` Right (BL.fromStrict payload) + it "Servant.API.QueryParam.QueryParams" $ \(_, baseUrl) -> do left show <$> runClient (getQueryParams []) baseUrl `shouldReturn` Right [] left show <$> runClient (getQueryParams ["alice", "bob"]) baseUrl