add implementRequestStreamingCombinator

This commit is contained in:
Sönke Hahn 2016-10-23 20:43:40 -04:00
parent be5e6e59c7
commit cee7b1ffd1
3 changed files with 105 additions and 30 deletions

View file

@ -152,11 +152,13 @@ test-suite spec
, base-compat
, aeson
, base64-bytestring
, blaze-builder
, bytestring
, deepseq
, directory
, exceptions
, http-types
, mtl
, network
, resourcet
, safe
, servant

View file

@ -11,10 +11,14 @@ module Servant.Server.CombinatorUtils (
implementRequestCheck,
implementAuthCombinator,
argumentCombinator,
implementRequestStreamingCombinator,
-- * re-exports
RouteResult(..),
) where
import Data.ByteString
import Data.Proxy
import Data.Text
import Network.Wai
@ -84,3 +88,15 @@ argumentCombinator getArg = CI $ \ Proxy context delayed ->
route (Proxy :: Proxy api) context $ addBodyCheck delayed -- fixme: shouldn't be body
(return ())
(\ () -> withRequest $ \ request -> liftRouteResult (getArg request))
implementRequestStreamingCombinator ::
forall api combinator arg context .
(ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler),
WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler),
HasServer api context) =>
(IO ByteString -> arg)
-> CombinatorImplementation combinator arg api context
implementRequestStreamingCombinator getArg = CI $ \ Proxy context delayed ->
route (Proxy :: Proxy api) context $ addBodyCheck delayed
(return ())
(\ () -> withRequest $ \ request -> liftRouteResult $ Route $ getArg $ requestBody request)

View file

@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
@ -11,11 +12,14 @@ module Servant.Server.CombinatorUtilsSpec where
import Control.Concurrent
import Data.ByteString.Builder
import Data.ByteString.Lazy
import Control.DeepSeq
import Control.Monad.IO.Class
import Data.ByteString as SBS hiding (map)
import Data.ByteString.Lazy as LBS hiding (map)
import Data.Monoid
import Data.Proxy
import Data.String.Conversions
import Data.Text
import Data.Text hiding (map)
import Network.HTTP.Types
import Network.Wai
import Network.Wai.Internal
@ -36,7 +40,7 @@ runApp app req = do
Nothing -> error "shouldn't happen"
Just response -> return (Just response, response)
responseBodyLbs :: Response -> IO ByteString
responseBodyLbs :: Response -> IO LBS.ByteString
responseBodyLbs response = do
let (_, _, action) = responseToStream response
action $ \ streamingBody -> do
@ -70,6 +74,74 @@ spec = do
response <- runApp app request
responseBodyLbs response `shouldReturn` "[]"
it "allows to write an auth combinator" $ do
let server (User name) = return name
app = serve (Proxy :: Proxy (AuthCombinator :> Get' String)) server
request = defaultRequest{
requestHeaders =
("Auth", "secret") :
requestHeaders defaultRequest
}
response <- runApp app request
responseStatus response `shouldBe` ok200
responseBodyLbs response `shouldReturn` "\"Alice\""
-- fixme: rename
it "allows to write a combinator by providing a function (Request -> a)" $ do
let server = return
app = serve (Proxy :: Proxy (FooHeader :> Get' String)) server
request = defaultRequest{
requestHeaders =
("Foo", "foo") :
requestHeaders defaultRequest
}
response <- runApp app request
responseBodyLbs response `shouldReturn` "\"foo\""
context "streaming request bodies" $ do
let toBody :: [IO SBS.ByteString] -> IO (IO SBS.ByteString)
toBody list = do
mvar <- newMVar list
return $ do
modifyMVar mvar $ \case
(a : r) -> do
chunk <- a
return (r, chunk)
[] -> return ([], "")
fromBody :: IO SBS.ByteString -> IO SBS.ByteString
fromBody getChunk = do
chunk <- getChunk
if chunk == ""
then return ""
else do
rest <- fromBody getChunk
return $ chunk <> rest
it "allows to write combinators" $ do
body <- toBody $ map return ["foo", "bar"]
let server (Source b) = liftIO $ cs <$> fromBody b
app = serve (Proxy :: Proxy (StreamRequest :> Get' String)) server
request = defaultRequest{
requestBody = body
}
response <- runApp app request
responseBodyLbs response `shouldReturn` "\"foobar\""
it "allows to stream lazily" $ do
mvar <- newEmptyMVar
body <- toBody [return "foo", takeMVar mvar >> return "bar"]
let server (Source b) = liftIO $ do
first <- b
deepseq first (return ())
putMVar mvar ()
cs <$> (first <>) <$> fromBody b
app = serve (Proxy :: Proxy (StreamRequest :> Get' String)) server
request = defaultRequest{
requestBody = body
}
response <- runApp app request
responseBodyLbs response `shouldReturn` "\"foobar\""
it "allows to write a combinator that errors out" $ do
let server = return
app = serve (Proxy :: Proxy (StringCapture :> Get' String)) server
@ -83,35 +155,9 @@ spec = do
it "allows to write a combinator using IO" $ do
pending
it "allows to write a combinator by providing a function (Request -> a)" $ do
let server = return
app = serve (Proxy :: Proxy (FooHeader :> Get' String)) server
request = defaultRequest{
requestHeaders =
("Foo", "foo") :
requestHeaders defaultRequest
}
response <- runApp app request
responseBodyLbs response `shouldReturn` "\"foo\""
it "allows to write an auth combinator" $ do
let server (User name) = return name
app = serve (Proxy :: Proxy (AuthCombinator :> Get' String)) server
request = defaultRequest{
requestHeaders =
("Auth", "secret") :
requestHeaders defaultRequest
}
response <- runApp app request
responseStatus response `shouldBe` ok200
responseBodyLbs response `shouldReturn` "\"Alice\""
it "allows to pick the request check phase" $ do
pending
it "allows to write streaming combinators for request bodies" $ do
pending
it "disallows to access the request body unless in the checkBody phase" $ do
pending
@ -173,3 +219,14 @@ getCustom :: Request -> RouteResult String
getCustom request = case lookup "Foo" (requestHeaders request) of
Nothing -> FailFatal err400
Just l -> Route $ cs l
data StreamRequest
data Source = Source (IO SBS.ByteString)
instance HasServer api context => HasServer (StreamRequest :> api) context where
type ServerT (StreamRequest :> api) m = Source -> ServerT api m
route = runCI $ implementRequestStreamingCombinator getSource
getSource :: IO SBS.ByteString -> Source
getSource = Source