Added sessionTracer to log graph operations. (#26)

* Added TracingTest.
This commit is contained in:
Greg Steuck 2016-11-14 15:14:51 -08:00 committed by GitHub
parent 630850c2d2
commit 0d4f5a9628
4 changed files with 119 additions and 25 deletions

View File

@ -166,6 +166,21 @@ Test-Suite MiscTest
, test-framework , test-framework
, test-framework-hunit , test-framework-hunit
Test-Suite TracingTest
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: TracingTest.hs
hs-source-dirs: tests
build-depends: HUnit
, base
, bytestring
, data-default
, lens-family
, tensorflow
, tensorflow-ops
, test-framework
, test-framework-hunit
Test-Suite TypesTest Test-Suite TypesTest
default-language: Haskell2010 default-language: Haskell2010
type: exitcode-stdio-1.0 type: exitcode-stdio-1.0

View File

@ -0,0 +1,49 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
-- | Testing tracing.
module Main where
import Control.Concurrent.MVar (newEmptyMVar, putMVar, tryReadMVar)
import Data.ByteString.Builder (toLazyByteString)
import Data.ByteString.Lazy (isPrefixOf)
import Data.Default (def)
import Data.Monoid ((<>))
import Lens.Family2 ((&), (.~))
import Test.Framework (defaultMain)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?), assertBool, assertFailure)
import qualified TensorFlow.Core as TF
import qualified TensorFlow.Ops as TF
testTracing :: IO ()
testTracing = do
-- Verifies that tracing happens as a side-effect of graph extension.
loggedValue <- newEmptyMVar
TF.runSessionWithOptions
(def & TF.sessionTracer .~ putMVar loggedValue)
(TF.buildAnd TF.run_ (pure (TF.scalar (0 :: Float))))
tryReadMVar loggedValue >>=
maybe (assertFailure "Logging never happened") expectedFormat
where expectedFormat x =
let got = toLazyByteString x in
assertBool ("Unexpected log entry " ++ show got)
("Session.extend" `isPrefixOf` got)
main = defaultMain
[ testCase "Tracing" testTracing
]

View File

