mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Add support for logging to tensorboard (#74)
Add support for logging to tensorboard Based on @gnezdo's internal version with some differences: * Uses a pure haskell implementation of EventWriter instead of FFI. * Special `buildAnd*` functions were dropped in favor of using `mergeAllSummaries :: Build SummaryTensor` with the normal `build` function.
This commit is contained in:
parent
dca49d8993
commit
b3c0997a8c
10 changed files with 303 additions and 13 deletions
|
@ -4,6 +4,7 @@ packages:
|
||||||
- google-shim
|
- google-shim
|
||||||
- tensorflow
|
- tensorflow
|
||||||
- tensorflow-core-ops
|
- tensorflow-core-ops
|
||||||
|
- tensorflow-logging
|
||||||
- tensorflow-opgen
|
- tensorflow-opgen
|
||||||
- tensorflow-ops
|
- tensorflow-ops
|
||||||
- tensorflow-proto
|
- tensorflow-proto
|
||||||
|
|
3
tensorflow-logging/Setup.hs
Normal file
3
tensorflow-logging/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
import Distribution.Simple
|
||||||
|
|
||||||
|
main = defaultMain
|
159
tensorflow-logging/src/TensorFlow/Logging.hs
Normal file
159
tensorflow-logging/src/TensorFlow/Logging.hs
Normal file
|
@ -0,0 +1,159 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
-- | TensorBoard Summary generation. Provides type safe wrappers around raw
|
||||||
|
-- string emitting CoreOps.
|
||||||
|
--
|
||||||
|
-- Example use:
|
||||||
|
--
|
||||||
|
-- > -- Call summary functions while constructing the graph.
|
||||||
|
-- > createModel = do
|
||||||
|
-- > loss <- -- ...
|
||||||
|
-- > TF.scalarSummary loss
|
||||||
|
-- >
|
||||||
|
-- > -- Write summaries to an EventWriter.
|
||||||
|
-- > train = TF.withEventWriter "/path/to/logs" $ \eventWriter -> do
|
||||||
|
-- > summaryTensor <- TF.build TF.allSummaries
|
||||||
|
-- > forM_ [1..] $ \step -> do
|
||||||
|
-- > if (step % 100 == 0)
|
||||||
|
-- > then do
|
||||||
|
-- > ((), summaryBytes) <- TF.run (trainStep, summaryTensor)
|
||||||
|
-- > let summary = decodeMessageOrDie (TF.unScalar summaryBytes)
|
||||||
|
-- > TF.logSummary eventWriter step summary
|
||||||
|
-- > else TF.run_ trainStep
|
||||||
|
|
||||||
|
{-# LANGUAGE TypeOperators #-}
|
||||||
|
|
||||||
|
module TensorFlow.Logging
|
||||||
|
( EventWriter
|
||||||
|
, withEventWriter
|
||||||
|
, logEvent
|
||||||
|
, logSummary
|
||||||
|
, SummaryTensor
|
||||||
|
, histogramSummary
|
||||||
|
, scalarSummary
|
||||||
|
, mergeAllSummaries
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Control.Concurrent (forkFinally)
|
||||||
|
import Control.Concurrent.MVar (MVar, newEmptyMVar, readMVar, putMVar)
|
||||||
|
import Control.Concurrent.STM (atomically)
|
||||||
|
import Control.Concurrent.STM.TBMQueue (TBMQueue, newTBMQueueIO, closeTBMQueue, writeTBMQueue)
|
||||||
|
import Control.Monad.Catch (MonadMask, bracket)
|
||||||
|
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||||
|
import Control.Monad.Trans.Resource (runResourceT)
|
||||||
|
import Data.ByteString (ByteString)
|
||||||
|
import Data.Conduit ((=$=))
|
||||||
|
import Data.Conduit.TQueue (sourceTBMQueue)
|
||||||
|
import Data.Default (def)
|
||||||
|
import Data.Int (Int64)
|
||||||
|
import Data.ProtoLens (encodeMessage)
|
||||||
|
import Data.Time.Clock (getCurrentTime)
|
||||||
|
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
|
||||||
|
import Lens.Family2 ((.~), (&))
|
||||||
|
import Network.HostName (getHostName)
|
||||||
|
import Proto.Tensorflow.Core.Framework.Summary (Summary)
|
||||||
|
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step, summary, wallTime)
|
||||||
|
import System.Directory (createDirectoryIfMissing)
|
||||||
|
import System.FilePath ((</>))
|
||||||
|
import TensorFlow.Build (Build, render, SummaryTensor, addSummary, collectAllSummaries)
|
||||||
|
import TensorFlow.Ops (scalar)
|
||||||
|
import TensorFlow.Records.Conduit (sinkTFRecords)
|
||||||
|
import TensorFlow.Tensor (Tensor)
|
||||||
|
import TensorFlow.Types (TensorType, type(/=))
|
||||||
|
import Text.Printf (printf)
|
||||||
|
import qualified Data.ByteString.Lazy as L
|
||||||
|
import qualified Data.Conduit as Conduit
|
||||||
|
import qualified Data.Conduit.List as Conduit
|
||||||
|
import qualified Data.Text as T
|
||||||
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
|
||||||
|
-- | Handle for logging TensorBoard events safely from multiple threads.
|
||||||
|
data EventWriter = EventWriter (TBMQueue Event) (MVar ())
|
||||||
|
|
||||||
|
-- | Writes Event protocol buffers to event files.
|
||||||
|
withEventWriter ::
|
||||||
|
(MonadIO m, MonadMask m)
|
||||||
|
=> FilePath
|
||||||
|
-- ^ logdir. Local filesystem directory where event file will be written.
|
||||||
|
-> (EventWriter -> m a)
|
||||||
|
-> m a
|
||||||
|
withEventWriter logdir =
|
||||||
|
bracket (liftIO (newEventWriter logdir)) (liftIO . closeEventWriter)
|
||||||
|
|
||||||
|
newEventWriter :: FilePath -> IO EventWriter
|
||||||
|
newEventWriter logdir = do
|
||||||
|
createDirectoryIfMissing True logdir
|
||||||
|
t <- doubleWallTime
|
||||||
|
hostname <- getHostName
|
||||||
|
let filename = printf (logdir </> "events.out.tfevents.%010d.%s")
|
||||||
|
(truncate t :: Integer) hostname
|
||||||
|
-- Asynchronously consume events from a queue.
|
||||||
|
-- We use a bounded queue to ensure the producer doesn't get too far ahead
|
||||||
|
-- of the consumer. The buffer size was picked arbitrarily.
|
||||||
|
q <- newTBMQueueIO 1024
|
||||||
|
-- Use an MVar to signal that the worker thread has completed.
|
||||||
|
done <- newEmptyMVar
|
||||||
|
let writer = EventWriter q done
|
||||||
|
consumeQueue = runResourceT $ Conduit.runConduit $
|
||||||
|
sourceTBMQueue q
|
||||||
|
=$= Conduit.map (L.fromStrict . encodeMessage)
|
||||||
|
=$= sinkTFRecords filename
|
||||||
|
_ <- forkFinally consumeQueue (\_ -> putMVar done ())
|
||||||
|
logEvent writer $ def & wallTime .~ t
|
||||||
|
& fileVersion .~ T.pack "brain.Event:2"
|
||||||
|
return writer
|
||||||
|
|
||||||
|
closeEventWriter :: EventWriter -> IO ()
|
||||||
|
closeEventWriter (EventWriter q done) =
|
||||||
|
atomically (closeTBMQueue q) >> readMVar done
|
||||||
|
|
||||||
|
-- | Logs the given Event protocol buffer.
|
||||||
|
logEvent :: MonadIO m => EventWriter -> Event -> m ()
|
||||||
|
logEvent (EventWriter q _) pb = liftIO (atomically (writeTBMQueue q pb))
|
||||||
|
|
||||||
|
-- | Logs the given Summary event with an optional global step (use 0 if not
|
||||||
|
-- applicable).
|
||||||
|
logSummary :: MonadIO m => EventWriter -> Int64 -> Summary -> m ()
|
||||||
|
logSummary writer step' summaryProto = do
|
||||||
|
t <- liftIO doubleWallTime
|
||||||
|
logEvent writer (def & wallTime .~ t
|
||||||
|
& step .~ step'
|
||||||
|
& summary .~ summaryProto
|
||||||
|
)
|
||||||
|
|
||||||
|
-- Number of seconds since epoch.
|
||||||
|
doubleWallTime :: IO Double
|
||||||
|
doubleWallTime = asDouble <$> getCurrentTime
|
||||||
|
where asDouble t = fromRational (toRational (utcTimeToPOSIXSeconds t))
|
||||||
|
|
||||||
|
-- | Adds a 'CoreOps.histogramSummary' node. The tag argument is intentionally
|
||||||
|
-- limited to a single value for simplicity.
|
||||||
|
histogramSummary ::
|
||||||
|
(TensorType t, t /= ByteString, t /= Bool)
|
||||||
|
-- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t)
|
||||||
|
=> ByteString -> Tensor v t -> Build ()
|
||||||
|
histogramSummary tag = addSummary . CoreOps.histogramSummary (scalar tag)
|
||||||
|
|
||||||
|
-- | Adds a 'CoreOps.scalarSummary' node.
|
||||||
|
scalarSummary ::
|
||||||
|
(TensorType t, t /= ByteString, t /= Bool)
|
||||||
|
-- (TensorType t,
|
||||||
|
-- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t)
|
||||||
|
=> ByteString -> Tensor v t -> Build ()
|
||||||
|
scalarSummary tag = addSummary . CoreOps.scalarSummary (scalar tag)
|
||||||
|
|
||||||
|
-- | Merge all summaries accumulated in the 'Build' into one summary.
|
||||||
|
mergeAllSummaries :: Build SummaryTensor
|
||||||
|
mergeAllSummaries = collectAllSummaries >>= render . CoreOps.mergeSummary
|
66
tensorflow-logging/tensorflow-logging.cabal
Normal file
66
tensorflow-logging/tensorflow-logging.cabal
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
name: tensorflow-logging
|
||||||
|
version: 0.1.0.0
|
||||||
|
synopsis: TensorBoard related functionality.
|
||||||
|
description: Please see README.md
|
||||||
|
homepage: https://github.com/tensorflow/haskell#readme
|
||||||
|
license: Apache
|
||||||
|
author: TensorFlow authors
|
||||||
|
maintainer: tensorflow-haskell@googlegroups.com
|
||||||
|
copyright: Google Inc.
|
||||||
|
category: Machine Learning
|
||||||
|
build-type: Simple
|
||||||
|
cabal-version: >=1.22
|
||||||
|
|
||||||
|
library
|
||||||
|
hs-source-dirs: src
|
||||||
|
exposed-modules: TensorFlow.Logging
|
||||||
|
build-depends: base >= 4.7 && < 5
|
||||||
|
, bytestring
|
||||||
|
, conduit
|
||||||
|
, data-default
|
||||||
|
, directory
|
||||||
|
, exceptions
|
||||||
|
, filepath
|
||||||
|
, hostname
|
||||||
|
, lens-family
|
||||||
|
, proto-lens == 0.1.*
|
||||||
|
, resourcet
|
||||||
|
, stm
|
||||||
|
, stm-chans
|
||||||
|
, stm-conduit
|
||||||
|
, tensorflow == 0.1.*
|
||||||
|
, tensorflow-core-ops == 0.1.*
|
||||||
|
, tensorflow-ops == 0.1.*
|
||||||
|
, tensorflow-proto == 0.1.*
|
||||||
|
, tensorflow-records-conduit
|
||||||
|
, text
|
||||||
|
, time
|
||||||
|
, transformers
|
||||||
|
default-language: Haskell2010
|
||||||
|
|
||||||
|
Test-Suite LoggingTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: LoggingTest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: HUnit
|
||||||
|
, base
|
||||||
|
, bytestring
|
||||||
|
, conduit
|
||||||
|
, data-default
|
||||||
|
, directory
|
||||||
|
, filepath
|
||||||
|
, lens-family
|
||||||
|
, proto-lens
|
||||||
|
, resourcet
|
||||||
|
, temporary
|
||||||
|
, tensorflow-logging
|
||||||
|
, tensorflow-proto
|
||||||
|
, tensorflow-records-conduit
|
||||||
|
, test-framework
|
||||||
|
, test-framework-hunit
|
||||||
|
, text
|
||||||
|
|
||||||
|
source-repository head
|
||||||
|
type: git
|
||||||
|
location: https://github.com/tensorflow/haskell
|
64
tensorflow-logging/tests/LoggingTest.hs
Normal file
64
tensorflow-logging/tests/LoggingTest.hs
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
-- | Tests for TensorFlow.Logging.
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import Control.Monad.Trans.Resource (runResourceT)
|
||||||
|
import Data.Conduit ((=$=))
|
||||||
|
import Data.Default (def)
|
||||||
|
import Data.List ((\\))
|
||||||
|
import Data.ProtoLens (decodeMessageOrDie)
|
||||||
|
import Lens.Family2 ((^.), (.~), (&))
|
||||||
|
import Proto.Tensorflow.Core.Util.Event (fileVersion, step)
|
||||||
|
import System.Directory (getDirectoryContents)
|
||||||
|
import System.FilePath ((</>))
|
||||||
|
import System.IO.Temp (withSystemTempDirectory)
|
||||||
|
import TensorFlow.Logging (withEventWriter, logEvent)
|
||||||
|
import TensorFlow.Records.Conduit (sourceTFRecords)
|
||||||
|
import Test.Framework (defaultMain, Test)
|
||||||
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
import Test.HUnit (assertBool, assertEqual)
|
||||||
|
import qualified Data.ByteString.Lazy as BL
|
||||||
|
import qualified Data.Conduit as Conduit
|
||||||
|
import qualified Data.Conduit.List as Conduit
|
||||||
|
import qualified Data.Text as T
|
||||||
|
|
||||||
|
-- TODO: This has been added to System.Directory in newer versions.
|
||||||
|
listDirectory :: String -> IO [String]
|
||||||
|
listDirectory dir = (\\ [".", ".."]) <$> getDirectoryContents dir
|
||||||
|
|
||||||
|
testEventWriter :: Test
|
||||||
|
testEventWriter = testCase "EventWriter" $
|
||||||
|
withSystemTempDirectory "event_writer_logs" $ \dir -> do
|
||||||
|
assertEqual "No file before" [] =<< listDirectory dir
|
||||||
|
let expected = [ def & step .~ 10
|
||||||
|
, def & step .~ 222
|
||||||
|
, def & step .~ 8
|
||||||
|
]
|
||||||
|
withEventWriter dir $ \eventWriter ->
|
||||||
|
mapM_ (logEvent eventWriter) expected
|
||||||
|
files <- listDirectory dir
|
||||||
|
assertEqual "One file exists after" 1 (length files)
|
||||||
|
records <- runResourceT $ Conduit.runConduit $
|
||||||
|
sourceTFRecords (dir </> head files) =$= Conduit.consume
|
||||||
|
assertBool "File is not empty" (not (null records))
|
||||||
|
let (header:body) = decodeMessageOrDie . BL.toStrict <$> records
|
||||||
|
assertEqual "Header has expected version"
|
||||||
|
(T.pack "brain.Event:2") (header ^. fileVersion)
|
||||||
|
assertEqual "Body has expected records" expected body
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = defaultMain [ testEventWriter
|
||||||
|
]
|
|
@ -13,6 +13,7 @@ cabal-version: >=1.22
|
||||||
extra-source-files: ../third_party/tensorflow/tensorflow/core/framework/*.proto
|
extra-source-files: ../third_party/tensorflow/tensorflow/core/framework/*.proto
|
||||||
, ../third_party/tensorflow/tensorflow/core/protobuf/config.proto
|
, ../third_party/tensorflow/tensorflow/core/protobuf/config.proto
|
||||||
, ../third_party/tensorflow/tensorflow/core/protobuf/debug.proto
|
, ../third_party/tensorflow/tensorflow/core/protobuf/debug.proto
|
||||||
|
, ../third_party/tensorflow/tensorflow/core/util/event.proto
|
||||||
|
|
||||||
library
|
library
|
||||||
exposed-modules: Proto.Tensorflow.Core.Framework.AttrValue
|
exposed-modules: Proto.Tensorflow.Core.Framework.AttrValue
|
||||||
|
@ -20,10 +21,12 @@ library
|
||||||
, Proto.Tensorflow.Core.Framework.NodeDef
|
, Proto.Tensorflow.Core.Framework.NodeDef
|
||||||
, Proto.Tensorflow.Core.Framework.OpDef
|
, Proto.Tensorflow.Core.Framework.OpDef
|
||||||
, Proto.Tensorflow.Core.Framework.ResourceHandle
|
, Proto.Tensorflow.Core.Framework.ResourceHandle
|
||||||
|
, Proto.Tensorflow.Core.Framework.Summary
|
||||||
, Proto.Tensorflow.Core.Framework.Tensor
|
, Proto.Tensorflow.Core.Framework.Tensor
|
||||||
, Proto.Tensorflow.Core.Framework.TensorShape
|
, Proto.Tensorflow.Core.Framework.TensorShape
|
||||||
, Proto.Tensorflow.Core.Framework.Types
|
, Proto.Tensorflow.Core.Framework.Types
|
||||||
, Proto.Tensorflow.Core.Protobuf.Config
|
, Proto.Tensorflow.Core.Protobuf.Config
|
||||||
|
, Proto.Tensorflow.Core.Util.Event
|
||||||
other-modules: Proto.Tensorflow.Core.Framework.AllocationDescription
|
other-modules: Proto.Tensorflow.Core.Framework.AllocationDescription
|
||||||
, Proto.Tensorflow.Core.Framework.CostGraph
|
, Proto.Tensorflow.Core.Framework.CostGraph
|
||||||
, Proto.Tensorflow.Core.Framework.Function
|
, Proto.Tensorflow.Core.Framework.Function
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
|
|
||||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||||
{-# LANGUAGE LambdaCase #-}
|
{-# LANGUAGE LambdaCase #-}
|
||||||
{-# LANGUAGE Rank2Types #-}
|
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE Rank2Types #-}
|
||||||
module TensorFlow.Build
|
module TensorFlow.Build
|
||||||
( -- * Graph node types
|
( -- * Graph node types
|
||||||
ControlNode(..)
|
ControlNode(..)
|
||||||
|
@ -61,6 +61,7 @@ module TensorFlow.Build
|
||||||
, collectAllSummaries
|
, collectAllSummaries
|
||||||
) where
|
) where
|
||||||
|
|
||||||
|
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
||||||
import Control.Monad.IO.Class (MonadIO(..))
|
import Control.Monad.IO.Class (MonadIO(..))
|
||||||
import Control.Monad.Trans.Class (MonadTrans(..))
|
import Control.Monad.Trans.Class (MonadTrans(..))
|
||||||
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
|
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
|
||||||
|
@ -196,7 +197,7 @@ summaries = lens _summaries (\g x -> g { _summaries = x })
|
||||||
-- Used to manage build state internally as part of the @Session@ monad.
|
-- Used to manage build state internally as part of the @Session@ monad.
|
||||||
newtype BuildT m a = BuildT (StateT GraphState m a)
|
newtype BuildT m a = BuildT (StateT GraphState m a)
|
||||||
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
|
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
|
||||||
MonadState GraphState)
|
MonadState GraphState, MonadThrow, MonadCatch, MonadMask)
|
||||||
|
|
||||||
-- | An action for building nodes in a TensorFlow graph.
|
-- | An action for building nodes in a TensorFlow graph.
|
||||||
type Build = BuildT Identity
|
type Build = BuildT Identity
|
||||||
|
|
|
@ -33,7 +33,6 @@ module TensorFlow.Core
|
||||||
-- ** Building graphs
|
-- ** Building graphs
|
||||||
, build
|
, build
|
||||||
, buildAnd
|
, buildAnd
|
||||||
, buildWithSummary
|
|
||||||
-- ** Running graphs
|
-- ** Running graphs
|
||||||
, Fetchable
|
, Fetchable
|
||||||
, Nodes
|
, Nodes
|
||||||
|
|
|
@ -28,7 +28,6 @@ module TensorFlow.Session (
|
||||||
runSessionWithOptions,
|
runSessionWithOptions,
|
||||||
build,
|
build,
|
||||||
buildAnd,
|
buildAnd,
|
||||||
buildWithSummary,
|
|
||||||
extend,
|
extend,
|
||||||
addGraphDef,
|
addGraphDef,
|
||||||
run,
|
run,
|
||||||
|
@ -39,6 +38,7 @@ module TensorFlow.Session (
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Control.Monad (forever, unless, void)
|
import Control.Monad (forever, unless, void)
|
||||||
|
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
||||||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||||
import Control.Monad.Trans.Class (lift)
|
import Control.Monad.Trans.Class (lift)
|
||||||
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
|
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
|
||||||
|
@ -77,7 +77,8 @@ data SessionState
|
||||||
|
|
||||||
newtype Session a
|
newtype Session a
|
||||||
= Session (ReaderT SessionState (BuildT IO) a)
|
= Session (ReaderT SessionState (BuildT IO) a)
|
||||||
deriving (Functor, Applicative, Monad, MonadIO)
|
deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch,
|
||||||
|
MonadMask)
|
||||||
|
|
||||||
-- | Run 'Session' actions in a new TensorFlow session.
|
-- | Run 'Session' actions in a new TensorFlow session.
|
||||||
runSession :: Session a -> IO a
|
runSession :: Session a -> IO a
|
||||||
|
@ -128,14 +129,6 @@ runSessionWithOptions options (Session m) =
|
||||||
build :: Build a -> Session a
|
build :: Build a -> Session a
|
||||||
build = Session . lift . hoistBuildT (return . runIdentity)
|
build = Session . lift . hoistBuildT (return . runIdentity)
|
||||||
|
|
||||||
-- | Lift a 'Build' action into a 'Session', including any explicit op
|
|
||||||
-- renderings. Returns the merged summary ops which can be used for
|
|
||||||
-- logging, see 'TensorFlow.Logging.build' for a convenient wrapper.
|
|
||||||
buildWithSummary :: forall a . Build a -> Session (a, [SummaryTensor])
|
|
||||||
buildWithSummary b = Session $ lift $ (,) <$> v <*> collectAllSummaries
|
|
||||||
where v :: BuildT IO a
|
|
||||||
v = hoistBuildT (return . runIdentity) b
|
|
||||||
|
|
||||||
-- | Add all pending rendered nodes to the TensorFlow graph and runs
|
-- | Add all pending rendered nodes to the TensorFlow graph and runs
|
||||||
-- any pending initializers.
|
-- any pending initializers.
|
||||||
--
|
--
|
||||||
|
|
|
@ -37,6 +37,7 @@ library
|
||||||
, bytestring
|
, bytestring
|
||||||
, containers
|
, containers
|
||||||
, data-default
|
, data-default
|
||||||
|
, exceptions
|
||||||
, fgl
|
, fgl
|
||||||
, lens-family
|
, lens-family
|
||||||
, mainland-pretty
|
, mainland-pretty
|
||||||
|
|
Loading…
Reference in a new issue