mirror of
https://github.com/tensorflow/haskell.git
synced 2025-05-06 20:36:44 +02:00
adds initialized value
This commit is contained in:
parent
db75350969
commit
0336df5488
4 changed files with 23 additions and 1 deletions
tensorflow-ops
tensorflow/src/TensorFlow
|
@ -68,6 +68,7 @@ module TensorFlow.Ops
|
||||||
, constant
|
, constant
|
||||||
, CoreOps.equal
|
, CoreOps.equal
|
||||||
, expandDims
|
, expandDims
|
||||||
|
, initializedValue
|
||||||
, initializedVariable
|
, initializedVariable
|
||||||
, zeroInitializedVariable
|
, zeroInitializedVariable
|
||||||
, CoreOps.fill
|
, CoreOps.fill
|
||||||
|
@ -111,7 +112,9 @@ import Data.Int (Int32, Int64)
|
||||||
import Prelude hiding (abs, sum, concat)
|
import Prelude hiding (abs, sum, concat)
|
||||||
import Data.ProtoLens (def)
|
import Data.ProtoLens (def)
|
||||||
import Data.Text.Encoding (encodeUtf8)
|
import Data.Text.Encoding (encodeUtf8)
|
||||||
|
import Data.Set (fromList)
|
||||||
import Lens.Family2 ((.~), (&))
|
import Lens.Family2 ((.~), (&))
|
||||||
|
import Lens.Family2.State.Strict (use)
|
||||||
import Text.Printf (printf)
|
import Text.Printf (printf)
|
||||||
import Proto.Tensorflow.Core.Framework.Tensor
|
import Proto.Tensorflow.Core.Framework.Tensor
|
||||||
( TensorProto
|
( TensorProto
|
||||||
|
@ -161,6 +164,18 @@ placeholder shape' =
|
||||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||||
& opAttr "shape" .~ shape'
|
& opAttr "shape" .~ shape'
|
||||||
|
|
||||||
|
|
||||||
|
-- | Construct a tensor whose value is the initialized value of the given
|
||||||
|
-- tensor.
|
||||||
|
initializedValue :: forall a. TensorType a
|
||||||
|
=> Tensor Ref a
|
||||||
|
-> Build (Tensor Ref a)
|
||||||
|
initializedValue t = do
|
||||||
|
ns <- use initializationNodes
|
||||||
|
-- Make this tensor depend on the initializers of the other.
|
||||||
|
withNodeDependencies (fromList ns) (render t)
|
||||||
|
|
||||||
|
|
||||||
-- | Creates a variable initialized to the given value.
|
-- | Creates a variable initialized to the given value.
|
||||||
-- Initialization happens next time session runs.
|
-- Initialization happens next time session runs.
|
||||||
initializedVariable :: forall a . TensorType a
|
initializedVariable :: forall a . TensorType a
|
||||||
|
|
|
@ -42,6 +42,7 @@ Test-Suite BuildTest
|
||||||
, google-shim
|
, google-shim
|
||||||
, tensorflow
|
, tensorflow
|
||||||
, tensorflow-ops
|
, tensorflow-ops
|
||||||
|
, tensorflow-core-ops
|
||||||
, tensorflow-proto
|
, tensorflow-proto
|
||||||
, test-framework
|
, test-framework
|
||||||
, test-framework-hunit
|
, test-framework-hunit
|
||||||
|
|
|
@ -43,10 +43,12 @@ import TensorFlow.Build
|
||||||
)
|
)
|
||||||
import TensorFlow.ControlFlow (named)
|
import TensorFlow.ControlFlow (named)
|
||||||
import TensorFlow.Types (unScalar)
|
import TensorFlow.Types (unScalar)
|
||||||
|
import TensorFlow.GenOps.Core (identity)
|
||||||
import TensorFlow.Ops
|
import TensorFlow.Ops
|
||||||
( add
|
( add
|
||||||
, assign
|
, assign
|
||||||
, constant
|
, constant
|
||||||
|
, initializedValue
|
||||||
, initializedVariable
|
, initializedVariable
|
||||||
, variable
|
, variable
|
||||||
)
|
)
|
||||||
|
@ -109,7 +111,10 @@ testInitializedVariable =
|
||||||
testInitializedVariableShape :: Test
|
testInitializedVariableShape :: Test
|
||||||
testInitializedVariableShape =
|
testInitializedVariableShape =
|
||||||
testCase "testInitializedVariableShape" $ runSession $ do
|
testCase "testInitializedVariableShape" $ runSession $ do
|
||||||
vector <- build $ initializedVariable (constant [1] [42 :: Float])
|
vector <- build $ do
|
||||||
|
a <- initializedVariable (constant [1] [42 :: Float])
|
||||||
|
b <- initializedValue a
|
||||||
|
return b
|
||||||
result <- run vector
|
result <- run vector
|
||||||
liftIO $ [42] @=? (result :: V.Vector Float)
|
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||||
|
|
||||||
|
|
|
@ -59,6 +59,7 @@ module TensorFlow.Build
|
||||||
, addSummary
|
, addSummary
|
||||||
, SummaryTensor
|
, SummaryTensor
|
||||||
, collectAllSummaries
|
, collectAllSummaries
|
||||||
|
, initializationNodes
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Control.Monad.IO.Class (MonadIO(..))
|
import Control.Monad.IO.Class (MonadIO(..))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue