diff --git a/src/Servant/Client.hs b/src/Servant/Client.hs index 8b4f5628..ed63de47 100644 --- a/src/Servant/Client.hs +++ b/src/Servant/Client.hs @@ -21,7 +21,7 @@ import Data.Proxy import Data.String.Conversions import Data.Text (unpack) import GHC.TypeLits -import Network.HTTP.Types +import qualified Network.HTTP.Types as H import Servant.API import Servant.Common.BaseUrl import Servant.Common.Req @@ -108,7 +108,7 @@ instance HasClient Delete where type Client Delete = BaseUrl -> EitherT String IO () clientWithRoute Proxy req host = - performRequestJSON methodDelete req 204 host + performRequestJSON H.methodDelete req 204 host -- | If you have a 'Get' endpoint in your API, the client -- side querying function that is created when calling 'client' @@ -117,7 +117,44 @@ instance HasClient Delete where instance FromJSON result => HasClient (Get result) where type Client (Get result) = BaseUrl -> EitherT String IO result clientWithRoute Proxy req host = - performRequestJSON methodGet req 200 host + performRequestJSON H.methodGet req 200 host + +-- | If you use a 'Header' in one of your endpoints in your API, +-- the corresponding querying function will automatically take +-- an additional argument of the type specified by your 'Header', +-- wrapped in Maybe. +-- +-- That function will take care of encoding this argument as Text +-- in the request headers. +-- +-- All you need is for your type to have a 'ToText' instance. +-- +-- Example: +-- +-- > newtype Referer = Referer Text +-- > deriving (Eq, Show, FromText, ToText) +-- > +-- > -- GET /view-my-referer +-- > type MyApi = "view-my-referer" :> Header "Referer" Referer :> Get Referer +-- > +-- > myApi :: Proxy MyApi +-- > myApi = Proxy +-- > +-- > viewReferer :: Maybe Referer -> BaseUrl -> EitherT String IO Book +-- > viewReferer = client myApi +-- > -- then you can just use "viewRefer" to query that endpoint +-- > -- specifying Nothing or Just "http://haskell.org/" as arguments +instance (KnownSymbol sym, ToText a, HasClient sublayout) + => HasClient (Header sym a :> sublayout) where + + type Client (Header sym a :> sublayout) = + Maybe a -> Client sublayout + + clientWithRoute Proxy req mval = + clientWithRoute (Proxy :: Proxy sublayout) $ + maybe req (\value -> addHeader hname value req) mval + + where hname = symbolVal (Proxy :: Proxy sym) -- | If you have a 'Post' endpoint in your API, the client -- side querying function that is created when calling 'client' @@ -127,7 +164,7 @@ instance FromJSON a => HasClient (Post a) where type Client (Post a) = BaseUrl -> EitherT String IO a clientWithRoute Proxy req uri = - performRequestJSON methodPost req 201 uri + performRequestJSON H.methodPost req 201 uri -- | If you have a 'Put' endpoint in your API, the client -- side querying function that is created when calling 'client' @@ -137,7 +174,7 @@ instance FromJSON a => HasClient (Put a) where type Client (Put a) = BaseUrl -> EitherT String IO a clientWithRoute Proxy req host = - performRequestJSON methodPut req 200 host + performRequestJSON H.methodPut req 200 host -- | If you use a 'QueryParam' in one of your endpoints in your API, -- the corresponding querying function will automatically take @@ -258,7 +295,7 @@ instance (KnownSymbol sym, HasClient sublayout) -- | Pick a 'Method' and specify where the server you want to query is. You get -- back the status code and the response body as a 'ByteString'. instance HasClient Raw where - type Client Raw = Method -> BaseUrl -> EitherT String IO (Int, ByteString) + type Client Raw = H.Method -> BaseUrl -> EitherT String IO (Int, ByteString) clientWithRoute :: Proxy Raw -> Req -> Client Raw clientWithRoute Proxy req httpMethod host = diff --git a/src/Servant/Common/Req.hs b/src/Servant/Common/Req.hs index 62c469d8..9f7db5c5 100644 --- a/src/Servant/Common/Req.hs +++ b/src/Servant/Common/Req.hs @@ -13,13 +13,16 @@ import Data.Aeson import Data.Aeson.Parser import Data.Aeson.Types import Data.Attoparsec.ByteString -import Data.ByteString.Lazy +import Data.ByteString.Lazy hiding (pack) +import Data.String import Data.String.Conversions import Data.Text +import Data.Text.Encoding import Network.HTTP.Client import Network.HTTP.Types import Network.URI import Servant.Common.BaseUrl +import Servant.Common.Text import System.IO.Unsafe import qualified Network.HTTP.Client as Client @@ -28,10 +31,11 @@ data Req = Req { reqPath :: String , qs :: QueryText , reqBody :: ByteString + , headers :: [(String, Text)] } defReq :: Req -defReq = Req "" [] "" +defReq = Req "" [] "" [] appendToPath :: String -> Req -> Req appendToPath p req = @@ -45,11 +49,17 @@ appendToQueryString pname pvalue req = req { qs = qs req ++ [(pname, pvalue)] } +addHeader :: ToText a => String -> a -> Req -> Req +addHeader name val req = req { headers = headers req + ++ [(name, toText val)] + } + setRQBody :: ByteString -> Req -> Req setRQBody b req = req { reqBody = b } reqToRequest :: (Functor m, MonadThrow m) => Req -> BaseUrl -> m Request -reqToRequest req (BaseUrl reqScheme reqHost reqPort) = fmap (setrqb . setQS ) $ parseUrl url +reqToRequest req (BaseUrl reqScheme reqHost reqPort) = + fmap (setheaders . setrqb . setQS ) $ parseUrl url where url = show $ nullURI { uriScheme = case reqScheme of Http -> "http:" @@ -64,6 +74,10 @@ reqToRequest req (BaseUrl reqScheme reqHost reqPort) = fmap (setrqb . setQS ) $ setrqb r = r { requestBody = RequestBodyLBS (reqBody req) } setQS = setQueryString $ queryTextToQuery (qs req) + setheaders r = r { requestHeaders = Prelude.map toProperHeader (headers req) } + + toProperHeader (name, val) = + (fromString name, encodeUtf8 val) -- * performing requests