mirror of https://github.com/tensorflow/haskell.git
synced 2025-01-23 09:19:49 +01:00
a7cbc27d36
Distinguish between "rendered" and "unrendered" Tensors.

There are now three types of `Tensor`:

- `Tensor Value a`: rendered value
- `Tensor Ref a`: rendered reference
- `Tensor Build a`: unrendered value

The extra bookkeeping makes it easier to track (and enforce) which tensors are rendered or not. For examples where this has been confusing in the past, see

With this change, pure ops look similar to before, returning `Tensor Build` instead of `Tensor Value`. "Stateful" (monadic) ops are unchanged. For example:

    add :: OneOf [..] t => Tensor v'1 t -> Tensor v'2 t -> Tensor Build t

    assign :: (MonadBuild m, TensorType t)
        => Tensor Ref t -> Tensor v'2 t -> m (Tensor Ref t)

The `gradients` function now requires that the variables over which it's differentiating are pre-rendered:

    gradients :: (..., Rendered v2)
        => Tensor v1 a -> [Tensor v2 a] -> m [Tensor Value a]

(`Rendered v2` means that `v2` is either a `Ref` or a `Value`.)

Additionally, the implementation of `gradients` now takes care to render every intermediate value when performing the reverse accumulation. I suspect this fixes an exponential blowup for complicated expressions.
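To make the distinction concrete, here is a minimal toy model of the tagging scheme. It is a sketch under stated assumptions, not tensorflow-haskell code: `Tensor` is just a string here, and `describe`, `render`, `nodeName`, `Graph` and `example` are hypothetical names invented for illustration (the toy `add` only mimics the shape of the real signature). The phantom tags `Value`, `Ref` and `Build` mirror the ones above, and `Rendered` has instances only for `Value` and `Ref`, so a function constrained like `gradients` cannot be handed a `Build` tensor. Rendering records the expression in a list-based graph and returns a `Value` tensor that refers to the new node by name, so later expressions reuse that name instead of copying the whole subexpression.

```haskell
-- Hypothetical toy model of rendered vs. unrendered tensors.
-- NOT the tensorflow-haskell API; all names here are illustrative.

-- Tag types with no values, used only as phantom type parameters.
data Value
data Ref
data Build

-- A toy tensor is only a textual description: a node name for a
-- rendered tensor, or the full unrendered expression for a Build tensor.
newtype Tensor v a = Tensor String

describe :: Tensor v a -> String
describe (Tensor s) = s

-- Only the rendered tags get instances, so Build cannot satisfy this.
class Rendered v
instance Rendered Value
instance Rendered Ref

-- A toy "pure" op: accepts tensors with any tags, returns a Build tensor.
add :: Tensor v1 a -> Tensor v2 a -> Tensor Build a
add x y = Tensor ("add(" ++ describe x ++ "," ++ describe y ++ ")")

-- A toy graph: the expressions rendered so far, in order.
type Graph = [String]

-- Rendering appends the expression to the graph and returns a Value
-- tensor that refers to the new node by its index-based name.
render :: Graph -> Tensor Build a -> (Graph, Tensor Value a)
render g t = (g ++ [describe t], Tensor ("node" ++ show (length g)))

-- Constrained like `gradients`: only rendered tensors are accepted.
nodeName :: Rendered v => Tensor v a -> String
nodeName = describe

-- Render two chained additions; the second refers to the first by name.
example :: (Graph, String)
example =
  let a = Tensor "a" :: Tensor Value Float
      (g1, b) = render [] (add a a)
      (g2, c) = render g1 (add b b)
  in (g2, nodeName c)
```

Loading this file in GHCi and evaluating `example` gives `(["add(a,a)","add(node0,node0)"],"node1")`. Applying `nodeName` directly to the result of `add` is rejected at compile time, because there is no `Rendered Build` instance.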
ArrayOpsTest.hs
BuildTest.hs
DataFlowOpsTest.hs
EmbeddingOpsTest.hs
FeedFetchBench.hs
GradientTest.hs
MiscTest.hs
OpsTest.hs
RegressionTest.hs
TracingTest.hs
TypesTest.hs