1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-26 21:09:44 +01:00

Expand Rendered class to support ResourceHandle wrappers like Variable

This allows functions like `feed`, `colocateWith`, and (in a later commit)
`gradients` to work with `Variable`.
This commit is contained in:
fkm3 2017-05-14 13:32:19 -07:00
parent e924901b90
commit 0f04e5a50d
6 changed files with 26 additions and 25 deletions

View file

@ -46,7 +46,7 @@ import qualified TensorFlow.GenOps.Core as CoreOps
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`. -- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
embeddingLookup :: forall a b v1 v2 m . embeddingLookup :: forall a b v1 v2 m .
( MonadBuild m ( MonadBuild m
, Rendered v1 , Rendered (Tensor v1)
, TensorType a , TensorType a
, OneOf '[Int64, Int32] b , OneOf '[Int64, Int32] b
, Num b , Num b

View file

@ -116,10 +116,10 @@ type GradientCompatible a =
-- | Gradient of @y@ w.r.t. each element of @xs@. -- | Gradient of @y@ w.r.t. each element of @xs@.
gradients :: forall a v1 v2 m . (MonadBuild m gradients :: forall a v1 v2 m . ( MonadBuild m
, Rendered v2 , Rendered (Tensor v2)
, GradientCompatible a , GradientCompatible a
) )
=> Tensor v1 a -- ^ The output of the graph. => Tensor v1 a -- ^ The output of the graph.
-> [Tensor v2 a] -- ^ Tensors for which gradients are computed. -> [Tensor v2 a] -- ^ Tensors for which gradients are computed.
-> m [Tensor Value a] -> m [Tensor Value a]

View file

@ -241,8 +241,8 @@ zeroInitializedVariable'
zeroInitializedVariable' params = initializedVariable' params . zeros zeroInitializedVariable' params = initializedVariable' params . zeros
-- TODO: Support heterogeneous list of tensors. -- TODO: Support heterogeneous list of tensors.
save :: forall a m v . (Rendered v, MonadBuild m, TensorType a) save :: forall a m v . (Rendered (Tensor v), MonadBuild m, TensorType a)
=> ByteString -- ^ File path. => ByteString -- ^ File path.
-> [Tensor v a] -- ^ Tensors to save. -> [Tensor v a] -- ^ Tensors to save.
-> m ControlNode -> m ControlNode
save path xs = build $ do save path xs = build $ do

View file

@ -30,13 +30,16 @@ import TensorFlow.Core
import TensorFlow.Build (opDef) import TensorFlow.Build (opDef)
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams) import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
import TensorFlow.Output (opInputs, unNodeName) import TensorFlow.Output (opInputs, unNodeName)
import TensorFlow.Tensor (tensorNodeName) import TensorFlow.Tensor (Rendered(..), tensorNodeName)
import TensorFlow.Types (tensorType) import TensorFlow.Types (tensorType)
import qualified TensorFlow.GenOps.Core as CoreOps import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Ops (zeros) import TensorFlow.Ops (zeros)
newtype Variable a = Variable (Tensor Value ResourceHandle) newtype Variable a = Variable (Tensor Value ResourceHandle)
instance Rendered Variable where
renderedOutput (Variable v) = renderedOutput v
-- | Creates a new, uninitialized variable. -- | Creates a new, uninitialized variable.
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a) variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
variable = variable' id variable = variable' id

View file

@ -126,13 +126,13 @@ recordResult = do
put $! ResultState (i+1) ns put $! ResultState (i+1) ns
return $! output i o return $! output i o
instance Rendered v => BuildResult (Tensor v a) where instance (TensorKind v, Rendered (Tensor v)) => BuildResult (Tensor v a) where
buildResult = Tensor . pure <$> recordResult buildResult = Tensor . pure <$> recordResult
instance BuildResult ControlNode where instance BuildResult ControlNode where
buildResult = ControlNode <$> ask buildResult = ControlNode <$> ask
instance (Rendered v, TensorTypes as) => BuildResult (TensorList v as) where instance (TensorKind v, Rendered (Tensor v), TensorTypes as) => BuildResult (TensorList v as) where
buildResult = loop (tensorTypes :: TensorTypeList as) buildResult = loop (tensorTypes :: TensorTypeList as)
where where
loop :: TensorTypeList bs -> Result (TensorList v bs) loop :: TensorTypeList bs -> Result (TensorList v bs)

View file

@ -13,6 +13,7 @@
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE DataKinds #-} {-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-} {-# LANGUAGE GADTs #-}
@ -89,19 +90,16 @@ data Feed = Feed Output FFI.TensorData
-- | A class ensuring that a given tensor is rendered, i.e., has a fixed -- | A class ensuring that a given tensor is rendered, i.e., has a fixed
-- name, device, etc. -- name, device, etc.
class TensorKind v => Rendered v where class Rendered t where
rendered :: v a -> a renderedOutput :: t a -> Output
instance Rendered Value where instance Rendered (Tensor Value) where
rendered = runValue renderedOutput = runValue . tensorOutput
instance Rendered Ref where instance Rendered (Tensor Ref) where
rendered = runRef renderedOutput = runRef . tensorOutput
renderedOutput :: Rendered v => Tensor v a -> Output tensorNodeName :: Rendered t => t a -> NodeName
renderedOutput = rendered . tensorOutput
tensorNodeName :: Rendered v => Tensor v a -> NodeName
tensorNodeName = outputNodeName . renderedOutput tensorNodeName = outputNodeName . renderedOutput
@ -110,7 +108,7 @@ tensorNodeName = outputNodeName . renderedOutput
-- --
-- Note that if a 'Tensor' is rendered, its identity may change; so feeding the -- Note that if a 'Tensor' is rendered, its identity may change; so feeding the
-- rendered 'Tensor' may be different than feeding the original 'Tensor'. -- rendered 'Tensor' may be different than feeding the original 'Tensor'.
feed :: Rendered v => Tensor v a -> TensorData a -> Feed feed :: Rendered t => t a -> TensorData a -> Feed
feed t (TensorData td) = Feed (renderedOutput t) td feed t (TensorData td) = Feed (renderedOutput t) td
-- | Create a 'Tensor' for a given name. This can be used to reference nodes -- | Create a 'Tensor' for a given name. This can be used to reference nodes
@ -129,7 +127,7 @@ tensorRefFromName = tensorFromName
type TensorList v = ListOf (Tensor v) type TensorList v = ListOf (Tensor v)
tensorListOutputs :: Rendered v => TensorList v as -> [Output] tensorListOutputs :: Rendered (Tensor v) => TensorList v as -> [Output]
tensorListOutputs Nil = [] tensorListOutputs Nil = []
tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts
@ -137,7 +135,7 @@ tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts
-- device as the given Tensor (see also 'withDevice'). Make sure that -- device as the given Tensor (see also 'withDevice'). Make sure that
-- the action has side effects of rendering the desired tensors. A pure -- the action has side effects of rendering the desired tensors. A pure
-- return would not have the desired effect. -- return would not have the desired effect.
colocateWith :: (MonadBuild m, Rendered v) => Tensor v b -> m a -> m a colocateWith :: (MonadBuild m, Rendered t) => t b -> m a -> m a
colocateWith t x = do colocateWith t x = do
d <- build $ Device . (^. device) d <- build $ Device . (^. device)
<$> lookupNode (outputNodeName $ renderedOutput t) <$> lookupNode (outputNodeName $ renderedOutput t)
@ -184,10 +182,10 @@ class Monad v => TensorKind v where
toBuild :: v a -> Build a toBuild :: v a -> Build a
instance TensorKind Value where instance TensorKind Value where
toBuild = return . rendered toBuild = return . runValue
instance TensorKind Ref where instance TensorKind Ref where
toBuild = return . rendered toBuild = return . runRef
instance TensorKind Build where instance TensorKind Build where
toBuild = id toBuild = id