1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-07-01 08:58:34 +02:00

Added logGraph for graph visualization in TensorBoard (#104)

This commit is contained in:
Christian Berentsen 2017-06-20 05:53:55 +02:00 committed by Greg Steuck
parent 423b34537e
commit 042910b000
3 changed files with 33 additions and 5 deletions

View File

@ -39,6 +39,7 @@ module TensorFlow.Logging
( EventWriter ( EventWriter
, withEventWriter , withEventWriter
, logEvent , logEvent
, logGraph
, logSummary , logSummary
, SummaryTensor , SummaryTensor
, histogramSummary , histogramSummary
@ -64,10 +65,10 @@ import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
import Lens.Family2 ((.~), (&)) import Lens.Family2 ((.~), (&))
import Network.HostName (getHostName) import Network.HostName (getHostName)
import Proto.Tensorflow.Core.Framework.Summary (Summary) import Proto.Tensorflow.Core.Framework.Summary (Summary)
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step, summary, wallTime) import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, graphDef, step, summary, wallTime)
import System.Directory (createDirectoryIfMissing) import System.Directory (createDirectoryIfMissing)
import System.FilePath ((</>)) import System.FilePath ((</>))
import TensorFlow.Build (MonadBuild) import TensorFlow.Build (MonadBuild, Build, asGraphDef)
import TensorFlow.Ops (scalar) import TensorFlow.Ops (scalar)
import TensorFlow.Records.Conduit (sinkTFRecords) import TensorFlow.Records.Conduit (sinkTFRecords)
import TensorFlow.Tensor (Tensor, render, SummaryTensor, addSummary, collectAllSummaries) import TensorFlow.Tensor (Tensor, render, SummaryTensor, addSummary, collectAllSummaries)
@ -123,6 +124,14 @@ closeEventWriter (EventWriter q done) =
logEvent :: MonadIO m => EventWriter -> Event -> m () logEvent :: MonadIO m => EventWriter -> Event -> m ()
logEvent (EventWriter q _) pb = liftIO (atomically (writeTBMQueue q pb)) logEvent (EventWriter q _) pb = liftIO (atomically (writeTBMQueue q pb))
-- | Logs the graph for the given 'Build' action.
logGraph :: MonadIO m => EventWriter -> Build a -> m ()
logGraph writer build = do
let graph = asGraphDef build
graphBytes = encodeMessage graph
graphEvent = (def :: Event) & graphDef .~ graphBytes
logEvent writer graphEvent
-- | Logs the given Summary event with an optional global step (use 0 if not -- | Logs the given Summary event with an optional global step (use 0 if not
-- applicable). -- applicable).
logSummary :: MonadIO m => EventWriter -> Int64 -> Summary -> m () logSummary :: MonadIO m => EventWriter -> Int64 -> Summary -> m ()
@ -133,6 +142,7 @@ logSummary writer step' summaryProto = do
& summary .~ summaryProto & summary .~ summaryProto
) )
-- Number of seconds since epoch. -- Number of seconds since epoch.
doubleWallTime :: IO Double doubleWallTime :: IO Double
doubleWallTime = asDouble <$> getCurrentTime doubleWallTime = asDouble <$> getCurrentTime

View File

@ -55,6 +55,7 @@ Test-Suite LoggingTest
, proto-lens , proto-lens
, resourcet , resourcet
, temporary , temporary
, tensorflow
, tensorflow-logging , tensorflow-logging
, tensorflow-proto , tensorflow-proto
, tensorflow-records-conduit , tensorflow-records-conduit

View File

@ -19,13 +19,14 @@ import Control.Monad.Trans.Resource (runResourceT)
import Data.Conduit ((=$=)) import Data.Conduit ((=$=))
import Data.Default (def) import Data.Default (def)
import Data.List ((\\)) import Data.List ((\\))
import Data.ProtoLens (decodeMessageOrDie) import Data.ProtoLens (encodeMessage, decodeMessageOrDie)
import Lens.Family2 ((^.), (.~), (&)) import Lens.Family2 ((^.), (.~), (&))
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step) import Proto.Tensorflow.Core.Util.Event (Event, graphDef, fileVersion, step)
import System.Directory (getDirectoryContents) import System.Directory (getDirectoryContents)
import System.FilePath ((</>)) import System.FilePath ((</>))
import System.IO.Temp (withSystemTempDirectory) import System.IO.Temp (withSystemTempDirectory)
import TensorFlow.Logging (withEventWriter, logEvent) import TensorFlow.Core (Build, ControlNode, asGraphDef, noOp)
import TensorFlow.Logging (withEventWriter, logEvent, logGraph)
import TensorFlow.Records.Conduit (sourceTFRecords) import TensorFlow.Records.Conduit (sourceTFRecords)
import Test.Framework (defaultMain, Test) import Test.Framework (defaultMain, Test)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
@ -59,6 +60,22 @@ testEventWriter = testCase "EventWriter" $
(T.pack "brain.Event:2") (header ^. fileVersion) (T.pack "brain.Event:2") (header ^. fileVersion)
assertEqual "Body has expected records" expected body assertEqual "Body has expected records" expected body
testLogGraph :: Test
testLogGraph = testCase "LogGraph" $
withSystemTempDirectory "event_writer_logs" $ \dir -> do
let graphBuild = noOp :: Build ControlNode
expectedGraph = asGraphDef graphBuild
expectedGraphEvent = (def :: Event) & graphDef .~ (encodeMessage expectedGraph)
withEventWriter dir $ \eventWriter ->
logGraph eventWriter graphBuild
files <- listDirectory dir
records <- runResourceT $ Conduit.runConduit $
sourceTFRecords (dir </> head files) =$= Conduit.consume
let (_:event:_) = decodeMessageOrDie . BL.toStrict <$> records
assertEqual "First record expected to be Event containing GraphDef" expectedGraphEvent event
main :: IO () main :: IO ()
main = defaultMain [ testEventWriter main = defaultMain [ testEventWriter
, testLogGraph
] ]