Add support for logging to tensorboard (#74)

Add support for logging to tensorboard

Based on @gnezdo's internal version with some differences:

* Uses a pure haskell implementation of EventWriter instead of FFI.
* Special `buildAnd*` functions were dropped in favor of using
  `mergeAllSummaries :: Build SummaryTensor` with the normal
  `build` function.
This commit is contained in:
fkm3 2017-02-20 19:16:42 -08:00 committed by GitHub
parent dca49d8993
commit b3c0997a8c
10 changed files with 303 additions and 13 deletions

View File

@ -4,6 +4,7 @@ packages:
- google-shim
- tensorflow
- tensorflow-core-ops
- tensorflow-logging
- tensorflow-opgen
- tensorflow-ops
- tensorflow-proto

View File

@ -0,0 +1,3 @@
import Distribution.Simple
main = defaultMain

View File

@ -0,0 +1,159 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- | TensorBoard Summary generation. Provides type safe wrappers around raw
-- string emitting CoreOps.
--
-- Example use:
--
-- > -- Call summary functions while constructing the graph.
-- > createModel = do
-- > loss <- -- ...
-- > TF.scalarSummary loss
-- >
-- > -- Write summaries to an EventWriter.
-- > train = TF.withEventWriter "/path/to/logs" $ \eventWriter -> do
-- > summaryTensor <- TF.build TF.allSummaries
-- > forM_ [1..] $ \step -> do
-- > if (step % 100 == 0)
-- > then do
-- > ((), summaryBytes) <- TF.run (trainStep, summaryTensor)
-- > let summary = decodeMessageOrDie (TF.unScalar summaryBytes)
-- > TF.logSummary eventWriter step summary
-- > else TF.run_ trainStep
{-# LANGUAGE TypeOperators #-}
module TensorFlow.Logging
( EventWriter
, withEventWriter
, logEvent
, logSummary
, SummaryTensor
, histogramSummary
, scalarSummary
, mergeAllSummaries
) where
import Control.Concurrent (forkFinally)
import Control.Concurrent.MVar (MVar, newEmptyMVar, readMVar, putMVar)
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TBMQueue (TBMQueue, newTBMQueueIO, closeTBMQueue, writeTBMQueue)
import Control.Monad.Catch (MonadMask, bracket)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Resource (runResourceT)
import Data.ByteString (ByteString)
import Data.Conduit ((=$=))
import Data.Conduit.TQueue (sourceTBMQueue)
import Data.Default (def)
import Data.Int (Int64)
import Data.ProtoLens (encodeMessage)
import Data.Time.Clock (getCurrentTime)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
import Lens.Family2 ((.~), (&))
import Network.HostName (getHostName)
import Proto.Tensorflow.Core.Framework.Summary (Summary)
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step, summary, wallTime)
import System.Directory (createDirectoryIfMissing)
import System.FilePath ((</>))
import TensorFlow.Build (Build, render, SummaryTensor, addSummary, collectAllSummaries)
import TensorFlow.Ops (scalar)
import TensorFlow.Records.Conduit (sinkTFRecords)
import TensorFlow.Tensor (Tensor)
import TensorFlow.Types (TensorType, type(/=))
import Text.Printf (printf)
import qualified Data.ByteString.Lazy as L
import qualified Data.Conduit as Conduit
import qualified Data.Conduit.List as Conduit
import qualified Data.Text as T
import qualified TensorFlow.GenOps.Core as CoreOps
-- | Handle for logging TensorBoard events safely from multiple threads.
data EventWriter = EventWriter (TBMQueue Event) (MVar ())
-- | Writes Event protocol buffers to event files.
withEventWriter ::
(MonadIO m, MonadMask m)
=> FilePath
-- ^ logdir. Local filesystem directory where event file will be written.
-> (EventWriter -> m a)
-> m a
withEventWriter logdir =
bracket (liftIO (newEventWriter logdir)) (liftIO . closeEventWriter)
newEventWriter :: FilePath -> IO EventWriter
newEventWriter logdir = do
createDirectoryIfMissing True logdir
t <- doubleWallTime
hostname <- getHostName
let filename = printf (logdir </> "events.out.tfevents.%010d.%s")
(truncate t :: Integer) hostname
-- Asynchronously consume events from a queue.
-- We use a bounded queue to ensure the producer doesn't get too far ahead
-- of the consumer. The buffer size was picked arbitrarily.
q <- newTBMQueueIO 1024
-- Use an MVar to signal that the worker thread has completed.
done <- newEmptyMVar
let writer = EventWriter q done
consumeQueue = runResourceT $ Conduit.runConduit $
sourceTBMQueue q
=$= Conduit.map (L.fromStrict . encodeMessage)
=$= sinkTFRecords filename
_ <- forkFinally consumeQueue (\_ -> putMVar done ())
logEvent writer $ def & wallTime .~ t
& fileVersion .~ T.pack "brain.Event:2"
return writer
closeEventWriter :: EventWriter -> IO ()
closeEventWriter (EventWriter q done) =
atomically (closeTBMQueue q) >> readMVar done
-- | Logs the given Event protocol buffer.
logEvent :: MonadIO m => EventWriter -> Event -> m ()
logEvent (EventWriter q _) pb = liftIO (atomically (writeTBMQueue q pb))
-- | Logs the given Summary event with an optional global step (use 0 if not
-- applicable).
logSummary :: MonadIO m => EventWriter -> Int64 -> Summary -> m ()
logSummary writer step' summaryProto = do
t <- liftIO doubleWallTime
logEvent writer (def & wallTime .~ t
& step .~ step'
& summary .~ summaryProto
)
-- Number of seconds since epoch.
doubleWallTime :: IO Double
doubleWallTime = asDouble <$> getCurrentTime
where asDouble t = fromRational (toRational (utcTimeToPOSIXSeconds t))
-- | Adds a 'CoreOps.histogramSummary' node. The tag argument is intentionally
-- limited to a single value for simplicity.
histogramSummary ::
(TensorType t, t /= ByteString, t /= Bool)
-- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t)
=> ByteString -> Tensor v t -> Build ()
histogramSummary tag = addSummary . CoreOps.histogramSummary (scalar tag)
-- | Adds a 'CoreOps.scalarSummary' node.
scalarSummary ::
(TensorType t, t /= ByteString, t /= Bool)
-- (TensorType t,
-- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t)
=> ByteString -> Tensor v t -> Build ()
scalarSummary tag = addSummary . CoreOps.scalarSummary (scalar tag)
-- | Merge all summaries accumulated in the 'Build' into one summary.
mergeAllSummaries :: Build SummaryTensor
mergeAllSummaries = collectAllSummaries >>= render . CoreOps.mergeSummary

