mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-22 19:09:43 +01:00
Add support for logging to tensorboard (#74)
Add support for logging to tensorboard Based on @gnezdo's internal version with some differences: * Uses a pure haskell implementation of EventWriter instead of FFI. * Special `buildAnd*` functions were dropped in favor of using `mergeAllSummaries :: Build SummaryTensor` with the normal `build` function.
This commit is contained in:
parent
dca49d8993
commit
b3c0997a8c
10 changed files with 303 additions and 13 deletions
|
@ -4,6 +4,7 @@ packages:
|
|||
- google-shim
|
||||
- tensorflow
|
||||
- tensorflow-core-ops
|
||||
- tensorflow-logging
|
||||
- tensorflow-opgen
|
||||
- tensorflow-ops
|
||||
- tensorflow-proto
|
||||
|
|
3
tensorflow-logging/Setup.hs
Normal file
3
tensorflow-logging/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
|||
import Distribution.Simple
|
||||
|
||||
main = defaultMain
|
159
tensorflow-logging/src/TensorFlow/Logging.hs
Normal file
159
tensorflow-logging/src/TensorFlow/Logging.hs
Normal file
|
@ -0,0 +1,159 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
-- | TensorBoard Summary generation. Provides type safe wrappers around raw
|
||||
-- string emitting CoreOps.
|
||||
--
|
||||
-- Example use:
|
||||
--
|
||||
-- > -- Call summary functions while constructing the graph.
|
||||
-- > createModel = do
|
||||
-- > loss <- -- ...
|
||||
-- > TF.scalarSummary loss
|
||||
-- >
|
||||
-- > -- Write summaries to an EventWriter.
|
||||
-- > train = TF.withEventWriter "/path/to/logs" $ \eventWriter -> do
|
||||
-- > summaryTensor <- TF.build TF.allSummaries
|
||||
-- > forM_ [1..] $ \step -> do
|
||||
-- > if (step % 100 == 0)
|
||||
-- > then do
|
||||
-- > ((), summaryBytes) <- TF.run (trainStep, summaryTensor)
|
||||
-- > let summary = decodeMessageOrDie (TF.unScalar summaryBytes)
|
||||
-- > TF.logSummary eventWriter step summary
|
||||
-- > else TF.run_ trainStep
|
||||
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
|
||||
module TensorFlow.Logging
|
||||
( EventWriter
|
||||
, withEventWriter
|
||||
, logEvent
|
||||
, logSummary
|
||||
, SummaryTensor
|
||||
, histogramSummary
|
||||
, scalarSummary
|
||||
, mergeAllSummaries
|
||||
) where
|
||||
|
||||
import Control.Concurrent (forkFinally)
|
||||
import Control.Concurrent.MVar (MVar, newEmptyMVar, readMVar, putMVar)
|
||||
import Control.Concurrent.STM (atomically)
|
||||
import Control.Concurrent.STM.TBMQueue (TBMQueue, newTBMQueueIO, closeTBMQueue, writeTBMQueue)
|
||||
import Control.Monad.Catch (MonadMask, bracket)
|
||||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||
import Control.Monad.Trans.Resource (runResourceT)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Conduit ((=$=))
|
||||
import Data.Conduit.TQueue (sourceTBMQueue)
|
||||
import Data.Default (def)
|
||||
import Data.Int (Int64)
|
||||
import Data.ProtoLens (encodeMessage)
|
||||
import Data.Time.Clock (getCurrentTime)
|
||||
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
|
||||
import Lens.Family2 ((.~), (&))
|
||||
import Network.HostName (getHostName)
|
||||
import Proto.Tensorflow.Core.Framework.Summary (Summary)
|
||||
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step, summary, wallTime)
|
||||
import System.Directory (createDirectoryIfMissing)
|
||||
import System.FilePath ((</>))
|
||||
import TensorFlow.Build (Build, render, SummaryTensor, addSummary, collectAllSummaries)
|
||||
import TensorFlow.Ops (scalar)
|
||||
import TensorFlow.Records.Conduit (sinkTFRecords)
|
||||
import TensorFlow.Tensor (Tensor)
|
||||
import TensorFlow.Types (TensorType, type(/=))
|
||||
import Text.Printf (printf)
|
||||
import qualified Data.ByteString.Lazy as L
|
||||
import qualified Data.Conduit as Conduit
|
||||
import qualified Data.Conduit.List as Conduit
|
||||
import qualified Data.Text as T
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
|
||||
-- | Handle for logging TensorBoard events safely from multiple threads.
|
||||
data EventWriter = EventWriter (TBMQueue Event) (MVar ())
|
||||
|
||||
-- | Writes Event protocol buffers to event files.
|
||||
withEventWriter ::
|
||||
(MonadIO m, MonadMask m)
|
||||
=> FilePath
|
||||
-- ^ logdir. Local filesystem directory where event file will be written.
|
||||
-> (EventWriter -> m a)
|
||||
-> m a
|
||||
withEventWriter logdir =
|
||||
bracket (liftIO (newEventWriter logdir)) (liftIO . closeEventWriter)
|
||||
|
||||
newEventWriter :: FilePath -> IO EventWriter
|
||||
newEventWriter logdir = do
|
||||
createDirectoryIfMissing True logdir
|
||||
t <- doubleWallTime
|
||||
hostname <- getHostName
|
||||
let filename = printf (logdir </> "events.out.tfevents.%010d.%s")
|
||||
(truncate t :: Integer) hostname
|
||||
-- Asynchronously consume events from a queue.
|
||||
-- We use a bounded queue to ensure the producer doesn't get too far ahead
|
||||
-- of the consumer. The buffer size was picked arbitrarily.
|
||||
q <- newTBMQueueIO 1024
|
||||
-- Use an MVar to signal that the worker thread has completed.
|
||||
done <- newEmptyMVar
|
||||
let writer = EventWriter q done
|
||||
consumeQueue = runResourceT $ Conduit.runConduit $
|
||||
sourceTBMQueue q
|
||||
=$= Conduit.map (L.fromStrict . encodeMessage)
|
||||
=$= sinkTFRecords filename
|
||||
_ <- forkFinally consumeQueue (\_ -> putMVar done ())
|
||||
logEvent writer $ def & wallTime .~ t
|
||||
& fileVersion .~ T.pack "brain.Event:2"
|
||||
return writer
|
||||
|
||||
closeEventWriter :: EventWriter -> IO ()
|
||||
closeEventWriter (EventWriter q done) =
|
||||
atomically (closeTBMQueue q) >> readMVar done
|
||||
|
||||
-- | Logs the given Event protocol buffer.
|
||||
logEvent :: MonadIO m => EventWriter -> Event -> m ()
|
||||
logEvent (EventWriter q _) pb = liftIO (atomically (writeTBMQueue q pb))
|
||||
|
||||
-- | Logs the given Summary event with an optional global step (use 0 if not
|
||||
-- applicable).
|
||||
logSummary :: MonadIO m => EventWriter -> Int64 -> Summary -> m ()
|
||||
logSummary writer step' summaryProto = do
|
||||
t <- liftIO doubleWallTime
|
||||
logEvent writer (def & wallTime .~ t
|
||||
& step .~ step'
|
||||
& summary .~ summaryProto
|
||||
)
|
||||
|
||||
-- Number of seconds since epoch.
|
||||
doubleWallTime :: IO Double
|
||||
doubleWallTime = asDouble <$> getCurrentTime
|
||||
where asDouble t = fromRational (toRational (utcTimeToPOSIXSeconds t))
|
||||
|
||||
-- | Adds a 'CoreOps.histogramSummary' node. The tag argument is intentionally
|
||||
-- limited to a single value for simplicity.
|
||||
histogramSummary ::
|
||||
(TensorType t, t /= ByteString, t /= Bool)
|
||||
-- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t)
|
||||
=> ByteString -> Tensor v t -> Build ()
|
||||
histogramSummary tag = addSummary . CoreOps.histogramSummary (scalar tag)
|
||||
|
||||
-- | Adds a 'CoreOps.scalarSummary' node.
|
||||
scalarSummary ::
|
||||
(TensorType t, t /= ByteString, t /= Bool)
|
||||
-- (TensorType t,
|
||||
-- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t)
|
||||
=> ByteString -> Tensor v t -> Build ()
|
||||
scalarSummary tag = addSummary . CoreOps.scalarSummary (scalar tag)
|
||||
|
||||
-- | Merge all summaries accumulated in the 'Build' into one summary.
|
||||
mergeAllSummaries :: Build SummaryTensor
|
||||
mergeAllSummaries = collectAllSummaries >>= render . CoreOps.mergeSummary
|
66
tensorflow-logging/tensorflow-logging.cabal
Normal file
66
tensorflow-logging/tensorflow-logging.cabal
Normal file
|
@ -0,0 +1,66 @@
|
|||
name: tensorflow-logging
|
||||
version: 0.1.0.0
|
||||
synopsis: TensorBoard related functionality.
|
||||
description: Please see README.md
|
||||
homepage: https://github.com/tensorflow/haskell#readme
|
||||
license: Apache
|
||||
author: TensorFlow authors
|
||||
maintainer: tensorflow-haskell@googlegroups.com
|
||||
copyright: Google Inc.
|
||||
category: Machine Learning
|
||||
build-type: Simple
|
||||
cabal-version: >=1.22
|
||||
|
||||
library
|
||||
hs-source-dirs: src
|
||||
exposed-modules: TensorFlow.Logging
|
||||
build-depends: base >= 4.7 && < 5
|
||||
, bytestring
|
||||
, conduit
|
||||
, data-default
|
||||
, directory
|
||||
, exceptions
|
||||
, filepath
|
||||
, hostname
|
||||
, lens-family
|
||||
, proto-lens == 0.1.*
|
||||
, resourcet
|
||||
, stm
|
||||
, stm-chans
|
||||
, stm-conduit
|
||||
, tensorflow == 0.1.*
|
||||
, tensorflow-core-ops == 0.1.*
|
||||
, tensorflow-ops == 0.1.*
|
||||
, tensorflow-proto == 0.1.*
|
||||
, tensorflow-records-conduit
|
||||
, text
|
||||
, time
|
||||
, transformers
|
||||
default-language: Haskell2010
|
||||
|
||||
Test-Suite LoggingTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: LoggingTest.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: HUnit
|
||||
, base
|
||||
, bytestring
|
||||
, conduit
|
||||
, data-default
|
||||
, directory
|
||||
, filepath
|
||||
, lens-family
|
||||
, proto-lens
|
||||
, resourcet
|
||||
, temporary
|
||||
, tensorflow-logging
|
||||
, tensorflow-proto
|
||||
, tensorflow-records-conduit
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
, text
|
||||
|
||||
source-repository head
|
||||
type: git
|
||||
location: https://github.com/tensorflow/haskell
|
64
tensorflow-logging/tests/LoggingTest.hs
Normal file
64
tensorflow-logging/tests/LoggingTest.hs
Normal file
|
@ -0,0 +1,64 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
-- | Tests for TensorFlow.Logging.
|
||||
module Main where
|
||||
|
||||
import Control.Monad.Trans.Resource (runResourceT)
|
||||
import Data.Conduit ((=$=))
|
||||
import Data.Default (def)
|
||||
import Data.List ((\\))
|
||||
import Data.ProtoLens (decodeMessageOrDie)
|
||||
import Lens.Family2 ((^.), (.~), (&))
|
||||
import Proto.Tensorflow.Core.Util.Event (fileVersion, step)
|
||||
import System.Directory (getDirectoryContents)
|
||||
import System.FilePath ((</>))
|
||||
import System.IO.Temp (withSystemTempDirectory)
|
||||
import TensorFlow.Logging (withEventWriter, logEvent)
|
||||
import TensorFlow.Records.Conduit (sourceTFRecords)
|
||||
import Test.Framework (defaultMain, Test)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit (assertBool, assertEqual)
|
||||
import qualified Data.ByteString.Lazy as BL
|
||||
import qualified Data.Conduit as Conduit
|
||||
import qualified Data.Conduit.List as Conduit
|
||||
import qualified Data.Text as T
|
||||
|
||||
-- TODO: This has been added to System.Directory in newer versions.
|
||||
listDirectory :: String -> IO [String]
|
||||
listDirectory dir = (\\ [".", ".."]) <$> getDirectoryContents dir
|
||||
|
||||
testEventWriter :: Test
|
||||
testEventWriter = testCase "EventWriter" $
|
||||
withSystemTempDirectory "event_writer_logs" $ \dir -> do
|
||||
assertEqual "No file before" [] =<< listDirectory dir
|
||||
let expected = [ def & step .~ 10
|
||||
, def & step .~ 222
|
||||
, def & step .~ 8
|
||||
]
|
||||
withEventWriter dir $ \eventWriter ->
|
||||
mapM_ (logEvent eventWriter) expected
|
||||
files <- listDirectory dir
|
||||
assertEqual "One file exists after" 1 (length files)
|
||||
records <- runResourceT $ Conduit.runConduit $
|
||||
sourceTFRecords (dir </> head files) =$= Conduit.consume
|
||||
assertBool "File is not empty" (not (null records))
|
||||
let (header:body) = decodeMessageOrDie . BL.toStrict <$> records
|
||||
assertEqual "Header has expected version"
|
||||
(T.pack "brain.Event:2") (header ^. fileVersion)
|
||||
assertEqual "Body has expected records" expected body
|
||||
|
||||
main :: IO ()
|
||||
main = defaultMain [ testEventWriter
|
||||
]
|
|
@ -13,6 +13,7 @@ cabal-version: >=1.22
|
|||
extra-source-files: ../third_party/tensorflow/tensorflow/core/framework/*.proto
|
||||
, ../third_party/tensorflow/tensorflow/core/protobuf/config.proto
|
||||
, ../third_party/tensorflow/tensorflow/core/protobuf/debug.proto
|
||||
, ../third_party/tensorflow/tensorflow/core/util/event.proto
|
||||
|
||||
library
|
||||
exposed-modules: Proto.Tensorflow.Core.Framework.AttrValue
|
||||
|
@ -20,10 +21,12 @@ library
|
|||
, Proto.Tensorflow.Core.Framework.NodeDef
|
||||
, Proto.Tensorflow.Core.Framework.OpDef
|
||||
, Proto.Tensorflow.Core.Framework.ResourceHandle
|
||||
, Proto.Tensorflow.Core.Framework.Summary
|
||||
, Proto.Tensorflow.Core.Framework.Tensor
|
||||
, Proto.Tensorflow.Core.Framework.TensorShape
|
||||
, Proto.Tensorflow.Core.Framework.Types
|
||||
, Proto.Tensorflow.Core.Protobuf.Config
|
||||
, Proto.Tensorflow.Core.Util.Event
|
||||
other-modules: Proto.Tensorflow.Core.Framework.AllocationDescription
|
||||
, Proto.Tensorflow.Core.Framework.CostGraph
|
||||
, Proto.Tensorflow.Core.Framework.Function
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE Rank2Types #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE Rank2Types #-}
|
||||
module TensorFlow.Build
|
||||
( -- * Graph node types
|
||||
ControlNode(..)
|
||||
|
@ -61,6 +61,7 @@ module TensorFlow.Build
|
|||
, collectAllSummaries
|
||||
) where
|
||||
|
||||
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
||||
import Control.Monad.IO.Class (MonadIO(..))
|
||||
import Control.Monad.Trans.Class (MonadTrans(..))
|
||||
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
|
||||
|
@ -196,7 +197,7 @@ summaries = lens _summaries (\g x -> g { _summaries = x })
|
|||
-- Used to manage build state internally as part of the @Session@ monad.
|
||||
newtype BuildT m a = BuildT (StateT GraphState m a)
|
||||
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
|
||||
MonadState GraphState)
|
||||
MonadState GraphState, MonadThrow, MonadCatch, MonadMask)
|
||||
|
||||
-- | An action for building nodes in a TensorFlow graph.
|
||||
type Build = BuildT Identity
|
||||
|
|
|
@ -33,7 +33,6 @@ module TensorFlow.Core
|
|||
-- ** Building graphs
|
||||
, build
|
||||
, buildAnd
|
||||
, buildWithSummary
|
||||
-- ** Running graphs
|
||||
, Fetchable
|
||||
, Nodes
|
||||
|
|
|
@ -28,7 +28,6 @@ module TensorFlow.Session (
|
|||
runSessionWithOptions,
|
||||
build,
|
||||
buildAnd,
|
||||
buildWithSummary,
|
||||
extend,
|
||||
addGraphDef,
|
||||
run,
|
||||
|
@ -39,6 +38,7 @@ module TensorFlow.Session (
|
|||
) where
|
||||
|
||||
import Control.Monad (forever, unless, void)
|
||||
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
||||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||
import Control.Monad.Trans.Class (lift)
|
||||
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
|
||||
|
@ -77,7 +77,8 @@ data SessionState
|
|||
|
||||
newtype Session a
|
||||
= Session (ReaderT SessionState (BuildT IO) a)
|
||||
deriving (Functor, Applicative, Monad, MonadIO)
|
||||
deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch,
|
||||
MonadMask)
|
||||
|
||||
-- | Run 'Session' actions in a new TensorFlow session.
|
||||
runSession :: Session a -> IO a
|
||||
|
@ -128,14 +129,6 @@ runSessionWithOptions options (Session m) =
|
|||
build :: Build a -> Session a
|
||||
build = Session . lift . hoistBuildT (return . runIdentity)
|
||||
|
||||
-- | Lift a 'Build' action into a 'Session', including any explicit op
|
||||
-- renderings. Returns the merged summary ops which can be used for
|
||||
-- logging, see 'TensorFlow.Logging.build' for a convenient wrapper.
|
||||
buildWithSummary :: forall a . Build a -> Session (a, [SummaryTensor])
|
||||
buildWithSummary b = Session $ lift $ (,) <$> v <*> collectAllSummaries
|
||||
where v :: BuildT IO a
|
||||
v = hoistBuildT (return . runIdentity) b
|
||||
|
||||
-- | Add all pending rendered nodes to the TensorFlow graph and runs
|
||||
-- any pending initializers.
|
||||
--
|
||||
|
|
|
@ -37,6 +37,7 @@ library
|
|||
, bytestring
|
||||
, containers
|
||||
, data-default
|
||||
, exceptions
|
||||
, fgl
|
||||
, lens-family
|
||||
, mainland-pretty
|
||||
|
|
Loading…
Reference in a new issue