diff --git a/servant-client/src/Servant/Common/BaseUrl.hs b/servant-client/src/Servant/Common/BaseUrl.hs index eae87c42..f8cc61e2 100644 --- a/servant-client/src/Servant/Common/BaseUrl.hs +++ b/servant-client/src/Servant/Common/BaseUrl.hs @@ -1,8 +1,19 @@ -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE ViewPatterns #-} -module Servant.Common.BaseUrl where +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE ViewPatterns #-} +module Servant.Common.BaseUrl ( + -- * types + BaseUrl (..) + , InvalidBaseUrlException + , Scheme (..) + -- * functions + , parseBaseUrl + , showBaseUrl +) where +import Control.Monad.Catch (MonadThrow, throwM, Exception) import Data.List +import Data.Typeable import GHC.Generics import Network.URI import Safe @@ -34,20 +45,23 @@ showBaseUrl (BaseUrl urlscheme host port) = (Https, 443) -> "" _ -> ":" ++ show port -parseBaseUrl :: String -> Either String BaseUrl +data InvalidBaseUrlException = InvalidBaseUrlException String deriving (Show, Typeable) +instance Exception InvalidBaseUrlException + +parseBaseUrl :: MonadThrow m => String -> m BaseUrl parseBaseUrl s = case parseURI (removeTrailingSlash s) of -- This is a rather hacky implementation and should be replaced with something -- implemented in attoparsec (which is already a dependency anyhow (via aeson)). Just (URI "http:" (Just (URIAuth "" host (':' : (readMaybe -> Just port)))) "" "" "") -> - Right (BaseUrl Http host port) + return (BaseUrl Http host port) Just (URI "http:" (Just (URIAuth "" host "")) "" "" "") -> - Right (BaseUrl Http host 80) + return (BaseUrl Http host 80) Just (URI "https:" (Just (URIAuth "" host (':' : (readMaybe -> Just port)))) "" "" "") -> - Right (BaseUrl Https host port) + return (BaseUrl Https host port) Just (URI "https:" (Just (URIAuth "" host "")) "" "" "") -> - Right (BaseUrl Https host 443) + return (BaseUrl Https host 443) _ -> if "://" `isInfixOf` s - then Left ("invalid base url: " ++ s) + then throwM (InvalidBaseUrlException $ "Invalid base URL: " ++ s) else parseBaseUrl ("http://" ++ s) where removeTrailingSlash str = case lastMay str of diff --git a/servant-client/src/Servant/Common/Req.hs b/servant-client/src/Servant/Common/Req.hs index b726e7a9..ac2c3dba 100644 --- a/servant-client/src/Servant/Common/Req.hs +++ b/servant-client/src/Servant/Common/Req.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -18,6 +19,7 @@ import Data.String.Conversions import Data.Proxy import Data.Text (Text) import Data.Text.Encoding +import Data.Typeable import Network.HTTP.Client hiding (Proxy) import Network.HTTP.Client.TLS import Network.HTTP.Media @@ -53,7 +55,9 @@ data ServantError { responseContentTypeHeader :: ByteString , responseBody :: ByteString } - deriving (Show) + deriving (Show, Typeable) + +instance Exception ServantError data Req = Req { reqPath :: String