1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-23 03:19:44 +01:00

Create monad transformer version of Session (closes #153) (#154)

This commit is contained in:
Jeroen Bransen 2017-10-02 22:33:49 +02:00 committed by fkm3
parent d79a919efa
commit d8bf349962
2 changed files with 33 additions and 25 deletions

View file

@ -33,8 +33,9 @@ module TensorFlow.Internal.FFI
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
import Control.Monad (when)
import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask_)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Bits (Bits, toIntegralSized)
import Data.Int (Int64)
import Data.Maybe (fromMaybe)
@ -75,13 +76,14 @@ data TensorData = TensorData
-- | Runs the given action after creating a session with options
-- populated by the given optionSetter.
withSession :: (Raw.SessionOptions -> IO ())
-> ((IO () -> IO ()) -> Raw.Session -> IO a)
withSession :: (MonadIO m, MonadMask m)
=> (Raw.SessionOptions -> IO ())
-> ((IO () -> IO ()) -> Raw.Session -> m a)
-- ^ The action can spawn concurrent tasks which will
-- be canceled before withSession returns.
-> IO a
-> m a
withSession optionSetter action = do
drain <- newMVar []
drain <- liftIO $ newMVar []
let cleanup s =
-- Closes the session to nudge the pending run calls to fail and exit.
finally (checkStatus (Raw.closeSession s)) $ do
@ -89,10 +91,10 @@ withSession optionSetter action = do
-- Collects all runners before deleting the session.
mapM_ shutDownRunner runners
checkStatus (Raw.deleteSession s)
bracket Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
optionSetter options
bracket
(checkStatus (Raw.newSession options))
let bracketIO x y = bracket (liftIO x) (liftIO . y)
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
bracketIO
(optionSetter options >> checkStatus (Raw.newSession options))
cleanup
(action (asyncCollector drain))
@ -225,7 +227,7 @@ checkStatus fn =
when (code /= Raw.TF_OK) $ do
msg <- T.decodeUtf8With T.lenientDecode <$>
(Raw.message status >>= B.packCString)
throwIO $ TensorFlowException code msg
throwM $ TensorFlowException code msg
return result
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
@ -258,7 +260,7 @@ getAllOpList = do
where
checkCall = do
p <- Raw.getAllOpList
when (p == nullPtr) (throwIO exception)
when (p == nullPtr) (throwM exception)
return p
exception = TensorFlowException
Raw.TF_UNKNOWN "GetAllOpList failure, check logs"

View file

@ -20,6 +20,7 @@
module TensorFlow.Session (
Session,
SessionT,
Options,
sessionConfig,
sessionTarget,
@ -39,7 +40,7 @@ module TensorFlow.Session (
import Control.Monad (forever, unless, void)
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Class (MonadTrans, lift)
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
import Data.ByteString (ByteString)
import Data.Default (Default, def)
@ -73,13 +74,18 @@ data SessionState
, tracer :: Tracer
}
newtype Session a
= Session (ReaderT SessionState (BuildT IO) a)
newtype SessionT m a
= Session (ReaderT SessionState (BuildT m) a)
deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch,
MonadMask)
instance MonadTrans SessionT where
lift = Session . lift . lift
type Session = SessionT IO
-- | Run 'Session' actions in a new TensorFlow session.
runSession :: Session a -> IO a
runSession :: (MonadMask m, MonadIO m) => SessionT m a -> m a
runSession = runSessionWithOptions def
-- | Customization for session. Use the lenses to update:
@ -112,7 +118,7 @@ sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x })
-- | Run 'Session' actions in a new TensorFlow session created with
-- the given option setter actions ('sessionTarget', 'sessionConfig').
runSessionWithOptions :: Options -> Session a -> IO a
runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
runSessionWithOptions options (Session m) =
FFI.withSession applyOptions $
\as rs ->
@ -122,14 +128,14 @@ runSessionWithOptions options (Session m) =
FFI.setSessionTarget (options ^. sessionTarget) opt
FFI.setSessionConfig (options ^. sessionConfig) opt
instance MonadBuild Session where
instance Monad m => MonadBuild (SessionT m) where
build = Session . lift . build
-- | Add all pending rendered nodes to the TensorFlow graph and runs
-- any pending initializers.
--
-- Note that run, runWithFeeds, etc. will all call this function implicitly.
extend :: Session ()
extend :: MonadIO m => SessionT m ()
extend = do
session <- Session (asks rawSession)
trace <- Session (asks tracer)
@ -145,13 +151,13 @@ extend = do
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered, and fetch the corresponding values for 'a'.
run :: Fetchable t a => t -> Session a
run :: (MonadIO m, Fetchable t a) => t -> SessionT m a
run = runWithFeeds []
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered, feed the given input values, and fetch the corresponding result
-- values for 'a'.
runWithFeeds :: Fetchable t a => [Feed] -> t -> Session a
runWithFeeds :: (MonadIO m, Fetchable t a) => [Feed] -> t -> SessionT m a
runWithFeeds feeds t = do
ns <- build $ getNodes t
-- Note that this call to "fetch" shouldn't affect the following "extend"
@ -160,7 +166,7 @@ runWithFeeds feeds t = do
fetch <- build $ getFetch t
runFetchWithFeeds feeds ns fetch
runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a
runFetchWithFeeds :: MonadIO m => [Feed] -> Set NodeName -> Fetch a -> SessionT m a
runFetchWithFeeds feeds target (Fetch fetch restore) = do
extend
let feeds' = fixFeeds feeds
@ -180,14 +186,14 @@ toNodeNames = map (encodeUtf8 . unNodeName)
-- | Run a subgraph 't', rendering and extending any dependent nodes that aren't
-- already rendered. This behaves like 'run' except that it doesn't do any
-- fetches.
run_ :: Nodes t => t -> Session ()
run_ :: (MonadIO m, Nodes t) => t -> SessionT m ()
run_ = runWithFeeds_ []
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered, feed the given input values, and fetch the corresponding result
-- values for 'a'. This behaves like 'runWithFeeds' except that it doesn't do
-- any fetches.
runWithFeeds_ :: Nodes t => [Feed] -> t -> Session ()
runWithFeeds_ :: (MonadIO m, Nodes t) => [Feed] -> t -> SessionT m ()
runWithFeeds_ feeds t = do
ns <- build $ getNodes t
runFetchWithFeeds feeds ns (pure ())
@ -199,9 +205,9 @@ fixFeeds = map $ \(Feed o d) -> (encodeUtf8 $ encodeOutput o, d)
-- forever until runSession exits or an exception occurs. Graph
-- extension happens synchronously, but the resultant run proceeds as
-- a separate thread.
asyncProdNodes :: Nodes t
asyncProdNodes :: (MonadIO m, Nodes t)
=> t -- ^ Node to evaluate concurrently.
-> Session ()
-> SessionT m ()
asyncProdNodes nodes = do
target <- build (getNodes nodes)
extend