{-# LANGUAGE TypeOperators #-}
module TensorFlow.Logging
( EventWriter
, withEventWriter
, logEvent
, logGraph
, logSummary
, SummaryTensor
, histogramSummary
, scalarSummary
, mergeAllSummaries
) where
import Control.Concurrent (forkFinally)
import Control.Concurrent.MVar (MVar, newEmptyMVar, readMVar, putMVar)
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TBMQueue (TBMQueue, newTBMQueueIO, closeTBMQueue, writeTBMQueue)
import Control.Monad.Catch (MonadMask, bracket)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Resource (runResourceT)
import Data.ByteString (ByteString)
import Data.Conduit ((=$=))
import Data.Conduit.TQueue (sourceTBMQueue)
import Data.Default (def)
import Data.Int (Int64)
import Data.ProtoLens (encodeMessage)
import Data.Time.Clock (getCurrentTime)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
import Lens.Family2 ((.~), (&))
import Network.HostName (getHostName)
import Proto.Tensorflow.Core.Framework.Summary (Summary)
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, graphDef, step, summary, wallTime)
import System.Directory (createDirectoryIfMissing)
import System.FilePath ((</>))
import TensorFlow.Build (MonadBuild, Build, asGraphDef)
import TensorFlow.Ops (scalar)
import TensorFlow.Records.Conduit (sinkTFRecords)
import TensorFlow.Tensor (Tensor, render, SummaryTensor, addSummary, collectAllSummaries)
import TensorFlow.Types (TensorType, type(/=))
import Text.Printf (printf)
import qualified Data.ByteString.Lazy as L
import qualified Data.Conduit as Conduit
import qualified Data.Conduit.List as Conduit
import qualified Data.Text as T
import qualified TensorFlow.GenOps.Core as CoreOps
-- | Handle to an event file that is written by a background thread.
--
-- The first field is the queue of pending events that the background
-- thread drains into the file. The second is an 'MVar' that the background
-- thread fills once it has exited.
--
-- Both fields are strict: they are always fully built handles, so there is
-- no reason to store them as thunks.
data EventWriter = EventWriter !(TBMQueue Event) !(MVar ())
-- | Runs an action with a fresh 'EventWriter' that writes into @logdir@.
--
-- The writer is always closed when the action finishes, whether it returns
-- normally or throws. Closing waits for the background writer thread to
-- exit.
withEventWriter ::
    (MonadIO m, MonadMask m)
    => FilePath
    -> (EventWriter -> m a)
    -> m a
withEventWriter logdir action =
    bracket acquire release action
  where
    acquire = liftIO (newEventWriter logdir)
    release writer = liftIO (closeEventWriter writer)
-- | Creates an 'EventWriter' for @logdir@.
--
-- Steps performed:
--
-- * Creates the directory (and any missing parents).
-- * Picks a file name from the current time and host name.
-- * Forks a thread that encodes every queued 'Event' and writes it to that
--   file as a TFRecord.
-- * Enqueues the initial file-version event.
newEventWriter :: FilePath -> IO EventWriter
newEventWriter logdir = do
    createDirectoryIfMissing True logdir
    t <- doubleWallTime
    hostname <- getHostName
    -- Only the file name is used as the printf format string. Splicing
    -- @logdir@ into the format would make any '%' in the directory path be
    -- interpreted as a format directive, which breaks at runtime.
    let filename = logdir </> printf "events.out.tfevents.%010d.%s"
                                     (truncate t :: Integer) hostname
    q <- newTBMQueueIO 1024
    done <- newEmptyMVar
    let writer = EventWriter q done
        -- Drain the queue until it is closed, encoding each event and
        -- appending it to the record file.
        consumeQueue = runResourceT $ Conduit.runConduit $
            sourceTBMQueue q
            =$= Conduit.map (L.fromStrict . encodeMessage)
            =$= sinkTFRecords filename
    -- NOTE(review): the consumer's result, including any exception, is
    -- discarded here. 'done' only signals that the thread has exited.
    _ <- forkFinally consumeQueue (\_ -> putMVar done ())
    logEvent writer $ def & wallTime .~ t
                          & fileVersion .~ T.pack "brain.Event:2"
    return writer
-- | Closes the writer's event queue, then blocks until the background
-- thread draining that queue has exited.
closeEventWriter :: EventWriter -> IO ()
closeEventWriter (EventWriter queue finished) = do
    atomically (closeTBMQueue queue)
    readMVar finished
-- | Hands an event to the writer's background thread by putting it on the
-- writer's queue.
--
-- The queue is a bounded 'TBMQueue', so this may block while it is full.
logEvent :: MonadIO m => EventWriter -> Event -> m ()
logEvent (EventWriter queue _) event =
    liftIO $ atomically $ writeTBMQueue queue event
-- | Serialises the graph produced by @build@ and logs it as a graph event.
--
-- The event is stamped with the current wall-clock time, the same way
-- 'logSummary' and the initial file-version event are. Previously it was
-- left with a zero wall time.
logGraph :: MonadIO m => EventWriter -> Build a -> m ()
logGraph writer build = do
    t <- liftIO doubleWallTime
    let graphBytes = encodeMessage (asGraphDef build)
    logEvent writer $ (def :: Event) & wallTime .~ t
                                     & graphDef .~ graphBytes
-- | Logs a 'Summary' protobuf for the given step, stamped with the current
-- wall-clock time.
logSummary :: MonadIO m => EventWriter -> Int64 -> Summary -> m ()
logSummary writer step' summaryProto =
    liftIO doubleWallTime >>= \now ->
        logEvent writer $ def & wallTime .~ now
                              & step .~ step'
                              & summary .~ summaryProto
-- | The current POSIX time in seconds, as a 'Double'.
--
-- 'Double' is the representation this module stores in an event's wall
-- time field.
doubleWallTime :: IO Double
doubleWallTime = fmap (realToFrac . utcTimeToPOSIXSeconds) getCurrentTime
-- | Adds a HistogramSummary op over @values@, labelled @tag@, and registers
-- it with 'addSummary'.
--
-- Presumably it is then gathered by 'mergeAllSummaries' via
-- 'collectAllSummaries' (confirm against "TensorFlow.Tensor").
histogramSummary ::
    (MonadBuild m, TensorType t, t /= ByteString, t /= Bool)
    => ByteString -> Tensor v t -> m ()
histogramSummary tag values =
    addSummary (CoreOps.histogramSummary (scalar tag) values)
-- | Adds a ScalarSummary op over @values@, labelled @tag@, and registers it
-- with 'addSummary'.
--
-- Presumably it is then gathered by 'mergeAllSummaries' via
-- 'collectAllSummaries' (confirm against "TensorFlow.Tensor").
scalarSummary ::
    (TensorType t, t /= ByteString, t /= Bool, MonadBuild m)
    => ByteString -> Tensor v t -> m ()
scalarSummary tag values =
    addSummary (CoreOps.scalarSummary (scalar tag) values)
-- | Collects every summary registered via 'addSummary' and merges them
-- into a single rendered 'SummaryTensor'.
--
-- NOTE(review): TensorFlow's MergeSummary op is believed to require at
-- least one input. Confirm what happens when no summaries were registered.
mergeAllSummaries :: MonadBuild m => m SummaryTensor
mergeAllSummaries = do
    summaries <- collectAllSummaries
    render (CoreOps.mergeSummary summaries)