diff --git a/tensorflow-ops/src/TensorFlow/Variable.hs b/tensorflow-ops/src/TensorFlow/Variable.hs index 548425d..141622f 100644 --- a/tensorflow-ops/src/TensorFlow/Variable.hs +++ b/tensorflow-ops/src/TensorFlow/Variable.hs @@ -30,7 +30,7 @@ import TensorFlow.Core import TensorFlow.Build (opDef) import TensorFlow.BuildOp (buildInputs, pureOp, OpParams) import TensorFlow.Output (opInputs, unNodeName) -import TensorFlow.Tensor (Rendered(..), tensorNodeName) +import TensorFlow.Tensor (Rendered(..), ToTensor(..), tensorNodeName) import TensorFlow.Types (tensorType) import qualified TensorFlow.GenOps.Core as CoreOps import TensorFlow.Ops (zeros) @@ -40,6 +40,9 @@ newtype Variable a = Variable (Tensor Value ResourceHandle) instance Rendered Variable where renderedOutput (Variable v) = renderedOutput v +instance ToTensor Variable where + toTensor = readValue + -- | Creates a new, uninitialized variable. variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a) variable = variable' id diff --git a/tensorflow/src/TensorFlow/Tensor.hs b/tensorflow/src/TensorFlow/Tensor.hs index bce7a75..6ba1588 100644 --- a/tensorflow/src/TensorFlow/Tensor.hs +++ b/tensorflow/src/TensorFlow/Tensor.hs @@ -38,7 +38,8 @@ import Proto.Tensorflow.Core.Framework.NodeDef (device) import TensorFlow.Build import TensorFlow.Output (Output, NodeName, outputNodeName, Device(..)) import TensorFlow.Types - ( TensorData(..) + ( TensorType + , TensorData(..) , ListOf(..) ) import qualified TensorFlow.Internal.FFI as FFI @@ -189,3 +190,11 @@ instance TensorKind Ref where instance TensorKind Build where toBuild = id + + +-- | Types which can be converted to `Tensor`. +class ToTensor t where + toTensor :: TensorType a => t a -> Tensor Build a + +instance TensorKind v => ToTensor (Tensor v) where + toTensor = expr