mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 11:29:43 +01:00
Clarify the behavior of readValue in a comment. (#99)
Also add a unit test corresponding to that comments' example code.
This commit is contained in:
parent
42f4fc647e
commit
51c883684b
2 changed files with 32 additions and 1 deletions
|
@ -77,6 +77,20 @@ zeroInitializedVariable'
|
||||||
zeroInitializedVariable' params = initializedVariable' params . zeros
|
zeroInitializedVariable' params = initializedVariable' params . zeros
|
||||||
|
|
||||||
-- | Gets the value stored in a variable.
|
-- | Gets the value stored in a variable.
|
||||||
|
--
|
||||||
|
-- Note that this op is stateful since it depends on the value of the variable;
|
||||||
|
-- however, it may be CSE'd with other reads in the same context. The context can
|
||||||
|
-- be fixed by using 'render' along with (for example) 'withControlDependencies'.
|
||||||
|
-- For example:
|
||||||
|
--
|
||||||
|
-- > runSession $ do
|
||||||
|
-- > v <- variable []
|
||||||
|
-- > a <- assign v 24
|
||||||
|
-- > r <- withControlDependencies a $ render $ readValue v + 18
|
||||||
|
-- > result <- run r
|
||||||
|
-- > liftIO $ (42 :: Float) @=? unScalar result
|
||||||
|
--
|
||||||
|
--
|
||||||
readValue :: TensorType a => Variable a -> Tensor Build a
|
readValue :: TensorType a => Variable a -> Tensor Build a
|
||||||
readValue = readValue' id
|
readValue = readValue' id
|
||||||
|
|
||||||
|
|
|
@ -4,13 +4,20 @@ module Main (main) where
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
import qualified Data.Vector.Storable as V
|
import qualified Data.Vector.Storable as V
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
import TensorFlow.Core (unScalar, run_, runSession, run)
|
import TensorFlow.Core
|
||||||
|
( unScalar
|
||||||
|
, render
|
||||||
|
, run_
|
||||||
|
, runSession
|
||||||
|
, run
|
||||||
|
, withControlDependencies)
|
||||||
import qualified TensorFlow.Ops as Ops
|
import qualified TensorFlow.Ops as Ops
|
||||||
import TensorFlow.Variable
|
import TensorFlow.Variable
|
||||||
( readValue
|
( readValue
|
||||||
, initializedVariable
|
, initializedVariable
|
||||||
, assign
|
, assign
|
||||||
, assignAdd
|
, assignAdd
|
||||||
|
, variable
|
||||||
)
|
)
|
||||||
import Test.Framework (Test)
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
@ -19,6 +26,7 @@ import Test.HUnit ((@=?))
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = googleTest [ testInitializedVariable
|
main = googleTest [ testInitializedVariable
|
||||||
, testInitializedVariableShape
|
, testInitializedVariableShape
|
||||||
|
, testDependency
|
||||||
, testRereadRef
|
, testRereadRef
|
||||||
, testAssignAdd
|
, testAssignAdd
|
||||||
]
|
]
|
||||||
|
@ -43,6 +51,15 @@ testInitializedVariableShape =
|
||||||
result <- run (readValue vector)
|
result <- run (readValue vector)
|
||||||
liftIO $ [42] @=? (result :: V.Vector Float)
|
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||||
|
|
||||||
|
testDependency :: Test
|
||||||
|
testDependency =
|
||||||
|
testCase "testDependency" $ runSession $ do
|
||||||
|
v <- variable []
|
||||||
|
a <- assign v 24
|
||||||
|
r <- withControlDependencies a $ render (readValue v + 18)
|
||||||
|
result <- run r
|
||||||
|
liftIO $ (42 :: Float) @=? unScalar result
|
||||||
|
|
||||||
-- | See https://github.com/tensorflow/haskell/issues/92.
|
-- | See https://github.com/tensorflow/haskell/issues/92.
|
||||||
-- Even though we're not explicitly evaluating `f0` until the end,
|
-- Even though we're not explicitly evaluating `f0` until the end,
|
||||||
-- it should hold the earlier value of the variable.
|
-- it should hold the earlier value of the variable.
|
||||||
|
|
Loading…
Reference in a new issue