Create monad transformer version of Session (closes #153) (#154)

This commit is contained in:
Jeroen Bransen 2017-10-02 22:33:49 +02:00 committed by fkm3
parent d79a919efa
commit d8bf349962
2 changed files with 33 additions and 25 deletions

View File

@ -33,8 +33,9 @@ module TensorFlow.Internal.FFI
import Control.Concurrent.Async (Async, async, cancel, waitCatch) import Control.Concurrent.Async (Async, async, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar) import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
import Control.Monad (when) import Control.Monad (when)
import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask_)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Bits (Bits, toIntegralSized) import Data.Bits (Bits, toIntegralSized)
import Data.Int (Int64) import Data.Int (Int64)
import Data.Maybe (fromMaybe) import Data.Maybe (fromMaybe)
@ -75,13 +76,14 @@ data TensorData = TensorData
-- | Runs the given action after creating a session with options -- | Runs the given action after creating a session with options
-- populated by the given optionSetter. -- populated by the given optionSetter.
withSession :: (Raw.SessionOptions -> IO ()) withSession :: (MonadIO m, MonadMask m)
-> ((IO () -> IO ()) -> Raw.Session -> IO a) => (Raw.SessionOptions -> IO ())
-> ((IO () -> IO ()) -> Raw.Session -> m a)
-- ^ The action can spawn concurrent tasks which will -- ^ The action can spawn concurrent tasks which will
-- be canceled before withSession returns. -- be canceled before withSession returns.
-> IO a -> m a
withSession optionSetter action = do withSession optionSetter action = do
drain <- newMVar [] drain <- liftIO $ newMVar []
let cleanup s = let cleanup s =
-- Closes the session to nudge the pending run calls to fail and exit. -- Closes the session to nudge the pending run calls to fail and exit.
finally (checkStatus (Raw.closeSession s)) $ do finally (checkStatus (Raw.closeSession s)) $ do
@ -89,10 +91,10 @@ withSession optionSetter action = do
-- Collects all runners before deleting the session. -- Collects all runners before deleting the session.
mapM_ shutDownRunner runners mapM_ shutDownRunner runners
checkStatus (Raw.deleteSession s) checkStatus (Raw.deleteSession s)
bracket Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do let bracketIO x y = bracket (liftIO x) (liftIO . y)
optionSetter options bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
bracket bracketIO
(checkStatus (Raw.newSession options)) (optionSetter options >> checkStatus (Raw.newSession options))
cleanup cleanup
(action (asyncCollector drain)) (action (asyncCollector drain))
@ -225,7 +227,7 @@ checkStatus fn =
when (code /= Raw.TF_OK) $ do when (code /= Raw.TF_OK) $ do
msg <- T.decodeUtf8With T.lenientDecode <$> msg <- T.decodeUtf8With T.lenientDecode <$>
(Raw.message status >>= B.packCString) (Raw.message status >>= B.packCString)
throwIO $ TensorFlowException code msg throwM $ TensorFlowException code msg
return result return result
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO () setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
@ -258,7 +260,7 @@ getAllOpList = do
where where
checkCall = do checkCall = do
p <- Raw.getAllOpList p <- Raw.getAllOpList
when (p == nullPtr) (throwIO exception) when (p == nullPtr) (throwM exception)
return p return p
exception = TensorFlowException exception = TensorFlowException
Raw.TF_UNKNOWN "GetAllOpList failure, check logs" Raw.TF_UNKNOWN "GetAllOpList failure, check logs"

View File

