diff --git a/servant-server/servant-server.cabal b/servant-server/servant-server.cabal index f781068f..08c3579e 100644 --- a/servant-server/servant-server.cabal +++ b/servant-server/servant-server.cabal @@ -152,11 +152,13 @@ test-suite spec , base-compat , aeson , base64-bytestring + , blaze-builder , bytestring + , deepseq + , directory , exceptions , http-types , mtl - , network , resourcet , safe , servant diff --git a/servant-server/src/Servant/Server/CombinatorUtils.hs b/servant-server/src/Servant/Server/CombinatorUtils.hs index ecd0d250..3fca9b57 100644 --- a/servant-server/src/Servant/Server/CombinatorUtils.hs +++ b/servant-server/src/Servant/Server/CombinatorUtils.hs @@ -11,10 +11,14 @@ module Servant.Server.CombinatorUtils ( implementRequestCheck, implementAuthCombinator, argumentCombinator, + implementRequestStreamingCombinator, + -- * re-exports + RouteResult(..), ) where +import Data.ByteString import Data.Proxy import Data.Text import Network.Wai @@ -84,3 +88,15 @@ argumentCombinator getArg = CI $ \ Proxy context delayed -> route (Proxy :: Proxy api) context $ addBodyCheck delayed -- fixme: shouldn't be body (return ()) (\ () -> withRequest $ \ request -> liftRouteResult (getArg request)) + +implementRequestStreamingCombinator :: + forall api combinator arg context . + (ServerT (combinator :> api) Handler ~ (arg -> ServerT api Handler), + WithArg arg (ServerT api Handler) ~ (arg -> ServerT api Handler), + HasServer api context) => + (IO ByteString -> arg) + -> CombinatorImplementation combinator arg api context +implementRequestStreamingCombinator getArg = CI $ \ Proxy context delayed -> + route (Proxy :: Proxy api) context $ addBodyCheck delayed + (return ()) + (\ () -> withRequest $ \ request -> liftRouteResult $ Route $ getArg $ requestBody request) diff --git a/servant-server/test/Servant/Server/CombinatorUtilsSpec.hs b/servant-server/test/Servant/Server/CombinatorUtilsSpec.hs index fd1887c2..f33a829e 100644 --- a/servant-server/test/Servant/Server/CombinatorUtilsSpec.hs +++ b/servant-server/test/Servant/Server/CombinatorUtilsSpec.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} @@ -11,11 +12,14 @@ module Servant.Server.CombinatorUtilsSpec where import Control.Concurrent import Data.ByteString.Builder -import Data.ByteString.Lazy +import Control.DeepSeq +import Control.Monad.IO.Class +import Data.ByteString as SBS hiding (map) +import Data.ByteString.Lazy as LBS hiding (map) import Data.Monoid import Data.Proxy import Data.String.Conversions -import Data.Text +import Data.Text hiding (map) import Network.HTTP.Types import Network.Wai import Network.Wai.Internal @@ -36,7 +40,7 @@ runApp app req = do Nothing -> error "shouldn't happen" Just response -> return (Just response, response) -responseBodyLbs :: Response -> IO ByteString +responseBodyLbs :: Response -> IO LBS.ByteString responseBodyLbs response = do let (_, _, action) = responseToStream response action $ \ streamingBody -> do @@ -70,6 +74,74 @@ spec = do response <- runApp app request responseBodyLbs response `shouldReturn` "[]" + it "allows to write an auth combinator" $ do + let server (User name) = return name + app = serve (Proxy :: Proxy (AuthCombinator :> Get' String)) server + request = defaultRequest{ + requestHeaders = + ("Auth", "secret") : + requestHeaders defaultRequest + } + response <- runApp app request + responseStatus response `shouldBe` ok200 + responseBodyLbs response `shouldReturn` "\"Alice\"" + + -- fixme: rename + it "allows to write a combinator by providing a function (Request -> a)" $ do + let server = return + app = serve (Proxy :: Proxy (FooHeader :> Get' String)) server + request = defaultRequest{ + requestHeaders = + ("Foo", "foo") : + requestHeaders defaultRequest + } + response <- runApp app request + responseBodyLbs response `shouldReturn` "\"foo\"" + + context "streaming request bodies" $ do + let toBody :: [IO SBS.ByteString] -> IO (IO SBS.ByteString) + toBody list = do + mvar <- newMVar list + return $ do + modifyMVar mvar $ \case + (a : r) -> do + chunk <- a + return (r, chunk) + [] -> return ([], "") + fromBody :: IO SBS.ByteString -> IO SBS.ByteString + fromBody getChunk = do + chunk <- getChunk + if chunk == "" + then return "" + else do + rest <- fromBody getChunk + return $ chunk <> rest + + it "allows to write combinators" $ do + body <- toBody $ map return ["foo", "bar"] + let server (Source b) = liftIO $ cs <$> fromBody b + app = serve (Proxy :: Proxy (StreamRequest :> Get' String)) server + request = defaultRequest{ + requestBody = body + } + response <- runApp app request + responseBodyLbs response `shouldReturn` "\"foobar\"" + + it "allows to stream lazily" $ do + mvar <- newEmptyMVar + body <- toBody [return "foo", takeMVar mvar >> return "bar"] + let server (Source b) = liftIO $ do + first <- b + deepseq first (return ()) + putMVar mvar () + cs <$> (first <>) <$> fromBody b + app = serve (Proxy :: Proxy (StreamRequest :> Get' String)) server + request = defaultRequest{ + requestBody = body + } + response <- runApp app request + responseBodyLbs response `shouldReturn` "\"foobar\"" + it "allows to write a combinator that errors out" $ do let server = return app = serve (Proxy :: Proxy (StringCapture :> Get' String)) server @@ -83,35 +155,9 @@ spec = do it "allows to write a combinator using IO" $ do pending - it "allows to write a combinator by providing a function (Request -> a)" $ do - let server = return - app = serve (Proxy :: Proxy (FooHeader :> Get' String)) server - request = defaultRequest{ - requestHeaders = - ("Foo", "foo") : - requestHeaders defaultRequest - } - response <- runApp app request - responseBodyLbs response `shouldReturn` "\"foo\"" - - it "allows to write an auth combinator" $ do - let server (User name) = return name - app = serve (Proxy :: Proxy (AuthCombinator :> Get' String)) server - request = defaultRequest{ - requestHeaders = - ("Auth", "secret") : - requestHeaders defaultRequest - } - response <- runApp app request - responseStatus response `shouldBe` ok200 - responseBodyLbs response `shouldReturn` "\"Alice\"" - it "allows to pick the request check phase" $ do pending - it "allows to write streaming combinators for request bodies" $ do - pending - it "disallows to access the request body unless in the checkBody phase" $ do pending @@ -173,3 +219,14 @@ getCustom :: Request -> RouteResult String getCustom request = case lookup "Foo" (requestHeaders request) of Nothing -> FailFatal err400 Just l -> Route $ cs l + +data StreamRequest + +data Source = Source (IO SBS.ByteString) + +instance HasServer api context => HasServer (StreamRequest :> api) context where + type ServerT (StreamRequest :> api) m = Source -> ServerT api m + route = runCI $ implementRequestStreamingCombinator getSource + +getSource :: IO SBS.ByteString -> Source +getSource = Source