diff --git a/stack.yaml b/stack.yaml index ac13a7e..82e9075 100644 --- a/stack.yaml +++ b/stack.yaml @@ -4,6 +4,7 @@ packages: - google-shim - tensorflow - tensorflow-core-ops +- tensorflow-logging - tensorflow-opgen - tensorflow-ops - tensorflow-proto diff --git a/tensorflow-logging/Setup.hs b/tensorflow-logging/Setup.hs new file mode 100644 index 0000000..e8ef27d --- /dev/null +++ b/tensorflow-logging/Setup.hs @@ -0,0 +1,3 @@ +import Distribution.Simple + +main = defaultMain diff --git a/tensorflow-logging/src/TensorFlow/Logging.hs b/tensorflow-logging/src/TensorFlow/Logging.hs new file mode 100644 index 0000000..da1c29f --- /dev/null +++ b/tensorflow-logging/src/TensorFlow/Logging.hs @@ -0,0 +1,159 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- | TensorBoard Summary generation. Provides type safe wrappers around raw +-- string emitting CoreOps. +-- +-- Example use: +-- +-- > -- Call summary functions while constructing the graph. +-- > createModel = do +-- > loss <- -- ... +-- > TF.scalarSummary loss +-- > +-- > -- Write summaries to an EventWriter. +-- > train = TF.withEventWriter "/path/to/logs" $ \eventWriter -> do +-- > summaryTensor <- TF.build TF.allSummaries +-- > forM_ [1..] $ \step -> do +-- > if (step % 100 == 0) +-- > then do +-- > ((), summaryBytes) <- TF.run (trainStep, summaryTensor) +-- > let summary = decodeMessageOrDie (TF.unScalar summaryBytes) +-- > TF.logSummary eventWriter step summary +-- > else TF.run_ trainStep + +{-# LANGUAGE TypeOperators #-} + +module TensorFlow.Logging + ( EventWriter + , withEventWriter + , logEvent + , logSummary + , SummaryTensor + , histogramSummary + , scalarSummary + , mergeAllSummaries + ) where + +import Control.Concurrent (forkFinally) +import Control.Concurrent.MVar (MVar, newEmptyMVar, readMVar, putMVar) +import Control.Concurrent.STM (atomically) +import Control.Concurrent.STM.TBMQueue (TBMQueue, newTBMQueueIO, closeTBMQueue, writeTBMQueue) +import Control.Monad.Catch (MonadMask, bracket) +import Control.Monad.IO.Class (MonadIO, liftIO) +import Control.Monad.Trans.Resource (runResourceT) +import Data.ByteString (ByteString) +import Data.Conduit ((=$=)) +import Data.Conduit.TQueue (sourceTBMQueue) +import Data.Default (def) +import Data.Int (Int64) +import Data.ProtoLens (encodeMessage) +import Data.Time.Clock (getCurrentTime) +import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds) +import Lens.Family2 ((.~), (&)) +import Network.HostName (getHostName) +import Proto.Tensorflow.Core.Framework.Summary (Summary) +import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step, summary, wallTime) +import System.Directory (createDirectoryIfMissing) +import System.FilePath (()) +import TensorFlow.Build (Build, render, SummaryTensor, addSummary, collectAllSummaries) +import TensorFlow.Ops (scalar) +import TensorFlow.Records.Conduit (sinkTFRecords) +import TensorFlow.Tensor (Tensor) +import TensorFlow.Types (TensorType, type(/=)) +import Text.Printf (printf) +import qualified Data.ByteString.Lazy as L +import qualified Data.Conduit as Conduit +import qualified Data.Conduit.List as Conduit +import qualified Data.Text as T +import qualified TensorFlow.GenOps.Core as CoreOps + +-- | Handle for logging TensorBoard events safely from multiple threads. +data EventWriter = EventWriter (TBMQueue Event) (MVar ()) + +-- | Writes Event protocol buffers to event files. +withEventWriter :: + (MonadIO m, MonadMask m) + => FilePath + -- ^ logdir. Local filesystem directory where event file will be written. + -> (EventWriter -> m a) + -> m a +withEventWriter logdir = + bracket (liftIO (newEventWriter logdir)) (liftIO . closeEventWriter) + +newEventWriter :: FilePath -> IO EventWriter +newEventWriter logdir = do + createDirectoryIfMissing True logdir + t <- doubleWallTime + hostname <- getHostName + let filename = printf (logdir "events.out.tfevents.%010d.%s") + (truncate t :: Integer) hostname + -- Asynchronously consume events from a queue. + -- We use a bounded queue to ensure the producer doesn't get too far ahead + -- of the consumer. The buffer size was picked arbitrarily. + q <- newTBMQueueIO 1024 + -- Use an MVar to signal that the worker thread has completed. + done <- newEmptyMVar + let writer = EventWriter q done + consumeQueue = runResourceT $ Conduit.runConduit $ + sourceTBMQueue q + =$= Conduit.map (L.fromStrict . encodeMessage) + =$= sinkTFRecords filename + _ <- forkFinally consumeQueue (\_ -> putMVar done ()) + logEvent writer $ def & wallTime .~ t + & fileVersion .~ T.pack "brain.Event:2" + return writer + +closeEventWriter :: EventWriter -> IO () +closeEventWriter (EventWriter q done) = + atomically (closeTBMQueue q) >> readMVar done + +-- | Logs the given Event protocol buffer. +logEvent :: MonadIO m => EventWriter -> Event -> m () +logEvent (EventWriter q _) pb = liftIO (atomically (writeTBMQueue q pb)) + +-- | Logs the given Summary event with an optional global step (use 0 if not +-- applicable). +logSummary :: MonadIO m => EventWriter -> Int64 -> Summary -> m () +logSummary writer step' summaryProto = do + t <- liftIO doubleWallTime + logEvent writer (def & wallTime .~ t + & step .~ step' + & summary .~ summaryProto + ) + +-- Number of seconds since epoch. +doubleWallTime :: IO Double +doubleWallTime = asDouble <$> getCurrentTime + where asDouble t = fromRational (toRational (utcTimeToPOSIXSeconds t)) + +-- | Adds a 'CoreOps.histogramSummary' node. The tag argument is intentionally +-- limited to a single value for simplicity. +histogramSummary :: + (TensorType t, t /= ByteString, t /= Bool) + -- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t) + => ByteString -> Tensor v t -> Build () +histogramSummary tag = addSummary . CoreOps.histogramSummary (scalar tag) + +-- | Adds a 'CoreOps.scalarSummary' node. +scalarSummary :: + (TensorType t, t /= ByteString, t /= Bool) + -- (TensorType t, + -- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t) + => ByteString -> Tensor v t -> Build () +scalarSummary tag = addSummary . CoreOps.scalarSummary (scalar tag) + +-- | Merge all summaries accumulated in the 'Build' into one summary. +mergeAllSummaries :: Build SummaryTensor +mergeAllSummaries = collectAllSummaries >>= render . CoreOps.mergeSummary diff --git a/tensorflow-logging/tensorflow-logging.cabal b/tensorflow-logging/tensorflow-logging.cabal new file mode 100644 index 0000000..91febd7 --- /dev/null +++ b/tensorflow-logging/tensorflow-logging.cabal @@ -0,0 +1,66 @@ +name: tensorflow-logging +version: 0.1.0.0 +synopsis: TensorBoard related functionality. +description: Please see README.md +homepage: https://github.com/tensorflow/haskell#readme +license: Apache +author: TensorFlow authors +maintainer: tensorflow-haskell@googlegroups.com +copyright: Google Inc. +category: Machine Learning +build-type: Simple +cabal-version: >=1.22 + +library + hs-source-dirs: src + exposed-modules: TensorFlow.Logging + build-depends: base >= 4.7 && < 5 + , bytestring + , conduit + , data-default + , directory + , exceptions + , filepath + , hostname + , lens-family + , proto-lens == 0.1.* + , resourcet + , stm + , stm-chans + , stm-conduit + , tensorflow == 0.1.* + , tensorflow-core-ops == 0.1.* + , tensorflow-ops == 0.1.* + , tensorflow-proto == 0.1.* + , tensorflow-records-conduit + , text + , time + , transformers + default-language: Haskell2010 + +Test-Suite LoggingTest + default-language: Haskell2010 + type: exitcode-stdio-1.0 + main-is: LoggingTest.hs + hs-source-dirs: tests + build-depends: HUnit + , base + , bytestring + , conduit + , data-default + , directory + , filepath + , lens-family + , proto-lens + , resourcet + , temporary + , tensorflow-logging + , tensorflow-proto + , tensorflow-records-conduit + , test-framework + , test-framework-hunit + , text + +source-repository head + type: git + location: https://github.com/tensorflow/haskell diff --git a/tensorflow-logging/tests/LoggingTest.hs b/tensorflow-logging/tests/LoggingTest.hs new file mode 100644 index 0000000..560d192 --- /dev/null +++ b/tensorflow-logging/tests/LoggingTest.hs @@ -0,0 +1,64 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- | Tests for TensorFlow.Logging. +module Main where + +import Control.Monad.Trans.Resource (runResourceT) +import Data.Conduit ((=$=)) +import Data.Default (def) +import Data.List ((\\)) +import Data.ProtoLens (decodeMessageOrDie) +import Lens.Family2 ((^.), (.~), (&)) +import Proto.Tensorflow.Core.Util.Event (fileVersion, step) +import System.Directory (getDirectoryContents) +import System.FilePath (()) +import System.IO.Temp (withSystemTempDirectory) +import TensorFlow.Logging (withEventWriter, logEvent) +import TensorFlow.Records.Conduit (sourceTFRecords) +import Test.Framework (defaultMain, Test) +import Test.Framework.Providers.HUnit (testCase) +import Test.HUnit (assertBool, assertEqual) +import qualified Data.ByteString.Lazy as BL +import qualified Data.Conduit as Conduit +import qualified Data.Conduit.List as Conduit +import qualified Data.Text as T + +-- TODO: This has been added to System.Directory in newer versions. +listDirectory :: String -> IO [String] +listDirectory dir = (\\ [".", ".."]) <$> getDirectoryContents dir + +testEventWriter :: Test +testEventWriter = testCase "EventWriter" $ + withSystemTempDirectory "event_writer_logs" $ \dir -> do + assertEqual "No file before" [] =<< listDirectory dir + let expected = [ def & step .~ 10 + , def & step .~ 222 + , def & step .~ 8 + ] + withEventWriter dir $ \eventWriter -> + mapM_ (logEvent eventWriter) expected + files <- listDirectory dir + assertEqual "One file exists after" 1 (length files) + records <- runResourceT $ Conduit.runConduit $ + sourceTFRecords (dir head files) =$= Conduit.consume + assertBool "File is not empty" (not (null records)) + let (header:body) = decodeMessageOrDie . BL.toStrict <$> records + assertEqual "Header has expected version" + (T.pack "brain.Event:2") (header ^. fileVersion) + assertEqual "Body has expected records" expected body + +main :: IO () +main = defaultMain [ testEventWriter + ] diff --git a/tensorflow-proto/tensorflow-proto.cabal b/tensorflow-proto/tensorflow-proto.cabal index a7ad83b..5f244c4 100644 --- a/tensorflow-proto/tensorflow-proto.cabal +++ b/tensorflow-proto/tensorflow-proto.cabal @@ -13,6 +13,7 @@ cabal-version: >=1.22 extra-source-files: ../third_party/tensorflow/tensorflow/core/framework/*.proto , ../third_party/tensorflow/tensorflow/core/protobuf/config.proto , ../third_party/tensorflow/tensorflow/core/protobuf/debug.proto + , ../third_party/tensorflow/tensorflow/core/util/event.proto library exposed-modules: Proto.Tensorflow.Core.Framework.AttrValue @@ -20,10 +21,12 @@ library , Proto.Tensorflow.Core.Framework.NodeDef , Proto.Tensorflow.Core.Framework.OpDef , Proto.Tensorflow.Core.Framework.ResourceHandle + , Proto.Tensorflow.Core.Framework.Summary , Proto.Tensorflow.Core.Framework.Tensor , Proto.Tensorflow.Core.Framework.TensorShape , Proto.Tensorflow.Core.Framework.Types , Proto.Tensorflow.Core.Protobuf.Config + , Proto.Tensorflow.Core.Util.Event other-modules: Proto.Tensorflow.Core.Framework.AllocationDescription , Proto.Tensorflow.Core.Framework.CostGraph , Proto.Tensorflow.Core.Framework.Function diff --git a/tensorflow/src/TensorFlow/Build.hs b/tensorflow/src/TensorFlow/Build.hs index 2165c94..8724fb1 100644 --- a/tensorflow/src/TensorFlow/Build.hs +++ b/tensorflow/src/TensorFlow/Build.hs @@ -14,8 +14,8 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} -{-# LANGUAGE Rank2Types #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE Rank2Types #-} module TensorFlow.Build ( -- * Graph node types ControlNode(..) @@ -61,6 +61,7 @@ module TensorFlow.Build , collectAllSummaries ) where +import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask) import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Trans.Class (MonadTrans(..)) import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT) @@ -196,7 +197,7 @@ summaries = lens _summaries (\g x -> g { _summaries = x }) -- Used to manage build state internally as part of the @Session@ monad. newtype BuildT m a = BuildT (StateT GraphState m a) deriving (Functor, Applicative, Monad, MonadIO, MonadTrans, - MonadState GraphState) + MonadState GraphState, MonadThrow, MonadCatch, MonadMask) -- | An action for building nodes in a TensorFlow graph. type Build = BuildT Identity diff --git a/tensorflow/src/TensorFlow/Core.hs b/tensorflow/src/TensorFlow/Core.hs index 3938e89..8af036e 100644 --- a/tensorflow/src/TensorFlow/Core.hs +++ b/tensorflow/src/TensorFlow/Core.hs @@ -33,7 +33,6 @@ module TensorFlow.Core -- ** Building graphs , build , buildAnd - , buildWithSummary -- ** Running graphs , Fetchable , Nodes diff --git a/tensorflow/src/TensorFlow/Session.hs b/tensorflow/src/TensorFlow/Session.hs index 1e9cf36..a9a0182 100644 --- a/tensorflow/src/TensorFlow/Session.hs +++ b/tensorflow/src/TensorFlow/Session.hs @@ -28,7 +28,6 @@ module TensorFlow.Session ( runSessionWithOptions, build, buildAnd, - buildWithSummary, extend, addGraphDef, run, @@ -39,6 +38,7 @@ module TensorFlow.Session ( ) where import Control.Monad (forever, unless, void) +import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask) import Control.Monad.IO.Class (MonadIO, liftIO) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Reader (ReaderT(..), ask, asks) @@ -77,7 +77,8 @@ data SessionState newtype Session a = Session (ReaderT SessionState (BuildT IO) a) - deriving (Functor, Applicative, Monad, MonadIO) + deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch, + MonadMask) -- | Run 'Session' actions in a new TensorFlow session. runSession :: Session a -> IO a @@ -128,14 +129,6 @@ runSessionWithOptions options (Session m) = build :: Build a -> Session a build = Session . lift . hoistBuildT (return . runIdentity) --- | Lift a 'Build' action into a 'Session', including any explicit op --- renderings. Returns the merged summary ops which can be used for --- logging, see 'TensorFlow.Logging.build' for a convenient wrapper. -buildWithSummary :: forall a . Build a -> Session (a, [SummaryTensor]) -buildWithSummary b = Session $ lift $ (,) <$> v <*> collectAllSummaries - where v :: BuildT IO a - v = hoistBuildT (return . runIdentity) b - -- | Add all pending rendered nodes to the TensorFlow graph and runs -- any pending initializers. -- diff --git a/tensorflow/tensorflow.cabal b/tensorflow/tensorflow.cabal index 09efe53..0d656f1 100644 --- a/tensorflow/tensorflow.cabal +++ b/tensorflow/tensorflow.cabal @@ -37,6 +37,7 @@ library , bytestring , containers , data-default + , exceptions , fgl , lens-family , mainland-pretty