Clarify the behavior of readValue in a comment. (#99)

Also add a unit test corresponding to that comments' example code.
This commit is contained in:
Judah Jacobson 2017-04-16 15:31:26 -07:00 committed by fkm3
parent 42f4fc647e
commit 51c883684b
2 changed files with 32 additions and 1 deletions

View File

@ -77,6 +77,20 @@ zeroInitializedVariable'
zeroInitializedVariable' params = initializedVariable' params . zeros
-- | Gets the value stored in a variable.
--
-- Note that this op is stateful since it depends on the value of the variable;
-- however, it may be CSE'd with other reads in the same context. The context can
-- be fixed by using 'render' along with (for example) 'withControlDependencies'.
-- For example:
--
-- > runSession $ do
-- > v <- variable []
-- > a <- assign v 24
-- > r <- withControlDependencies a $ render $ readValue v + 18
-- > result <- run r
-- > liftIO $ (42 :: Float) @=? unScalar result
--
--
readValue :: TensorType a => Variable a -> Tensor Build a
readValue = readValue' id

View File

@ -4,13 +4,20 @@ module Main (main) where
import Control.Monad.IO.Class (liftIO)
import qualified Data.Vector.Storable as V
import Google.Test (googleTest)
import TensorFlow.Core (unScalar, run_, runSession, run)
import TensorFlow.Core
( unScalar
, render
, run_
, runSession
, run
, withControlDependencies)
import qualified TensorFlow.Ops as Ops
import TensorFlow.Variable
( readValue
, initializedVariable
, assign
, assignAdd
, variable
)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
@ -19,6 +26,7 @@ import Test.HUnit ((@=?))
main :: IO ()
main = googleTest [ testInitializedVariable
, testInitializedVariableShape
, testDependency
, testRereadRef
, testAssignAdd
]
@ -43,6 +51,15 @@ testInitializedVariableShape =
result <- run (readValue vector)
liftIO $ [42] @=? (result :: V.Vector Float)
testDependency :: Test
testDependency =
testCase "testDependency" $ runSession $ do
v <- variable []
a <- assign v 24
r <- withControlDependencies a $ render (readValue v + 18)
result <- run r
liftIO $ (42 :: Float) @=? unScalar result
-- | See https://github.com/tensorflow/haskell/issues/92.
-- Even though we're not explicitly evaluating `f0` until the end,
-- it should hold the earlier value of the variable.