Addresses problems with URL encodings

This changes the way URL encoding for query parameters is handled,
making it possible to correctly encode arbitrary binary data into query
parameter values.

Closes #1418
This commit is contained in:
Ian Shipman 2021-08-27 17:57:37 -05:00
parent 48bc24768e
commit 9666f1956b
5 changed files with 60 additions and 12 deletions

View file

@ -33,6 +33,9 @@ import Control.Arrow
(left, (+++))
import Control.Monad
(unless)
import qualified Data.ByteString as BS
import Data.ByteString.Builder
(toLazyByteString)
import qualified Data.ByteString.Lazy as BL
import Data.Either
(partitionEithers)
@ -76,7 +79,7 @@ import Servant.API
ReflectMethod (..), RemoteHost, ReqBody', SBoolI, Stream,
StreamBody', Summary, ToHttpApiData, ToSourceIO (..), Vault,
Verb, WithNamedContext, WithStatus (..), contentType, getHeadersHList,
getResponse, toQueryParam, toUrlPiece)
getResponse, toEncodedUrlPiece, toUrlPiece)
import Servant.API.ContentTypes
(contentTypes, AllMime (allMime), AllMimeUnrender (allMimeUnrender))
import Servant.API.TypeLevel (FragmentUnique, AtLeastOneFragment)
@ -554,7 +557,7 @@ instance (KnownSymbol sym, ToHttpApiData a, HasClient m api, SBoolI (FoldRequire
(Proxy :: Proxy mods) add (maybe req add) mparam
where
add :: a -> Request
add param = appendToQueryString pname (Just $ toQueryParam param) req
add param = appendToQueryString pname (Just $ encodeQueryParam param) req
pname :: Text
pname = pack $ symbolVal (Proxy :: Proxy sym)
@ -562,6 +565,9 @@ instance (KnownSymbol sym, ToHttpApiData a, HasClient m api, SBoolI (FoldRequire
hoistClientMonad pm _ f cl = \arg ->
hoistClientMonad pm (Proxy :: Proxy api) f (cl arg)
encodeQueryParam :: ToHttpApiData a => a -> BS.ByteString
encodeQueryParam = BL.toStrict . toLazyByteString . toEncodedUrlPiece
-- | If you use a 'QueryParams' in one of your endpoints in your API,
-- the corresponding querying function will automatically take
-- an additional argument, a list of values of the type specified
@ -603,7 +609,7 @@ instance (KnownSymbol sym, ToHttpApiData a, HasClient m api)
)
where pname = pack $ symbolVal (Proxy :: Proxy sym)
paramlist' = map (Just . toQueryParam) paramlist
paramlist' = map (Just . encodeQueryParam) paramlist
hoistClientMonad pm _ f cl = \as ->
hoistClientMonad pm (Proxy :: Proxy api) f (cl as)

View file

@ -146,12 +146,12 @@ appendToPath p req
= req { requestPath = requestPath req <> "/" <> toEncodedUrlPiece p }
appendToQueryString :: Text -- ^ param name
-> Maybe Text -- ^ param value
-> Maybe BS.ByteString -- ^ param value
-> Request
-> Request
appendToQueryString pname pvalue req
= req { requestQueryString = requestQueryString req
Seq.|> (encodeUtf8 pname, encodeUtf8 <$> pvalue)}
Seq.|> (encodeUtf8 pname, pvalue)}
addHeader :: ToHttpApiData a => HeaderName -> a -> Request -> Request
addHeader name val req

View file

@ -46,7 +46,7 @@ import qualified Data.ByteString.Lazy as BSL
import Data.Either
(either)
import Data.Foldable
(toList)
(foldl',toList)
import Data.Functor.Alt
(Alt (..))
import Data.Maybe
@ -63,7 +63,7 @@ import GHC.Generics
import Network.HTTP.Media
(renderHeader)
import Network.HTTP.Types
(hContentType, renderQuery, statusCode, Status)
(hContentType, renderQuery, statusCode, urlEncode, Status)
import Servant.Client.Core
import qualified Network.HTTP.Client as Client
@ -238,7 +238,7 @@ defaultMakeClientRequest burl r = Client.defaultRequest
, Client.path = BSL.toStrict
$ fromString (baseUrlPath burl)
<> toLazyByteString (requestPath r)
, Client.queryString = renderQuery True . toList $ requestQueryString r
, Client.queryString = buildQueryString . toList $ requestQueryString r
, Client.requestHeaders =
maybeToList acceptHdr ++ maybeToList contentTypeHdr ++ headers
, Client.requestBody = body
@ -289,6 +289,13 @@ defaultMakeClientRequest burl r = Client.defaultRequest
Http -> False
Https -> True
-- Query string builder which does not do any encoding
buildQueryString = ("?" <>) . foldl' addQueryParam mempty
addQueryParam qs (k, v) =
qs <> (if BS.null qs then mempty else "&") <> urlEncode True k <> foldMap ("=" <>) v
catchConnectionError :: IO a -> IO (Either ClientError a)
catchConnectionError action =
catch (Right <$> action) $ \e ->

View file

@ -24,9 +24,15 @@ import Prelude.Compat
import Control.Concurrent
(ThreadId, forkIO, killThread)
import Control.Monad
(join)
import Control.Monad.Error.Class
(throwError)
import Data.Aeson
import Data.ByteString
(ByteString)
import Data.ByteString.Builder
(byteString)
import qualified Data.ByteString.Lazy as LazyByteString
import Data.Char
(chr, isPrint)
@ -54,10 +60,10 @@ import Web.FormUrlEncoded
import Servant.API
((:<|>) ((:<|>)), (:>), AuthProtect, BasicAuth,
BasicAuthData (..), Capture, CaptureAll, DeleteNoContent,
EmptyAPI, FormUrlEncoded, Fragment, Get, Header, Headers,
EmptyAPI, FormUrlEncoded, Fragment, FromHttpApiData (..), Get, Header, Headers,
JSON, MimeRender (mimeRender), MimeUnrender (mimeUnrender),
NoContent (NoContent), PlainText, Post, QueryFlag, QueryParam,
QueryParams, Raw, ReqBody, StdMethod (GET), UVerb, Union,
QueryParams, Raw, ReqBody, StdMethod (GET), ToHttpApiData (..), UVerb, Union,
WithStatus (WithStatus), addHeader)
import Servant.Client
import qualified Servant.Client.Core.Auth as Auth
@ -109,6 +115,10 @@ type Api =
:<|> "captureAll" :> CaptureAll "names" String :> Get '[JSON] [Person]
:<|> "body" :> ReqBody '[FormUrlEncoded,JSON] Person :> Post '[JSON] Person
:<|> "param" :> QueryParam "name" String :> Get '[FormUrlEncoded,JSON] Person
-- This endpoint makes use of a 'Raw' server because it is not currently
-- possible to handle arbitrary binary query param values with
-- @servant-server@
:<|> "param-binary" :> QueryParam "payload" UrlEncodedByteString :> Raw
:<|> "params" :> QueryParams "names" String :> Get '[JSON] [Person]
:<|> "flag" :> QueryFlag "flag" :> Get '[JSON] Bool
:<|> "fragment" :> Fragment String :> Get '[JSON] Person
@ -143,6 +153,7 @@ getCapture :: String -> ClientM Person
getCaptureAll :: [String] -> ClientM [Person]
getBody :: Person -> ClientM Person
getQueryParam :: Maybe String -> ClientM Person
getQueryParamBinary :: Maybe UrlEncodedByteString -> HTTP.Method -> ClientM Response
getQueryParams :: [String] -> ClientM [Person]
getQueryFlag :: Bool -> ClientM Bool
getFragment :: ClientM Person
@ -167,6 +178,7 @@ getRoot
:<|> getCaptureAll
:<|> getBody
:<|> getQueryParam
:<|> getQueryParamBinary
:<|> getQueryParams
:<|> getQueryFlag
:<|> getFragment
@ -194,6 +206,13 @@ server = serve api (
Just "alice" -> return alice
Just n -> throwError $ ServerError 400 (n ++ " not found") "" []
Nothing -> throwError $ ServerError 400 "missing parameter" "" [])
:<|> const (Tagged $ \request respond ->
respond . maybe (Wai.responseLBS HTTP.notFound404 [] "Missing: payload")
(Wai.responseLBS HTTP.ok200 [] . LazyByteString.fromStrict)
. join
. lookup "payload"
$ Wai.queryString request
)
:<|> (\ names -> return (zipWith Person names [0..]))
:<|> return
:<|> return alice
@ -310,3 +329,12 @@ pathGen = fmap NonEmpty path
filter (not . (`elem` ("?%[]/#;" :: String))) $
filter isPrint $
map chr [0..127]
newtype UrlEncodedByteString = UrlEncodedByteString { unUrlEncodedByteString :: ByteString }
instance ToHttpApiData UrlEncodedByteString where
toEncodedUrlPiece = byteString . HTTP.urlEncode True . unUrlEncodedByteString
toUrlPiece = decodeUtf8 . HTTP.urlEncode True . unUrlEncodedByteString
instance FromHttpApiData UrlEncodedByteString where
parseUrlPiece = pure . UrlEncodedByteString . HTTP.urlDecode True . encodeUtf8

View file

@ -22,11 +22,13 @@ import Prelude ()
import Prelude.Compat
import Control.Arrow
(left)
((+++), left)
import Control.Concurrent.STM
(atomically)
import Control.Concurrent.STM.TVar
(newTVar, readTVar)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import Data.Foldable
(forM_, toList)
import Data.Maybe
@ -93,6 +95,11 @@ successSpec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do
Left (FailureResponse _ r) <- runClient (getQueryParam (Just "bob")) baseUrl
responseStatusCode r `shouldBe` HTTP.Status 400 "bob not found"
it "Servant.API.QueryParam binary data" $ \(_, baseUrl) -> do
let payload = BS.pack [0, 1, 2, 4, 8, 16, 32, 64, 128]
apiCall = getQueryParamBinary (Just $ UrlEncodedByteString payload) HTTP.methodGet
(show +++ responseBody) <$> runClient apiCall baseUrl `shouldReturn` Right (BL.fromStrict payload)
it "Servant.API.QueryParam.QueryParams" $ \(_, baseUrl) -> do
left show <$> runClient (getQueryParams []) baseUrl `shouldReturn` Right []
left show <$> runClient (getQueryParams ["alice", "bob"]) baseUrl