View File

@ -0,0 +1,66 @@
name: tensorflow-logging
version: 0.1.0.0
synopsis: TensorBoard related functionality.
description: Please see README.md
homepage: https://github.com/tensorflow/haskell#readme
license: Apache
author: TensorFlow authors
maintainer: tensorflow-haskell@googlegroups.com
copyright: Google Inc.
category: Machine Learning
build-type: Simple
cabal-version: >=1.22
library
hs-source-dirs: src
exposed-modules: TensorFlow.Logging
build-depends: base >= 4.7 && < 5
, bytestring
, conduit
, data-default
, directory
, exceptions
, filepath
, hostname
, lens-family
, proto-lens == 0.1.*
, resourcet
, stm
, stm-chans
, stm-conduit
, tensorflow == 0.1.*
, tensorflow-core-ops == 0.1.*
, tensorflow-ops == 0.1.*
, tensorflow-proto == 0.1.*
, tensorflow-records-conduit
, text
, time
, transformers
default-language: Haskell2010
Test-Suite LoggingTest
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: LoggingTest.hs
hs-source-dirs: tests
build-depends: HUnit
, base
, bytestring
, conduit
, data-default
, directory
, filepath
, lens-family
, proto-lens
, resourcet
, temporary
, tensorflow-logging
, tensorflow-proto
, tensorflow-records-conduit
, test-framework
, test-framework-hunit
, text
source-repository head
type: git
location: https://github.com/tensorflow/haskell

View File

