mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 19:13:34 +02:00
Added sessionTracer to log graph operations.
This commit is contained in:
parent
630850c2d2
commit
2bf1783d3f
|
@ -24,9 +24,10 @@
|
|||
module TensorFlow.Core
|
||||
( -- * Session
|
||||
Session
|
||||
, SessionOption
|
||||
, Options
|
||||
, sessionConfig
|
||||
, sessionTarget
|
||||
, sessionTracer
|
||||
, runSession
|
||||
, runSessionWithOptions
|
||||
-- ** Building graphs
|
||||
|
|
|
@ -13,15 +13,17 @@
|
|||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE Rank2Types #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module TensorFlow.Session (
|
||||
Session,
|
||||
SessionOption,
|
||||
Options,
|
||||
sessionConfig,
|
||||
sessionTarget,
|
||||
sessionTracer,
|
||||
runSession,
|
||||
runSessionWithOptions,
|
||||
build,
|
||||
|
@ -41,29 +43,38 @@ import Control.Monad.IO.Class (MonadIO, liftIO)
|
|||
import Control.Monad.Trans.Class (lift)
|
||||
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Default (Default, def)
|
||||
import Data.Functor.Identity (runIdentity)
|
||||
import Data.Monoid ((<>))
|
||||
import qualified Data.Map.Strict as Map
|
||||
import qualified Data.Set as Set
|
||||
import Data.Set (Set)
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Data.ProtoLens (def)
|
||||
import Lens.Family2 ((&), (.~))
|
||||
import Data.ProtoLens (def, showMessage)
|
||||
import Lens.Family2 (Lens', (^.), (&), (.~))
|
||||
import Lens.Family2.Unchecked (lens)
|
||||
import Proto.Tensorflow.Core.Framework.Graph (node)
|
||||
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
||||
|
||||
import TensorFlow.Build
|
||||
import qualified TensorFlow.Internal.FFI as FFI
|
||||
import qualified TensorFlow.Internal.Raw as Raw
|
||||
import TensorFlow.Nodes
|
||||
import TensorFlow.Output (NodeName, unNodeName)
|
||||
import TensorFlow.Tensor
|
||||
|
||||
import qualified Data.ByteString.Builder as Builder
|
||||
import qualified TensorFlow.Internal.FFI as FFI
|
||||
import qualified TensorFlow.Internal.Raw as Raw
|
||||
|
||||
-- | An action for logging.
|
||||
type Tracer = Builder.Builder -> IO ()
|
||||
|
||||
-- Common state threaded through the session.
|
||||
data SessionState
|
||||
= SessionState {
|
||||
rawSession :: FFI.Session
|
||||
, asyncCollector :: IO () -> IO ()
|
||||
-- ^ Starts the given action concurrently.
|
||||
, tracer :: Tracer
|
||||
}
|
||||
|
||||
newtype Session a
|
||||
|
@ -72,30 +83,47 @@ newtype Session a
|
|||
|
||||
-- | Run 'Session' actions in a new TensorFlow session.
|
||||
runSession :: Session a -> IO a
|
||||
runSession = runSessionWithOptions []
|
||||
runSession = runSessionWithOptions def
|
||||
|
||||
-- | Setting of an option for the session (see 'runSessionWithOptions').
|
||||
-- Opaque value created via 'sessionConfig' and 'sessionTarget'.
|
||||
newtype SessionOption =
|
||||
SessionOption { unSesssionOption :: Raw.SessionOptions -> IO () }
|
||||
-- | Customization for session. Use the lenses to update:
|
||||
-- 'sessionTarget', 'sessionTracer', 'sessionConfig'.
|
||||
data Options = Options
|
||||
{ _sessionTarget :: ByteString
|
||||
, _sessionConfig :: ConfigProto
|
||||
, _sessionTracer :: Tracer
|
||||
}
|
||||
|
||||
instance Default Options where
|
||||
def = Options
|
||||
{ _sessionTarget = ""
|
||||
, _sessionConfig = def
|
||||
, _sessionTracer = const (return ())
|
||||
}
|
||||
|
||||
-- | Target can be: "local", ip:port, host:port.
|
||||
-- The set of supported factories depends on the linked in libraries.
|
||||
-- REQUIRES "//learning/brain/public:tensorflow_remote" dependency for the binary.
|
||||
sessionTarget :: ByteString -> SessionOption
|
||||
sessionTarget = SessionOption . FFI.setSessionTarget
|
||||
sessionTarget :: Lens' Options ByteString
|
||||
sessionTarget = lens _sessionTarget (\g x -> g { _sessionTarget = x })
|
||||
|
||||
-- | Uses the specified config for the created session.
|
||||
sessionConfig :: ConfigProto -> SessionOption
|
||||
sessionConfig = SessionOption . FFI.setSessionConfig
|
||||
sessionConfig :: Lens' Options ConfigProto
|
||||
sessionConfig = lens _sessionConfig (\g x -> g { _sessionConfig = x })
|
||||
|
||||
-- | Uses the given logger to monitor session progress.
|
||||
sessionTracer :: Lens' Options Tracer
|
||||
sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x })
|
||||
|
||||
-- | Run 'Session' actions in a new TensorFlow session created with
|
||||
-- the given option setter actions ('sessionTarget', 'sessionConfig').
|
||||
runSessionWithOptions :: [SessionOption] -> Session a -> IO a
|
||||
runSessionWithOptions :: Options -> Session a -> IO a
|
||||
runSessionWithOptions options (Session m) =
|
||||
FFI.withSession applyOptions $
|
||||
\as rs -> evalBuildT (runReaderT m (SessionState rs as))
|
||||
where applyOptions opt = mapM_ (`unSesssionOption` opt) options
|
||||
\as rs -> do
|
||||
let initState = SessionState rs as (options ^. sessionTracer)
|
||||
evalBuildT (runReaderT m initState)
|
||||
where applyOptions opt = do
|
||||
FFI.setSessionTarget (options ^. sessionTarget) opt
|
||||
FFI.setSessionConfig (options ^. sessionConfig) opt
|
||||
|
||||
-- | Lift a 'Build' action into a 'Session', including any explicit op
|
||||
-- renderings.
|
||||
|
@ -116,15 +144,16 @@ buildWithSummary b = Session $ lift $ (,) <$> v <*> collectAllSummaries
|
|||
-- Note that run, runWithFeeds, etc. will all call this function implicitly.
|
||||
extend :: Session ()
|
||||
extend = do
|
||||
let withSessionWhen vs action =
|
||||
unless (null vs) $ Session (asks rawSession) >>= action
|
||||
session <- Session (asks rawSession)
|
||||
trace <- Session (asks tracer)
|
||||
nodesToExtend <- build flushNodeBuffer
|
||||
withSessionWhen nodesToExtend $ \session ->
|
||||
liftIO $ FFI.extendGraph session
|
||||
$ def & node .~ nodesToExtend
|
||||
unless (null nodesToExtend) $ liftIO $ do
|
||||
let graphDef = def & node .~ nodesToExtend
|
||||
trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
|
||||
FFI.extendGraph session graphDef
|
||||
-- Now that all the nodes are created, run the initializers.
|
||||
initializers <- build flushInitializers
|
||||
withSessionWhen initializers $ \session ->
|
||||
unless (null initializers) $
|
||||
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
|
||||
|
||||
-- | Helper combinator for doing something with the result of a 'Build' action.
|
||||
|
|
Loading…
Reference in New Issue
Block a user