@ -24,9 +24,10 @@
module TensorFlow.Core module TensorFlow.Core
( -- * Session ( -- * Session
Session Session
, SessionOption , Options
, sessionConfig , sessionConfig
, sessionTarget , sessionTarget
, sessionTracer
, runSession , runSession
, runSessionWithOptions , runSessionWithOptions
-- ** Building graphs -- ** Building graphs

View File

@ -13,15 +13,17 @@
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-} {-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-} {-# LANGUAGE TupleSections #-}
module TensorFlow.Session ( module TensorFlow.Session (
Session, Session,
SessionOption, Options,
sessionConfig, sessionConfig,
sessionTarget, sessionTarget,
sessionTracer,
runSession, runSession,
runSessionWithOptions, runSessionWithOptions,
build, build,
@ -41,29 +43,38 @@ import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks) import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import Data.Default (Default, def)
import Data.Functor.Identity (runIdentity) import Data.Functor.Identity (runIdentity)
import Data.Monoid ((<>))
import qualified Data.Map.Strict as Map import qualified Data.Map.Strict as Map
import qualified Data.Set as Set import qualified Data.Set as Set
import Data.Set (Set) import Data.Set (Set)
import Data.Text.Encoding (encodeUtf8) import Data.Text.Encoding (encodeUtf8)
import Data.ProtoLens (def) import Data.ProtoLens (def, showMessage)
import Lens.Family2 ((&), (.~)) import Lens.Family2 (Lens', (^.), (&), (.~))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.Graph (node) import Proto.Tensorflow.Core.Framework.Graph (node)
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto) import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import TensorFlow.Build import TensorFlow.Build
import qualified TensorFlow.Internal.FFI as FFI
import qualified TensorFlow.Internal.Raw as Raw
import TensorFlow.Nodes import TensorFlow.Nodes
import TensorFlow.Output (NodeName, unNodeName) import TensorFlow.Output (NodeName, unNodeName)
import TensorFlow.Tensor import TensorFlow.Tensor
import qualified Data.ByteString.Builder as Builder
import qualified TensorFlow.Internal.FFI as FFI
import qualified TensorFlow.Internal.Raw as Raw
-- | An action for logging.
type Tracer = Builder.Builder -> IO ()
-- Common state threaded through the session. -- Common state threaded through the session.
data SessionState data SessionState
= SessionState { = SessionState {
rawSession :: FFI.Session rawSession :: FFI.Session
, asyncCollector :: IO () -> IO () , asyncCollector :: IO () -> IO ()
-- ^ Starts the given action concurrently. -- ^ Starts the given action concurrently.
, tracer :: Tracer
} }
newtype Session a newtype Session a
@ -72,30 +83,47 @@ newtype Session a
-- | Run 'Session' actions in a new TensorFlow session. -- | Run 'Session' actions in a new TensorFlow session.
runSession :: Session a -> IO a runSession :: Session a -> IO a
runSession = runSessionWithOptions [] runSession = runSessionWithOptions def
-- | Setting of an option for the session (see 'runSessionWithOptions'). -- | Customization for session. Use the lenses to update:
-- Opaque value created via 'sessionConfig' and 'sessionTarget'. -- 'sessionTarget', 'sessionTracer', 'sessionConfig'.
newtype SessionOption = data Options = Options
SessionOption { unSesssionOption :: Raw.SessionOptions -> IO () } { _sessionTarget :: ByteString
, _sessionConfig :: ConfigProto
, _sessionTracer :: Tracer
}
instance Default Options where
def = Options
{ _sessionTarget = ""
, _sessionConfig = def
, _sessionTracer = const (return ())
}
-- | Target can be: "local", ip:port, host:port. -- | Target can be: "local", ip:port, host:port.
-- The set of supported factories depends on the linked in libraries. -- The set of supported factories depends on the linked in libraries.
-- REQUIRES "//learning/brain/public:tensorflow_remote" dependency for the binary. sessionTarget :: Lens' Options ByteString
sessionTarget :: ByteString -> SessionOption sessionTarget = lens _sessionTarget (\g x -> g { _sessionTarget = x })
sessionTarget = SessionOption . FFI.setSessionTarget
-- | Uses the specified config for the created session. -- | Uses the specified config for the created session.
sessionConfig :: ConfigProto -> SessionOption sessionConfig :: Lens' Options ConfigProto
sessionConfig = SessionOption . FFI.setSessionConfig sessionConfig = lens _sessionConfig (\g x -> g { _sessionConfig = x })
-- | Uses the given logger to monitor session progress.
sessionTracer :: Lens' Options Tracer
sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x })
-- | Run 'Session' actions in a new TensorFlow session created with -- | Run 'Session' actions in a new TensorFlow session created with
-- the given option setter actions ('sessionTarget', 'sessionConfig'). -- the given option setter actions ('sessionTarget', 'sessionConfig').
runSessionWithOptions :: [SessionOption] -> Session a -> IO a runSessionWithOptions :: Options -> Session a -> IO a
runSessionWithOptions options (Session m) = runSessionWithOptions options (Session m) =
FFI.withSession applyOptions $ FFI.withSession applyOptions $
\as rs -> evalBuildT (runReaderT m (SessionState rs as)) \as rs ->
where applyOptions opt = mapM_ (`unSesssionOption` opt) options let initState = SessionState rs as (options ^. sessionTracer)
in evalBuildT (runReaderT m initState)
where applyOptions opt = do
FFI.setSessionTarget (options ^. sessionTarget) opt
FFI.setSessionConfig (options ^. sessionConfig) opt
-- | Lift a 'Build' action into a 'Session', including any explicit op -- | Lift a 'Build' action into a 'Session', including any explicit op
-- renderings. -- renderings.
@ -116,15 +144,16 @@ buildWithSummary b = Session $ lift $ (,) <$> v <*> collectAllSummaries
-- Note that run, runWithFeeds, etc. will all call this function implicitly. -- Note that run, runWithFeeds, etc. will all call this function implicitly.
extend :: Session () extend :: Session ()
extend = do extend = do
let withSessionWhen vs action = session <- Session (asks rawSession)
unless (null vs) $ Session (asks rawSession) >>= action trace <- Session (asks tracer)
nodesToExtend <- build flushNodeBuffer nodesToExtend <- build flushNodeBuffer
withSessionWhen nodesToExtend $ \session -> unless (null nodesToExtend) $ liftIO $ do
liftIO $ FFI.extendGraph session let graphDef = def & node .~ nodesToExtend
$ def & node .~ nodesToExtend trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
FFI.extendGraph session graphDef
-- Now that all the nodes are created, run the initializers. -- Now that all the nodes are created, run the initializers.
initializers <- build flushInitializers initializers <- build flushInitializers
withSessionWhen initializers $ \session -> unless (null initializers) $
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers) void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
-- | Helper combinator for doing something with the result of a 'Build' action. -- | Helper combinator for doing something with the result of a 'Build' action.