diff --git a/tensorflow/src/TensorFlow/Internal/FFI.hs b/tensorflow/src/TensorFlow/Internal/FFI.hs index 59a79b3..2e29f68 100644 --- a/tensorflow/src/TensorFlow/Internal/FFI.hs +++ b/tensorflow/src/TensorFlow/Internal/FFI.hs @@ -33,8 +33,9 @@ module TensorFlow.Internal.FFI import Control.Concurrent.Async (Async, async, cancel, waitCatch) import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar) -import Control.Exception (Exception, throwIO, bracket, finally, mask_) import Control.Monad (when) +import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask_) +import Control.Monad.IO.Class (MonadIO, liftIO) import Data.Bits (Bits, toIntegralSized) import Data.Int (Int64) import Data.Maybe (fromMaybe) @@ -75,13 +76,14 @@ data TensorData = TensorData -- | Runs the given action after creating a session with options -- populated by the given optionSetter. -withSession :: (Raw.SessionOptions -> IO ()) - -> ((IO () -> IO ()) -> Raw.Session -> IO a) +withSession :: (MonadIO m, MonadMask m) + => (Raw.SessionOptions -> IO ()) + -> ((IO () -> IO ()) -> Raw.Session -> m a) -- ^ The action can spawn concurrent tasks which will -- be canceled before withSession returns. - -> IO a + -> m a withSession optionSetter action = do - drain <- newMVar [] + drain <- liftIO $ newMVar [] let cleanup s = -- Closes the session to nudge the pending run calls to fail and exit. finally (checkStatus (Raw.closeSession s)) $ do @@ -89,10 +91,10 @@ withSession optionSetter action = do -- Collects all runners before deleting the session. mapM_ shutDownRunner runners checkStatus (Raw.deleteSession s) - bracket Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do - optionSetter options - bracket - (checkStatus (Raw.newSession options)) + let bracketIO x y = bracket (liftIO x) (liftIO . y) + bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do + bracketIO + (optionSetter options >> checkStatus (Raw.newSession options)) cleanup (action (asyncCollector drain)) @@ -225,7 +227,7 @@ checkStatus fn = when (code /= Raw.TF_OK) $ do msg <- T.decodeUtf8With T.lenientDecode <$> (Raw.message status >>= B.packCString) - throwIO $ TensorFlowException code msg + throwM $ TensorFlowException code msg return result setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO () @@ -258,7 +260,7 @@ getAllOpList = do where checkCall = do p <- Raw.getAllOpList - when (p == nullPtr) (throwIO exception) + when (p == nullPtr) (throwM exception) return p exception = TensorFlowException Raw.TF_UNKNOWN "GetAllOpList failure, check logs" diff --git a/tensorflow/src/TensorFlow/Session.hs b/tensorflow/src/TensorFlow/Session.hs index 70d85d4..a227349 100644 --- a/tensorflow/src/TensorFlow/Session.hs +++ b/tensorflow/src/TensorFlow/Session.hs @@ -20,6 +20,7 @@ module TensorFlow.Session ( Session, + SessionT, Options, sessionConfig, sessionTarget, @@ -39,7 +40,7 @@ module TensorFlow.Session ( import Control.Monad (forever, unless, void) import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask) import Control.Monad.IO.Class (MonadIO, liftIO) -import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.Class (MonadTrans, lift) import Control.Monad.Trans.Reader (ReaderT(..), ask, asks) import Data.ByteString (ByteString) import Data.Default (Default, def) @@ -73,13 +74,18 @@ data SessionState , tracer :: Tracer } -newtype Session a - = Session (ReaderT SessionState (BuildT IO) a) +newtype SessionT m a + = Session (ReaderT SessionState (BuildT m) a) deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch, MonadMask) +instance MonadTrans SessionT where + lift = Session . lift . lift + +type Session = SessionT IO + -- | Run 'Session' actions in a new TensorFlow session. -runSession :: Session a -> IO a +runSession :: (MonadMask m, MonadIO m) => SessionT m a -> m a runSession = runSessionWithOptions def -- | Customization for session. Use the lenses to update: @@ -112,7 +118,7 @@ sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x }) -- | Run 'Session' actions in a new TensorFlow session created with -- the given option setter actions ('sessionTarget', 'sessionConfig'). -runSessionWithOptions :: Options -> Session a -> IO a +runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a runSessionWithOptions options (Session m) = FFI.withSession applyOptions $ \as rs -> @@ -122,14 +128,14 @@ runSessionWithOptions options (Session m) = FFI.setSessionTarget (options ^. sessionTarget) opt FFI.setSessionConfig (options ^. sessionConfig) opt -instance MonadBuild Session where +instance Monad m => MonadBuild (SessionT m) where build = Session . lift . build -- | Add all pending rendered nodes to the TensorFlow graph and runs -- any pending initializers. -- -- Note that run, runWithFeeds, etc. will all call this function implicitly. -extend :: Session () +extend :: MonadIO m => SessionT m () extend = do session <- Session (asks rawSession) trace <- Session (asks tracer) @@ -145,13 +151,13 @@ extend = do -- | Run a subgraph 't', rendering any dependent nodes that aren't already -- rendered, and fetch the corresponding values for 'a'. -run :: Fetchable t a => t -> Session a +run :: (MonadIO m, Fetchable t a) => t -> SessionT m a run = runWithFeeds [] -- | Run a subgraph 't', rendering any dependent nodes that aren't already -- rendered, feed the given input values, and fetch the corresponding result -- values for 'a'. -runWithFeeds :: Fetchable t a => [Feed] -> t -> Session a +runWithFeeds :: (MonadIO m, Fetchable t a) => [Feed] -> t -> SessionT m a runWithFeeds feeds t = do ns <- build $ getNodes t -- Note that this call to "fetch" shouldn't affect the following "extend" @@ -160,7 +166,7 @@ runWithFeeds feeds t = do fetch <- build $ getFetch t runFetchWithFeeds feeds ns fetch -runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a +runFetchWithFeeds :: MonadIO m => [Feed] -> Set NodeName -> Fetch a -> SessionT m a runFetchWithFeeds feeds target (Fetch fetch restore) = do extend let feeds' = fixFeeds feeds @@ -180,14 +186,14 @@ toNodeNames = map (encodeUtf8 . unNodeName) -- | Run a subgraph 't', rendering and extending any dependent nodes that aren't -- already rendered. This behaves like 'run' except that it doesn't do any -- fetches. -run_ :: Nodes t => t -> Session () +run_ :: (MonadIO m, Nodes t) => t -> SessionT m () run_ = runWithFeeds_ [] -- | Run a subgraph 't', rendering any dependent nodes that aren't already -- rendered, feed the given input values, and fetch the corresponding result -- values for 'a'. This behaves like 'runWithFeeds' except that it doesn't do -- any fetches. -runWithFeeds_ :: Nodes t => [Feed] -> t -> Session () +runWithFeeds_ :: (MonadIO m, Nodes t) => [Feed] -> t -> SessionT m () runWithFeeds_ feeds t = do ns <- build $ getNodes t runFetchWithFeeds feeds ns (pure ()) @@ -199,9 +205,9 @@ fixFeeds = map $ \(Feed o d) -> (encodeUtf8 $ encodeOutput o, d) -- forever until runSession exits or an exception occurs. Graph -- extension happens synchronously, but the resultant run proceeds as -- a separate thread. -asyncProdNodes :: Nodes t +asyncProdNodes :: (MonadIO m, Nodes t) => t -- ^ Node to evaluate concurrently. - -> Session () + -> SessionT m () asyncProdNodes nodes = do target <- build (getNodes nodes) extend