1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-12-03 16:29:46 +01:00

Expand Rendered class to support ResourceHandle wrappers like Variable

This allows functions like `feed`, `colocateWith`, and (in a later commit)
`gradients` to work with `Variable`.
This commit is contained in:
fkm3 2017-05-14 13:32:19 -07:00
parent e924901b90
commit 0f04e5a50d
6 changed files with 26 additions and 25 deletions

View file

@ -46,7 +46,7 @@ import qualified TensorFlow.GenOps.Core as CoreOps
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
embeddingLookup :: forall a b v1 v2 m .
( MonadBuild m
, Rendered v1
, Rendered (Tensor v1)
, TensorType a
, OneOf '[Int64, Int32] b
, Num b

View file

@ -116,10 +116,10 @@ type GradientCompatible a =
-- | Gradient of @y@ w.r.t. each element of @xs@.
gradients :: forall a v1 v2 m . (MonadBuild m
, Rendered v2
, GradientCompatible a
)
gradients :: forall a v1 v2 m . ( MonadBuild m
, Rendered (Tensor v2)
, GradientCompatible a
)
=> Tensor v1 a -- ^ The output of the graph.
-> [Tensor v2 a] -- ^ Tensors for which gradients are computed.
-> m [Tensor Value a]

View file

@ -241,8 +241,8 @@ zeroInitializedVariable'
zeroInitializedVariable' params = initializedVariable' params . zeros
-- TODO: Support heterogeneous list of tensors.
save :: forall a m v . (Rendered v, MonadBuild m, TensorType a)
=> ByteString -- ^ File path.
save :: forall a m v . (Rendered (Tensor v), MonadBuild m, TensorType a)
=> ByteString -- ^ File path.
-> [Tensor v a] -- ^ Tensors to save.
-> m ControlNode
save path xs = build $ do

View file

@ -30,13 +30,16 @@ import TensorFlow.Core
import TensorFlow.Build (opDef)
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
import TensorFlow.Output (opInputs, unNodeName)
import TensorFlow.Tensor (tensorNodeName)
import TensorFlow.Tensor (Rendered(..), tensorNodeName)
import TensorFlow.Types (tensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Ops (zeros)
newtype Variable a = Variable (Tensor Value ResourceHandle)
instance Rendered Variable where
renderedOutput (Variable v) = renderedOutput v
-- | Creates a new, uninitialized variable.
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
variable = variable' id

View file

@ -126,13 +126,13 @@ recordResult = do
put $! ResultState (i+1) ns
return $! output i o
instance Rendered v => BuildResult (Tensor v a) where
instance (TensorKind v, Rendered (Tensor v)) => BuildResult (Tensor v a) where
buildResult = Tensor . pure <$> recordResult
instance BuildResult ControlNode where
buildResult = ControlNode <$> ask
instance (Rendered v, TensorTypes as) => BuildResult (TensorList v as) where
instance (TensorKind v, Rendered (Tensor v), TensorTypes as) => BuildResult (TensorList v as) where
buildResult = loop (tensorTypes :: TensorTypeList as)
where
loop :: TensorTypeList bs -> Result (TensorList v bs)

View file

@ -13,6 +13,7 @@
-- limitations under the License.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
@ -89,19 +90,16 @@ data Feed = Feed Output FFI.TensorData
-- | A class ensuring that a given tensor is rendered, i.e., has a fixed
-- name, device, etc.
class TensorKind v => Rendered v where
rendered :: v a -> a
class Rendered t where
renderedOutput :: t a -> Output
instance Rendered Value where
rendered = runValue
instance Rendered (Tensor Value) where
renderedOutput = runValue . tensorOutput
instance Rendered Ref where
rendered = runRef
instance Rendered (Tensor Ref) where
renderedOutput = runRef . tensorOutput
renderedOutput :: Rendered v => Tensor v a -> Output
renderedOutput = rendered . tensorOutput
tensorNodeName :: Rendered v => Tensor v a -> NodeName
tensorNodeName :: Rendered t => t a -> NodeName
tensorNodeName = outputNodeName . renderedOutput
@ -110,7 +108,7 @@ tensorNodeName = outputNodeName . renderedOutput
--
-- Note that if a 'Tensor' is rendered, its identity may change; so feeding the
-- rendered 'Tensor' may be different than feeding the original 'Tensor'.
feed :: Rendered v => Tensor v a -> TensorData a -> Feed
feed :: Rendered t => t a -> TensorData a -> Feed
feed t (TensorData td) = Feed (renderedOutput t) td
-- | Create a 'Tensor' for a given name. This can be used to reference nodes
@ -129,7 +127,7 @@ tensorRefFromName = tensorFromName
type TensorList v = ListOf (Tensor v)
tensorListOutputs :: Rendered v => TensorList v as -> [Output]
tensorListOutputs :: Rendered (Tensor v) => TensorList v as -> [Output]
tensorListOutputs Nil = []
tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts
@ -137,7 +135,7 @@ tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts
-- device as the given Tensor (see also 'withDevice'). Make sure that
-- the action has side effects of rendering the desired tensors. A pure
-- return would not have the desired effect.
colocateWith :: (MonadBuild m, Rendered v) => Tensor v b -> m a -> m a
colocateWith :: (MonadBuild m, Rendered t) => t b -> m a -> m a
colocateWith t x = do
d <- build $ Device . (^. device)
<$> lookupNode (outputNodeName $ renderedOutput t)
@ -184,10 +182,10 @@ class Monad v => TensorKind v where
toBuild :: v a -> Build a
instance TensorKind Value where
toBuild = return . rendered
toBuild = return . runValue
instance TensorKind Ref where
toBuild = return . rendered
toBuild = return . runRef
instance TensorKind Build where
toBuild = id