mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-11 11:29:47 +01:00
Clarify the behavior of readValue in a comment. (#99)
Also add a unit test corresponding to that comments' example code.
This commit is contained in:
parent
42f4fc647e
commit
51c883684b
2 changed files with 32 additions and 1 deletions
|
@ -77,6 +77,20 @@ zeroInitializedVariable'
|
|||
zeroInitializedVariable' params = initializedVariable' params . zeros
|
||||
|
||||
-- | Gets the value stored in a variable.
|
||||
--
|
||||
-- Note that this op is stateful since it depends on the value of the variable;
|
||||
-- however, it may be CSE'd with other reads in the same context. The context can
|
||||
-- be fixed by using 'render' along with (for example) 'withControlDependencies'.
|
||||
-- For example:
|
||||
--
|
||||
-- > runSession $ do
|
||||
-- > v <- variable []
|
||||
-- > a <- assign v 24
|
||||
-- > r <- withControlDependencies a $ render $ readValue v + 18
|
||||
-- > result <- run r
|
||||
-- > liftIO $ (42 :: Float) @=? unScalar result
|
||||
--
|
||||
--
|
||||
readValue :: TensorType a => Variable a -> Tensor Build a
|
||||
readValue = readValue' id
|
||||
|
||||
|
|
|
@ -4,13 +4,20 @@ module Main (main) where
|
|||
import Control.Monad.IO.Class (liftIO)
|
||||
import qualified Data.Vector.Storable as V
|
||||
import Google.Test (googleTest)
|
||||
import TensorFlow.Core (unScalar, run_, runSession, run)
|
||||
import TensorFlow.Core
|
||||
( unScalar
|
||||
, render
|
||||
, run_
|
||||
, runSession
|
||||
, run
|
||||
, withControlDependencies)
|
||||
import qualified TensorFlow.Ops as Ops
|
||||
import TensorFlow.Variable
|
||||
( readValue
|
||||
, initializedVariable
|
||||
, assign
|
||||
, assignAdd
|
||||
, variable
|
||||
)
|
||||
import Test.Framework (Test)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
|
@ -19,6 +26,7 @@ import Test.HUnit ((@=?))
|
|||
main :: IO ()
|
||||
main = googleTest [ testInitializedVariable
|
||||
, testInitializedVariableShape
|
||||
, testDependency
|
||||
, testRereadRef
|
||||
, testAssignAdd
|
||||
]
|
||||
|
@ -43,6 +51,15 @@ testInitializedVariableShape =
|
|||
result <- run (readValue vector)
|
||||
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||
|
||||
testDependency :: Test
|
||||
testDependency =
|
||||
testCase "testDependency" $ runSession $ do
|
||||
v <- variable []
|
||||
a <- assign v 24
|
||||
r <- withControlDependencies a $ render (readValue v + 18)
|
||||
result <- run r
|
||||
liftIO $ (42 :: Float) @=? unScalar result
|
||||
|
||||
-- | See https://github.com/tensorflow/haskell/issues/92.
|
||||
-- Even though we're not explicitly evaluating `f0` until the end,
|
||||
-- it should hold the earlier value of the variable.
|
||||
|
|
Loading…
Reference in a new issue