@ -0,0 +1,64 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- | Tests for TensorFlow.Logging.
module Main where
import Control.Monad.Trans.Resource (runResourceT)
import Data.Conduit ((=$=))
import Data.Default (def)
import Data.List ((\\))
import Data.ProtoLens (decodeMessageOrDie)
import Lens.Family2 ((^.), (.~), (&))
import Proto.Tensorflow.Core.Util.Event (fileVersion, step)
import System.Directory (getDirectoryContents)
import System.FilePath ((</>))
import System.IO.Temp (withSystemTempDirectory)
import TensorFlow.Logging (withEventWriter, logEvent)
import TensorFlow.Records.Conduit (sourceTFRecords)
import Test.Framework (defaultMain, Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit (assertBool, assertEqual)
import qualified Data.ByteString.Lazy as BL
import qualified Data.Conduit as Conduit
import qualified Data.Conduit.List as Conduit
import qualified Data.Text as T
-- TODO: This has been added to System.Directory in newer versions.
listDirectory :: String -> IO [String]
listDirectory dir = (\\ [".", ".."]) <$> getDirectoryContents dir
testEventWriter :: Test
testEventWriter = testCase "EventWriter" $
withSystemTempDirectory "event_writer_logs" $ \dir -> do
assertEqual "No file before" [] =<< listDirectory dir
let expected = [ def & step .~ 10
, def & step .~ 222
, def & step .~ 8
]
withEventWriter dir $ \eventWriter ->
mapM_ (logEvent eventWriter) expected
files <- listDirectory dir
assertEqual "One file exists after" 1 (length files)
records <- runResourceT $ Conduit.runConduit $
sourceTFRecords (dir </> head files) =$= Conduit.consume
assertBool "File is not empty" (not (null records))
let (header:body) = decodeMessageOrDie . BL.toStrict <$> records
assertEqual "Header has expected version"
(T.pack "brain.Event:2") (header ^. fileVersion)
assertEqual "Body has expected records" expected body
main :: IO ()
main = defaultMain [ testEventWriter
]

View File

@ -13,6 +13,7 @@ cabal-version: >=1.22
extra-source-files: ../third_party/tensorflow/tensorflow/core/framework/*.proto
, ../third_party/tensorflow/tensorflow/core/protobuf/config.proto
, ../third_party/tensorflow/tensorflow/core/protobuf/debug.proto
, ../third_party/tensorflow/tensorflow/core/util/event.proto
library
exposed-modules: Proto.Tensorflow.Core.Framework.AttrValue
@ -20,10 +21,12 @@ library
, Proto.Tensorflow.Core.Framework.NodeDef
, Proto.Tensorflow.Core.Framework.OpDef
, Proto.Tensorflow.Core.Framework.ResourceHandle
, Proto.Tensorflow.Core.Framework.Summary
, Proto.Tensorflow.Core.Framework.Tensor
, Proto.Tensorflow.Core.Framework.TensorShape
, Proto.Tensorflow.Core.Framework.Types
, Proto.Tensorflow.Core.Protobuf.Config
, Proto.Tensorflow.Core.Util.Event
other-modules: Proto.Tensorflow.Core.Framework.AllocationDescription
, Proto.Tensorflow.Core.Framework.CostGraph
, Proto.Tensorflow.Core.Framework.Function

View File

@ -14,8 +14,8 @@
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
module TensorFlow.Build
( -- * Graph node types
ControlNode(..)
@ -61,6 +61,7 @@ module TensorFlow.Build
, collectAllSummaries
) where
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
@ -196,7 +197,7 @@ summaries = lens _summaries (\g x -> g { _summaries = x })
-- Used to manage build state internally as part of the @Session@ monad.
newtype BuildT m a = BuildT (StateT GraphState m a)
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
MonadState GraphState)
MonadState GraphState, MonadThrow, MonadCatch, MonadMask)
-- | An action for building nodes in a TensorFlow graph.
type Build = BuildT Identity

View File

@ -33,7 +33,6 @@ module TensorFlow.Core
-- ** Building graphs
, build
, buildAnd
, buildWithSummary
-- ** Running graphs
, Fetchable
, Nodes

View File

@ -28,7 +28,6 @@ module TensorFlow.Session (
runSessionWithOptions,
build,
buildAnd,
buildWithSummary,
extend,
addGraphDef,
run,
@ -39,6 +38,7 @@ module TensorFlow.Session (
) where
import Control.Monad (forever, unless, void)
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
@ -77,7 +77,8 @@ data SessionState
newtype Session a
= Session (ReaderT SessionState (BuildT IO) a)
deriving (Functor, Applicative, Monad, MonadIO)
deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch,
MonadMask)
-- | Run 'Session' actions in a new TensorFlow session.
runSession :: Session a -> IO a
@ -128,14 +129,6 @@ runSessionWithOptions options (Session m) =
build :: Build a -> Session a
build = Session . lift . hoistBuildT (return . runIdentity)
-- | Lift a 'Build' action into a 'Session', including any explicit op
-- renderings. Returns the merged summary ops which can be used for
-- logging, see 'TensorFlow.Logging.build' for a convenient wrapper.
buildWithSummary :: forall a . Build a -> Session (a, [SummaryTensor])
buildWithSummary b = Session $ lift $ (,) <$> v <*> collectAllSummaries
where v :: BuildT IO a
v = hoistBuildT (return . runIdentity) b
-- | Add all pending rendered nodes to the TensorFlow graph and runs
-- any pending initializers.
--

View File

@ -37,6 +37,7 @@ library
, bytestring
, containers
, data-default
, exceptions
, fgl
, lens-family
, mainland-pretty