1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2025-05-06 20:36:44 +02:00

adds initialized value

This commit is contained in:
silky 2016-12-22 11:57:19 +11:00
parent db75350969
commit 0336df5488
4 changed files with 23 additions and 1 deletions
tensorflow-ops
tensorflow/src/TensorFlow

View file

@ -68,6 +68,7 @@ module TensorFlow.Ops
, constant , constant
, CoreOps.equal , CoreOps.equal
, expandDims , expandDims
, initializedValue
, initializedVariable , initializedVariable
, zeroInitializedVariable , zeroInitializedVariable
, CoreOps.fill , CoreOps.fill
@ -111,7 +112,9 @@ import Data.Int (Int32, Int64)
import Prelude hiding (abs, sum, concat) import Prelude hiding (abs, sum, concat)
import Data.ProtoLens (def) import Data.ProtoLens (def)
import Data.Text.Encoding (encodeUtf8) import Data.Text.Encoding (encodeUtf8)
import Data.Set (fromList)
import Lens.Family2 ((.~), (&)) import Lens.Family2 ((.~), (&))
import Lens.Family2.State.Strict (use)
import Text.Printf (printf) import Text.Printf (printf)
import Proto.Tensorflow.Core.Framework.Tensor import Proto.Tensorflow.Core.Framework.Tensor
( TensorProto ( TensorProto
@ -161,6 +164,18 @@ placeholder shape' =
& opAttr "dtype" .~ tensorType (undefined :: a) & opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "shape" .~ shape' & opAttr "shape" .~ shape'
-- | Construct a tensor whose value is the initialized value of the given
-- tensor.
initializedValue :: forall a. TensorType a
=> Tensor Ref a
-> Build (Tensor Ref a)
initializedValue t = do
ns <- use initializationNodes
-- Make this tensor depend on the initializers of the other.
withNodeDependencies (fromList ns) (render t)
-- | Creates a variable initialized to the given value. -- | Creates a variable initialized to the given value.
-- Initialization happens next time session runs. -- Initialization happens next time session runs.
initializedVariable :: forall a . TensorType a initializedVariable :: forall a . TensorType a

View file

@ -42,6 +42,7 @@ Test-Suite BuildTest
, google-shim , google-shim
, tensorflow , tensorflow
, tensorflow-ops , tensorflow-ops
, tensorflow-core-ops
, tensorflow-proto , tensorflow-proto
, test-framework , test-framework
, test-framework-hunit , test-framework-hunit

View file

@ -43,10 +43,12 @@ import TensorFlow.Build
) )
import TensorFlow.ControlFlow (named) import TensorFlow.ControlFlow (named)
import TensorFlow.Types (unScalar) import TensorFlow.Types (unScalar)
import TensorFlow.GenOps.Core (identity)
import TensorFlow.Ops import TensorFlow.Ops
( add ( add
, assign , assign
, constant , constant
, initializedValue
, initializedVariable , initializedVariable
, variable , variable
) )
@ -109,7 +111,10 @@ testInitializedVariable =
testInitializedVariableShape :: Test testInitializedVariableShape :: Test
testInitializedVariableShape = testInitializedVariableShape =
testCase "testInitializedVariableShape" $ runSession $ do testCase "testInitializedVariableShape" $ runSession $ do
vector <- build $ initializedVariable (constant [1] [42 :: Float]) vector <- build $ do
a <- initializedVariable (constant [1] [42 :: Float])
b <- initializedValue a
return b
result <- run vector result <- run vector
liftIO $ [42] @=? (result :: V.Vector Float) liftIO $ [42] @=? (result :: V.Vector Float)

View file

@ -59,6 +59,7 @@ module TensorFlow.Build
, addSummary , addSummary
, SummaryTensor , SummaryTensor
, collectAllSummaries , collectAllSummaries
, initializationNodes
) where ) where
import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.IO.Class (MonadIO(..))