diff --git a/servant-server/servant-server.cabal b/servant-server/servant-server.cabal index fc07f74a..fb03edd7 100644 --- a/servant-server/servant-server.cabal +++ b/servant-server/servant-server.cabal @@ -109,6 +109,7 @@ test-suite spec Servant.ArbitraryMonadServerSpec Servant.Server.ErrorSpec Servant.Server.Internal.ContextSpec + Servant.Server.Internal.RoutingApplicationSpec Servant.Server.RouterSpec Servant.Server.StreamingSpec Servant.Server.UsingContextSpec diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index 7c89b8f5..890c1856 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -23,6 +23,7 @@ module Servant.Server.Internal , module Servant.Server.Internal.ServantErr ) where +import Control.Exception (finally) import Control.Monad.Trans (liftIO) import qualified Data.ByteString as B import qualified Data.ByteString.Char8 as BC8 @@ -400,11 +401,17 @@ instance HasServer Raw context where type ServerT Raw m = Application route Proxy _ rawApplication = RawRouter $ \ env request respond -> do - r <- runDelayed rawApplication env request - case r of - Route app -> app request (respond . Route) - Fail a -> respond $ Fail a - FailFatal e -> respond $ FailFatal e + -- note: a Raw application doesn't register any cleanup + -- but for the sake of consistency, we nonetheless run + -- the cleanup once its done + cleanupRef <- newCleanupRef + r <- runDelayed rawApplication env request cleanupRef + go r request respond `finally` runCleanup cleanupRef + + where go r request respond = case r of + Route app -> app request (respond . Route) + Fail a -> respond $ Fail a + FailFatal e -> respond $ FailFatal e -- | If you use 'ReqBody' in one of the endpoints for your API, -- this automatically requires your server-side handler to be a function diff --git a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs index 85fb04dc..e8bd7bc6 100644 --- a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs +++ b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs @@ -6,10 +6,13 @@ {-# LANGUAGE KindSignatures #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TupleSections #-} module Servant.Server.Internal.RoutingApplication where +import Control.Exception (finally) import Control.Monad (ap, liftM) import Control.Monad.Trans (MonadIO(..)) +import Data.IORef (newIORef, readIORef, IORef, atomicModifyIORef) import Network.Wai (Application, Request, Response, ResponseReceived) import Prelude () @@ -112,12 +115,33 @@ instance Functor (Delayed env) where , .. } -- Note [Existential Record Update] +-- | A mutable cleanup action +newtype CleanupRef = CleanupRef (IORef (IO ())) + +newCleanupRef :: IO CleanupRef +newCleanupRef = CleanupRef <$> newIORef (return ()) + +-- | Add a clean up action to a 'CleanupRef' +addCleanup' :: IO () -> CleanupRef -> IO () +addCleanup' act (CleanupRef ref) = atomicModifyIORef ref (\act' -> (act' >> act, ())) + +addCleanup :: IO () -> DelayedIO () +addCleanup act = DelayedIO $ \_req cleanupRef -> + addCleanup' act cleanupRef >> return (Route ()) + +-- | Run all the clean up actions registered in +-- a 'CleanupRef'. +runCleanup :: CleanupRef -> IO () +runCleanup (CleanupRef ref) = do + act <- readIORef ref + act + -- | Computations used in a 'Delayed' can depend on the -- incoming 'Request', may perform 'IO, and result in a -- 'RouteResult, meaning they can either suceed, fail -- (with the possibility to recover), or fail fatally. -- -newtype DelayedIO a = DelayedIO { runDelayedIO :: Request -> IO (RouteResult a) } +newtype DelayedIO a = DelayedIO { runDelayedIO :: Request -> CleanupRef -> IO (RouteResult a) } instance Functor DelayedIO where fmap = liftM @@ -127,17 +151,17 @@ instance Applicative DelayedIO where (<*>) = ap instance Monad DelayedIO where - return x = DelayedIO (const $ return (Route x)) + return x = DelayedIO (\_req _cleanup -> return (Route x)) DelayedIO m >>= f = - DelayedIO $ \ req -> do - r <- m req + DelayedIO $ \ req cl -> do + r <- m req cl case r of Fail e -> return $ Fail e FailFatal e -> return $ FailFatal e - Route a -> runDelayedIO (f a) req + Route a -> runDelayedIO (f a) req cl instance MonadIO DelayedIO where - liftIO m = DelayedIO (const $ Route <$> m) + liftIO m = DelayedIO (\_req _cl -> Route <$> m) -- | A 'Delayed' without any stored checks. emptyDelayed :: RouteResult a -> Delayed env a @@ -148,15 +172,15 @@ emptyDelayed result = -- | Fail with the option to recover. delayedFail :: ServantErr -> DelayedIO a -delayedFail err = DelayedIO (const $ return $ Fail err) +delayedFail err = DelayedIO (\_req _cleanup -> return $ Fail err) -- | Fail fatally, i.e., without any option to recover. delayedFailFatal :: ServantErr -> DelayedIO a -delayedFailFatal err = DelayedIO (const $ return $ FailFatal err) +delayedFailFatal err = DelayedIO (\_req _cleanup -> return $ FailFatal err) -- | Gain access to the incoming request. withRequest :: (Request -> DelayedIO a) -> DelayedIO a -withRequest f = DelayedIO (\ req -> runDelayedIO (f req) req) +withRequest f = DelayedIO (\ req cl -> runDelayedIO (f req) req cl) -- | Add a capture to the end of the capture block. addCapture :: Delayed env (a -> b) @@ -196,8 +220,8 @@ addBodyCheck :: Delayed env (a -> b) -> Delayed env b addBodyCheck Delayed{..} new = Delayed - { bodyD = (,) <$> bodyD <*> new - , serverD = \ c a (z, v) req -> ($ v) <$> serverD c a z req + { bodyD = (,) <$> bodyD <*> new + , serverD = \ c a (z, v) req -> ($ v) <$> serverD c a z req , .. } -- Note [Existential Record Update] @@ -240,13 +264,18 @@ passToServer Delayed{..} x = runDelayed :: Delayed env a -> env -> Request + -> CleanupRef -> IO (RouteResult a) -runDelayed Delayed{..} env = runDelayedIO $ do - c <- capturesD env - methodD - a <- authD - b <- bodyD - DelayedIO (\ req -> return $ serverD c a b req) +runDelayed Delayed{..} env req cleanupRef = + runDelayedIO + (do c <- capturesD env + methodD + a <- authD + b <- bodyD + DelayedIO $ \ r _cleanup -> return (serverD c a b r) + ) + req + cleanupRef -- | Runs a delayed server and the resulting action. -- Takes a continuation that lets us send a response. @@ -258,8 +287,11 @@ runAction :: Delayed env (Handler a) -> (RouteResult Response -> IO r) -> (a -> RouteResult Response) -> IO r -runAction action env req respond k = - runDelayed action env req >>= go >>= respond +runAction action env req respond k = do + cleanupRef <- newCleanupRef + (runDelayed action env req cleanupRef >>= go >>= respond) + `finally` runCleanup cleanupRef + where go (Fail e) = return $ Fail e go (FailFatal e) = return $ FailFatal e diff --git a/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs b/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs new file mode 100644 index 00000000..abdf016d --- /dev/null +++ b/servant-server/test/Servant/Server/Internal/RoutingApplicationSpec.hs @@ -0,0 +1,61 @@ +module Servant.Server.Internal.RoutingApplicationSpec (spec) where + +import Prelude () +import Prelude.Compat + +import Control.Exception hiding (Handler) +import Control.Monad.IO.Class +import Data.Maybe (isJust) +import Data.IORef +import Servant.Server +import Servant.Server.Internal.RoutingApplication +import Test.Hspec + +import System.IO.Unsafe (unsafePerformIO) + +ok :: IO (RouteResult ()) +ok = return (Route ()) + +-- Let's not write to the filesystem +delayedTestRef :: IORef (Maybe String) +delayedTestRef = unsafePerformIO $ newIORef Nothing + +delayed :: DelayedIO () -> RouteResult (Handler ()) -> Delayed () (Handler ()) +delayed body srv = Delayed + { capturesD = \() -> DelayedIO $ \_req _cl -> ok + , methodD = DelayedIO $ \_req_ _cl -> ok + , authD = DelayedIO $ \_req _cl -> ok + , bodyD = do + liftIO (writeIORef delayedTestRef (Just "hia") >> putStrLn "garbage created") + addCleanup (writeIORef delayedTestRef Nothing >> putStrLn "garbage collected") + body + , serverD = \() () _body _req -> srv + } + +simpleRun :: Delayed () (Handler ()) + -> IO () +simpleRun d = fmap (either ignoreE id) . try $ + runAction d () undefined (\_ -> return ()) (\_ -> FailFatal err500) + + where ignoreE :: SomeException -> () + ignoreE = const () + +spec :: Spec +spec = do + describe "Delayed" $ do + it "actually runs clean up actions" $ do + _ <- simpleRun $ delayed (return ()) (Route $ return ()) + cleanUpDone <- isJust <$> readIORef delayedTestRef + cleanUpDone `shouldBe` False + it "even with exceptions in serverD" $ do + _ <- simpleRun $ delayed (return ()) (Route $ throw DivideByZero) + cleanUpDone <- isJust <$> readIORef delayedTestRef + cleanUpDone `shouldBe` False + it "even with routing failure in bodyD" $ do + _ <- simpleRun $ delayed (delayedFailFatal err500) (Route $ return ()) + cleanUpDone <- isJust <$> readIORef delayedTestRef + cleanUpDone `shouldBe` False + it "even with exceptions in bodyD" $ do + _ <- simpleRun $ delayed (liftIO $ throwIO DivideByZero) (Route $ return ()) + cleanUpDone <- isJust <$> readIORef delayedTestRef + cleanUpDone `shouldBe` False