mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Added sessionTracer to log graph operations. (#26)
* Added TracingTest.
This commit is contained in:
parent
630850c2d2
commit
0d4f5a9628
4 changed files with 119 additions and 25 deletions
|
@ -166,6 +166,21 @@ Test-Suite MiscTest
|
||||||
, test-framework
|
, test-framework
|
||||||
, test-framework-hunit
|
, test-framework-hunit
|
||||||
|
|
||||||
|
Test-Suite TracingTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: TracingTest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: HUnit
|
||||||
|
, base
|
||||||
|
, bytestring
|
||||||
|
, data-default
|
||||||
|
, lens-family
|
||||||
|
, tensorflow
|
||||||
|
, tensorflow-ops
|
||||||
|
, test-framework
|
||||||
|
, test-framework-hunit
|
||||||
|
|
||||||
Test-Suite TypesTest
|
Test-Suite TypesTest
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
type: exitcode-stdio-1.0
|
type: exitcode-stdio-1.0
|
||||||
|
|
49
tensorflow-ops/tests/TracingTest.hs
Normal file
49
tensorflow-ops/tests/TracingTest.hs
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
|
-- | Testing tracing.
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import Control.Concurrent.MVar (newEmptyMVar, putMVar, tryReadMVar)
|
||||||
|
import Data.ByteString.Builder (toLazyByteString)
|
||||||
|
import Data.ByteString.Lazy (isPrefixOf)
|
||||||
|
import Data.Default (def)
|
||||||
|
import Data.Monoid ((<>))
|
||||||
|
import Lens.Family2 ((&), (.~))
|
||||||
|
import Test.Framework (defaultMain)
|
||||||
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
import Test.HUnit ((@=?), assertBool, assertFailure)
|
||||||
|
|
||||||
|
import qualified TensorFlow.Core as TF
|
||||||
|
import qualified TensorFlow.Ops as TF
|
||||||
|
|
||||||
|
testTracing :: IO ()
|
||||||
|
testTracing = do
|
||||||
|
-- Verifies that tracing happens as a side-effect of graph extension.
|
||||||
|
loggedValue <- newEmptyMVar
|
||||||
|
TF.runSessionWithOptions
|
||||||
|
(def & TF.sessionTracer .~ putMVar loggedValue)
|
||||||
|
(TF.buildAnd TF.run_ (pure (TF.scalar (0 :: Float))))
|
||||||
|
tryReadMVar loggedValue >>=
|
||||||
|
maybe (assertFailure "Logging never happened") expectedFormat
|
||||||
|
where expectedFormat x =
|
||||||
|
let got = toLazyByteString x in
|
||||||
|
assertBool ("Unexpected log entry " ++ show got)
|
||||||
|
("Session.extend" `isPrefixOf` got)
|
||||||
|
|
||||||
|
main = defaultMain
|
||||||
|
[ testCase "Tracing" testTracing
|
||||||
|
]
|
|
@ -24,9 +24,10 @@
|
||||||
module TensorFlow.Core
|
module TensorFlow.Core
|
||||||
( -- * Session
|
( -- * Session
|
||||||
Session
|
Session
|
||||||
, SessionOption
|
, Options
|
||||||
, sessionConfig
|
, sessionConfig
|
||||||
, sessionTarget
|
, sessionTarget
|
||||||
|
, sessionTracer
|
||||||
, runSession
|
, runSession
|
||||||
, runSessionWithOptions
|
, runSessionWithOptions
|
||||||
-- ** Building graphs
|
-- ** Building graphs
|
||||||
|
|
|
@ -13,15 +13,17 @@
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE Rank2Types #-}
|
{-# LANGUAGE Rank2Types #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
{-# LANGUAGE TupleSections #-}
|
{-# LANGUAGE TupleSections #-}
|
||||||
|
|
||||||
module TensorFlow.Session (
|
module TensorFlow.Session (
|
||||||
Session,
|
Session,
|
||||||
SessionOption,
|
Options,
|
||||||
sessionConfig,
|
sessionConfig,
|
||||||
sessionTarget,
|
sessionTarget,
|
||||||
|
sessionTracer,
|
||||||
runSession,
|
runSession,
|
||||||
runSessionWithOptions,
|
runSessionWithOptions,
|
||||||
build,
|
build,
|
||||||
|
@ -41,29 +43,38 @@ import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||||
import Control.Monad.Trans.Class (lift)
|
import Control.Monad.Trans.Class (lift)
|
||||||
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
|
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
|
||||||
import Data.ByteString (ByteString)
|
import Data.ByteString (ByteString)
|
||||||
|
import Data.Default (Default, def)
|
||||||
import Data.Functor.Identity (runIdentity)
|
import Data.Functor.Identity (runIdentity)
|
||||||
|
import Data.Monoid ((<>))
|
||||||
import qualified Data.Map.Strict as Map
|
import qualified Data.Map.Strict as Map
|
||||||
import qualified Data.Set as Set
|
import qualified Data.Set as Set
|
||||||
import Data.Set (Set)
|
import Data.Set (Set)
|
||||||
import Data.Text.Encoding (encodeUtf8)
|
import Data.Text.Encoding (encodeUtf8)
|
||||||
import Data.ProtoLens (def)
|
import Data.ProtoLens (def, showMessage)
|
||||||
import Lens.Family2 ((&), (.~))
|
import Lens.Family2 (Lens', (^.), (&), (.~))
|
||||||
|
import Lens.Family2.Unchecked (lens)
|
||||||
import Proto.Tensorflow.Core.Framework.Graph (node)
|
import Proto.Tensorflow.Core.Framework.Graph (node)
|
||||||
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
||||||
|
|
||||||
import TensorFlow.Build
|
import TensorFlow.Build
|
||||||
import qualified TensorFlow.Internal.FFI as FFI
|
|
||||||
import qualified TensorFlow.Internal.Raw as Raw
|
|
||||||
import TensorFlow.Nodes
|
import TensorFlow.Nodes
|
||||||
import TensorFlow.Output (NodeName, unNodeName)
|
import TensorFlow.Output (NodeName, unNodeName)
|
||||||
import TensorFlow.Tensor
|
import TensorFlow.Tensor
|
||||||
|
|
||||||
|
import qualified Data.ByteString.Builder as Builder
|
||||||
|
import qualified TensorFlow.Internal.FFI as FFI
|
||||||
|
import qualified TensorFlow.Internal.Raw as Raw
|
||||||
|
|
||||||
|
-- | An action for logging.
|
||||||
|
type Tracer = Builder.Builder -> IO ()
|
||||||
|
|
||||||
-- Common state threaded through the session.
|
-- Common state threaded through the session.
|
||||||
data SessionState
|
data SessionState
|
||||||
= SessionState {
|
= SessionState {
|
||||||
rawSession :: FFI.Session
|
rawSession :: FFI.Session
|
||||||
, asyncCollector :: IO () -> IO ()
|
, asyncCollector :: IO () -> IO ()
|
||||||
-- ^ Starts the given action concurrently.
|
-- ^ Starts the given action concurrently.
|
||||||
|
, tracer :: Tracer
|
||||||
}
|
}
|
||||||
|
|
||||||
newtype Session a
|
newtype Session a
|
||||||
|
@ -72,30 +83,47 @@ newtype Session a
|
||||||
|
|
||||||
-- | Run 'Session' actions in a new TensorFlow session.
|
-- | Run 'Session' actions in a new TensorFlow session.
|
||||||
runSession :: Session a -> IO a
|
runSession :: Session a -> IO a
|
||||||
runSession = runSessionWithOptions []
|
runSession = runSessionWithOptions def
|
||||||
|
|
||||||
-- | Setting of an option for the session (see 'runSessionWithOptions').
|
-- | Customization for session. Use the lenses to update:
|
||||||
-- Opaque value created via 'sessionConfig' and 'sessionTarget'.
|
-- 'sessionTarget', 'sessionTracer', 'sessionConfig'.
|
||||||
newtype SessionOption =
|
data Options = Options
|
||||||
SessionOption { unSesssionOption :: Raw.SessionOptions -> IO () }
|
{ _sessionTarget :: ByteString
|
||||||
|
, _sessionConfig :: ConfigProto
|
||||||
|
, _sessionTracer :: Tracer
|
||||||
|
}
|
||||||
|
|
||||||
|
instance Default Options where
|
||||||
|
def = Options
|
||||||
|
{ _sessionTarget = ""
|
||||||
|
, _sessionConfig = def
|
||||||
|
, _sessionTracer = const (return ())
|
||||||
|
}
|
||||||
|
|
||||||
-- | Target can be: "local", ip:port, host:port.
|
-- | Target can be: "local", ip:port, host:port.
|
||||||
-- The set of supported factories depends on the linked in libraries.
|
-- The set of supported factories depends on the linked in libraries.
|
||||||
-- REQUIRES "//learning/brain/public:tensorflow_remote" dependency for the binary.
|
sessionTarget :: Lens' Options ByteString
|
||||||
sessionTarget :: ByteString -> SessionOption
|
sessionTarget = lens _sessionTarget (\g x -> g { _sessionTarget = x })
|
||||||
sessionTarget = SessionOption . FFI.setSessionTarget
|
|
||||||
|
|
||||||
-- | Uses the specified config for the created session.
|
-- | Uses the specified config for the created session.
|
||||||
sessionConfig :: ConfigProto -> SessionOption
|
sessionConfig :: Lens' Options ConfigProto
|
||||||
sessionConfig = SessionOption . FFI.setSessionConfig
|
sessionConfig = lens _sessionConfig (\g x -> g { _sessionConfig = x })
|
||||||
|
|
||||||
|
-- | Uses the given logger to monitor session progress.
|
||||||
|
sessionTracer :: Lens' Options Tracer
|
||||||
|
sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x })
|
||||||
|
|
||||||
-- | Run 'Session' actions in a new TensorFlow session created with
|
-- | Run 'Session' actions in a new TensorFlow session created with
|
||||||
-- the given option setter actions ('sessionTarget', 'sessionConfig').
|
-- the given option setter actions ('sessionTarget', 'sessionConfig').
|
||||||
runSessionWithOptions :: [SessionOption] -> Session a -> IO a
|
runSessionWithOptions :: Options -> Session a -> IO a
|
||||||
runSessionWithOptions options (Session m) =
|
runSessionWithOptions options (Session m) =
|
||||||
FFI.withSession applyOptions $
|
FFI.withSession applyOptions $
|
||||||
\as rs -> evalBuildT (runReaderT m (SessionState rs as))
|
\as rs ->
|
||||||
where applyOptions opt = mapM_ (`unSesssionOption` opt) options
|
let initState = SessionState rs as (options ^. sessionTracer)
|
||||||
|
in evalBuildT (runReaderT m initState)
|
||||||
|
where applyOptions opt = do
|
||||||
|
FFI.setSessionTarget (options ^. sessionTarget) opt
|
||||||
|
FFI.setSessionConfig (options ^. sessionConfig) opt
|
||||||
|
|
||||||
-- | Lift a 'Build' action into a 'Session', including any explicit op
|
-- | Lift a 'Build' action into a 'Session', including any explicit op
|
||||||
-- renderings.
|
-- renderings.
|
||||||
|
@ -116,15 +144,16 @@ buildWithSummary b = Session $ lift $ (,) <$> v <*> collectAllSummaries
|
||||||
-- Note that run, runWithFeeds, etc. will all call this function implicitly.
|
-- Note that run, runWithFeeds, etc. will all call this function implicitly.
|
||||||
extend :: Session ()
|
extend :: Session ()
|
||||||
extend = do
|
extend = do
|
||||||
let withSessionWhen vs action =
|
session <- Session (asks rawSession)
|
||||||
unless (null vs) $ Session (asks rawSession) >>= action
|
trace <- Session (asks tracer)
|
||||||
nodesToExtend <- build flushNodeBuffer
|
nodesToExtend <- build flushNodeBuffer
|
||||||
withSessionWhen nodesToExtend $ \session ->
|
unless (null nodesToExtend) $ liftIO $ do
|
||||||
liftIO $ FFI.extendGraph session
|
let graphDef = def & node .~ nodesToExtend
|
||||||
$ def & node .~ nodesToExtend
|
trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
|
||||||
|
FFI.extendGraph session graphDef
|
||||||
-- Now that all the nodes are created, run the initializers.
|
-- Now that all the nodes are created, run the initializers.
|
||||||
initializers <- build flushInitializers
|
initializers <- build flushInitializers
|
||||||
withSessionWhen initializers $ \session ->
|
unless (null initializers) $
|
||||||
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
|
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
|
||||||
|
|
||||||
-- | Helper combinator for doing something with the result of a 'Build' action.
|
-- | Helper combinator for doing something with the result of a 'Build' action.
|
||||||
|
|
Loading…
Reference in a new issue