Commit Graph

9 Commits

Author SHA1 Message Date
Bart Schuurmans dad19fde31 Add support for loading a Session from a SavedModel 2023-02-05 13:41:24 -08:00
Judah Jacobson 42f4fc647e Add resource-based variable ops. (#98)
The main difference between these and the `Ref`-bases ops is the explicit
`readValue` op.  I'm not sure how this should interact with gradients
and save/restore, so I'm keeping it as a separate module for now.  Once we
figure out the details, we can merge it into `TensorFlow.Ops` and replace
all uses of the old `Ref`-based ops.  (That would also fix #92.)

Also replaces our special case newtype `ResourceHandle` to
`Tensor Value ResourceHandle`, where `ResourceHandle` is the TF proto
corresponding to `DT_RESOURCE`.
2017-04-16 09:24:02 -07:00
Judah Jacobson d62c614695 Distinguish between "rendered" and "unrendered" Tensors. (#88)
Distinguish between "rendered" and "unrendered" Tensors.

There are now three types of `Tensor`:

- `Tensor Value a`: rendered value
- `Tensor Ref a`: rendered reference
- `Tensor Build a` : unrendered value

The extra bookkeeping makes it easier to track (and enforce) which tensors are
rendered or not.  For examples where this has been confusing in the past, see

With this change, pure ops look similar to before, returning `Tensor Build`
instead of `Tensor Value`.  "Stateful" (monadic) ops are unchanged.  For
example:

    add :: OneOf [..] t => Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
    assign :: (MonadBuild m, TensorType t)
           => Tensor Ref t -> Tensor v'2 t -> m (Tensor Ref t)

The `gradients` function now requires that the variables over which it's
differentiating are pre-rendered:

    gradients :: (..., Rendered v2) => Tensor v1 a -> [Tensor v2 a]
              -> m [Tensor Value a]

(`Rendered v2` means that `v2` is either a `Ref` or a `Value`.)

Additionally, the implementation of `gradients` now takes care to render every
intermediate value when performing the reverse accumulation.  I suspect this
fixes an exponential blowup for complicated expressions.
2017-04-06 15:10:33 -07:00
Judah Jacobson c99a23b6a7 Add versions of each op that take optional params as an extra arg. (#84)
Each op `foo :: ...` now has a corresponding `foo' :: OpParams -> ...`
which lets you set optional attributes.  `OpParams` is currently a type alias for
`OpDef -> OpDef`.  In the future we should consider more type safety, e.g.,
using type-level strings and OverloadedLabels for optional attributes.

I used it to replace a few manual `buildOp`s in our code with the codegenerated
ops, now that it's easier to set attributes.  I also removed `tensorAttr` and
`named` since it's now possible to set those op attributes directly.

Although this clutters up the API a bit, I think it's simpler than using type
classes to implement optional arguments (as in, for example, `Text.Printf`) --
especially in terms of type inference with the rest of the library.
2017-03-20 18:16:38 -07:00
Judah Jacobson 2c5c879037 Introduce a MonadBuild class, and remove `buildAnd`. (#83)
This change adds a class that both `Build` and `Session` are instances of:

    class MonadBuild m where
        build :: Build a -> m a

All stateful ops (generated and manually written) now have a signature that returns
an instance of `MonadBuild` (rather than just `Build`).  For example:

    assign_ :: (MonadBuild m, TensorType t)
            => Tensor Ref t -> Tensor v t -> m (Tensor Ref t)

This lets us remove a bunch of spurious calls to `build` in user code.  It also
lets us replace the pattern `buildAnd run foo` with the simpler pattern `foo >>= run`
(or `run =<< foo`, which is sometimes nicer when foo is a complicated expression).

I went ahead and deleted `buildAnd` altogether since it seems to lead to
confusion; in particular a few tests had `buildAnd run . pure` which is
actually equivalent to just `run`.
2017-03-18 12:08:53 -07:00
fkm3 b3c0997a8c Add support for logging to tensorboard (#74)
Add support for logging to tensorboard

Based on @gnezdo's internal version with some differences:

* Uses a pure haskell implementation of EventWriter instead of FFI.
* Special `buildAnd*` functions were dropped in favor of using
  `mergeAllSummaries :: Build SummaryTensor` with the normal
  `build` function.
2017-02-20 19:16:42 -08:00
fkm3 f170df9d13 Support fetching storable vectors + use them in benchmark (#50)
In addition, you can now fetch TensorData directly. This might be useful in
scenarios where you feed the result of a computation back in, like RNN.

Before:

benchmarking feedFetch/4 byte
time                 83.31 μs   (81.88 μs .. 84.75 μs)
                     0.997 R²   (0.994 R² .. 0.998 R²)
mean                 87.32 μs   (86.06 μs .. 88.83 μs)
std dev              4.580 μs   (3.698 μs .. 5.567 μs)
variance introduced by outliers: 55% (severely inflated)

benchmarking feedFetch/4 KiB
time                 114.9 μs   (111.5 μs .. 118.2 μs)
                     0.996 R²   (0.994 R² .. 0.998 R²)
mean                 117.3 μs   (116.2 μs .. 118.6 μs)
std dev              3.877 μs   (3.058 μs .. 5.565 μs)
variance introduced by outliers: 31% (moderately inflated)

benchmarking feedFetch/4 MiB
time                 109.0 ms   (107.9 ms .. 110.7 ms)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 108.6 ms   (108.2 ms .. 109.2 ms)
std dev              740.2 μs   (353.2 μs .. 1.186 ms)

After:

benchmarking feedFetch/4 byte
time                 82.92 μs   (80.55 μs .. 85.24 μs)
                     0.996 R²   (0.993 R² .. 0.998 R²)
mean                 83.58 μs   (82.34 μs .. 84.89 μs)
std dev              4.327 μs   (3.664 μs .. 5.375 μs)
variance introduced by outliers: 54% (severely inflated)

benchmarking feedFetch/4 KiB
time                 85.69 μs   (83.81 μs .. 87.30 μs)
                     0.997 R²   (0.996 R² .. 0.999 R²)
mean                 86.99 μs   (86.11 μs .. 88.15 μs)
std dev              3.608 μs   (2.854 μs .. 5.273 μs)
variance introduced by outliers: 43% (moderately inflated)

benchmarking feedFetch/4 MiB
time                 1.582 ms   (1.509 ms .. 1.677 ms)
                     0.970 R²   (0.936 R² .. 0.993 R²)
mean                 1.645 ms   (1.554 ms .. 1.981 ms)
std dev              490.6 μs   (138.9 μs .. 1.067 ms)
variance introduced by outliers: 97% (severely inflated)
2016-12-14 18:53:06 -08:00
Greg Steuck 0d4f5a9628 Added sessionTracer to log graph operations. (#26)
* Added TracingTest.
2016-11-14 15:14:51 -08:00
fkm3 630850c2d2 Add TensorFlow.Core module to start formalizing the exposed API (#17)
* Add TensorFlow.Core module to start formalizing the exposed API

* Refer to ops packages instead of modules
2016-11-10 09:47:41 -08:00