diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs index 7e33696..b0afc5c 100644 --- a/tensorflow-ops/src/TensorFlow/Ops.hs +++ b/tensorflow-ops/src/TensorFlow/Ops.hs @@ -68,6 +68,7 @@ module TensorFlow.Ops , constant , CoreOps.equal , expandDims + , initializedValue , initializedVariable , zeroInitializedVariable , CoreOps.fill @@ -111,7 +112,9 @@ import Data.Int (Int32, Int64) import Prelude hiding (abs, sum, concat) import Data.ProtoLens (def) import Data.Text.Encoding (encodeUtf8) +import Data.Set (fromList) import Lens.Family2 ((.~), (&)) +import Lens.Family2.State.Strict (use) import Text.Printf (printf) import Proto.Tensorflow.Core.Framework.Tensor ( TensorProto @@ -161,6 +164,18 @@ placeholder shape' = & opAttr "dtype" .~ tensorType (undefined :: a) & opAttr "shape" .~ shape' + +-- | Construct a tensor whose value is the initialized value of the given +-- tensor. +initializedValue :: forall a. TensorType a + => Tensor Ref a + -> Build (Tensor Ref a) +initializedValue t = do + ns <- use initializationNodes + -- Make this tensor depend on the initializers of the other. + withNodeDependencies (fromList ns) (render t) + + -- | Creates a variable initialized to the given value. -- Initialization happens next time session runs. initializedVariable :: forall a . TensorType a diff --git a/tensorflow-ops/tensorflow-ops.cabal b/tensorflow-ops/tensorflow-ops.cabal index 94c7478..4c3edf6 100644 --- a/tensorflow-ops/tensorflow-ops.cabal +++ b/tensorflow-ops/tensorflow-ops.cabal @@ -42,6 +42,7 @@ Test-Suite BuildTest , google-shim , tensorflow , tensorflow-ops + , tensorflow-core-ops , tensorflow-proto , test-framework , test-framework-hunit diff --git a/tensorflow-ops/tests/BuildTest.hs b/tensorflow-ops/tests/BuildTest.hs index d8bf859..efc641c 100644 --- a/tensorflow-ops/tests/BuildTest.hs +++ b/tensorflow-ops/tests/BuildTest.hs @@ -43,10 +43,12 @@ import TensorFlow.Build ) import TensorFlow.ControlFlow (named) import TensorFlow.Types (unScalar) +import TensorFlow.GenOps.Core (identity) import TensorFlow.Ops ( add , assign , constant + , initializedValue , initializedVariable , variable ) @@ -109,7 +111,10 @@ testInitializedVariable = testInitializedVariableShape :: Test testInitializedVariableShape = testCase "testInitializedVariableShape" $ runSession $ do - vector <- build $ initializedVariable (constant [1] [42 :: Float]) + vector <- build $ do + a <- initializedVariable (constant [1] [42 :: Float]) + b <- initializedValue a + return b result <- run vector liftIO $ [42] @=? (result :: V.Vector Float) diff --git a/tensorflow/src/TensorFlow/Build.hs b/tensorflow/src/TensorFlow/Build.hs index 2165c94..2482eb7 100644 --- a/tensorflow/src/TensorFlow/Build.hs +++ b/tensorflow/src/TensorFlow/Build.hs @@ -59,6 +59,7 @@ module TensorFlow.Build , addSummary , SummaryTensor , collectAllSummaries + , initializationNodes ) where import Control.Monad.IO.Class (MonadIO(..))