1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2025-01-11 19:39:49 +01:00

Added imageSummary wrapper (#159)

This commit is contained in:
Christian Berentsen 2018-01-23 19:02:58 +01:00 committed by fkm3
parent 760c067e89
commit f2cafa7071

View file

@ -33,8 +33,10 @@
-- > TF.logSummary eventWriter step summary
-- > else TF.run_ trainStep
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}
module TensorFlow.Logging
( EventWriter
, withEventWriter
@ -43,6 +45,7 @@ module TensorFlow.Logging
, logSummary
, SummaryTensor
, histogramSummary
, imageSummary
, scalarSummary
, mergeAllSummaries
) where
@ -59,6 +62,7 @@ import Data.Conduit ((=$=))
import Data.Conduit.TQueue (sourceTBMQueue)
import Data.Default (def)
import Data.Int (Int64)
import Data.Word (Word8, Word16)
import Data.ProtoLens (encodeMessage)
import Data.Time.Clock (getCurrentTime)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
@ -72,7 +76,7 @@ import TensorFlow.Build (MonadBuild, Build, asGraphDef)
import TensorFlow.Ops (scalar)
import TensorFlow.Records.Conduit (sinkTFRecords)
import TensorFlow.Tensor (Tensor, render, SummaryTensor, addSummary, collectAllSummaries)
import TensorFlow.Types (TensorType, type(/=))
import TensorFlow.Types (TensorType, type(/=), OneOf)
import Text.Printf (printf)
import qualified Data.ByteString.Lazy as L
import qualified Data.Conduit as Conduit
@ -142,7 +146,6 @@ logSummary writer step' summaryProto = do
& summary .~ summaryProto
)
-- Number of seconds since epoch.
doubleWallTime :: IO Double
doubleWallTime = asDouble <$> getCurrentTime
@ -156,6 +159,16 @@ histogramSummary ::
=> ByteString -> Tensor v t -> m ()
histogramSummary tag = addSummary . CoreOps.histogramSummary (scalar tag)
-- | Adds a 'CoreOps.imageSummary' node. The tag argument is intentionally
-- limited to a single value for simplicity.
imageSummary ::
(OneOf '[Word8, Word16, Float] t, MonadBuild m)
=> ByteString
-> Tensor v t
-> m ()
imageSummary tag = addSummary . CoreOps.imageSummary (scalar tag)
-- | Adds a 'CoreOps.scalarSummary' node.
scalarSummary ::
(TensorType t, t /= ByteString, t /= Bool, MonadBuild m)