1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2025-01-09 18:39:48 +01:00
Commit graph

17 commits

Author SHA1 Message Date
Judah Jacobson
a11a417ad5 Add another test of CSE and feeds. (#87)
As a follow-up to #86, check that our CSE isn't too aggressive to prevent feeds
of pure ops with distinct names.
2017-03-23 12:58:40 -07:00
Judah Jacobson
fdbfd050f8 Prevent CSE of placeholder ops. (#86)
The bug was introduced in #84.
2017-03-22 22:47:42 -07:00
Judah Jacobson
c99a23b6a7 Add versions of each op that take optional params as an extra arg. (#84)
Each op `foo :: ...` now has a corresponding `foo' :: OpParams -> ...`
which lets you set optional attributes.  `OpParams` is currently a type alias for
`OpDef -> OpDef`.  In the future we should consider more type safety, e.g.,
using type-level strings and OverloadedLabels for optional attributes.

I used it to replace a few manual `buildOp`s in our code with the codegenerated
ops, now that it's easier to set attributes.  I also removed `tensorAttr` and
`named` since it's now possible to set those op attributes directly.

Although this clutters up the API a bit, I think it's simpler than using type
classes to implement optional arguments (as in, for example, `Text.Printf`) --
especially in terms of type inference with the rest of the library.
2017-03-20 18:16:38 -07:00
Judah Jacobson
2c5c879037 Introduce a MonadBuild class, and remove buildAnd. (#83)
This change adds a class that both `Build` and `Session` are instances of:

    class MonadBuild m where
        build :: Build a -> m a

All stateful ops (generated and manually written) now have a signature that returns
an instance of `MonadBuild` (rather than just `Build`).  For example:

    assign_ :: (MonadBuild m, TensorType t)
            => Tensor Ref t -> Tensor v t -> m (Tensor Ref t)

This lets us remove a bunch of spurious calls to `build` in user code.  It also
lets us replace the pattern `buildAnd run foo` with the simpler pattern `foo >>= run`
(or `run =<< foo`, which is sometimes nicer when foo is a complicated expression).

I went ahead and deleted `buildAnd` altogether since it seems to lead to
confusion; in particular a few tests had `buildAnd run . pure` which is
actually equivalent to just `run`.
2017-03-18 12:08:53 -07:00
fkm3
4fb68f3aa3 Add example to README + make haddock link more prominent (#60) 2017-01-16 20:44:45 -08:00
fkm3
f170df9d13 Support fetching storable vectors + use them in benchmark (#50)
In addition, you can now fetch TensorData directly. This might be useful in
scenarios where you feed the result of a computation back in, like RNN.

Before:

benchmarking feedFetch/4 byte
time                 83.31 μs   (81.88 μs .. 84.75 μs)
                     0.997 R²   (0.994 R² .. 0.998 R²)
mean                 87.32 μs   (86.06 μs .. 88.83 μs)
std dev              4.580 μs   (3.698 μs .. 5.567 μs)
variance introduced by outliers: 55% (severely inflated)

benchmarking feedFetch/4 KiB
time                 114.9 μs   (111.5 μs .. 118.2 μs)
                     0.996 R²   (0.994 R² .. 0.998 R²)
mean                 117.3 μs   (116.2 μs .. 118.6 μs)
std dev              3.877 μs   (3.058 μs .. 5.565 μs)
variance introduced by outliers: 31% (moderately inflated)

benchmarking feedFetch/4 MiB
time                 109.0 ms   (107.9 ms .. 110.7 ms)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 108.6 ms   (108.2 ms .. 109.2 ms)
std dev              740.2 μs   (353.2 μs .. 1.186 ms)

After:

benchmarking feedFetch/4 byte
time                 82.92 μs   (80.55 μs .. 85.24 μs)
                     0.996 R²   (0.993 R² .. 0.998 R²)
mean                 83.58 μs   (82.34 μs .. 84.89 μs)
std dev              4.327 μs   (3.664 μs .. 5.375 μs)
variance introduced by outliers: 54% (severely inflated)

benchmarking feedFetch/4 KiB
time                 85.69 μs   (83.81 μs .. 87.30 μs)
                     0.997 R²   (0.996 R² .. 0.999 R²)
mean                 86.99 μs   (86.11 μs .. 88.15 μs)
std dev              3.608 μs   (2.854 μs .. 5.273 μs)
variance introduced by outliers: 43% (moderately inflated)

benchmarking feedFetch/4 MiB
time                 1.582 ms   (1.509 ms .. 1.677 ms)
                     0.970 R²   (0.936 R² .. 0.993 R²)
mean                 1.645 ms   (1.554 ms .. 1.981 ms)
std dev              490.6 μs   (138.9 μs .. 1.067 ms)
variance introduced by outliers: 97% (severely inflated)
2016-12-14 18:53:06 -08:00
fkm3
91f508eb5c Fix TensorData encode and decode for Bool (#49) 2016-12-12 19:40:32 -08:00
fkm3
cc08520dc7 Fix gradients calculation for min and max (#48) 2016-12-12 09:47:02 -08:00
Judah Jacobson
1539783ee5 Update type constraints to work around a ghc-8 bug. (#47)
Also removes all the ghc-8-specific logic in the .cabal files.

ghc-8 has issues with deeply nested tuples of constraints.  We can
work around it by:
- Changing TensorTypes to a regular class.  This required FlexibleContexts.
  (But we'll probably need it anyway when we support heterogeneous tensor
  lists.)
- Specializing NoneOf for long type lists.

For more details, see: https://ghc.haskell.org/trac/ghc/ticket/12175.

Also added 'directory' to tensorflow-core-ops' dependencies since it's used
in the Setup script.

One more step towards fixing #38.
2016-11-28 21:15:09 -08:00
Judah Jacobson
a277c7ddb3 Refactor OpGen. (#36)
Also fixes op lists when the same attribute specifies the length of
both an input and an output.  I added a test of "shapeN" which
previously failed with the following error:

    ERROR: Ran out of counts in toResult. Likely misuse of buildListOp.
2016-11-20 10:00:22 -08:00
Greg Steuck
2b5e41ffeb Make code --pedantic (#35)
* Enforce pedantic build mode in CI.
* Our imports drifted really far from where they should be.
2016-11-18 10:42:02 -08:00
Noon van der Silk
69fdbf677f test case to show can't calculate grad for embedding (and associated fix) (#23)
* Fix for embedding gradient calculation

- Passes vectors instead of scalars to slice
- converts the numRows to a scalar
- add `toScalar` utility function
- minor change to test case so that it actually works

* added lib for testing helper functions

* add flatSlice function
2016-11-17 13:54:36 -08:00
fkm3
fc3d398ca9 Optimize fetching (#27)
* Add MNIST data to gitignore
* Add simple tensor round-trip benchmark
* Use deepseq + cleaner imports
* Use safe version of fromIntegral in FFI code
* Don't copy data when fetching tensors

BEFORE

benchmarking feedFetch/4 byte
time                 55.79 μs   (54.88 μs .. 56.62 μs)
                     0.998 R²   (0.997 R² .. 0.999 R²)
mean                 55.61 μs   (55.09 μs .. 56.11 μs)
std dev              1.828 μs   (1.424 μs .. 2.518 μs)
variance introduced by outliers: 34% (moderately inflated)

benchmarking feedFetch/4 KiB
time                 231.4 μs   (221.9 μs .. 247.3 μs)
                     0.988 R²   (0.974 R² .. 1.000 R²)
mean                 226.6 μs   (224.1 μs .. 236.2 μs)
std dev              13.45 μs   (7.115 μs .. 27.14 μs)
variance introduced by outliers: 57% (severely inflated)

benchmarking feedFetch/4 MiB
time                 485.8 ms   (424.6 ms .. 526.7 ms)
                     0.998 R²   (0.994 R² .. 1.000 R²)
mean                 515.7 ms   (512.5 ms .. 517.9 ms)
std dev              3.320 ms   (0.0 s .. 3.822 ms)
variance introduced by outliers: 19% (moderately inflated)

AFTER

benchmarking feedFetch/4 byte
time                 53.11 μs   (52.12 μs .. 54.22 μs)
                     0.996 R²   (0.995 R² .. 0.998 R²)
mean                 54.64 μs   (53.59 μs .. 56.18 μs)
std dev              4.249 μs   (2.910 μs .. 6.076 μs)
variance introduced by outliers: 75% (severely inflated)

benchmarking feedFetch/4 KiB
time                 83.83 μs   (82.72 μs .. 84.92 μs)
                     0.999 R²   (0.998 R² .. 0.999 R²)
mean                 83.82 μs   (83.20 μs .. 84.35 μs)
std dev              1.943 μs   (1.557 μs .. 2.614 μs)
variance introduced by outliers: 20% (moderately inflated)

benchmarking feedFetch/4 MiB
time                 95.54 ms   (93.62 ms .. 97.82 ms)
                     0.999 R²   (0.998 R² .. 1.000 R²)
mean                 96.61 ms   (95.76 ms .. 97.51 ms)
std dev              1.408 ms   (1.005 ms .. 1.889 ms)
2016-11-17 10:41:49 -08:00
Greg Steuck
0d4f5a9628 Added sessionTracer to log graph operations. (#26)
* Added TracingTest.
2016-11-14 15:14:51 -08:00
silky
9c81241439 Tests for "embedding_lookup" and minor fix
- added a test that fails for a partitioned embedding
- added a test that passes for a single embedding
2016-11-09 16:21:40 +11:00
Greg Steuck
4ec78a8fca Replaced topK with topKV2. (#21)
topK is obsolete and generating warnings.
2016-11-08 20:57:22 -08:00
Greg Steuck
67690d1499 Initial commit 2016-10-24 19:26:42 +00:00