mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 11:29:43 +01:00
parent
d79a919efa
commit
d8bf349962
2 changed files with 33 additions and 25 deletions
|
@ -33,8 +33,9 @@ module TensorFlow.Internal.FFI
|
||||||
|
|
||||||
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
|
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
|
||||||
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
|
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
|
||||||
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
|
|
||||||
import Control.Monad (when)
|
import Control.Monad (when)
|
||||||
|
import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask_)
|
||||||
|
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||||
import Data.Bits (Bits, toIntegralSized)
|
import Data.Bits (Bits, toIntegralSized)
|
||||||
import Data.Int (Int64)
|
import Data.Int (Int64)
|
||||||
import Data.Maybe (fromMaybe)
|
import Data.Maybe (fromMaybe)
|
||||||
|
@ -75,13 +76,14 @@ data TensorData = TensorData
|
||||||
|
|
||||||
-- | Runs the given action after creating a session with options
|
-- | Runs the given action after creating a session with options
|
||||||
-- populated by the given optionSetter.
|
-- populated by the given optionSetter.
|
||||||
withSession :: (Raw.SessionOptions -> IO ())
|
withSession :: (MonadIO m, MonadMask m)
|
||||||
-> ((IO () -> IO ()) -> Raw.Session -> IO a)
|
=> (Raw.SessionOptions -> IO ())
|
||||||
|
-> ((IO () -> IO ()) -> Raw.Session -> m a)
|
||||||
-- ^ The action can spawn concurrent tasks which will
|
-- ^ The action can spawn concurrent tasks which will
|
||||||
-- be canceled before withSession returns.
|
-- be canceled before withSession returns.
|
||||||
-> IO a
|
-> m a
|
||||||
withSession optionSetter action = do
|
withSession optionSetter action = do
|
||||||
drain <- newMVar []
|
drain <- liftIO $ newMVar []
|
||||||
let cleanup s =
|
let cleanup s =
|
||||||
-- Closes the session to nudge the pending run calls to fail and exit.
|
-- Closes the session to nudge the pending run calls to fail and exit.
|
||||||
finally (checkStatus (Raw.closeSession s)) $ do
|
finally (checkStatus (Raw.closeSession s)) $ do
|
||||||
|
@ -89,10 +91,10 @@ withSession optionSetter action = do
|
||||||
-- Collects all runners before deleting the session.
|
-- Collects all runners before deleting the session.
|
||||||
mapM_ shutDownRunner runners
|
mapM_ shutDownRunner runners
|
||||||
checkStatus (Raw.deleteSession s)
|
checkStatus (Raw.deleteSession s)
|
||||||
bracket Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
|
let bracketIO x y = bracket (liftIO x) (liftIO . y)
|
||||||
optionSetter options
|
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
|
||||||
bracket
|
bracketIO
|
||||||
(checkStatus (Raw.newSession options))
|
(optionSetter options >> checkStatus (Raw.newSession options))
|
||||||
cleanup
|
cleanup
|
||||||
(action (asyncCollector drain))
|
(action (asyncCollector drain))
|
||||||
|
|
||||||
|
@ -225,7 +227,7 @@ checkStatus fn =
|
||||||
when (code /= Raw.TF_OK) $ do
|
when (code /= Raw.TF_OK) $ do
|
||||||
msg <- T.decodeUtf8With T.lenientDecode <$>
|
msg <- T.decodeUtf8With T.lenientDecode <$>
|
||||||
(Raw.message status >>= B.packCString)
|
(Raw.message status >>= B.packCString)
|
||||||
throwIO $ TensorFlowException code msg
|
throwM $ TensorFlowException code msg
|
||||||
return result
|
return result
|
||||||
|
|
||||||
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
|
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
|
||||||
|
@ -258,7 +260,7 @@ getAllOpList = do
|
||||||
where
|
where
|
||||||
checkCall = do
|
checkCall = do
|
||||||
p <- Raw.getAllOpList
|
p <- Raw.getAllOpList
|
||||||
when (p == nullPtr) (throwIO exception)
|
when (p == nullPtr) (throwM exception)
|
||||||
return p
|
return p
|
||||||
exception = TensorFlowException
|
exception = TensorFlowException
|
||||||
Raw.TF_UNKNOWN "GetAllOpList failure, check logs"
|
Raw.TF_UNKNOWN "GetAllOpList failure, check logs"
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
module TensorFlow.Session (
|
module TensorFlow.Session (
|
||||||
Session,
|
Session,
|
||||||
|
SessionT,
|
||||||
Options,
|
Options,
|
||||||
sessionConfig,
|
sessionConfig,
|
||||||
sessionTarget,
|
sessionTarget,
|
||||||
|
@ -39,7 +40,7 @@ module TensorFlow.Session (
|
||||||
import Control.Monad (forever, unless, void)
|
import Control.Monad (forever, unless, void)
|
||||||
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
||||||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||||
import Control.Monad.Trans.Class (lift)
|
import Control.Monad.Trans.Class (MonadTrans, lift)
|
||||||
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
|
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
|
||||||
import Data.ByteString (ByteString)
|
import Data.ByteString (ByteString)
|
||||||
import Data.Default (Default, def)
|
import Data.Default (Default, def)
|
||||||
|
@ -73,13 +74,18 @@ data SessionState
|
||||||
, tracer :: Tracer
|
, tracer :: Tracer
|
||||||
}
|
}
|
||||||
|
|
||||||
newtype Session a
|
newtype SessionT m a
|
||||||
= Session (ReaderT SessionState (BuildT IO) a)
|
= Session (ReaderT SessionState (BuildT m) a)
|
||||||
deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch,
|
deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch,
|
||||||
MonadMask)
|
MonadMask)
|
||||||
|
|
||||||
|
instance MonadTrans SessionT where
|
||||||
|
lift = Session . lift . lift
|
||||||
|
|
||||||
|
type Session = SessionT IO
|
||||||
|
|
||||||
-- | Run 'Session' actions in a new TensorFlow session.
|
-- | Run 'Session' actions in a new TensorFlow session.
|
||||||
runSession :: Session a -> IO a
|
runSession :: (MonadMask m, MonadIO m) => SessionT m a -> m a
|
||||||
runSession = runSessionWithOptions def
|
runSession = runSessionWithOptions def
|
||||||
|
|
||||||
-- | Customization for session. Use the lenses to update:
|
-- | Customization for session. Use the lenses to update:
|
||||||
|
@ -112,7 +118,7 @@ sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x })
|
||||||
|
|
||||||
-- | Run 'Session' actions in a new TensorFlow session created with
|
-- | Run 'Session' actions in a new TensorFlow session created with
|
||||||
-- the given option setter actions ('sessionTarget', 'sessionConfig').
|
-- the given option setter actions ('sessionTarget', 'sessionConfig').
|
||||||
runSessionWithOptions :: Options -> Session a -> IO a
|
runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
|
||||||
runSessionWithOptions options (Session m) =
|
runSessionWithOptions options (Session m) =
|
||||||
FFI.withSession applyOptions $
|
FFI.withSession applyOptions $
|
||||||
\as rs ->
|
\as rs ->
|
||||||
|
@ -122,14 +128,14 @@ runSessionWithOptions options (Session m) =
|
||||||
FFI.setSessionTarget (options ^. sessionTarget) opt
|
FFI.setSessionTarget (options ^. sessionTarget) opt
|
||||||
FFI.setSessionConfig (options ^. sessionConfig) opt
|
FFI.setSessionConfig (options ^. sessionConfig) opt
|
||||||
|
|
||||||
instance MonadBuild Session where
|
instance Monad m => MonadBuild (SessionT m) where
|
||||||
build = Session . lift . build
|
build = Session . lift . build
|
||||||
|
|
||||||
-- | Add all pending rendered nodes to the TensorFlow graph and runs
|
-- | Add all pending rendered nodes to the TensorFlow graph and runs
|
||||||
-- any pending initializers.
|
-- any pending initializers.
|
||||||
--
|
--
|
||||||
-- Note that run, runWithFeeds, etc. will all call this function implicitly.
|
-- Note that run, runWithFeeds, etc. will all call this function implicitly.
|
||||||
extend :: Session ()
|
extend :: MonadIO m => SessionT m ()
|
||||||
extend = do
|
extend = do
|
||||||
session <- Session (asks rawSession)
|
session <- Session (asks rawSession)
|
||||||
trace <- Session (asks tracer)
|
trace <- Session (asks tracer)
|
||||||
|
@ -145,13 +151,13 @@ extend = do
|
||||||
|
|
||||||
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||||
-- rendered, and fetch the corresponding values for 'a'.
|
-- rendered, and fetch the corresponding values for 'a'.
|
||||||
run :: Fetchable t a => t -> Session a
|
run :: (MonadIO m, Fetchable t a) => t -> SessionT m a
|
||||||
run = runWithFeeds []
|
run = runWithFeeds []
|
||||||
|
|
||||||
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||||
-- rendered, feed the given input values, and fetch the corresponding result
|
-- rendered, feed the given input values, and fetch the corresponding result
|
||||||
-- values for 'a'.
|
-- values for 'a'.
|
||||||
runWithFeeds :: Fetchable t a => [Feed] -> t -> Session a
|
runWithFeeds :: (MonadIO m, Fetchable t a) => [Feed] -> t -> SessionT m a
|
||||||
runWithFeeds feeds t = do
|
runWithFeeds feeds t = do
|
||||||
ns <- build $ getNodes t
|
ns <- build $ getNodes t
|
||||||
-- Note that this call to "fetch" shouldn't affect the following "extend"
|
-- Note that this call to "fetch" shouldn't affect the following "extend"
|
||||||
|
@ -160,7 +166,7 @@ runWithFeeds feeds t = do
|
||||||
fetch <- build $ getFetch t
|
fetch <- build $ getFetch t
|
||||||
runFetchWithFeeds feeds ns fetch
|
runFetchWithFeeds feeds ns fetch
|
||||||
|
|
||||||
runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a
|
runFetchWithFeeds :: MonadIO m => [Feed] -> Set NodeName -> Fetch a -> SessionT m a
|
||||||
runFetchWithFeeds feeds target (Fetch fetch restore) = do
|
runFetchWithFeeds feeds target (Fetch fetch restore) = do
|
||||||
extend
|
extend
|
||||||
let feeds' = fixFeeds feeds
|
let feeds' = fixFeeds feeds
|
||||||
|
@ -180,14 +186,14 @@ toNodeNames = map (encodeUtf8 . unNodeName)
|
||||||
-- | Run a subgraph 't', rendering and extending any dependent nodes that aren't
|
-- | Run a subgraph 't', rendering and extending any dependent nodes that aren't
|
||||||
-- already rendered. This behaves like 'run' except that it doesn't do any
|
-- already rendered. This behaves like 'run' except that it doesn't do any
|
||||||
-- fetches.
|
-- fetches.
|
||||||
run_ :: Nodes t => t -> Session ()
|
run_ :: (MonadIO m, Nodes t) => t -> SessionT m ()
|
||||||
run_ = runWithFeeds_ []
|
run_ = runWithFeeds_ []
|
||||||
|
|
||||||
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||||
-- rendered, feed the given input values, and fetch the corresponding result
|
-- rendered, feed the given input values, and fetch the corresponding result
|
||||||
-- values for 'a'. This behaves like 'runWithFeeds' except that it doesn't do
|
-- values for 'a'. This behaves like 'runWithFeeds' except that it doesn't do
|
||||||
-- any fetches.
|
-- any fetches.
|
||||||
runWithFeeds_ :: Nodes t => [Feed] -> t -> Session ()
|
runWithFeeds_ :: (MonadIO m, Nodes t) => [Feed] -> t -> SessionT m ()
|
||||||
runWithFeeds_ feeds t = do
|
runWithFeeds_ feeds t = do
|
||||||
ns <- build $ getNodes t
|
ns <- build $ getNodes t
|
||||||
runFetchWithFeeds feeds ns (pure ())
|
runFetchWithFeeds feeds ns (pure ())
|
||||||
|
@ -199,9 +205,9 @@ fixFeeds = map $ \(Feed o d) -> (encodeUtf8 $ encodeOutput o, d)
|
||||||
-- forever until runSession exits or an exception occurs. Graph
|
-- forever until runSession exits or an exception occurs. Graph
|
||||||
-- extension happens synchronously, but the resultant run proceeds as
|
-- extension happens synchronously, but the resultant run proceeds as
|
||||||
-- a separate thread.
|
-- a separate thread.
|
||||||
asyncProdNodes :: Nodes t
|
asyncProdNodes :: (MonadIO m, Nodes t)
|
||||||
=> t -- ^ Node to evaluate concurrently.
|
=> t -- ^ Node to evaluate concurrently.
|
||||||
-> Session ()
|
-> SessionT m ()
|
||||||
asyncProdNodes nodes = do
|
asyncProdNodes nodes = do
|
||||||
target <- build (getNodes nodes)
|
target <- build (getNodes nodes)
|
||||||
extend
|
extend
|
||||||
|
|
Loading…
Reference in a new issue