diff --git a/tensorflow-logging/src/TensorFlow/Logging.hs b/tensorflow-logging/src/TensorFlow/Logging.hs index 2c0c160..fac2eed 100644 --- a/tensorflow-logging/src/TensorFlow/Logging.hs +++ b/tensorflow-logging/src/TensorFlow/Logging.hs @@ -33,8 +33,10 @@ -- > TF.logSummary eventWriter step summary -- > else TF.run_ trainStep +{-# LANGUAGE DataKinds #-} {-# LANGUAGE TypeOperators #-} + module TensorFlow.Logging ( EventWriter , withEventWriter @@ -43,6 +45,7 @@ module TensorFlow.Logging , logSummary , SummaryTensor , histogramSummary + , imageSummary , scalarSummary , mergeAllSummaries ) where @@ -59,6 +62,7 @@ import Data.Conduit ((=$=)) import Data.Conduit.TQueue (sourceTBMQueue) import Data.Default (def) import Data.Int (Int64) +import Data.Word (Word8, Word16) import Data.ProtoLens (encodeMessage) import Data.Time.Clock (getCurrentTime) import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds) @@ -72,7 +76,7 @@ import TensorFlow.Build (MonadBuild, Build, asGraphDef) import TensorFlow.Ops (scalar) import TensorFlow.Records.Conduit (sinkTFRecords) import TensorFlow.Tensor (Tensor, render, SummaryTensor, addSummary, collectAllSummaries) -import TensorFlow.Types (TensorType, type(/=)) +import TensorFlow.Types (TensorType, type(/=), OneOf) import Text.Printf (printf) import qualified Data.ByteString.Lazy as L import qualified Data.Conduit as Conduit @@ -142,7 +146,6 @@ logSummary writer step' summaryProto = do & summary .~ summaryProto ) - -- Number of seconds since epoch. doubleWallTime :: IO Double doubleWallTime = asDouble <$> getCurrentTime @@ -156,6 +159,16 @@ histogramSummary :: => ByteString -> Tensor v t -> m () histogramSummary tag = addSummary . CoreOps.histogramSummary (scalar tag) +-- | Adds a 'CoreOps.imageSummary' node. The tag argument is intentionally +-- limited to a single value for simplicity. +imageSummary :: + (OneOf '[Word8, Word16, Float] t, MonadBuild m) + => ByteString + -> Tensor v t + -> m () + +imageSummary tag = addSummary . CoreOps.imageSummary (scalar tag) + -- | Adds a 'CoreOps.scalarSummary' node. scalarSummary :: (TensorType t, t /= ByteString, t /= Bool, MonadBuild m)