diff --git a/tensorflow-logging/src/TensorFlow/Logging.hs b/tensorflow-logging/src/TensorFlow/Logging.hs index f0d5439..2c0c160 100644 --- a/tensorflow-logging/src/TensorFlow/Logging.hs +++ b/tensorflow-logging/src/TensorFlow/Logging.hs @@ -39,6 +39,7 @@ module TensorFlow.Logging ( EventWriter , withEventWriter , logEvent + , logGraph , logSummary , SummaryTensor , histogramSummary @@ -64,10 +65,10 @@ import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds) import Lens.Family2 ((.~), (&)) import Network.HostName (getHostName) import Proto.Tensorflow.Core.Framework.Summary (Summary) -import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step, summary, wallTime) +import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, graphDef, step, summary, wallTime) import System.Directory (createDirectoryIfMissing) import System.FilePath (()) -import TensorFlow.Build (MonadBuild) +import TensorFlow.Build (MonadBuild, Build, asGraphDef) import TensorFlow.Ops (scalar) import TensorFlow.Records.Conduit (sinkTFRecords) import TensorFlow.Tensor (Tensor, render, SummaryTensor, addSummary, collectAllSummaries) @@ -123,6 +124,14 @@ closeEventWriter (EventWriter q done) = logEvent :: MonadIO m => EventWriter -> Event -> m () logEvent (EventWriter q _) pb = liftIO (atomically (writeTBMQueue q pb)) +-- | Logs the graph for the given 'Build' action. +logGraph :: MonadIO m => EventWriter -> Build a -> m () +logGraph writer build = do + let graph = asGraphDef build + graphBytes = encodeMessage graph + graphEvent = (def :: Event) & graphDef .~ graphBytes + logEvent writer graphEvent + -- | Logs the given Summary event with an optional global step (use 0 if not -- applicable). logSummary :: MonadIO m => EventWriter -> Int64 -> Summary -> m () @@ -133,6 +142,7 @@ logSummary writer step' summaryProto = do & summary .~ summaryProto ) + -- Number of seconds since epoch. doubleWallTime :: IO Double doubleWallTime = asDouble <$> getCurrentTime diff --git a/tensorflow-logging/tensorflow-logging.cabal b/tensorflow-logging/tensorflow-logging.cabal index 5feb859..d9c14a0 100644 --- a/tensorflow-logging/tensorflow-logging.cabal +++ b/tensorflow-logging/tensorflow-logging.cabal @@ -55,6 +55,7 @@ Test-Suite LoggingTest , proto-lens , resourcet , temporary + , tensorflow , tensorflow-logging , tensorflow-proto , tensorflow-records-conduit diff --git a/tensorflow-logging/tests/LoggingTest.hs b/tensorflow-logging/tests/LoggingTest.hs index 12d46d6..8b874c4 100644 --- a/tensorflow-logging/tests/LoggingTest.hs +++ b/tensorflow-logging/tests/LoggingTest.hs @@ -19,13 +19,14 @@ import Control.Monad.Trans.Resource (runResourceT) import Data.Conduit ((=$=)) import Data.Default (def) import Data.List ((\\)) -import Data.ProtoLens (decodeMessageOrDie) +import Data.ProtoLens (encodeMessage, decodeMessageOrDie) import Lens.Family2 ((^.), (.~), (&)) -import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step) +import Proto.Tensorflow.Core.Util.Event (Event, graphDef, fileVersion, step) import System.Directory (getDirectoryContents) import System.FilePath (()) import System.IO.Temp (withSystemTempDirectory) -import TensorFlow.Logging (withEventWriter, logEvent) +import TensorFlow.Core (Build, ControlNode, asGraphDef, noOp) +import TensorFlow.Logging (withEventWriter, logEvent, logGraph) import TensorFlow.Records.Conduit (sourceTFRecords) import Test.Framework (defaultMain, Test) import Test.Framework.Providers.HUnit (testCase) @@ -59,6 +60,22 @@ testEventWriter = testCase "EventWriter" $ (T.pack "brain.Event:2") (header ^. fileVersion) assertEqual "Body has expected records" expected body +testLogGraph :: Test +testLogGraph = testCase "LogGraph" $ + withSystemTempDirectory "event_writer_logs" $ \dir -> do + let graphBuild = noOp :: Build ControlNode + expectedGraph = asGraphDef graphBuild + expectedGraphEvent = (def :: Event) & graphDef .~ (encodeMessage expectedGraph) + + withEventWriter dir $ \eventWriter -> + logGraph eventWriter graphBuild + files <- listDirectory dir + records <- runResourceT $ Conduit.runConduit $ + sourceTFRecords (dir head files) =$= Conduit.consume + let (_:event:_) = decodeMessageOrDie . BL.toStrict <$> records + assertEqual "First record expected to be Event containing GraphDef" expectedGraphEvent event + main :: IO () main = defaultMain [ testEventWriter + , testLogGraph ]