mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 11:29:43 +01:00
Expand Rendered class to support ResourceHandle wrappers like Variable
This allows functions like `feed`, `colocateWith`, and (in a later commit) `gradients` to work with `Variable`.
This commit is contained in:
parent
e924901b90
commit
0f04e5a50d
6 changed files with 26 additions and 25 deletions
|
@ -46,7 +46,7 @@ import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
||||||
embeddingLookup :: forall a b v1 v2 m .
|
embeddingLookup :: forall a b v1 v2 m .
|
||||||
( MonadBuild m
|
( MonadBuild m
|
||||||
, Rendered v1
|
, Rendered (Tensor v1)
|
||||||
, TensorType a
|
, TensorType a
|
||||||
, OneOf '[Int64, Int32] b
|
, OneOf '[Int64, Int32] b
|
||||||
, Num b
|
, Num b
|
||||||
|
|
|
@ -116,8 +116,8 @@ type GradientCompatible a =
|
||||||
|
|
||||||
|
|
||||||
-- | Gradient of @y@ w.r.t. each element of @xs@.
|
-- | Gradient of @y@ w.r.t. each element of @xs@.
|
||||||
gradients :: forall a v1 v2 m . (MonadBuild m
|
gradients :: forall a v1 v2 m . ( MonadBuild m
|
||||||
, Rendered v2
|
, Rendered (Tensor v2)
|
||||||
, GradientCompatible a
|
, GradientCompatible a
|
||||||
)
|
)
|
||||||
=> Tensor v1 a -- ^ The output of the graph.
|
=> Tensor v1 a -- ^ The output of the graph.
|
||||||
|
|
|
@ -241,7 +241,7 @@ zeroInitializedVariable'
|
||||||
zeroInitializedVariable' params = initializedVariable' params . zeros
|
zeroInitializedVariable' params = initializedVariable' params . zeros
|
||||||
|
|
||||||
-- TODO: Support heterogeneous list of tensors.
|
-- TODO: Support heterogeneous list of tensors.
|
||||||
save :: forall a m v . (Rendered v, MonadBuild m, TensorType a)
|
save :: forall a m v . (Rendered (Tensor v), MonadBuild m, TensorType a)
|
||||||
=> ByteString -- ^ File path.
|
=> ByteString -- ^ File path.
|
||||||
-> [Tensor v a] -- ^ Tensors to save.
|
-> [Tensor v a] -- ^ Tensors to save.
|
||||||
-> m ControlNode
|
-> m ControlNode
|
||||||
|
|
|
@ -30,13 +30,16 @@ import TensorFlow.Core
|
||||||
import TensorFlow.Build (opDef)
|
import TensorFlow.Build (opDef)
|
||||||
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
|
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
|
||||||
import TensorFlow.Output (opInputs, unNodeName)
|
import TensorFlow.Output (opInputs, unNodeName)
|
||||||
import TensorFlow.Tensor (tensorNodeName)
|
import TensorFlow.Tensor (Rendered(..), tensorNodeName)
|
||||||
import TensorFlow.Types (tensorType)
|
import TensorFlow.Types (tensorType)
|
||||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
import TensorFlow.Ops (zeros)
|
import TensorFlow.Ops (zeros)
|
||||||
|
|
||||||
newtype Variable a = Variable (Tensor Value ResourceHandle)
|
newtype Variable a = Variable (Tensor Value ResourceHandle)
|
||||||
|
|
||||||
|
instance Rendered Variable where
|
||||||
|
renderedOutput (Variable v) = renderedOutput v
|
||||||
|
|
||||||
-- | Creates a new, uninitialized variable.
|
-- | Creates a new, uninitialized variable.
|
||||||
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
|
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
|
||||||
variable = variable' id
|
variable = variable' id
|
||||||
|
|
|
@ -126,13 +126,13 @@ recordResult = do
|
||||||
put $! ResultState (i+1) ns
|
put $! ResultState (i+1) ns
|
||||||
return $! output i o
|
return $! output i o
|
||||||
|
|
||||||
instance Rendered v => BuildResult (Tensor v a) where
|
instance (TensorKind v, Rendered (Tensor v)) => BuildResult (Tensor v a) where
|
||||||
buildResult = Tensor . pure <$> recordResult
|
buildResult = Tensor . pure <$> recordResult
|
||||||
|
|
||||||
instance BuildResult ControlNode where
|
instance BuildResult ControlNode where
|
||||||
buildResult = ControlNode <$> ask
|
buildResult = ControlNode <$> ask
|
||||||
|
|
||||||
instance (Rendered v, TensorTypes as) => BuildResult (TensorList v as) where
|
instance (TensorKind v, Rendered (Tensor v), TensorTypes as) => BuildResult (TensorList v as) where
|
||||||
buildResult = loop (tensorTypes :: TensorTypeList as)
|
buildResult = loop (tensorTypes :: TensorTypeList as)
|
||||||
where
|
where
|
||||||
loop :: TensorTypeList bs -> Result (TensorList v bs)
|
loop :: TensorTypeList bs -> Result (TensorList v bs)
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
{-# LANGUAGE DataKinds #-}
|
{-# LANGUAGE DataKinds #-}
|
||||||
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
{-# LANGUAGE FunctionalDependencies #-}
|
{-# LANGUAGE FunctionalDependencies #-}
|
||||||
{-# LANGUAGE GADTs #-}
|
{-# LANGUAGE GADTs #-}
|
||||||
|
@ -89,19 +90,16 @@ data Feed = Feed Output FFI.TensorData
|
||||||
|
|
||||||
-- | A class ensuring that a given tensor is rendered, i.e., has a fixed
|
-- | A class ensuring that a given tensor is rendered, i.e., has a fixed
|
||||||
-- name, device, etc.
|
-- name, device, etc.
|
||||||
class TensorKind v => Rendered v where
|
class Rendered t where
|
||||||
rendered :: v a -> a
|
renderedOutput :: t a -> Output
|
||||||
|
|
||||||
instance Rendered Value where
|
instance Rendered (Tensor Value) where
|
||||||
rendered = runValue
|
renderedOutput = runValue . tensorOutput
|
||||||
|
|
||||||
instance Rendered Ref where
|
instance Rendered (Tensor Ref) where
|
||||||
rendered = runRef
|
renderedOutput = runRef . tensorOutput
|
||||||
|
|
||||||
renderedOutput :: Rendered v => Tensor v a -> Output
|
tensorNodeName :: Rendered t => t a -> NodeName
|
||||||
renderedOutput = rendered . tensorOutput
|
|
||||||
|
|
||||||
tensorNodeName :: Rendered v => Tensor v a -> NodeName
|
|
||||||
tensorNodeName = outputNodeName . renderedOutput
|
tensorNodeName = outputNodeName . renderedOutput
|
||||||
|
|
||||||
|
|
||||||
|
@ -110,7 +108,7 @@ tensorNodeName = outputNodeName . renderedOutput
|
||||||
--
|
--
|
||||||
-- Note that if a 'Tensor' is rendered, its identity may change; so feeding the
|
-- Note that if a 'Tensor' is rendered, its identity may change; so feeding the
|
||||||
-- rendered 'Tensor' may be different than feeding the original 'Tensor'.
|
-- rendered 'Tensor' may be different than feeding the original 'Tensor'.
|
||||||
feed :: Rendered v => Tensor v a -> TensorData a -> Feed
|
feed :: Rendered t => t a -> TensorData a -> Feed
|
||||||
feed t (TensorData td) = Feed (renderedOutput t) td
|
feed t (TensorData td) = Feed (renderedOutput t) td
|
||||||
|
|
||||||
-- | Create a 'Tensor' for a given name. This can be used to reference nodes
|
-- | Create a 'Tensor' for a given name. This can be used to reference nodes
|
||||||
|
@ -129,7 +127,7 @@ tensorRefFromName = tensorFromName
|
||||||
|
|
||||||
type TensorList v = ListOf (Tensor v)
|
type TensorList v = ListOf (Tensor v)
|
||||||
|
|
||||||
tensorListOutputs :: Rendered v => TensorList v as -> [Output]
|
tensorListOutputs :: Rendered (Tensor v) => TensorList v as -> [Output]
|
||||||
tensorListOutputs Nil = []
|
tensorListOutputs Nil = []
|
||||||
tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts
|
tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts
|
||||||
|
|
||||||
|
@ -137,7 +135,7 @@ tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts
|
||||||
-- device as the given Tensor (see also 'withDevice'). Make sure that
|
-- device as the given Tensor (see also 'withDevice'). Make sure that
|
||||||
-- the action has side effects of rendering the desired tensors. A pure
|
-- the action has side effects of rendering the desired tensors. A pure
|
||||||
-- return would not have the desired effect.
|
-- return would not have the desired effect.
|
||||||
colocateWith :: (MonadBuild m, Rendered v) => Tensor v b -> m a -> m a
|
colocateWith :: (MonadBuild m, Rendered t) => t b -> m a -> m a
|
||||||
colocateWith t x = do
|
colocateWith t x = do
|
||||||
d <- build $ Device . (^. device)
|
d <- build $ Device . (^. device)
|
||||||
<$> lookupNode (outputNodeName $ renderedOutput t)
|
<$> lookupNode (outputNodeName $ renderedOutput t)
|
||||||
|
@ -184,10 +182,10 @@ class Monad v => TensorKind v where
|
||||||
toBuild :: v a -> Build a
|
toBuild :: v a -> Build a
|
||||||
|
|
||||||
instance TensorKind Value where
|
instance TensorKind Value where
|
||||||
toBuild = return . rendered
|
toBuild = return . runValue
|
||||||
|
|
||||||
instance TensorKind Ref where
|
instance TensorKind Ref where
|
||||||
toBuild = return . rendered
|
toBuild = return . runRef
|
||||||
|
|
||||||
instance TensorKind Build where
|
instance TensorKind Build where
|
||||||
toBuild = id
|
toBuild = id
|
||||||
|
|
Loading…
Reference in a new issue