mirror of
https://github.com/tensorflow/haskell.git
synced 2025-04-28 16:55:23 +02:00
adds initialized value
This commit is contained in:
parent
db75350969
commit
0336df5488
4 changed files with 23 additions and 1 deletions
tensorflow-ops
tensorflow/src/TensorFlow
|
@ -68,6 +68,7 @@ module TensorFlow.Ops
|
|||
, constant
|
||||
, CoreOps.equal
|
||||
, expandDims
|
||||
, initializedValue
|
||||
, initializedVariable
|
||||
, zeroInitializedVariable
|
||||
, CoreOps.fill
|
||||
|
@ -111,7 +112,9 @@ import Data.Int (Int32, Int64)
|
|||
import Prelude hiding (abs, sum, concat)
|
||||
import Data.ProtoLens (def)
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Data.Set (fromList)
|
||||
import Lens.Family2 ((.~), (&))
|
||||
import Lens.Family2.State.Strict (use)
|
||||
import Text.Printf (printf)
|
||||
import Proto.Tensorflow.Core.Framework.Tensor
|
||||
( TensorProto
|
||||
|
@ -161,6 +164,18 @@ placeholder shape' =
|
|||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
& opAttr "shape" .~ shape'
|
||||
|
||||
|
||||
-- | Construct a tensor whose value is the initialized value of the given
|
||||
-- tensor.
|
||||
initializedValue :: forall a. TensorType a
|
||||
=> Tensor Ref a
|
||||
-> Build (Tensor Ref a)
|
||||
initializedValue t = do
|
||||
ns <- use initializationNodes
|
||||
-- Make this tensor depend on the initializers of the other.
|
||||
withNodeDependencies (fromList ns) (render t)
|
||||
|
||||
|
||||
-- | Creates a variable initialized to the given value.
|
||||
-- Initialization happens next time session runs.
|
||||
initializedVariable :: forall a . TensorType a
|
||||
|
|
|
@ -42,6 +42,7 @@ Test-Suite BuildTest
|
|||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-ops
|
||||
, tensorflow-core-ops
|
||||
, tensorflow-proto
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
|
|
|
@ -43,10 +43,12 @@ import TensorFlow.Build
|
|||
)
|
||||
import TensorFlow.ControlFlow (named)
|
||||
import TensorFlow.Types (unScalar)
|
||||
import TensorFlow.GenOps.Core (identity)
|
||||
import TensorFlow.Ops
|
||||
( add
|
||||
, assign
|
||||
, constant
|
||||
, initializedValue
|
||||
, initializedVariable
|
||||
, variable
|
||||
)
|
||||
|
@ -109,7 +111,10 @@ testInitializedVariable =
|
|||
testInitializedVariableShape :: Test
|
||||
testInitializedVariableShape =
|
||||
testCase "testInitializedVariableShape" $ runSession $ do
|
||||
vector <- build $ initializedVariable (constant [1] [42 :: Float])
|
||||
vector <- build $ do
|
||||
a <- initializedVariable (constant [1] [42 :: Float])
|
||||
b <- initializedValue a
|
||||
return b
|
||||
result <- run vector
|
||||
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||
|
||||
|
|
|
@ -59,6 +59,7 @@ module TensorFlow.Build
|
|||
, addSummary
|
||||
, SummaryTensor
|
||||
, collectAllSummaries
|
||||
, initializationNodes
|
||||
) where
|
||||
|
||||
import Control.Monad.IO.Class (MonadIO(..))
|
||||
|
|
Loading…
Add table
Reference in a new issue