diff --git a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs index 7b9e7fc..1b68870 100644 --- a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs +++ b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs @@ -46,7 +46,7 @@ import qualified TensorFlow.GenOps.Core as CoreOps -- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`. embeddingLookup :: forall a b v1 v2 m . ( MonadBuild m - , Rendered v1 + , Rendered (Tensor v1) , TensorType a , OneOf '[Int64, Int32] b , Num b diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index 612eb42..491800f 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -116,10 +116,10 @@ type GradientCompatible a = -- | Gradient of @y@ w.r.t. each element of @xs@. -gradients :: forall a v1 v2 m . (MonadBuild m - , Rendered v2 - , GradientCompatible a - ) +gradients :: forall a v1 v2 m . ( MonadBuild m + , Rendered (Tensor v2) + , GradientCompatible a + ) => Tensor v1 a -- ^ The output of the graph. -> [Tensor v2 a] -- ^ Tensors for which gradients are computed. -> m [Tensor Value a] diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs index 2811f75..6031a55 100644 --- a/tensorflow-ops/src/TensorFlow/Ops.hs +++ b/tensorflow-ops/src/TensorFlow/Ops.hs @@ -241,8 +241,8 @@ zeroInitializedVariable' zeroInitializedVariable' params = initializedVariable' params . zeros -- TODO: Support heterogeneous list of tensors. -save :: forall a m v . (Rendered v, MonadBuild m, TensorType a) - => ByteString -- ^ File path. +save :: forall a m v . (Rendered (Tensor v), MonadBuild m, TensorType a) + => ByteString -- ^ File path. -> [Tensor v a] -- ^ Tensors to save. -> m ControlNode save path xs = build $ do diff --git a/tensorflow-ops/src/TensorFlow/Variable.hs b/tensorflow-ops/src/TensorFlow/Variable.hs index 76e6e48..548425d 100644 --- a/tensorflow-ops/src/TensorFlow/Variable.hs +++ b/tensorflow-ops/src/TensorFlow/Variable.hs @@ -30,13 +30,16 @@ import TensorFlow.Core import TensorFlow.Build (opDef) import TensorFlow.BuildOp (buildInputs, pureOp, OpParams) import TensorFlow.Output (opInputs, unNodeName) -import TensorFlow.Tensor (tensorNodeName) +import TensorFlow.Tensor (Rendered(..), tensorNodeName) import TensorFlow.Types (tensorType) import qualified TensorFlow.GenOps.Core as CoreOps import TensorFlow.Ops (zeros) newtype Variable a = Variable (Tensor Value ResourceHandle) +instance Rendered Variable where + renderedOutput (Variable v) = renderedOutput v + -- | Creates a new, uninitialized variable. variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a) variable = variable' id diff --git a/tensorflow/src/TensorFlow/BuildOp.hs b/tensorflow/src/TensorFlow/BuildOp.hs index 50019f7..f47ddef 100644 --- a/tensorflow/src/TensorFlow/BuildOp.hs +++ b/tensorflow/src/TensorFlow/BuildOp.hs @@ -126,13 +126,13 @@ recordResult = do put $! ResultState (i+1) ns return $! output i o -instance Rendered v => BuildResult (Tensor v a) where +instance (TensorKind v, Rendered (Tensor v)) => BuildResult (Tensor v a) where buildResult = Tensor . pure <$> recordResult instance BuildResult ControlNode where buildResult = ControlNode <$> ask -instance (Rendered v, TensorTypes as) => BuildResult (TensorList v as) where +instance (TensorKind v, Rendered (Tensor v), TensorTypes as) => BuildResult (TensorList v as) where buildResult = loop (tensorTypes :: TensorTypeList as) where loop :: TensorTypeList bs -> Result (TensorList v bs) diff --git a/tensorflow/src/TensorFlow/Tensor.hs b/tensorflow/src/TensorFlow/Tensor.hs index bf3fc1f..bce7a75 100644 --- a/tensorflow/src/TensorFlow/Tensor.hs +++ b/tensorflow/src/TensorFlow/Tensor.hs @@ -13,6 +13,7 @@ -- limitations under the License. {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE GADTs #-} @@ -89,19 +90,16 @@ data Feed = Feed Output FFI.TensorData -- | A class ensuring that a given tensor is rendered, i.e., has a fixed -- name, device, etc. -class TensorKind v => Rendered v where - rendered :: v a -> a +class Rendered t where + renderedOutput :: t a -> Output -instance Rendered Value where - rendered = runValue +instance Rendered (Tensor Value) where + renderedOutput = runValue . tensorOutput -instance Rendered Ref where - rendered = runRef +instance Rendered (Tensor Ref) where + renderedOutput = runRef . tensorOutput -renderedOutput :: Rendered v => Tensor v a -> Output -renderedOutput = rendered . tensorOutput - -tensorNodeName :: Rendered v => Tensor v a -> NodeName +tensorNodeName :: Rendered t => t a -> NodeName tensorNodeName = outputNodeName . renderedOutput @@ -110,7 +108,7 @@ tensorNodeName = outputNodeName . renderedOutput -- -- Note that if a 'Tensor' is rendered, its identity may change; so feeding the -- rendered 'Tensor' may be different than feeding the original 'Tensor'. -feed :: Rendered v => Tensor v a -> TensorData a -> Feed +feed :: Rendered t => t a -> TensorData a -> Feed feed t (TensorData td) = Feed (renderedOutput t) td -- | Create a 'Tensor' for a given name. This can be used to reference nodes @@ -129,7 +127,7 @@ tensorRefFromName = tensorFromName type TensorList v = ListOf (Tensor v) -tensorListOutputs :: Rendered v => TensorList v as -> [Output] +tensorListOutputs :: Rendered (Tensor v) => TensorList v as -> [Output] tensorListOutputs Nil = [] tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts @@ -137,7 +135,7 @@ tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts -- device as the given Tensor (see also 'withDevice'). Make sure that -- the action has side effects of rendering the desired tensors. A pure -- return would not have the desired effect. -colocateWith :: (MonadBuild m, Rendered v) => Tensor v b -> m a -> m a +colocateWith :: (MonadBuild m, Rendered t) => t b -> m a -> m a colocateWith t x = do d <- build $ Device . (^. device) <$> lookupNode (outputNodeName $ renderedOutput t) @@ -184,10 +182,10 @@ class Monad v => TensorKind v where toBuild :: v a -> Build a instance TensorKind Value where - toBuild = return . rendered + toBuild = return . runValue instance TensorKind Ref where - toBuild = return . rendered + toBuild = return . runRef instance TensorKind Build where toBuild = id