mirror of
https://github.com/tensorflow/haskell.git
synced 2024-12-03 16:29:46 +01:00
Expand Rendered class to support ResourceHandle wrappers like Variable
This allows functions like `feed`, `colocateWith`, and (in a later commit) `gradients` to work with `Variable`.
This commit is contained in:
parent
e924901b90
commit
0f04e5a50d
6 changed files with 26 additions and 25 deletions
|
@ -46,7 +46,7 @@ import qualified TensorFlow.GenOps.Core as CoreOps
|
|||
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
||||
embeddingLookup :: forall a b v1 v2 m .
|
||||
( MonadBuild m
|
||||
, Rendered v1
|
||||
, Rendered (Tensor v1)
|
||||
, TensorType a
|
||||
, OneOf '[Int64, Int32] b
|
||||
, Num b
|
||||
|
|
|
@ -116,10 +116,10 @@ type GradientCompatible a =
|
|||
|
||||
|
||||
-- | Gradient of @y@ w.r.t. each element of @xs@.
|
||||
gradients :: forall a v1 v2 m . (MonadBuild m
|
||||
, Rendered v2
|
||||
, GradientCompatible a
|
||||
)
|
||||
gradients :: forall a v1 v2 m . ( MonadBuild m
|
||||
, Rendered (Tensor v2)
|
||||
, GradientCompatible a
|
||||
)
|
||||
=> Tensor v1 a -- ^ The output of the graph.
|
||||
-> [Tensor v2 a] -- ^ Tensors for which gradients are computed.
|
||||
-> m [Tensor Value a]
|
||||
|
|
|
@ -241,8 +241,8 @@ zeroInitializedVariable'
|
|||
zeroInitializedVariable' params = initializedVariable' params . zeros
|
||||
|
||||
-- TODO: Support heterogeneous list of tensors.
|
||||
save :: forall a m v . (Rendered v, MonadBuild m, TensorType a)
|
||||
=> ByteString -- ^ File path.
|
||||
save :: forall a m v . (Rendered (Tensor v), MonadBuild m, TensorType a)
|
||||
=> ByteString -- ^ File path.
|
||||
-> [Tensor v a] -- ^ Tensors to save.
|
||||
-> m ControlNode
|
||||
save path xs = build $ do
|
||||
|
|
|
@ -30,13 +30,16 @@ import TensorFlow.Core
|
|||
import TensorFlow.Build (opDef)
|
||||
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
|
||||
import TensorFlow.Output (opInputs, unNodeName)
|
||||
import TensorFlow.Tensor (tensorNodeName)
|
||||
import TensorFlow.Tensor (Rendered(..), tensorNodeName)
|
||||
import TensorFlow.Types (tensorType)
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
import TensorFlow.Ops (zeros)
|
||||
|
||||
newtype Variable a = Variable (Tensor Value ResourceHandle)
|
||||
|
||||
instance Rendered Variable where
|
||||
renderedOutput (Variable v) = renderedOutput v
|
||||
|
||||
-- | Creates a new, uninitialized variable.
|
||||
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
|
||||
variable = variable' id
|
||||
|
|
|
@ -126,13 +126,13 @@ recordResult = do
|
|||
put $! ResultState (i+1) ns
|
||||
return $! output i o
|
||||
|
||||
instance Rendered v => BuildResult (Tensor v a) where
|
||||
instance (TensorKind v, Rendered (Tensor v)) => BuildResult (Tensor v a) where
|
||||
buildResult = Tensor . pure <$> recordResult
|
||||
|
||||
instance BuildResult ControlNode where
|
||||
buildResult = ControlNode <$> ask
|
||||
|
||||
instance (Rendered v, TensorTypes as) => BuildResult (TensorList v as) where
|
||||
instance (TensorKind v, Rendered (Tensor v), TensorTypes as) => BuildResult (TensorList v as) where
|
||||
buildResult = loop (tensorTypes :: TensorTypeList as)
|
||||
where
|
||||
loop :: TensorTypeList bs -> Result (TensorList v bs)
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE FunctionalDependencies #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
|
@ -89,19 +90,16 @@ data Feed = Feed Output FFI.TensorData
|
|||
|
||||
-- | A class ensuring that a given tensor is rendered, i.e., has a fixed
|
||||
-- name, device, etc.
|
||||
class TensorKind v => Rendered v where
|
||||
rendered :: v a -> a
|
||||
class Rendered t where
|
||||
renderedOutput :: t a -> Output
|
||||
|
||||
instance Rendered Value where
|
||||
rendered = runValue
|
||||
instance Rendered (Tensor Value) where
|
||||
renderedOutput = runValue . tensorOutput
|
||||
|
||||
instance Rendered Ref where
|
||||
rendered = runRef
|
||||
instance Rendered (Tensor Ref) where
|
||||
renderedOutput = runRef . tensorOutput
|
||||
|
||||
renderedOutput :: Rendered v => Tensor v a -> Output
|
||||
renderedOutput = rendered . tensorOutput
|
||||
|
||||
tensorNodeName :: Rendered v => Tensor v a -> NodeName
|
||||
tensorNodeName :: Rendered t => t a -> NodeName
|
||||
tensorNodeName = outputNodeName . renderedOutput
|
||||
|
||||
|
||||
|
@ -110,7 +108,7 @@ tensorNodeName = outputNodeName . renderedOutput
|
|||
--
|
||||
-- Note that if a 'Tensor' is rendered, its identity may change; so feeding the
|
||||
-- rendered 'Tensor' may be different than feeding the original 'Tensor'.
|
||||
feed :: Rendered v => Tensor v a -> TensorData a -> Feed
|
||||
feed :: Rendered t => t a -> TensorData a -> Feed
|
||||
feed t (TensorData td) = Feed (renderedOutput t) td
|
||||
|
||||
-- | Create a 'Tensor' for a given name. This can be used to reference nodes
|
||||
|
@ -129,7 +127,7 @@ tensorRefFromName = tensorFromName
|
|||
|
||||
type TensorList v = ListOf (Tensor v)
|
||||
|
||||
tensorListOutputs :: Rendered v => TensorList v as -> [Output]
|
||||
tensorListOutputs :: Rendered (Tensor v) => TensorList v as -> [Output]
|
||||
tensorListOutputs Nil = []
|
||||
tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts
|
||||
|
||||
|
@ -137,7 +135,7 @@ tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts
|
|||
-- device as the given Tensor (see also 'withDevice'). Make sure that
|
||||
-- the action has side effects of rendering the desired tensors. A pure
|
||||
-- return would not have the desired effect.
|
||||
colocateWith :: (MonadBuild m, Rendered v) => Tensor v b -> m a -> m a
|
||||
colocateWith :: (MonadBuild m, Rendered t) => t b -> m a -> m a
|
||||
colocateWith t x = do
|
||||
d <- build $ Device . (^. device)
|
||||
<$> lookupNode (outputNodeName $ renderedOutput t)
|
||||
|
@ -184,10 +182,10 @@ class Monad v => TensorKind v where
|
|||
toBuild :: v a -> Build a
|
||||
|
||||
instance TensorKind Value where
|
||||
toBuild = return . rendered
|
||||
toBuild = return . runValue
|
||||
|
||||
instance TensorKind Ref where
|
||||
toBuild = return . rendered
|
||||
toBuild = return . runRef
|
||||
|
||||
instance TensorKind Build where
|
||||
toBuild = id
|
||||
|
|
Loading…
Reference in a new issue