mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-11 19:39:49 +01:00
Added logGraph for graph visualization in TensorBoard (#104)
This commit is contained in:
parent
423b34537e
commit
042910b000
3 changed files with 33 additions and 5 deletions
|
@ -39,6 +39,7 @@ module TensorFlow.Logging
|
|||
( EventWriter
|
||||
, withEventWriter
|
||||
, logEvent
|
||||
, logGraph
|
||||
, logSummary
|
||||
, SummaryTensor
|
||||
, histogramSummary
|
||||
|
@ -64,10 +65,10 @@ import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
|
|||
import Lens.Family2 ((.~), (&))
|
||||
import Network.HostName (getHostName)
|
||||
import Proto.Tensorflow.Core.Framework.Summary (Summary)
|
||||
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step, summary, wallTime)
|
||||
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, graphDef, step, summary, wallTime)
|
||||
import System.Directory (createDirectoryIfMissing)
|
||||
import System.FilePath ((</>))
|
||||
import TensorFlow.Build (MonadBuild)
|
||||
import TensorFlow.Build (MonadBuild, Build, asGraphDef)
|
||||
import TensorFlow.Ops (scalar)
|
||||
import TensorFlow.Records.Conduit (sinkTFRecords)
|
||||
import TensorFlow.Tensor (Tensor, render, SummaryTensor, addSummary, collectAllSummaries)
|
||||
|
@ -123,6 +124,14 @@ closeEventWriter (EventWriter q done) =
|
|||
logEvent :: MonadIO m => EventWriter -> Event -> m ()
|
||||
logEvent (EventWriter q _) pb = liftIO (atomically (writeTBMQueue q pb))
|
||||
|
||||
-- | Logs the graph for the given 'Build' action.
|
||||
logGraph :: MonadIO m => EventWriter -> Build a -> m ()
|
||||
logGraph writer build = do
|
||||
let graph = asGraphDef build
|
||||
graphBytes = encodeMessage graph
|
||||
graphEvent = (def :: Event) & graphDef .~ graphBytes
|
||||
logEvent writer graphEvent
|
||||
|
||||
-- | Logs the given Summary event with an optional global step (use 0 if not
|
||||
-- applicable).
|
||||
logSummary :: MonadIO m => EventWriter -> Int64 -> Summary -> m ()
|
||||
|
@ -133,6 +142,7 @@ logSummary writer step' summaryProto = do
|
|||
& summary .~ summaryProto
|
||||
)
|
||||
|
||||
|
||||
-- Number of seconds since epoch.
|
||||
doubleWallTime :: IO Double
|
||||
doubleWallTime = asDouble <$> getCurrentTime
|
||||
|
|
|
@ -55,6 +55,7 @@ Test-Suite LoggingTest
|
|||
, proto-lens
|
||||
, resourcet
|
||||
, temporary
|
||||
, tensorflow
|
||||
, tensorflow-logging
|
||||
, tensorflow-proto
|
||||
, tensorflow-records-conduit
|
||||
|
|
|
@ -19,13 +19,14 @@ import Control.Monad.Trans.Resource (runResourceT)
|
|||
import Data.Conduit ((=$=))
|
||||
import Data.Default (def)
|
||||
import Data.List ((\\))
|
||||
import Data.ProtoLens (decodeMessageOrDie)
|
||||
import Data.ProtoLens (encodeMessage, decodeMessageOrDie)
|
||||
import Lens.Family2 ((^.), (.~), (&))
|
||||
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step)
|
||||
import Proto.Tensorflow.Core.Util.Event (Event, graphDef, fileVersion, step)
|
||||
import System.Directory (getDirectoryContents)
|
||||
import System.FilePath ((</>))
|
||||
import System.IO.Temp (withSystemTempDirectory)
|
||||
import TensorFlow.Logging (withEventWriter, logEvent)
|
||||
import TensorFlow.Core (Build, ControlNode, asGraphDef, noOp)
|
||||
import TensorFlow.Logging (withEventWriter, logEvent, logGraph)
|
||||
import TensorFlow.Records.Conduit (sourceTFRecords)
|
||||
import Test.Framework (defaultMain, Test)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
|
@ -59,6 +60,22 @@ testEventWriter = testCase "EventWriter" $
|
|||
(T.pack "brain.Event:2") (header ^. fileVersion)
|
||||
assertEqual "Body has expected records" expected body
|
||||
|
||||
testLogGraph :: Test
|
||||
testLogGraph = testCase "LogGraph" $
|
||||
withSystemTempDirectory "event_writer_logs" $ \dir -> do
|
||||
let graphBuild = noOp :: Build ControlNode
|
||||
expectedGraph = asGraphDef graphBuild
|
||||
expectedGraphEvent = (def :: Event) & graphDef .~ (encodeMessage expectedGraph)
|
||||
|
||||
withEventWriter dir $ \eventWriter ->
|
||||
logGraph eventWriter graphBuild
|
||||
files <- listDirectory dir
|
||||
records <- runResourceT $ Conduit.runConduit $
|
||||
sourceTFRecords (dir </> head files) =$= Conduit.consume
|
||||
let (_:event:_) = decodeMessageOrDie . BL.toStrict <$> records
|
||||
assertEqual "First record expected to be Event containing GraphDef" expectedGraphEvent event
|
||||
|
||||
main :: IO ()
|
||||
main = defaultMain [ testEventWriter
|
||||
, testLogGraph
|
||||
]
|
||||
|
|
Loading…
Reference in a new issue