mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Added sessionTracer to log graph operations. (#26)
* Added TracingTest.
This commit is contained in:
parent
630850c2d2
commit
0d4f5a9628
4 changed files with 119 additions and 25 deletions
|
@ -166,6 +166,21 @@ Test-Suite MiscTest
|
|||
, test-framework
|
||||
, test-framework-hunit
|
||||
|
||||
Test-Suite TracingTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: TracingTest.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: HUnit
|
||||
, base
|
||||
, bytestring
|
||||
, data-default
|
||||
, lens-family
|
||||
, tensorflow
|
||||
, tensorflow-ops
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
|
||||
Test-Suite TypesTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
|
|
49
tensorflow-ops/tests/TracingTest.hs
Normal file
49
tensorflow-ops/tests/TracingTest.hs
Normal file
|
@ -0,0 +1,49 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
-- | Testing tracing.
|
||||
module Main where
|
||||
|
||||
import Control.Concurrent.MVar (newEmptyMVar, putMVar, tryReadMVar)
|
||||
import Data.ByteString.Builder (toLazyByteString)
|
||||
import Data.ByteString.Lazy (isPrefixOf)
|
||||
import Data.Default (def)
|
||||
import Data.Monoid ((<>))
|
||||
import Lens.Family2 ((&), (.~))
|
||||
import Test.Framework (defaultMain)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@=?), assertBool, assertFailure)
|
||||
|
||||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
|
||||
testTracing :: IO ()
|
||||
testTracing = do
|
||||
-- Verifies that tracing happens as a side-effect of graph extension.
|
||||
loggedValue <- newEmptyMVar
|
||||
TF.runSessionWithOptions
|
||||
(def & TF.sessionTracer .~ putMVar loggedValue)
|
||||
(TF.buildAnd TF.run_ (pure (TF.scalar (0 :: Float))))
|
||||
tryReadMVar loggedValue >>=
|
||||
maybe (assertFailure "Logging never happened") expectedFormat
|
||||
where expectedFormat x =
|
||||
let got = toLazyByteString x in
|
||||
assertBool ("Unexpected log entry " ++ show got)
|
||||
("Session.extend" `isPrefixOf` got)
|
||||
|
||||
main = defaultMain
|
||||
[ testCase "Tracing" testTracing
|
||||
]
|
|
@ -24,9 +24,10 @@
|
|||
module TensorFlow.Core
|
||||
( -- * Session
|
||||
Session
|
||||
, SessionOption
|
||||
, Options
|
||||
, sessionConfig
|
||||
, sessionTarget
|
||||
, sessionTracer
|
||||
, runSession
|
||||
, runSessionWithOptions
|
||||
-- ** Building graphs
|
||||
|
|
|
@ -13,15 +13,17 @@
|
|||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE Rank2Types #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module TensorFlow.Session (
|
||||
Session,
|
||||
SessionOption,
|
||||
Options,
|
||||
sessionConfig,
|
||||
sessionTarget,
|
||||
sessionTracer,
|
||||
runSession,
|
||||
runSessionWithOptions,
|
||||
build,
|
||||
|
@ -41,29 +43,38 @@ import Control.Monad.IO.Class (MonadIO, liftIO)
|
|||
import Control.Monad.Trans.Class (lift)
|
||||
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Default (Default, def)
|
||||
import Data.Functor.Identity (runIdentity)
|
||||
import Data.Monoid ((<>))
|
||||
import qualified Data.Map.Strict as Map
|
||||
import qualified Data.Set as Set
|
||||
import Data.Set (Set)
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Data.ProtoLens (def)
|
||||
import Lens.Family2 ((&), (.~))
|
||||
import Data.ProtoLens (def, showMessage)
|
||||
import Lens.Family2 (Lens', (^.), (&), (.~))
|
||||
import Lens.Family2.Unchecked (lens)
|
||||
import Proto.Tensorflow.Core.Framework.Graph (node)
|
||||
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
||||
|
||||
import TensorFlow.Build
|
||||
import qualified TensorFlow.Internal.FFI as FFI
|
||||
import qualified TensorFlow.Internal.Raw as Raw
|
||||
import TensorFlow.Nodes
|
||||
import TensorFlow.Output (NodeName, unNodeName)
|
||||
import TensorFlow.Tensor
|
||||
|
||||
import qualified Data.ByteString.Builder as Builder
|
||||
import qualified TensorFlow.Internal.FFI as FFI
|
||||
import qualified TensorFlow.Internal.Raw as Raw
|
||||
|
||||
-- | An action for logging.
|
||||
type Tracer = Builder.Builder -> IO ()
|
||||
|
||||
-- Common state threaded through the session.
|
||||
data SessionState
|
||||
= SessionState {
|
||||
rawSession :: FFI.Session
|
||||
, asyncCollector :: IO () -> IO ()
|
||||
-- ^ Starts the given action concurrently.
|
||||
, tracer :: Tracer
|
||||
}
|
||||
|
||||
newtype Session a
|
||||
|
@ -72,30 +83,47 @@ newtype Session a
|
|||
|
||||
-- | Run 'Session' actions in a new TensorFlow session.
|
||||
runSession :: Session a -> IO a
|
||||
runSession = runSessionWithOptions []
|
||||
runSession = runSessionWithOptions def
|
||||
|
||||
-- | Setting of an option for the session (see 'runSessionWithOptions').
|
||||
-- Opaque value created via 'sessionConfig' and 'sessionTarget'.
|
||||
newtype SessionOption =
|
||||
SessionOption { unSesssionOption :: Raw.SessionOptions -> IO () }
|
||||
-- | Customization for session. Use the lenses to update:
|
||||
-- 'sessionTarget', 'sessionTracer', 'sessionConfig'.
|
||||
data Options = Options
|
||||
{ _sessionTarget :: ByteString
|
||||
, _sessionConfig :: ConfigProto
|
||||
, _sessionTracer :: Tracer
|
||||
}
|
||||
|
||||
instance Default Options where
|
||||
def = Options
|
||||
{ _sessionTarget = ""
|
||||
, _sessionConfig = def
|
||||
, _sessionTracer = const (return ())
|
||||
}
|
||||
|
||||
-- | Target can be: "local", ip:port, host:port.
|
||||
-- The set of supported factories depends on the linked in libraries.
|
||||
-- REQUIRES "//learning/brain/public:tensorflow_remote" dependency for the binary.
|
||||
sessionTarget :: ByteString -> SessionOption
|
||||
sessionTarget = SessionOption . FFI.setSessionTarget
|
||||
sessionTarget :: Lens' Options ByteString
|
||||
sessionTarget = lens _sessionTarget (\g x -> g { _sessionTarget = x })
|
||||
|
||||
-- | Uses the specified config for the created session.
|
||||
sessionConfig :: ConfigProto -> SessionOption
|
||||
sessionConfig = SessionOption . FFI.setSessionConfig
|
||||
sessionConfig :: Lens' Options ConfigProto
|
||||
sessionConfig = lens _sessionConfig (\g x -> g { _sessionConfig = x })
|
||||
|
||||
-- | Uses the given logger to monitor session progress.
|
||||
sessionTracer :: Lens' Options Tracer
|
||||
sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x })
|
||||
|
||||
-- | Run 'Session' actions in a new TensorFlow session created with
|
||||
-- the given option setter actions ('sessionTarget', 'sessionConfig').
|
||||
runSessionWithOptions :: [SessionOption] -> Session a -> IO a
|
||||
runSessionWithOptions :: Options -> Session a -> IO a
|
||||
runSessionWithOptions options (Session m) =
|
||||
FFI.withSession applyOptions $
|
||||
\as rs -> evalBuildT (runReaderT m (SessionState rs as))
|
||||
where applyOptions opt = mapM_ (`unSesssionOption` opt) options
|
||||
\as rs ->
|
||||
let initState = SessionState rs as (options ^. sessionTracer)
|
||||
in evalBuildT (runReaderT m initState)
|
||||
where applyOptions opt = do
|
||||
FFI.setSessionTarget (options ^. sessionTarget) opt
|
||||
FFI.setSessionConfig (options ^. sessionConfig) opt
|
||||
|
||||
-- | Lift a 'Build' action into a 'Session', including any explicit op
|
||||
-- renderings.
|
||||
|
@ -116,15 +144,16 @@ buildWithSummary b = Session $ lift $ (,) <$> v <*> collectAllSummaries
|
|||
-- Note that run, runWithFeeds, etc. will all call this function implicitly.
|
||||
extend :: Session ()
|
||||
extend = do
|
||||
let withSessionWhen vs action =
|
||||
unless (null vs) $ Session (asks rawSession) >>= action
|
||||
session <- Session (asks rawSession)
|
||||
trace <- Session (asks tracer)
|
||||
nodesToExtend <- build flushNodeBuffer
|
||||
withSessionWhen nodesToExtend $ \session ->
|
||||
liftIO $ FFI.extendGraph session
|
||||
$ def & node .~ nodesToExtend
|
||||
unless (null nodesToExtend) $ liftIO $ do
|
||||
let graphDef = def & node .~ nodesToExtend
|
||||
trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
|
||||
FFI.extendGraph session graphDef
|
||||
-- Now that all the nodes are created, run the initializers.
|
||||
initializers <- build flushInitializers
|
||||
withSessionWhen initializers $ \session ->
|
||||
unless (null initializers) $
|
||||
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
|
||||
|
||||
-- | Helper combinator for doing something with the result of a 'Build' action.
|
||||
|
|
Loading…
Reference in a new issue