Addresses problems with URL encodings

This changes the way URL encoding for query parameters is handled,
making it possible to correctly encode arbitrary binary data into query
parameter values.

Closes #1418
This commit is contained in:
Ian Shipman 2021-08-27 17:57:37 -05:00
parent 48bc24768e
commit 9666f1956b
5 changed files with 60 additions and 12 deletions

View file

@ -33,6 +33,9 @@ import Control.Arrow
(left, (+++)) (left, (+++))
import Control.Monad import Control.Monad
(unless) (unless)
import qualified Data.ByteString as BS
import Data.ByteString.Builder
(toLazyByteString)
import qualified Data.ByteString.Lazy as BL import qualified Data.ByteString.Lazy as BL
import Data.Either import Data.Either
(partitionEithers) (partitionEithers)
@ -76,7 +79,7 @@ import Servant.API
ReflectMethod (..), RemoteHost, ReqBody', SBoolI, Stream, ReflectMethod (..), RemoteHost, ReqBody', SBoolI, Stream,
StreamBody', Summary, ToHttpApiData, ToSourceIO (..), Vault, StreamBody', Summary, ToHttpApiData, ToSourceIO (..), Vault,
Verb, WithNamedContext, WithStatus (..), contentType, getHeadersHList, Verb, WithNamedContext, WithStatus (..), contentType, getHeadersHList,
getResponse, toQueryParam, toUrlPiece) getResponse, toEncodedUrlPiece, toUrlPiece)
import Servant.API.ContentTypes import Servant.API.ContentTypes
(contentTypes, AllMime (allMime), AllMimeUnrender (allMimeUnrender)) (contentTypes, AllMime (allMime), AllMimeUnrender (allMimeUnrender))
import Servant.API.TypeLevel (FragmentUnique, AtLeastOneFragment) import Servant.API.TypeLevel (FragmentUnique, AtLeastOneFragment)
@ -554,7 +557,7 @@ instance (KnownSymbol sym, ToHttpApiData a, HasClient m api, SBoolI (FoldRequire
(Proxy :: Proxy mods) add (maybe req add) mparam (Proxy :: Proxy mods) add (maybe req add) mparam
where where
add :: a -> Request add :: a -> Request
add param = appendToQueryString pname (Just $ toQueryParam param) req add param = appendToQueryString pname (Just $ encodeQueryParam param) req
pname :: Text pname :: Text
pname = pack $ symbolVal (Proxy :: Proxy sym) pname = pack $ symbolVal (Proxy :: Proxy sym)
@ -562,6 +565,9 @@ instance (KnownSymbol sym, ToHttpApiData a, HasClient m api, SBoolI (FoldRequire
hoistClientMonad pm _ f cl = \arg -> hoistClientMonad pm _ f cl = \arg ->
hoistClientMonad pm (Proxy :: Proxy api) f (cl arg) hoistClientMonad pm (Proxy :: Proxy api) f (cl arg)
encodeQueryParam :: ToHttpApiData a => a -> BS.ByteString
encodeQueryParam = BL.toStrict . toLazyByteString . toEncodedUrlPiece
-- | If you use a 'QueryParams' in one of your endpoints in your API, -- | If you use a 'QueryParams' in one of your endpoints in your API,
-- the corresponding querying function will automatically take -- the corresponding querying function will automatically take
-- an additional argument, a list of values of the type specified -- an additional argument, a list of values of the type specified
@ -603,7 +609,7 @@ instance (KnownSymbol sym, ToHttpApiData a, HasClient m api)
) )
where pname = pack $ symbolVal (Proxy :: Proxy sym) where pname = pack $ symbolVal (Proxy :: Proxy sym)
paramlist' = map (Just . toQueryParam) paramlist paramlist' = map (Just . encodeQueryParam) paramlist
hoistClientMonad pm _ f cl = \as -> hoistClientMonad pm _ f cl = \as ->
hoistClientMonad pm (Proxy :: Proxy api) f (cl as) hoistClientMonad pm (Proxy :: Proxy api) f (cl as)

View file

@ -145,13 +145,13 @@ appendToPath :: Text -> Request -> Request
appendToPath p req appendToPath p req
= req { requestPath = requestPath req <> "/" <> toEncodedUrlPiece p } = req { requestPath = requestPath req <> "/" <> toEncodedUrlPiece p }
appendToQueryString :: Text -- ^ param name appendToQueryString :: Text -- ^ param name
-> Maybe Text -- ^ param value -> Maybe BS.ByteString -- ^ param value
-> Request -> Request
-> Request -> Request
appendToQueryString pname pvalue req appendToQueryString pname pvalue req
= req { requestQueryString = requestQueryString req = req { requestQueryString = requestQueryString req
Seq.|> (encodeUtf8 pname, encodeUtf8 <$> pvalue)} Seq.|> (encodeUtf8 pname, pvalue)}
addHeader :: ToHttpApiData a => HeaderName -> a -> Request -> Request addHeader :: ToHttpApiData a => HeaderName -> a -> Request -> Request
addHeader name val req addHeader name val req

View file

@ -46,7 +46,7 @@ import qualified Data.ByteString.Lazy as BSL
import Data.Either import Data.Either
(either) (either)
import Data.Foldable import Data.Foldable
(toList) (foldl',toList)
import Data.Functor.Alt import Data.Functor.Alt
(Alt (..)) (Alt (..))
import Data.Maybe import Data.Maybe
@ -63,7 +63,7 @@ import GHC.Generics
import Network.HTTP.Media import Network.HTTP.Media
(renderHeader) (renderHeader)
import Network.HTTP.Types import Network.HTTP.Types
(hContentType, renderQuery, statusCode, Status) (hContentType, renderQuery, statusCode, urlEncode, Status)
import Servant.Client.Core import Servant.Client.Core
import qualified Network.HTTP.Client as Client import qualified Network.HTTP.Client as Client
@ -238,7 +238,7 @@ defaultMakeClientRequest burl r = Client.defaultRequest
, Client.path = BSL.toStrict , Client.path = BSL.toStrict
$ fromString (baseUrlPath burl) $ fromString (baseUrlPath burl)
<> toLazyByteString (requestPath r) <> toLazyByteString (requestPath r)
, Client.queryString = renderQuery True . toList $ requestQueryString r , Client.queryString = buildQueryString . toList $ requestQueryString r
, Client.requestHeaders = , Client.requestHeaders =
maybeToList acceptHdr ++ maybeToList contentTypeHdr ++ headers maybeToList acceptHdr ++ maybeToList contentTypeHdr ++ headers
, Client.requestBody = body , Client.requestBody = body
@ -289,6 +289,13 @@ defaultMakeClientRequest burl r = Client.defaultRequest
Http -> False Http -> False
Https -> True Https -> True
-- Query string builder which does not do any encoding
buildQueryString = ("?" <>) . foldl' addQueryParam mempty
addQueryParam qs (k, v) =
qs <> (if BS.null qs then mempty else "&") <> urlEncode True k <> foldMap ("=" <>) v
catchConnectionError :: IO a -> IO (Either ClientError a) catchConnectionError :: IO a -> IO (Either ClientError a)
catchConnectionError action = catchConnectionError action =
catch (Right <$> action) $ \e -> catch (Right <$> action) $ \e ->

View file

@ -24,9 +24,15 @@ import Prelude.Compat
import Control.Concurrent import Control.Concurrent
(ThreadId, forkIO, killThread) (ThreadId, forkIO, killThread)
import Control.Monad
(join)
import Control.Monad.Error.Class import Control.Monad.Error.Class
(throwError) (throwError)
import Data.Aeson import Data.Aeson
import Data.ByteString
(ByteString)
import Data.ByteString.Builder
(byteString)
import qualified Data.ByteString.Lazy as LazyByteString import qualified Data.ByteString.Lazy as LazyByteString
import Data.Char import Data.Char
(chr, isPrint) (chr, isPrint)
@ -54,10 +60,10 @@ import Web.FormUrlEncoded
import Servant.API import Servant.API
((:<|>) ((:<|>)), (:>), AuthProtect, BasicAuth, ((:<|>) ((:<|>)), (:>), AuthProtect, BasicAuth,
BasicAuthData (..), Capture, CaptureAll, DeleteNoContent, BasicAuthData (..), Capture, CaptureAll, DeleteNoContent,
EmptyAPI, FormUrlEncoded, Fragment, Get, Header, Headers, EmptyAPI, FormUrlEncoded, Fragment, FromHttpApiData (..), Get, Header, Headers,
JSON, MimeRender (mimeRender), MimeUnrender (mimeUnrender), JSON, MimeRender (mimeRender), MimeUnrender (mimeUnrender),
NoContent (NoContent), PlainText, Post, QueryFlag, QueryParam, NoContent (NoContent), PlainText, Post, QueryFlag, QueryParam,
QueryParams, Raw, ReqBody, StdMethod (GET), UVerb, Union, QueryParams, Raw, ReqBody, StdMethod (GET), ToHttpApiData (..), UVerb, Union,
WithStatus (WithStatus), addHeader) WithStatus (WithStatus), addHeader)
import Servant.Client import Servant.Client
import qualified Servant.Client.Core.Auth as Auth import qualified Servant.Client.Core.Auth as Auth
@ -109,6 +115,10 @@ type Api =
:<|> "captureAll" :> CaptureAll "names" String :> Get '[JSON] [Person] :<|> "captureAll" :> CaptureAll "names" String :> Get '[JSON] [Person]
:<|> "body" :> ReqBody '[FormUrlEncoded,JSON] Person :> Post '[JSON] Person :<|> "body" :> ReqBody '[FormUrlEncoded,JSON] Person :> Post '[JSON] Person
:<|> "param" :> QueryParam "name" String :> Get '[FormUrlEncoded,JSON] Person :<|> "param" :> QueryParam "name" String :> Get '[FormUrlEncoded,JSON] Person
-- This endpoint makes use of a 'Raw' server because it is not currently
-- possible to handle arbitrary binary query param values with
-- @servant-server@
:<|> "param-binary" :> QueryParam "payload" UrlEncodedByteString :> Raw
:<|> "params" :> QueryParams "names" String :> Get '[JSON] [Person] :<|> "params" :> QueryParams "names" String :> Get '[JSON] [Person]
:<|> "flag" :> QueryFlag "flag" :> Get '[JSON] Bool :<|> "flag" :> QueryFlag "flag" :> Get '[JSON] Bool
:<|> "fragment" :> Fragment String :> Get '[JSON] Person :<|> "fragment" :> Fragment String :> Get '[JSON] Person
@ -143,6 +153,7 @@ getCapture :: String -> ClientM Person
getCaptureAll :: [String] -> ClientM [Person] getCaptureAll :: [String] -> ClientM [Person]
getBody :: Person -> ClientM Person getBody :: Person -> ClientM Person
getQueryParam :: Maybe String -> ClientM Person getQueryParam :: Maybe String -> ClientM Person
getQueryParamBinary :: Maybe UrlEncodedByteString -> HTTP.Method -> ClientM Response
getQueryParams :: [String] -> ClientM [Person] getQueryParams :: [String] -> ClientM [Person]
getQueryFlag :: Bool -> ClientM Bool getQueryFlag :: Bool -> ClientM Bool
getFragment :: ClientM Person getFragment :: ClientM Person
@ -167,6 +178,7 @@ getRoot
:<|> getCaptureAll :<|> getCaptureAll
:<|> getBody :<|> getBody
:<|> getQueryParam :<|> getQueryParam
:<|> getQueryParamBinary
:<|> getQueryParams :<|> getQueryParams
:<|> getQueryFlag :<|> getQueryFlag
:<|> getFragment :<|> getFragment
@ -194,6 +206,13 @@ server = serve api (
Just "alice" -> return alice Just "alice" -> return alice
Just n -> throwError $ ServerError 400 (n ++ " not found") "" [] Just n -> throwError $ ServerError 400 (n ++ " not found") "" []
Nothing -> throwError $ ServerError 400 "missing parameter" "" []) Nothing -> throwError $ ServerError 400 "missing parameter" "" [])
:<|> const (Tagged $ \request respond ->
respond . maybe (Wai.responseLBS HTTP.notFound404 [] "Missing: payload")
(Wai.responseLBS HTTP.ok200 [] . LazyByteString.fromStrict)
. join
. lookup "payload"
$ Wai.queryString request
)
:<|> (\ names -> return (zipWith Person names [0..])) :<|> (\ names -> return (zipWith Person names [0..]))
:<|> return :<|> return
:<|> return alice :<|> return alice
@ -310,3 +329,12 @@ pathGen = fmap NonEmpty path
filter (not . (`elem` ("?%[]/#;" :: String))) $ filter (not . (`elem` ("?%[]/#;" :: String))) $
filter isPrint $ filter isPrint $
map chr [0..127] map chr [0..127]
newtype UrlEncodedByteString = UrlEncodedByteString { unUrlEncodedByteString :: ByteString }
instance ToHttpApiData UrlEncodedByteString where
toEncodedUrlPiece = byteString . HTTP.urlEncode True . unUrlEncodedByteString
toUrlPiece = decodeUtf8 . HTTP.urlEncode True . unUrlEncodedByteString
instance FromHttpApiData UrlEncodedByteString where
parseUrlPiece = pure . UrlEncodedByteString . HTTP.urlDecode True . encodeUtf8

View file

@ -22,11 +22,13 @@ import Prelude ()
import Prelude.Compat import Prelude.Compat
import Control.Arrow import Control.Arrow
(left) ((+++), left)
import Control.Concurrent.STM import Control.Concurrent.STM
(atomically) (atomically)
import Control.Concurrent.STM.TVar import Control.Concurrent.STM.TVar
(newTVar, readTVar) (newTVar, readTVar)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import Data.Foldable import Data.Foldable
(forM_, toList) (forM_, toList)
import Data.Maybe import Data.Maybe
@ -93,6 +95,11 @@ successSpec = beforeAll (startWaiApp server) $ afterAll endWaiApp $ do
Left (FailureResponse _ r) <- runClient (getQueryParam (Just "bob")) baseUrl Left (FailureResponse _ r) <- runClient (getQueryParam (Just "bob")) baseUrl
responseStatusCode r `shouldBe` HTTP.Status 400 "bob not found" responseStatusCode r `shouldBe` HTTP.Status 400 "bob not found"
it "Servant.API.QueryParam binary data" $ \(_, baseUrl) -> do
let payload = BS.pack [0, 1, 2, 4, 8, 16, 32, 64, 128]
apiCall = getQueryParamBinary (Just $ UrlEncodedByteString payload) HTTP.methodGet
(show +++ responseBody) <$> runClient apiCall baseUrl `shouldReturn` Right (BL.fromStrict payload)
it "Servant.API.QueryParam.QueryParams" $ \(_, baseUrl) -> do it "Servant.API.QueryParam.QueryParams" $ \(_, baseUrl) -> do
left show <$> runClient (getQueryParams []) baseUrl `shouldReturn` Right [] left show <$> runClient (getQueryParams []) baseUrl `shouldReturn` Right []
left show <$> runClient (getQueryParams ["alice", "bob"]) baseUrl left show <$> runClient (getQueryParams ["alice", "bob"]) baseUrl