mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
parent
d79a919efa
commit
d8bf349962
2 changed files with 33 additions and 25 deletions
|
@ -33,8 +33,9 @@ module TensorFlow.Internal.FFI
|
|||
|
||||
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
|
||||
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
|
||||
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
|
||||
import Control.Monad (when)
|
||||
import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask_)
|
||||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||
import Data.Bits (Bits, toIntegralSized)
|
||||
import Data.Int (Int64)
|
||||
import Data.Maybe (fromMaybe)
|
||||
|
@ -75,13 +76,14 @@ data TensorData = TensorData
|
|||
|
||||
-- | Runs the given action after creating a session with options
|
||||
-- populated by the given optionSetter.
|
||||
withSession :: (Raw.SessionOptions -> IO ())
|
||||
-> ((IO () -> IO ()) -> Raw.Session -> IO a)
|
||||
withSession :: (MonadIO m, MonadMask m)
|
||||
=> (Raw.SessionOptions -> IO ())
|
||||
-> ((IO () -> IO ()) -> Raw.Session -> m a)
|
||||
-- ^ The action can spawn concurrent tasks which will
|
||||
-- be canceled before withSession returns.
|
||||
-> IO a
|
||||
-> m a
|
||||
withSession optionSetter action = do
|
||||
drain <- newMVar []
|
||||
drain <- liftIO $ newMVar []
|
||||
let cleanup s =
|
||||
-- Closes the session to nudge the pending run calls to fail and exit.
|
||||
finally (checkStatus (Raw.closeSession s)) $ do
|
||||
|
@ -89,10 +91,10 @@ withSession optionSetter action = do
|
|||
-- Collects all runners before deleting the session.
|
||||
mapM_ shutDownRunner runners
|
||||
checkStatus (Raw.deleteSession s)
|
||||
bracket Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
|
||||
optionSetter options
|
||||
bracket
|
||||
(checkStatus (Raw.newSession options))
|
||||
let bracketIO x y = bracket (liftIO x) (liftIO . y)
|
||||
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
|
||||
bracketIO
|
||||
(optionSetter options >> checkStatus (Raw.newSession options))
|
||||
cleanup
|
||||
(action (asyncCollector drain))
|
||||
|
||||
|
@ -225,7 +227,7 @@ checkStatus fn =
|
|||
when (code /= Raw.TF_OK) $ do
|
||||
msg <- T.decodeUtf8With T.lenientDecode <$>
|
||||
(Raw.message status >>= B.packCString)
|
||||
throwIO $ TensorFlowException code msg
|
||||
throwM $ TensorFlowException code msg
|
||||
return result
|
||||
|
||||
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
|
||||
|
@ -258,7 +260,7 @@ getAllOpList = do
|
|||
where
|
||||
checkCall = do
|
||||
p <- Raw.getAllOpList
|
||||
when (p == nullPtr) (throwIO exception)
|
||||
when (p == nullPtr) (throwM exception)
|
||||
return p
|
||||
exception = TensorFlowException
|
||||
Raw.TF_UNKNOWN "GetAllOpList failure, check logs"
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
module TensorFlow.Session (
|
||||
Session,
|
||||
SessionT,
|
||||
Options,
|
||||
sessionConfig,
|
||||
sessionTarget,
|
||||
|
@ -39,7 +40,7 @@ module TensorFlow.Session (
|
|||
import Control.Monad (forever, unless, void)
|
||||
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
||||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||
import Control.Monad.Trans.Class (lift)
|
||||
import Control.Monad.Trans.Class (MonadTrans, lift)
|
||||
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Default (Default, def)
|
||||
|
@ -73,13 +74,18 @@ data SessionState
|
|||
, tracer :: Tracer
|
||||
}
|
||||
|
||||
newtype Session a
|
||||
= Session (ReaderT SessionState (BuildT IO) a)
|
||||
newtype SessionT m a
|
||||
= Session (ReaderT SessionState (BuildT m) a)
|
||||
deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch,
|
||||
MonadMask)
|
||||
|
||||
instance MonadTrans SessionT where
|
||||
lift = Session . lift . lift
|
||||
|
||||
type Session = SessionT IO
|
||||
|
||||
-- | Run 'Session' actions in a new TensorFlow session.
|
||||
runSession :: Session a -> IO a
|
||||
runSession :: (MonadMask m, MonadIO m) => SessionT m a -> m a
|
||||
runSession = runSessionWithOptions def
|
||||
|
||||
-- | Customization for session. Use the lenses to update:
|
||||
|
@ -112,7 +118,7 @@ sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x })
|
|||
|
||||
-- | Run 'Session' actions in a new TensorFlow session created with
|
||||
-- the given option setter actions ('sessionTarget', 'sessionConfig').
|
||||
runSessionWithOptions :: Options -> Session a -> IO a
|
||||
runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
|
||||
runSessionWithOptions options (Session m) =
|
||||
FFI.withSession applyOptions $
|
||||
\as rs ->
|
||||
|
@ -122,14 +128,14 @@ runSessionWithOptions options (Session m) =
|
|||
FFI.setSessionTarget (options ^. sessionTarget) opt
|
||||
FFI.setSessionConfig (options ^. sessionConfig) opt
|
||||
|
||||
instance MonadBuild Session where
|
||||
instance Monad m => MonadBuild (SessionT m) where
|
||||
build = Session . lift . build
|
||||
|
||||
-- | Add all pending rendered nodes to the TensorFlow graph and runs
|
||||
-- any pending initializers.
|
||||
--
|
||||
-- Note that run, runWithFeeds, etc. will all call this function implicitly.
|
||||
extend :: Session ()
|
||||
extend :: MonadIO m => SessionT m ()
|
||||
extend = do
|
||||
session <- Session (asks rawSession)
|
||||
trace <- Session (asks tracer)
|
||||
|
@ -145,13 +151,13 @@ extend = do
|
|||
|
||||
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||
-- rendered, and fetch the corresponding values for 'a'.
|
||||
run :: Fetchable t a => t -> Session a
|
||||
run :: (MonadIO m, Fetchable t a) => t -> SessionT m a
|
||||
run = runWithFeeds []
|
||||
|
||||
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||
-- rendered, feed the given input values, and fetch the corresponding result
|
||||
-- values for 'a'.
|
||||
runWithFeeds :: Fetchable t a => [Feed] -> t -> Session a
|
||||
runWithFeeds :: (MonadIO m, Fetchable t a) => [Feed] -> t -> SessionT m a
|
||||
runWithFeeds feeds t = do
|
||||
ns <- build $ getNodes t
|
||||
-- Note that this call to "fetch" shouldn't affect the following "extend"
|
||||
|
@ -160,7 +166,7 @@ runWithFeeds feeds t = do
|
|||
fetch <- build $ getFetch t
|
||||
runFetchWithFeeds feeds ns fetch
|
||||
|
||||
runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a
|
||||
runFetchWithFeeds :: MonadIO m => [Feed] -> Set NodeName -> Fetch a -> SessionT m a
|
||||
runFetchWithFeeds feeds target (Fetch fetch restore) = do
|
||||
extend
|
||||
let feeds' = fixFeeds feeds
|
||||
|
@ -180,14 +186,14 @@ toNodeNames = map (encodeUtf8 . unNodeName)
|
|||
-- | Run a subgraph 't', rendering and extending any dependent nodes that aren't
|
||||
-- already rendered. This behaves like 'run' except that it doesn't do any
|
||||
-- fetches.
|
||||
run_ :: Nodes t => t -> Session ()
|
||||
run_ :: (MonadIO m, Nodes t) => t -> SessionT m ()
|
||||
run_ = runWithFeeds_ []
|
||||
|
||||
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||
-- rendered, feed the given input values, and fetch the corresponding result
|
||||
-- values for 'a'. This behaves like 'runWithFeeds' except that it doesn't do
|
||||
-- any fetches.
|
||||
runWithFeeds_ :: Nodes t => [Feed] -> t -> Session ()
|
||||
runWithFeeds_ :: (MonadIO m, Nodes t) => [Feed] -> t -> SessionT m ()
|
||||
runWithFeeds_ feeds t = do
|
||||
ns <- build $ getNodes t
|
||||
runFetchWithFeeds feeds ns (pure ())
|
||||
|
@ -199,9 +205,9 @@ fixFeeds = map $ \(Feed o d) -> (encodeUtf8 $ encodeOutput o, d)
|
|||
-- forever until runSession exits or an exception occurs. Graph
|
||||
-- extension happens synchronously, but the resultant run proceeds as
|
||||
-- a separate thread.
|
||||
asyncProdNodes :: Nodes t
|
||||
asyncProdNodes :: (MonadIO m, Nodes t)
|
||||
=> t -- ^ Node to evaluate concurrently.
|
||||
-> Session ()
|
||||
-> SessionT m ()
|
||||
asyncProdNodes nodes = do
|
||||
target <- build (getNodes nodes)
|
||||
extend
|
||||
|
|
Loading…
Reference in a new issue