1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 19:13:34 +02:00

Added sessionTracer to log graph operations.

This commit is contained in:
Greg Steuck 2016-11-11 17:01:15 -08:00
parent 630850c2d2
commit 2bf1783d3f
2 changed files with 55 additions and 25 deletions

View File

@ -24,9 +24,10 @@
module TensorFlow.Core
( -- * Session
Session
, SessionOption
, Options
, sessionConfig
, sessionTarget
, sessionTracer
, runSession
, runSessionWithOptions
-- ** Building graphs

View File

@ -13,15 +13,17 @@
-- limitations under the License.
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module TensorFlow.Session (
Session,
SessionOption,
Options,
sessionConfig,
sessionTarget,
sessionTracer,
runSession,
runSessionWithOptions,
build,
@ -41,29 +43,38 @@ import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
import Data.ByteString (ByteString)
import Data.Default (Default, def)
import Data.Functor.Identity (runIdentity)
import Data.Monoid ((<>))
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Data.Set (Set)
import Data.Text.Encoding (encodeUtf8)
import Data.ProtoLens (def)
import Lens.Family2 ((&), (.~))
import Data.ProtoLens (def, showMessage)
import Lens.Family2 (Lens', (^.), (&), (.~))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.Graph (node)
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import TensorFlow.Build
import qualified TensorFlow.Internal.FFI as FFI
import qualified TensorFlow.Internal.Raw as Raw
import TensorFlow.Nodes
import TensorFlow.Output (NodeName, unNodeName)
import TensorFlow.Tensor
import qualified Data.ByteString.Builder as Builder
import qualified TensorFlow.Internal.FFI as FFI
import qualified TensorFlow.Internal.Raw as Raw
-- | An action for logging.
type Tracer = Builder.Builder -> IO ()
-- Common state threaded through the session.
data SessionState
= SessionState {
rawSession :: FFI.Session
, asyncCollector :: IO () -> IO ()
-- ^ Starts the given action concurrently.
, tracer :: Tracer
}
newtype Session a
@ -72,30 +83,47 @@ newtype Session a
-- | Run 'Session' actions in a new TensorFlow session.
runSession :: Session a -> IO a
runSession = runSessionWithOptions []
runSession = runSessionWithOptions def
-- | Setting of an option for the session (see 'runSessionWithOptions').
-- Opaque value created via 'sessionConfig' and 'sessionTarget'.
newtype SessionOption =
SessionOption { unSesssionOption :: Raw.SessionOptions -> IO () }
-- | Customization for session. Use the lenses to update:
-- 'sessionTarget', 'sessionTracer', 'sessionConfig'.
data Options = Options
{ _sessionTarget :: ByteString
, _sessionConfig :: ConfigProto
, _sessionTracer :: Tracer
}
instance Default Options where
def = Options
{ _sessionTarget = ""
, _sessionConfig = def
, _sessionTracer = const (return ())
}
-- | Target can be: "local", ip:port, host:port.
-- The set of supported factories depends on the linked in libraries.
-- REQUIRES "//learning/brain/public:tensorflow_remote" dependency for the binary.
sessionTarget :: ByteString -> SessionOption
sessionTarget = SessionOption . FFI.setSessionTarget
sessionTarget :: Lens' Options ByteString
sessionTarget = lens _sessionTarget (\g x -> g { _sessionTarget = x })
-- | Uses the specified config for the created session.
sessionConfig :: ConfigProto -> SessionOption
sessionConfig = SessionOption . FFI.setSessionConfig
sessionConfig :: Lens' Options ConfigProto
sessionConfig = lens _sessionConfig (\g x -> g { _sessionConfig = x })
-- | Uses the given logger to monitor session progress.
sessionTracer :: Lens' Options Tracer
sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x })
-- | Run 'Session' actions in a new TensorFlow session created with
-- the given option setter actions ('sessionTarget', 'sessionConfig').
runSessionWithOptions :: [SessionOption] -> Session a -> IO a
runSessionWithOptions :: Options -> Session a -> IO a
runSessionWithOptions options (Session m) =
FFI.withSession applyOptions $
\as rs -> evalBuildT (runReaderT m (SessionState rs as))
where applyOptions opt = mapM_ (`unSesssionOption` opt) options
\as rs -> do
let initState = SessionState rs as (options ^. sessionTracer)
evalBuildT (runReaderT m initState)
where applyOptions opt = do
FFI.setSessionTarget (options ^. sessionTarget) opt
FFI.setSessionConfig (options ^. sessionConfig) opt
-- | Lift a 'Build' action into a 'Session', including any explicit op
-- renderings.
@ -116,15 +144,16 @@ buildWithSummary b = Session $ lift $ (,) <$> v <*> collectAllSummaries
-- Note that run, runWithFeeds, etc. will all call this function implicitly.
extend :: Session ()
extend = do
let withSessionWhen vs action =
unless (null vs) $ Session (asks rawSession) >>= action
session <- Session (asks rawSession)
trace <- Session (asks tracer)
nodesToExtend <- build flushNodeBuffer
withSessionWhen nodesToExtend $ \session ->
liftIO $ FFI.extendGraph session
$ def & node .~ nodesToExtend
unless (null nodesToExtend) $ liftIO $ do
let graphDef = def & node .~ nodesToExtend
trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
FFI.extendGraph session graphDef
-- Now that all the nodes are created, run the initializers.
initializers <- build flushInitializers
withSessionWhen initializers $ \session ->
unless (null initializers) $
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
-- | Helper combinator for doing something with the result of a 'Build' action.