@ -20,6 +20,7 @@
module TensorFlow.Session ( module TensorFlow.Session (
Session, Session,
SessionT,
Options, Options,
sessionConfig, sessionConfig,
sessionTarget, sessionTarget,
@ -39,7 +40,7 @@ module TensorFlow.Session (
import Control.Monad (forever, unless, void) import Control.Monad (forever, unless, void)
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask) import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
import Control.Monad.IO.Class (MonadIO, liftIO) import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Class (MonadTrans, lift)
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks) import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import Data.Default (Default, def) import Data.Default (Default, def)
@ -73,13 +74,18 @@ data SessionState
, tracer :: Tracer , tracer :: Tracer
} }
newtype Session a newtype SessionT m a
= Session (ReaderT SessionState (BuildT IO) a) = Session (ReaderT SessionState (BuildT m) a)
deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch, deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch,
MonadMask) MonadMask)
instance MonadTrans SessionT where
lift = Session . lift . lift
type Session = SessionT IO
-- | Run 'Session' actions in a new TensorFlow session. -- | Run 'Session' actions in a new TensorFlow session.
runSession :: Session a -> IO a runSession :: (MonadMask m, MonadIO m) => SessionT m a -> m a
runSession = runSessionWithOptions def runSession = runSessionWithOptions def
-- | Customization for session. Use the lenses to update: -- | Customization for session. Use the lenses to update:
@ -112,7 +118,7 @@ sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x })
-- | Run 'Session' actions in a new TensorFlow session created with -- | Run 'Session' actions in a new TensorFlow session created with
-- the given option setter actions ('sessionTarget', 'sessionConfig'). -- the given option setter actions ('sessionTarget', 'sessionConfig').
runSessionWithOptions :: Options -> Session a -> IO a runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
runSessionWithOptions options (Session m) = runSessionWithOptions options (Session m) =
FFI.withSession applyOptions $ FFI.withSession applyOptions $
\as rs -> \as rs ->
@ -122,14 +128,14 @@ runSessionWithOptions options (Session m) =
FFI.setSessionTarget (options ^. sessionTarget) opt FFI.setSessionTarget (options ^. sessionTarget) opt
FFI.setSessionConfig (options ^. sessionConfig) opt FFI.setSessionConfig (options ^. sessionConfig) opt
instance MonadBuild Session where instance Monad m => MonadBuild (SessionT m) where
build = Session . lift . build build = Session . lift . build
-- | Add all pending rendered nodes to the TensorFlow graph and runs -- | Add all pending rendered nodes to the TensorFlow graph and runs
-- any pending initializers. -- any pending initializers.
-- --
-- Note that run, runWithFeeds, etc. will all call this function implicitly. -- Note that run, runWithFeeds, etc. will all call this function implicitly.
extend :: Session () extend :: MonadIO m => SessionT m ()
extend = do extend = do
session <- Session (asks rawSession) session <- Session (asks rawSession)
trace <- Session (asks tracer) trace <- Session (asks tracer)
@ -145,13 +151,13 @@ extend = do
-- | Run a subgraph 't', rendering any dependent nodes that aren't already -- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered, and fetch the corresponding values for 'a'. -- rendered, and fetch the corresponding values for 'a'.
run :: Fetchable t a => t -> Session a run :: (MonadIO m, Fetchable t a) => t -> SessionT m a
run = runWithFeeds [] run = runWithFeeds []
-- | Run a subgraph 't', rendering any dependent nodes that aren't already -- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered, feed the given input values, and fetch the corresponding result -- rendered, feed the given input values, and fetch the corresponding result
-- values for 'a'. -- values for 'a'.
runWithFeeds :: Fetchable t a => [Feed] -> t -> Session a runWithFeeds :: (MonadIO m, Fetchable t a) => [Feed] -> t -> SessionT m a
runWithFeeds feeds t = do runWithFeeds feeds t = do
ns <- build $ getNodes t ns <- build $ getNodes t
-- Note that this call to "fetch" shouldn't affect the following "extend" -- Note that this call to "fetch" shouldn't affect the following "extend"
@ -160,7 +166,7 @@ runWithFeeds feeds t = do
fetch <- build $ getFetch t fetch <- build $ getFetch t
runFetchWithFeeds feeds ns fetch runFetchWithFeeds feeds ns fetch
runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a runFetchWithFeeds :: MonadIO m => [Feed] -> Set NodeName -> Fetch a -> SessionT m a
runFetchWithFeeds feeds target (Fetch fetch restore) = do runFetchWithFeeds feeds target (Fetch fetch restore) = do
extend extend
let feeds' = fixFeeds feeds let feeds' = fixFeeds feeds
@ -180,14 +186,14 @@ toNodeNames = map (encodeUtf8 . unNodeName)
-- | Run a subgraph 't', rendering and extending any dependent nodes that aren't -- | Run a subgraph 't', rendering and extending any dependent nodes that aren't
-- already rendered. This behaves like 'run' except that it doesn't do any -- already rendered. This behaves like 'run' except that it doesn't do any
-- fetches. -- fetches.
run_ :: Nodes t => t -> Session () run_ :: (MonadIO m, Nodes t) => t -> SessionT m ()
run_ = runWithFeeds_ [] run_ = runWithFeeds_ []
-- | Run a subgraph 't', rendering any dependent nodes that aren't already -- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered, feed the given input values, and fetch the corresponding result -- rendered, feed the given input values, and fetch the corresponding result
-- values for 'a'. This behaves like 'runWithFeeds' except that it doesn't do -- values for 'a'. This behaves like 'runWithFeeds' except that it doesn't do
-- any fetches. -- any fetches.
runWithFeeds_ :: Nodes t => [Feed] -> t -> Session () runWithFeeds_ :: (MonadIO m, Nodes t) => [Feed] -> t -> SessionT m ()
runWithFeeds_ feeds t = do runWithFeeds_ feeds t = do
ns <- build $ getNodes t ns <- build $ getNodes t
runFetchWithFeeds feeds ns (pure ()) runFetchWithFeeds feeds ns (pure ())
@ -199,9 +205,9 @@ fixFeeds = map $ \(Feed o d) -> (encodeUtf8 $ encodeOutput o, d)
-- forever until runSession exits or an exception occurs. Graph -- forever until runSession exits or an exception occurs. Graph
-- extension happens synchronously, but the resultant run proceeds as -- extension happens synchronously, but the resultant run proceeds as
-- a separate thread. -- a separate thread.
asyncProdNodes :: Nodes t asyncProdNodes :: (MonadIO m, Nodes t)
=> t -- ^ Node to evaluate concurrently. => t -- ^ Node to evaluate concurrently.
-> Session () -> SessionT m ()
asyncProdNodes nodes = do asyncProdNodes nodes = do
target <- build (getNodes nodes) target <- build (getNodes nodes)
extend extend