diff --git a/tensorflow-ops/tensorflow-ops.cabal b/tensorflow-ops/tensorflow-ops.cabal index d819788..3ab99b2 100644 --- a/tensorflow-ops/tensorflow-ops.cabal +++ b/tensorflow-ops/tensorflow-ops.cabal @@ -166,6 +166,21 @@ Test-Suite MiscTest , test-framework , test-framework-hunit +Test-Suite TracingTest + default-language: Haskell2010 + type: exitcode-stdio-1.0 + main-is: TracingTest.hs + hs-source-dirs: tests + build-depends: HUnit + , base + , bytestring + , data-default + , lens-family + , tensorflow + , tensorflow-ops + , test-framework + , test-framework-hunit + Test-Suite TypesTest default-language: Haskell2010 type: exitcode-stdio-1.0 diff --git a/tensorflow-ops/tests/TracingTest.hs b/tensorflow-ops/tests/TracingTest.hs new file mode 100644 index 0000000..dab3fe9 --- /dev/null +++ b/tensorflow-ops/tests/TracingTest.hs @@ -0,0 +1,49 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE OverloadedStrings #-} + +-- | Testing tracing. +module Main where + +import Control.Concurrent.MVar (newEmptyMVar, putMVar, tryReadMVar) +import Data.ByteString.Builder (toLazyByteString) +import Data.ByteString.Lazy (isPrefixOf) +import Data.Default (def) +import Data.Monoid ((<>)) +import Lens.Family2 ((&), (.~)) +import Test.Framework (defaultMain) +import Test.Framework.Providers.HUnit (testCase) +import Test.HUnit ((@=?), assertBool, assertFailure) + +import qualified TensorFlow.Core as TF +import qualified TensorFlow.Ops as TF + +testTracing :: IO () +testTracing = do + -- Verifies that tracing happens as a side-effect of graph extension. + loggedValue <- newEmptyMVar + TF.runSessionWithOptions + (def & TF.sessionTracer .~ putMVar loggedValue) + (TF.buildAnd TF.run_ (pure (TF.scalar (0 :: Float)))) + tryReadMVar loggedValue >>= + maybe (assertFailure "Logging never happened") expectedFormat + where expectedFormat x = + let got = toLazyByteString x in + assertBool ("Unexpected log entry " ++ show got) + ("Session.extend" `isPrefixOf` got) + +main = defaultMain + [ testCase "Tracing" testTracing + ] diff --git a/tensorflow/src/TensorFlow/Core.hs b/tensorflow/src/TensorFlow/Core.hs index a7a4ae7..0fc0590 100644 --- a/tensorflow/src/TensorFlow/Core.hs +++ b/tensorflow/src/TensorFlow/Core.hs @@ -24,9 +24,10 @@ module TensorFlow.Core ( -- * Session Session - , SessionOption + , Options , sessionConfig , sessionTarget + , sessionTracer , runSession , runSessionWithOptions -- ** Building graphs diff --git a/tensorflow/src/TensorFlow/Session.hs b/tensorflow/src/TensorFlow/Session.hs index ec3f806..f032985 100644 --- a/tensorflow/src/TensorFlow/Session.hs +++ b/tensorflow/src/TensorFlow/Session.hs @@ -13,15 +13,17 @@ -- limitations under the License. {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE Rank2Types #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} module TensorFlow.Session ( Session, - SessionOption, + Options, sessionConfig, sessionTarget, + sessionTracer, runSession, runSessionWithOptions, build, @@ -41,29 +43,38 @@ import Control.Monad.IO.Class (MonadIO, liftIO) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Reader (ReaderT(..), ask, asks) import Data.ByteString (ByteString) +import Data.Default (Default, def) import Data.Functor.Identity (runIdentity) +import Data.Monoid ((<>)) import qualified Data.Map.Strict as Map import qualified Data.Set as Set import Data.Set (Set) import Data.Text.Encoding (encodeUtf8) -import Data.ProtoLens (def) -import Lens.Family2 ((&), (.~)) +import Data.ProtoLens (def, showMessage) +import Lens.Family2 (Lens', (^.), (&), (.~)) +import Lens.Family2.Unchecked (lens) import Proto.Tensorflow.Core.Framework.Graph (node) import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto) import TensorFlow.Build -import qualified TensorFlow.Internal.FFI as FFI -import qualified TensorFlow.Internal.Raw as Raw import TensorFlow.Nodes import TensorFlow.Output (NodeName, unNodeName) import TensorFlow.Tensor +import qualified Data.ByteString.Builder as Builder +import qualified TensorFlow.Internal.FFI as FFI +import qualified TensorFlow.Internal.Raw as Raw + +-- | An action for logging. +type Tracer = Builder.Builder -> IO () + -- Common state threaded through the session. data SessionState = SessionState { rawSession :: FFI.Session , asyncCollector :: IO () -> IO () -- ^ Starts the given action concurrently. + , tracer :: Tracer } newtype Session a @@ -72,30 +83,47 @@ newtype Session a -- | Run 'Session' actions in a new TensorFlow session. runSession :: Session a -> IO a -runSession = runSessionWithOptions [] +runSession = runSessionWithOptions def --- | Setting of an option for the session (see 'runSessionWithOptions'). --- Opaque value created via 'sessionConfig' and 'sessionTarget'. -newtype SessionOption = - SessionOption { unSesssionOption :: Raw.SessionOptions -> IO () } +-- | Customization for session. Use the lenses to update: +-- 'sessionTarget', 'sessionTracer', 'sessionConfig'. +data Options = Options + { _sessionTarget :: ByteString + , _sessionConfig :: ConfigProto + , _sessionTracer :: Tracer + } + +instance Default Options where + def = Options + { _sessionTarget = "" + , _sessionConfig = def + , _sessionTracer = const (return ()) + } -- | Target can be: "local", ip:port, host:port. -- The set of supported factories depends on the linked in libraries. --- REQUIRES "//learning/brain/public:tensorflow_remote" dependency for the binary. -sessionTarget :: ByteString -> SessionOption -sessionTarget = SessionOption . FFI.setSessionTarget +sessionTarget :: Lens' Options ByteString +sessionTarget = lens _sessionTarget (\g x -> g { _sessionTarget = x }) -- | Uses the specified config for the created session. -sessionConfig :: ConfigProto -> SessionOption -sessionConfig = SessionOption . FFI.setSessionConfig +sessionConfig :: Lens' Options ConfigProto +sessionConfig = lens _sessionConfig (\g x -> g { _sessionConfig = x }) + +-- | Uses the given logger to monitor session progress. +sessionTracer :: Lens' Options Tracer +sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x }) -- | Run 'Session' actions in a new TensorFlow session created with -- the given option setter actions ('sessionTarget', 'sessionConfig'). -runSessionWithOptions :: [SessionOption] -> Session a -> IO a +runSessionWithOptions :: Options -> Session a -> IO a runSessionWithOptions options (Session m) = FFI.withSession applyOptions $ - \as rs -> evalBuildT (runReaderT m (SessionState rs as)) - where applyOptions opt = mapM_ (`unSesssionOption` opt) options + \as rs -> + let initState = SessionState rs as (options ^. sessionTracer) + in evalBuildT (runReaderT m initState) + where applyOptions opt = do + FFI.setSessionTarget (options ^. sessionTarget) opt + FFI.setSessionConfig (options ^. sessionConfig) opt -- | Lift a 'Build' action into a 'Session', including any explicit op -- renderings. @@ -116,15 +144,16 @@ buildWithSummary b = Session $ lift $ (,) <$> v <*> collectAllSummaries -- Note that run, runWithFeeds, etc. will all call this function implicitly. extend :: Session () extend = do - let withSessionWhen vs action = - unless (null vs) $ Session (asks rawSession) >>= action + session <- Session (asks rawSession) + trace <- Session (asks tracer) nodesToExtend <- build flushNodeBuffer - withSessionWhen nodesToExtend $ \session -> - liftIO $ FFI.extendGraph session - $ def & node .~ nodesToExtend + unless (null nodesToExtend) $ liftIO $ do + let graphDef = def & node .~ nodesToExtend + trace ("Session.extend " <> Builder.string8 (showMessage graphDef)) + FFI.extendGraph session graphDef -- Now that all the nodes are created, run the initializers. initializers <- build flushInitializers - withSessionWhen initializers $ \session -> + unless (null initializers) $ void $ liftIO $ FFI.run session [] [] (toNodeNames initializers) -- | Helper combinator for doing something with the result of a 'Build' action.