From 51c883684b5d92a1a5f5e1c43b2fea53ab31af11 Mon Sep 17 00:00:00 2001 From: Judah Jacobson Date: Sun, 16 Apr 2017 15:31:26 -0700 Subject: [PATCH] Clarify the behavior of readValue in a comment. (#99) Also add a unit test corresponding to that comments' example code. --- tensorflow-ops/src/TensorFlow/Variable.hs | 14 ++++++++++++++ tensorflow-ops/tests/VariableTest.hs | 19 ++++++++++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/tensorflow-ops/src/TensorFlow/Variable.hs b/tensorflow-ops/src/TensorFlow/Variable.hs index b62082e..76e6e48 100644 --- a/tensorflow-ops/src/TensorFlow/Variable.hs +++ b/tensorflow-ops/src/TensorFlow/Variable.hs @@ -77,6 +77,20 @@ zeroInitializedVariable' zeroInitializedVariable' params = initializedVariable' params . zeros -- | Gets the value stored in a variable. +-- +-- Note that this op is stateful since it depends on the value of the variable; +-- however, it may be CSE'd with other reads in the same context. The context can +-- be fixed by using 'render' along with (for example) 'withControlDependencies'. +-- For example: +-- +-- > runSession $ do +-- > v <- variable [] +-- > a <- assign v 24 +-- > r <- withControlDependencies a $ render $ readValue v + 18 +-- > result <- run r +-- > liftIO $ (42 :: Float) @=? unScalar result +-- +-- readValue :: TensorType a => Variable a -> Tensor Build a readValue = readValue' id diff --git a/tensorflow-ops/tests/VariableTest.hs b/tensorflow-ops/tests/VariableTest.hs index eccc76a..e9cf303 100644 --- a/tensorflow-ops/tests/VariableTest.hs +++ b/tensorflow-ops/tests/VariableTest.hs @@ -4,13 +4,20 @@ module Main (main) where import Control.Monad.IO.Class (liftIO) import qualified Data.Vector.Storable as V import Google.Test (googleTest) -import TensorFlow.Core (unScalar, run_, runSession, run) +import TensorFlow.Core + ( unScalar + , render + , run_ + , runSession + , run + , withControlDependencies) import qualified TensorFlow.Ops as Ops import TensorFlow.Variable ( readValue , initializedVariable , assign , assignAdd + , variable ) import Test.Framework (Test) import Test.Framework.Providers.HUnit (testCase) @@ -19,6 +26,7 @@ import Test.HUnit ((@=?)) main :: IO () main = googleTest [ testInitializedVariable , testInitializedVariableShape + , testDependency , testRereadRef , testAssignAdd ] @@ -43,6 +51,15 @@ testInitializedVariableShape = result <- run (readValue vector) liftIO $ [42] @=? (result :: V.Vector Float) +testDependency :: Test +testDependency = + testCase "testDependency" $ runSession $ do + v <- variable [] + a <- assign v 24 + r <- withControlDependencies a $ render (readValue v + 18) + result <- run r + liftIO $ (42 :: Float) @=? unScalar result + -- | See https://github.com/tensorflow/haskell/issues/92. -- Even though we're not explicitly evaluating `f0` until the end, -- it should hold the earlier value of the variable.