1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 11:03:34 +02:00
This commit is contained in:
Noon van der Silk 2017-03-18 08:23:25 +00:00 committed by GitHub
commit 5913ade2b0
4 changed files with 24 additions and 1 deletions

View File

@ -68,6 +68,7 @@ module TensorFlow.Ops
, constant
, CoreOps.equal
, expandDims
, initializedValue
, initializedVariable
, zeroInitializedVariable
, CoreOps.fill
@ -111,7 +112,9 @@ import Data.Int (Int32, Int64)
import Prelude hiding (abs, sum, concat)
import Data.ProtoLens (def)
import Data.Text.Encoding (encodeUtf8)
import Data.Set (fromList)
import Lens.Family2 ((.~), (&))
import Lens.Family2.State.Strict (use)
import Text.Printf (printf)
import Proto.Tensorflow.Core.Framework.Tensor
( TensorProto
@ -161,6 +164,18 @@ placeholder shape' =
& opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "shape" .~ shape'
-- | Construct a tensor whose value is the initialized value of the given
-- tensor.
initializedValue :: forall a v. TensorType a
=> Tensor v a
-> Build (Tensor v a)
initializedValue t = do
ns <- use initializationNodes
-- Make this tensor depend on the initializers of the other.
withNodeDependencies (fromList ns) (render t)
-- | Creates a variable initialized to the given value.
-- Initialization happens next time session runs.
initializedVariable :: forall a . TensorType a

View File

@ -54,6 +54,7 @@ Test-Suite BuildTest
, google-shim
, tensorflow
, tensorflow-ops
, tensorflow-core-ops
, tensorflow-proto
, test-framework
, test-framework-hunit

View File

@ -43,10 +43,12 @@ import TensorFlow.Build
)
import TensorFlow.ControlFlow (named)
import TensorFlow.Types (unScalar)
import TensorFlow.GenOps.Core (identity)
import TensorFlow.Ops
( add
, assign
, constant
, initializedValue
, initializedVariable
, variable
)
@ -109,7 +111,11 @@ testInitializedVariable =
testInitializedVariableShape :: Test
testInitializedVariableShape =
testCase "testInitializedVariableShape" $ runSession $ do
vector <- build $ initializedVariable (constant [1] [42 :: Float])
vector <- build $ do
a <- initializedVariable (constant [1] [42 :: Float])
b <- initializedValue (identity a)
c <- initializedVariable b
return c
result <- run vector
liftIO $ [42] @=? (result :: V.Vector Float)

View File

@ -59,6 +59,7 @@ module TensorFlow.Build
, addSummary
, SummaryTensor
, collectAllSummaries
, initializationNodes
) where
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)