mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 11:29:43 +01:00
Added logGraph for graph visualization in TensorBoard (#104)
This commit is contained in:
parent
423b34537e
commit
042910b000
3 changed files with 33 additions and 5 deletions
|
@ -39,6 +39,7 @@ module TensorFlow.Logging
|
||||||
( EventWriter
|
( EventWriter
|
||||||
, withEventWriter
|
, withEventWriter
|
||||||
, logEvent
|
, logEvent
|
||||||
|
, logGraph
|
||||||
, logSummary
|
, logSummary
|
||||||
, SummaryTensor
|
, SummaryTensor
|
||||||
, histogramSummary
|
, histogramSummary
|
||||||
|
@ -64,10 +65,10 @@ import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
|
||||||
import Lens.Family2 ((.~), (&))
|
import Lens.Family2 ((.~), (&))
|
||||||
import Network.HostName (getHostName)
|
import Network.HostName (getHostName)
|
||||||
import Proto.Tensorflow.Core.Framework.Summary (Summary)
|
import Proto.Tensorflow.Core.Framework.Summary (Summary)
|
||||||
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step, summary, wallTime)
|
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, graphDef, step, summary, wallTime)
|
||||||
import System.Directory (createDirectoryIfMissing)
|
import System.Directory (createDirectoryIfMissing)
|
||||||
import System.FilePath ((</>))
|
import System.FilePath ((</>))
|
||||||
import TensorFlow.Build (MonadBuild)
|
import TensorFlow.Build (MonadBuild, Build, asGraphDef)
|
||||||
import TensorFlow.Ops (scalar)
|
import TensorFlow.Ops (scalar)
|
||||||
import TensorFlow.Records.Conduit (sinkTFRecords)
|
import TensorFlow.Records.Conduit (sinkTFRecords)
|
||||||
import TensorFlow.Tensor (Tensor, render, SummaryTensor, addSummary, collectAllSummaries)
|
import TensorFlow.Tensor (Tensor, render, SummaryTensor, addSummary, collectAllSummaries)
|
||||||
|
@ -123,6 +124,14 @@ closeEventWriter (EventWriter q done) =
|
||||||
logEvent :: MonadIO m => EventWriter -> Event -> m ()
|
logEvent :: MonadIO m => EventWriter -> Event -> m ()
|
||||||
logEvent (EventWriter q _) pb = liftIO (atomically (writeTBMQueue q pb))
|
logEvent (EventWriter q _) pb = liftIO (atomically (writeTBMQueue q pb))
|
||||||
|
|
||||||
|
-- | Logs the graph for the given 'Build' action.
|
||||||
|
logGraph :: MonadIO m => EventWriter -> Build a -> m ()
|
||||||
|
logGraph writer build = do
|
||||||
|
let graph = asGraphDef build
|
||||||
|
graphBytes = encodeMessage graph
|
||||||
|
graphEvent = (def :: Event) & graphDef .~ graphBytes
|
||||||
|
logEvent writer graphEvent
|
||||||
|
|
||||||
-- | Logs the given Summary event with an optional global step (use 0 if not
|
-- | Logs the given Summary event with an optional global step (use 0 if not
|
||||||
-- applicable).
|
-- applicable).
|
||||||
logSummary :: MonadIO m => EventWriter -> Int64 -> Summary -> m ()
|
logSummary :: MonadIO m => EventWriter -> Int64 -> Summary -> m ()
|
||||||
|
@ -133,6 +142,7 @@ logSummary writer step' summaryProto = do
|
||||||
& summary .~ summaryProto
|
& summary .~ summaryProto
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
-- Number of seconds since epoch.
|
-- Number of seconds since epoch.
|
||||||
doubleWallTime :: IO Double
|
doubleWallTime :: IO Double
|
||||||
doubleWallTime = asDouble <$> getCurrentTime
|
doubleWallTime = asDouble <$> getCurrentTime
|
||||||
|
|
|
@ -55,6 +55,7 @@ Test-Suite LoggingTest
|
||||||
, proto-lens
|
, proto-lens
|
||||||
, resourcet
|
, resourcet
|
||||||
, temporary
|
, temporary
|
||||||
|
, tensorflow
|
||||||
, tensorflow-logging
|
, tensorflow-logging
|
||||||
, tensorflow-proto
|
, tensorflow-proto
|
||||||
, tensorflow-records-conduit
|
, tensorflow-records-conduit
|
||||||
|
|
|
@ -19,13 +19,14 @@ import Control.Monad.Trans.Resource (runResourceT)
|
||||||
import Data.Conduit ((=$=))
|
import Data.Conduit ((=$=))
|
||||||
import Data.Default (def)
|
import Data.Default (def)
|
||||||
import Data.List ((\\))
|
import Data.List ((\\))
|
||||||
import Data.ProtoLens (decodeMessageOrDie)
|
import Data.ProtoLens (encodeMessage, decodeMessageOrDie)
|
||||||
import Lens.Family2 ((^.), (.~), (&))
|
import Lens.Family2 ((^.), (.~), (&))
|
||||||
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step)
|
import Proto.Tensorflow.Core.Util.Event (Event, graphDef, fileVersion, step)
|
||||||
import System.Directory (getDirectoryContents)
|
import System.Directory (getDirectoryContents)
|
||||||
import System.FilePath ((</>))
|
import System.FilePath ((</>))
|
||||||
import System.IO.Temp (withSystemTempDirectory)
|
import System.IO.Temp (withSystemTempDirectory)
|
||||||
import TensorFlow.Logging (withEventWriter, logEvent)
|
import TensorFlow.Core (Build, ControlNode, asGraphDef, noOp)
|
||||||
|
import TensorFlow.Logging (withEventWriter, logEvent, logGraph)
|
||||||
import TensorFlow.Records.Conduit (sourceTFRecords)
|
import TensorFlow.Records.Conduit (sourceTFRecords)
|
||||||
import Test.Framework (defaultMain, Test)
|
import Test.Framework (defaultMain, Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
@ -59,6 +60,22 @@ testEventWriter = testCase "EventWriter" $
|
||||||
(T.pack "brain.Event:2") (header ^. fileVersion)
|
(T.pack "brain.Event:2") (header ^. fileVersion)
|
||||||
assertEqual "Body has expected records" expected body
|
assertEqual "Body has expected records" expected body
|
||||||
|
|
||||||
|
testLogGraph :: Test
|
||||||
|
testLogGraph = testCase "LogGraph" $
|
||||||
|
withSystemTempDirectory "event_writer_logs" $ \dir -> do
|
||||||
|
let graphBuild = noOp :: Build ControlNode
|
||||||
|
expectedGraph = asGraphDef graphBuild
|
||||||
|
expectedGraphEvent = (def :: Event) & graphDef .~ (encodeMessage expectedGraph)
|
||||||
|
|
||||||
|
withEventWriter dir $ \eventWriter ->
|
||||||
|
logGraph eventWriter graphBuild
|
||||||
|
files <- listDirectory dir
|
||||||
|
records <- runResourceT $ Conduit.runConduit $
|
||||||
|
sourceTFRecords (dir </> head files) =$= Conduit.consume
|
||||||
|
let (_:event:_) = decodeMessageOrDie . BL.toStrict <$> records
|
||||||
|
assertEqual "First record expected to be Event containing GraphDef" expectedGraphEvent event
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = defaultMain [ testEventWriter
|
main = defaultMain [ testEventWriter
|
||||||
|
, testLogGraph
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in a new issue