diff --git a/ChangeLog.md b/ChangeLog.md index 46ce463..be60358 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -4,6 +4,7 @@ - Expand the `Rendered` class and add a `ToTensor` class to let more functions (gradients, feed, colocateWith) support `ResourceHandle` wrappers like `Variables`. +- Add `initializedValue` function for `Variable`. ## v0.1.0.2 - Add extra-lib-dirs for OS X in the Hackage release (#122). diff --git a/tensorflow-ops/src/TensorFlow/Variable.hs b/tensorflow-ops/src/TensorFlow/Variable.hs index 141622f..5c8f6c7 100644 --- a/tensorflow-ops/src/TensorFlow/Variable.hs +++ b/tensorflow-ops/src/TensorFlow/Variable.hs @@ -14,6 +14,7 @@ module TensorFlow.Variable , variable , variable' , readValue + , initializedValue , initializedVariable , initializedVariable' , zeroInitializedVariable @@ -30,15 +31,19 @@ import TensorFlow.Core import TensorFlow.Build (opDef) import TensorFlow.BuildOp (buildInputs, pureOp, OpParams) import TensorFlow.Output (opInputs, unNodeName) -import TensorFlow.Tensor (Rendered(..), ToTensor(..), tensorNodeName) +import TensorFlow.Tensor (Rendered(..), ToTensor(..), renderValue, tensorNodeName) import TensorFlow.Types (tensorType) import qualified TensorFlow.GenOps.Core as CoreOps import TensorFlow.Ops (zeros) -newtype Variable a = Variable (Tensor Value ResourceHandle) +data Variable a = Variable + { variableHandle :: Tensor Value ResourceHandle + , initializedValue :: Maybe (Tensor Value a) + -- ^ The initial value of a 'Variable' created with 'initializedVariable'. + } instance Rendered Variable where - renderedOutput (Variable v) = renderedOutput v + renderedOutput = renderedOutput . variableHandle instance ToTensor Variable where toTensor = readValue @@ -56,7 +61,7 @@ variable' params s = build $ do rec t <- CoreOps.varHandleOp' (params . (opAttr "shared_name" .~ n)) (tensorType (undefined :: a)) s let n = encodeUtf8 $ unNodeName $ tensorNodeName t - return $ Variable t + return $ Variable t Nothing -- | Creates a variable initialized to the given value. -- Initialization happens next time session runs. @@ -68,10 +73,11 @@ initializedVariable' :: forall a m v . (MonadBuild m, TensorType a) => OpParams -> Tensor v a -> m (Variable a) initializedVariable' params initializer = do -- The shape is not known initially. - v@(Variable h) <- variable' params (Shape []) - i <- CoreOps.assignVariableOp h initializer + (Variable h Nothing :: Variable a) <- variable' params (Shape []) + initializer' <- renderValue initializer + i <- CoreOps.assignVariableOp h initializer' addInitializer =<< group i - return v + return (Variable h (Just initializer')) -- | Creates a zero-initialized variable with the given shape. zeroInitializedVariable @@ -102,7 +108,7 @@ readValue = readValue' id readValue' :: forall a . TensorType a => OpParams -> Variable a -> Tensor Build a -readValue' params (Variable h) +readValue' params (Variable h _) = pureOp [] $ do os <- buildInputs h pure $ opDef "ReadVariableOp" @@ -117,7 +123,7 @@ assign = assign' id assign' :: (MonadBuild m, TensorType a) => OpParams -> Variable a -> Tensor v a -> m ControlNode -assign' params (Variable h) v = CoreOps.assignVariableOp' params h v +assign' params (Variable h _) v = CoreOps.assignVariableOp' params h v -- | Increments the value of a variable. assignAdd :: (MonadBuild m, TensorType a) @@ -126,4 +132,4 @@ assignAdd = assignAdd' id assignAdd' :: (MonadBuild m, TensorType a) => OpParams -> Variable a -> Tensor v a -> m ControlNode -assignAdd' params (Variable h) v = CoreOps.assignAddVariableOp' params h v +assignAdd' params (Variable h _) v = CoreOps.assignAddVariableOp' params h v diff --git a/tensorflow-ops/tests/VariableTest.hs b/tensorflow-ops/tests/VariableTest.hs index f752a60..0a8f4e9 100644 --- a/tensorflow-ops/tests/VariableTest.hs +++ b/tensorflow-ops/tests/VariableTest.hs @@ -1,6 +1,8 @@ {-# LANGUAGE OverloadedLists #-} module Main (main) where +import Data.Maybe (isJust) +import Control.Monad (when) import Control.Monad.IO.Class (liftIO) import qualified Data.Vector.Storable as V import TensorFlow.Core @@ -12,7 +14,9 @@ import TensorFlow.Core , withControlDependencies) import qualified TensorFlow.Ops as Ops import TensorFlow.Variable - ( readValue + ( Variable + , readValue + , initializedValue , initializedVariable , assign , assignAdd @@ -20,12 +24,13 @@ import TensorFlow.Variable ) import Test.Framework (defaultMain, Test) import Test.Framework.Providers.HUnit (testCase) -import Test.HUnit ((@=?)) +import Test.HUnit ((@=?), assertFailure) main :: IO () main = defaultMain [ testInitializedVariable , testInitializedVariableShape + , testInitializedValue , testDependency , testRereadRef , testAssignAdd @@ -51,6 +56,18 @@ testInitializedVariableShape = result <- run (readValue vector) liftIO $ [42] @=? (result :: V.Vector Float) +testInitializedValue :: Test +testInitializedValue = + testCase "testInitializedValue" $ runSession $ do + initialized <- initializedVariable (Ops.constant [1] [42 :: Float]) + result <- run (initializedValue initialized) + liftIO $ Just [42] @=? (result :: Maybe (V.Vector Float)) + + uninitialized <- variable [1] + -- Can't use @=? because there is no Show instance for Tensor. + when (isJust (initializedValue (uninitialized :: Variable Float))) $ + liftIO $ assertFailure "initializedValue should be Nothing, got Just" + testDependency :: Test testDependency = testCase "testDependency" $ runSession $ do diff --git a/tensorflow/src/TensorFlow/Nodes.hs b/tensorflow/src/TensorFlow/Nodes.hs index 6731600..ba2f32d 100644 --- a/tensorflow/src/TensorFlow/Nodes.hs +++ b/tensorflow/src/TensorFlow/Nodes.hs @@ -89,6 +89,12 @@ instance Nodes t => Nodes [t] where instance Fetchable t a => Fetchable [t] [a] where getFetch ts = sequenceA <$> mapM getFetch ts +instance Nodes t => Nodes (Maybe t) where + getNodes = nodesUnion . fmap getNodes + +instance Fetchable t a => Fetchable (Maybe t) (Maybe a) where + getFetch = fmap sequenceA . mapM getFetch + instance Nodes ControlNode where getNodes (ControlNode o) = pure $ Set.singleton o