mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Initial commit
This commit is contained in:
commit
67690d1499
67 changed files with 6400 additions and 0 deletions
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
**/.stack-work
|
||||||
|
.stack/
|
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
[submodule "third_party/tensorflow"]
|
||||||
|
path = third_party/tensorflow
|
||||||
|
url = https://github.com/tensorflow/tensorflow.git
|
25
CONTRIBUTING.md
Normal file
25
CONTRIBUTING.md
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
Want to contribute? Great! First, read this page (including the small print at the end).
|
||||||
|
|
||||||
|
### Before you contribute
|
||||||
|
Before we can use your code, you must sign the
|
||||||
|
[Google Individual Contributor License Agreement](https://cla.developers.google.com/about/google-individual)
|
||||||
|
(CLA), which you can do online. The CLA is necessary mainly because you own the
|
||||||
|
copyright to your changes, even after your contribution becomes part of our
|
||||||
|
codebase, so we need your permission to use and distribute your code. We also
|
||||||
|
need to be sure of various other things—for instance that you'll tell us if you
|
||||||
|
know that your code infringes on other people's patents. You don't have to sign
|
||||||
|
the CLA until after you've submitted your code for review and a member has
|
||||||
|
approved it, but you must do it before we can put your code into our codebase.
|
||||||
|
Before you start working on a larger contribution, you should get in touch with
|
||||||
|
us first through the issue tracker with your idea so that we can help out and
|
||||||
|
possibly guide you. Coordinating up front makes it much easier to avoid
|
||||||
|
frustration later on.
|
||||||
|
|
||||||
|
### Code reviews
|
||||||
|
All submissions, including submissions by project members, require review. We
|
||||||
|
use Github pull requests for this purpose.
|
||||||
|
|
||||||
|
### The small print
|
||||||
|
Contributions made by corporations are covered by a different agreement than
|
||||||
|
the one above, the
|
||||||
|
[Software Grant and Corporate Contributor License Agreement](https://cla.developers.google.com/about/google-corporate).
|
203
LICENSE
Normal file
203
LICENSE
Normal file
|
@ -0,0 +1,203 @@
|
||||||
|
Copyright 2016 The TensorFlow Authors. All rights reserved.
|
||||||
|
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright 2016, The TensorFlow Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
24
README.md
Normal file
24
README.md
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
The tensorflow-haskell package provides Haskell bindings to
|
||||||
|
[TensorFlow](https://www.tensorflow.org/).
|
||||||
|
|
||||||
|
This is not an official Google product.
|
||||||
|
|
||||||
|
# Instructions
|
||||||
|
|
||||||
|
## Build
|
||||||
|
|
||||||
|
For now [docker](https://www.docker.com/) is required. Once you have docker
|
||||||
|
working, the following commands will compile and run the tests.
|
||||||
|
|
||||||
|
git clone --recursive https://github.com/tensorflow/haskell.git tensorflow-haskell
|
||||||
|
cd tensorflow-haskell
|
||||||
|
IMAGE_NAME=tensorflow/haskell:v0
|
||||||
|
docker build -t $IMAGE_NAME docker
|
||||||
|
# TODO: move the setup step to the docker script.
|
||||||
|
stack --docker --docker-image=$IMAGE_NAME setup
|
||||||
|
stack --docker --docker-image=$IMAGE_NAME test
|
||||||
|
|
||||||
|
There is also a demo application:
|
||||||
|
|
||||||
|
cd tensorflow-mnist
|
||||||
|
stack --docker --docker-image=$IMAGE_NAME build --exec Main
|
29
docker/Dockerfile
Normal file
29
docker/Dockerfile
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
# Prepare the image with:
|
||||||
|
# docker build -t gnezdo/tfhs:2016-10-14-1 docker
|
||||||
|
FROM gcr.io/tensorflow/tensorflow:latest-devel
|
||||||
|
MAINTAINER Greg Steuck <gnezdo+tfhs@google.com>
|
||||||
|
# Installs protoc and the libraries.
|
||||||
|
RUN \
|
||||||
|
cd /tensorflow && \
|
||||||
|
bazel --batch build -c opt '@protobuf//:protoc' && \
|
||||||
|
install -s bazel-bin/external/protobuf/protoc /usr/local/bin && \
|
||||||
|
bazel --batch build -c opt '//tensorflow:libtensorflow_c.so' && \
|
||||||
|
install bazel-bin/tensorflow/libtensorflow_c.so /usr/local/lib && \
|
||||||
|
bazel --batch clean
|
||||||
|
|
||||||
|
RUN apt-get update
|
||||||
|
|
||||||
|
RUN apt-get install -y \
|
||||||
|
# Avoids /usr/bin/ld: cannot find -ltinfo
|
||||||
|
libncurses5-dev \
|
||||||
|
# Makes stack viable in the container
|
||||||
|
libgmp-dev \
|
||||||
|
# Required for locales configuration.
|
||||||
|
locales
|
||||||
|
|
||||||
|
# Our MNIST demo program outputs Unicode characters.
|
||||||
|
RUN dpkg-reconfigure locales && \
|
||||||
|
locale-gen en_US.UTF-8 && \
|
||||||
|
update-locale LANG=en_US.UTF-8
|
||||||
|
|
||||||
|
ENV LANG en_US.UTF-8
|
3
google-shim/Setup.hs
Normal file
3
google-shim/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
import Distribution.Simple
|
||||||
|
|
||||||
|
main = defaultMain
|
23
google-shim/google-shim.cabal
Normal file
23
google-shim/google-shim.cabal
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
name: google-shim
|
||||||
|
version: 0.1.0.0
|
||||||
|
synopsis: Adapters to externalize TensorFlow code.
|
||||||
|
description: Please see README.md
|
||||||
|
homepage: https://github.com/tensorflow/haskell#readme
|
||||||
|
license: Apache
|
||||||
|
author: TensorFlow authors
|
||||||
|
maintainer: tensorflow-haskell@googlegroups.com
|
||||||
|
copyright: Google Inc.
|
||||||
|
category: Machine Learning
|
||||||
|
build-type: Simple
|
||||||
|
cabal-version: >=1.22
|
||||||
|
|
||||||
|
library
|
||||||
|
hs-source-dirs: src
|
||||||
|
exposed-modules: Google.Test
|
||||||
|
build-depends: base >= 4.7 && < 5
|
||||||
|
, test-framework
|
||||||
|
default-language: Haskell2010
|
||||||
|
|
||||||
|
source-repository head
|
||||||
|
type: git
|
||||||
|
location: https://github.com/tensorflow/haskell
|
22
google-shim/src/Google/Test.hs
Normal file
22
google-shim/src/Google/Test.hs
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
-- | Alternative implementations to make dependent code work without
|
||||||
|
-- changes.
|
||||||
|
module Google.Test where
|
||||||
|
|
||||||
|
import Test.Framework (Test, defaultMain)
|
||||||
|
|
||||||
|
googleTest :: [Test] -> IO ()
|
||||||
|
googleTest = defaultMain
|
24
stack.yaml
Normal file
24
stack.yaml
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
resolver: lts-6.2
|
||||||
|
|
||||||
|
packages:
|
||||||
|
- google-shim
|
||||||
|
- tensorflow
|
||||||
|
- tensorflow-core-ops
|
||||||
|
- tensorflow-opgen
|
||||||
|
- tensorflow-ops
|
||||||
|
- tensorflow-proto
|
||||||
|
- tensorflow-mnist
|
||||||
|
- tensorflow-mnist-input-data
|
||||||
|
- tensorflow-queue
|
||||||
|
|
||||||
|
extra-deps:
|
||||||
|
# proto-lens is not yet in Stackage.
|
||||||
|
- proto-lens-0.1.0.4
|
||||||
|
- proto-lens-protoc-0.1.0.4
|
||||||
|
|
||||||
|
# Allow our custom Setup.hs scripts to import Data.ProtoLens.Setup from the version of
|
||||||
|
# `proto-lens-protoc` in stack's local DB. See:
|
||||||
|
# https://github.com/google/proto-lens/blob/master/README.md#using-cabal
|
||||||
|
explicit-setup-deps:
|
||||||
|
"*": true
|
||||||
|
|
100
tensorflow-core-ops/Setup.hs
Normal file
100
tensorflow-core-ops/Setup.hs
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
-- | Generates the wrappers for Ops shipped with tensorflow_c.
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import Distribution.Simple.BuildPaths (autogenModulesDir)
|
||||||
|
import Distribution.Simple.LocalBuildInfo (LocalBuildInfo)
|
||||||
|
import Distribution.Simple
|
||||||
|
( defaultMainWithHooks
|
||||||
|
, simpleUserHooks
|
||||||
|
, UserHooks(..)
|
||||||
|
)
|
||||||
|
import Data.List (intercalate)
|
||||||
|
import Data.ProtoLens (decodeMessage)
|
||||||
|
import System.Directory (createDirectoryIfMissing)
|
||||||
|
import System.Exit (exitFailure)
|
||||||
|
import System.FilePath ((</>))
|
||||||
|
import System.IO (hPutStrLn, stderr)
|
||||||
|
import TensorFlow.Internal.FFI (getAllOpList)
|
||||||
|
import TensorFlow.OpGen (docOpList, OpGenFlags(..))
|
||||||
|
import Text.PrettyPrint.Mainland (prettyLazyText)
|
||||||
|
import qualified Data.Text.Lazy.IO as Text
|
||||||
|
|
||||||
|
main = defaultMainWithHooks generatingOpsWrappers
|
||||||
|
|
||||||
|
-- TODO: Generalize for user libraries by replacing getAllOpList with
|
||||||
|
-- a wrapper around TF_LoadLibrary. The complicated part is interplay
|
||||||
|
-- between bazel and Haskell build system.
|
||||||
|
generatingOpsWrappers :: UserHooks
|
||||||
|
generatingOpsWrappers = hooks
|
||||||
|
{ buildHook = \p l h f -> generateSources l >> buildHook hooks p l h f
|
||||||
|
, haddockHook = \p l h f -> generateSources l >> haddockHook hooks p l h f
|
||||||
|
, replHook = \p l h f args -> generateSources l
|
||||||
|
>> replHook hooks p l h f args
|
||||||
|
}
|
||||||
|
where
|
||||||
|
flagsBuilder dir = OpGenFlags
|
||||||
|
{ outputFile = dir </> "Core.hs"
|
||||||
|
, prefix = "TensorFlow.GenOps"
|
||||||
|
, excludeList = intercalate "," blackList
|
||||||
|
}
|
||||||
|
hooks = simpleUserHooks
|
||||||
|
generateSources :: LocalBuildInfo -> IO ()
|
||||||
|
generateSources l = do
|
||||||
|
let dir = autogenModulesDir l </> "TensorFlow/GenOps"
|
||||||
|
createDirectoryIfMissing True dir
|
||||||
|
let flags = flagsBuilder dir
|
||||||
|
pb <- getAllOpList
|
||||||
|
case decodeMessage pb of
|
||||||
|
Left e -> hPutStrLn stderr e >> exitFailure
|
||||||
|
Right x -> Text.writeFile (outputFile flags)
|
||||||
|
(prettyLazyText 80 $ docOpList flags x)
|
||||||
|
|
||||||
|
blackList =
|
||||||
|
-- A few data flow ops take a list of heterogeneous
|
||||||
|
-- parameters which we don't support in general form.
|
||||||
|
[ "HashTable"
|
||||||
|
, "MutableDenseHashTable"
|
||||||
|
, "MutableHashTable"
|
||||||
|
, "MutableHashTableOfTensors"
|
||||||
|
, "QueueDequeue"
|
||||||
|
, "QueueDequeueMany"
|
||||||
|
, "QueueDequeueUpTo"
|
||||||
|
, "Stack"
|
||||||
|
, "TensorArray"
|
||||||
|
-- These should be possible to support by adding a bunch of
|
||||||
|
-- overloads with a variable number of tuple arguments.
|
||||||
|
, "Assert"
|
||||||
|
, "BarrierTakeMany"
|
||||||
|
, "Print"
|
||||||
|
, "QueueEnqueue"
|
||||||
|
, "QueueEnqueueMany"
|
||||||
|
-- These have type ambiguities because one of the type arguments
|
||||||
|
-- doesn't appear in the signature.
|
||||||
|
, "ConditionalAccumulator"
|
||||||
|
, "SparseConditionalAccumulator"
|
||||||
|
-- Need list of types support.
|
||||||
|
, "DecodeCSV"
|
||||||
|
, "ParseExample"
|
||||||
|
, "ParseSingleSequenceExample"
|
||||||
|
, "Save"
|
||||||
|
, "SaveSlices"
|
||||||
|
, "SymbolicGradient"
|
||||||
|
, "_ArrayToList"
|
||||||
|
, "_ListToArray"
|
||||||
|
-- Easy: support larger result tuples.
|
||||||
|
, "Skipgram"
|
||||||
|
]
|
30
tensorflow-core-ops/tensorflow-core-ops.cabal
Normal file
30
tensorflow-core-ops/tensorflow-core-ops.cabal
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
name: tensorflow-core-ops
|
||||||
|
version: 0.1.0.0
|
||||||
|
synopsis: Haskell wrappers for Core Tensorflow Ops.
|
||||||
|
description: Code generated signatures for the Ops in libtensorflow_c.
|
||||||
|
homepage: https://github.com/tensorflow/haskell#readme
|
||||||
|
license: Apache
|
||||||
|
author: TensorFlow authors
|
||||||
|
maintainer: tensorflow-haskell@googlegroups.com
|
||||||
|
copyright: Google Inc.
|
||||||
|
category: Machine Learning
|
||||||
|
build-type: Custom
|
||||||
|
cabal-version: >=1.22
|
||||||
|
|
||||||
|
library
|
||||||
|
exposed-modules: TensorFlow.GenOps.Core
|
||||||
|
build-depends: Cabal >= 1.22 && < 1.25
|
||||||
|
, bytestring
|
||||||
|
, proto-lens == 0.1.*
|
||||||
|
, tensorflow-opgen == 0.1.*
|
||||||
|
, tensorflow == 0.1.*
|
||||||
|
, base >= 4.7 && < 5
|
||||||
|
, filepath
|
||||||
|
, mainland-pretty
|
||||||
|
, lens-family
|
||||||
|
, text
|
||||||
|
default-language: Haskell2010
|
||||||
|
|
||||||
|
source-repository head
|
||||||
|
type: git
|
||||||
|
location: https://github.com/tensorflow/haskell
|
113
tensorflow-mnist-input-data/Setup.hs
Normal file
113
tensorflow-mnist-input-data/Setup.hs
Normal file
|
@ -0,0 +1,113 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE LambdaCase #-}
|
||||||
|
|
||||||
|
-- | Downloads the MNIST data set and packages them as data files.
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import Control.Monad (when)
|
||||||
|
import Data.Maybe (fromMaybe)
|
||||||
|
import Distribution.PackageDescription
|
||||||
|
( GenericPackageDescription(packageDescription)
|
||||||
|
, dataDir
|
||||||
|
)
|
||||||
|
import Distribution.Simple
|
||||||
|
( UserHooks(..)
|
||||||
|
, defaultMainWithHooks
|
||||||
|
, simpleUserHooks
|
||||||
|
)
|
||||||
|
import System.IO (hPutStrLn, stderr)
|
||||||
|
import System.FilePath ((</>))
|
||||||
|
import System.Directory (doesFileExist)
|
||||||
|
import qualified Crypto.Hash as Hash
|
||||||
|
import qualified Data.ByteString.Lazy as B
|
||||||
|
import qualified Network.HTTP as HTTP
|
||||||
|
import qualified Network.URI as URI
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = defaultMainWithHooks downloadingDataFiles
|
||||||
|
|
||||||
|
downloadingDataFiles :: UserHooks
|
||||||
|
downloadingDataFiles = hooks
|
||||||
|
{ confHook = \gh@(g, _) c -> downloadFiles g >> confHook hooks gh c
|
||||||
|
}
|
||||||
|
where
|
||||||
|
hooks = simpleUserHooks
|
||||||
|
downloadFiles :: GenericPackageDescription -> IO ()
|
||||||
|
downloadFiles g = do
|
||||||
|
let dir = dataDir (packageDescription g)
|
||||||
|
mapM_ (maybeDownload dir) fileInfos
|
||||||
|
|
||||||
|
maybeDownload :: FilePath -> (String, String) -> IO ()
|
||||||
|
maybeDownload dataDir (basename, sha256) = do
|
||||||
|
let filePath = dataDir </> basename
|
||||||
|
exists <- doesFileExist filePath
|
||||||
|
when (not exists) $ do
|
||||||
|
let url = urlPrefix ++ basename
|
||||||
|
hPutStrLn stderr ("Downloading " ++ url)
|
||||||
|
httpDownload url filePath
|
||||||
|
verify filePath sha256
|
||||||
|
|
||||||
|
httpDownload :: String -> FilePath -> IO ()
|
||||||
|
httpDownload url outFile = do
|
||||||
|
let uri = fromMaybe
|
||||||
|
(error ("Can't be: invalid URI " ++ url))
|
||||||
|
(URI.parseURI url)
|
||||||
|
result <- HTTP.simpleHTTP (HTTP.defaultGETRequest_ uri)
|
||||||
|
HTTP.getResponseCode result >>= \case
|
||||||
|
(2, 0, 0) -> HTTP.getResponseBody result >>= B.writeFile outFile
|
||||||
|
s -> error ( "Failed to download " ++ url ++ " error code " ++ show s
|
||||||
|
++ helpfulMessage
|
||||||
|
)
|
||||||
|
|
||||||
|
verify :: FilePath -> String -> IO ()
|
||||||
|
verify filePath hash = do
|
||||||
|
let sha256 = Hash.hashlazy :: B.ByteString -> Hash.Digest Hash.SHA256
|
||||||
|
computed <- show . sha256 <$> B.readFile filePath
|
||||||
|
when (hash /= computed) $
|
||||||
|
error ( "Incorrect checksum for " ++ filePath
|
||||||
|
++ "\nexpected " ++ hash
|
||||||
|
++ "\ncomputed " ++ computed
|
||||||
|
++ helpfulMessage
|
||||||
|
)
|
||||||
|
|
||||||
|
urlPrefix = "http://yann.lecun.com/exdb/mnist/"
|
||||||
|
|
||||||
|
-- | File names relative to 'urlPrefix' and their sha256.
|
||||||
|
fileInfos = [
|
||||||
|
( "train-images-idx3-ubyte.gz"
|
||||||
|
, "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
|
||||||
|
)
|
||||||
|
,
|
||||||
|
( "train-labels-idx1-ubyte.gz"
|
||||||
|
, "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
|
||||||
|
)
|
||||||
|
,
|
||||||
|
( "t10k-images-idx3-ubyte.gz"
|
||||||
|
, "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
|
||||||
|
)
|
||||||
|
,
|
||||||
|
( "t10k-labels-idx1-ubyte.gz"
|
||||||
|
, "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
helpfulMessage =
|
||||||
|
unlines
|
||||||
|
( ""
|
||||||
|
: ""
|
||||||
|
: "Please download the following URLs manually and put them in data/"
|
||||||
|
: [ urlPrefix ++ h | (h, _) <- fileInfos ]
|
||||||
|
)
|
|
@ -0,0 +1,31 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
module TensorFlow.Examples.MNIST.InputData
|
||||||
|
( trainingImageData
|
||||||
|
, trainingLabelData
|
||||||
|
, testImageData
|
||||||
|
, testLabelData
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Paths_tensorflow_mnist_input_data (getDataFileName)
|
||||||
|
|
||||||
|
-- | Download the files containing the canonical MNIST samples and labels.
|
||||||
|
trainingImageData, trainingLabelData :: IO FilePath
|
||||||
|
trainingImageData = getDataFileName "train-images-idx3-ubyte.gz"
|
||||||
|
trainingLabelData = getDataFileName "train-labels-idx1-ubyte.gz"
|
||||||
|
|
||||||
|
testImageData, testLabelData :: IO FilePath
|
||||||
|
testImageData = getDataFileName "t10k-images-idx3-ubyte.gz"
|
||||||
|
testLabelData = getDataFileName "t10k-labels-idx1-ubyte.gz"
|
|
@ -0,0 +1,35 @@
|
||||||
|
name: tensorflow-mnist-input-data
|
||||||
|
version: 0.1.0.0
|
||||||
|
synopsis: Downloader of input data for training MNIST.
|
||||||
|
description: Please see README.md
|
||||||
|
homepage: https://github.com/tensorflow/haskell#readme
|
||||||
|
license: Apache
|
||||||
|
author: TensorFlow authors
|
||||||
|
maintainer: tensorflow-haskell@googlegroups.com
|
||||||
|
copyright: Google Inc.
|
||||||
|
category: Machine Learning
|
||||||
|
build-type: Custom
|
||||||
|
cabal-version: >=1.22
|
||||||
|
-- These files are downloaded automatically by Setup.hs. If the
|
||||||
|
-- automatic download fails, follow the instructions in error messages
|
||||||
|
-- displayed by Setup.hs.
|
||||||
|
data-dir: data
|
||||||
|
data-files: *.gz
|
||||||
|
|
||||||
|
library
|
||||||
|
hs-source-dirs: src
|
||||||
|
exposed-modules: TensorFlow.Examples.MNIST.InputData
|
||||||
|
other-modules: Paths_tensorflow_mnist_input_data
|
||||||
|
build-depends: Cabal
|
||||||
|
, HTTP
|
||||||
|
, base >= 4.7 && < 5
|
||||||
|
, bytestring
|
||||||
|
, cryptonite
|
||||||
|
, directory
|
||||||
|
, filepath
|
||||||
|
, network-uri
|
||||||
|
default-language: Haskell2010
|
||||||
|
|
||||||
|
source-repository head
|
||||||
|
type: git
|
||||||
|
location: https://github.com/tensorflow/haskell
|
3
tensorflow-mnist/Setup.hs
Normal file
3
tensorflow-mnist/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
import Distribution.Simple
|
||||||
|
|
||||||
|
main = defaultMain
|
161
tensorflow-mnist/app/Main.hs
Normal file
161
tensorflow-mnist/app/Main.hs
Normal file
|
@ -0,0 +1,161 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
|
|
||||||
|
import Control.Monad (zipWithM, when, forM, forM_)
|
||||||
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
import Data.Int (Int32, Int64)
|
||||||
|
import qualified Data.Text.IO as T
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
|
import qualified TensorFlow.ControlFlow as TF
|
||||||
|
import qualified TensorFlow.Build as TF
|
||||||
|
import qualified TensorFlow.Ops as TF
|
||||||
|
import qualified TensorFlow.Session as TF
|
||||||
|
import qualified TensorFlow.Tensor as TF
|
||||||
|
import qualified TensorFlow.Types as TF
|
||||||
|
import qualified TensorFlow.Gradient as TF
|
||||||
|
|
||||||
|
import TensorFlow.Examples.MNIST.InputData
|
||||||
|
import TensorFlow.Examples.MNIST.Parse
|
||||||
|
|
||||||
|
numPixels = 28^2 :: Int64
|
||||||
|
numLabels = 10 :: Int64
|
||||||
|
|
||||||
|
-- | Create tensor with random values where the stddev depends on the width.
|
||||||
|
randomParam :: Int64 -> TF.Shape -> TF.Build (TF.Tensor TF.Value Float)
|
||||||
|
randomParam width (TF.Shape shape) =
|
||||||
|
(* stddev) <$> TF.truncatedNormal (TF.vector shape)
|
||||||
|
where
|
||||||
|
stddev = TF.scalar (1 / sqrt (fromIntegral width))
|
||||||
|
|
||||||
|
-- Types must match due to model structure (sparseToDense requires
|
||||||
|
-- index types to match)
|
||||||
|
type LabelType = Int32
|
||||||
|
type BatchSize = Int32
|
||||||
|
|
||||||
|
-- | Convert scalar labels to one-hot vectors.
|
||||||
|
labelClasses :: TF.Tensor TF.Value LabelType
|
||||||
|
-> LabelType
|
||||||
|
-> BatchSize
|
||||||
|
-> TF.Tensor TF.Value Float
|
||||||
|
labelClasses labels numClasses batchSize =
|
||||||
|
let indices = TF.range 0 (TF.scalar batchSize) 1
|
||||||
|
concated = TF.concat 1 [TF.expandDims indices 1, TF.expandDims labels 1]
|
||||||
|
in TF.sparseToDense concated
|
||||||
|
(TF.constant [2] [batchSize, numClasses])
|
||||||
|
1 {- ON value -}
|
||||||
|
0 {- default (OFF) value -}
|
||||||
|
|
||||||
|
-- | Fraction of elements that differ between two vectors.
|
||||||
|
errorRate :: Eq a => V.Vector a -> V.Vector a -> Double
|
||||||
|
errorRate xs ys = fromIntegral (len - numCorrect) / fromIntegral len
|
||||||
|
where
|
||||||
|
numCorrect = V.length $ V.filter id $ V.zipWith (==) xs ys
|
||||||
|
len = V.length xs
|
||||||
|
|
||||||
|
data Model = Model {
|
||||||
|
train :: TF.TensorData Float -- ^ images
|
||||||
|
-> TF.TensorData LabelType
|
||||||
|
-> TF.Session ()
|
||||||
|
, infer :: TF.TensorData Float -- ^ images
|
||||||
|
-> TF.Session (V.Vector LabelType) -- ^ predictions
|
||||||
|
}
|
||||||
|
|
||||||
|
createModel :: Int64 -> TF.Build Model
|
||||||
|
createModel batchSize = do
|
||||||
|
-- Inputs.
|
||||||
|
images <- TF.placeholder [batchSize, numPixels]
|
||||||
|
-- Hidden layer.
|
||||||
|
let numUnits = 500
|
||||||
|
hiddenWeights <-
|
||||||
|
TF.initializedVariable =<< randomParam numPixels [numPixels, numUnits]
|
||||||
|
hiddenBiases <- TF.zeroInitializedVariable [numUnits]
|
||||||
|
let hiddenZ = (images `TF.matMul` hiddenWeights) `TF.add` hiddenBiases
|
||||||
|
let hidden = TF.relu hiddenZ
|
||||||
|
-- Logits.
|
||||||
|
logitWeights <-
|
||||||
|
TF.initializedVariable =<< randomParam numUnits [numUnits, numLabels]
|
||||||
|
logitBiases <- TF.zeroInitializedVariable [numLabels]
|
||||||
|
let logits = (hidden `TF.matMul` logitWeights) `TF.add` logitBiases
|
||||||
|
predict <- TF.render $ TF.cast $
|
||||||
|
TF.argMax (TF.softmax logits) (TF.scalar (1 :: LabelType))
|
||||||
|
|
||||||
|
-- Create training action.
|
||||||
|
labels <- TF.placeholder [batchSize]
|
||||||
|
let labelVecs = labelClasses labels 10 (fromIntegral batchSize)
|
||||||
|
loss = fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
|
||||||
|
params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases]
|
||||||
|
grads <- TF.gradients loss params
|
||||||
|
|
||||||
|
let lr = TF.scalar $ 0.001 / fromIntegral batchSize
|
||||||
|
applyGrad param grad
|
||||||
|
= TF.assign param $ param `TF.sub` (lr * grad)
|
||||||
|
trainStep <- TF.group =<< zipWithM applyGrad params grads
|
||||||
|
|
||||||
|
return Model {
|
||||||
|
train = \imFeed lFeed -> TF.runWithFeeds_ [
|
||||||
|
TF.feed images imFeed
|
||||||
|
, TF.feed labels lFeed
|
||||||
|
] trainStep
|
||||||
|
, infer = \imFeed -> TF.runWithFeeds [TF.feed images imFeed] predict
|
||||||
|
}
|
||||||
|
|
||||||
|
main = TF.runSession $ do
|
||||||
|
-- Read training and test data.
|
||||||
|
trainingImages <- liftIO (readMNISTSamples =<< trainingImageData)
|
||||||
|
trainingLabels <- liftIO (readMNISTLabels =<< trainingLabelData)
|
||||||
|
testImages <- liftIO (readMNISTSamples =<< testImageData)
|
||||||
|
testLabels <- liftIO (readMNISTLabels =<< testLabelData)
|
||||||
|
|
||||||
|
let batchSize = 100 :: Int64
|
||||||
|
|
||||||
|
-- Create the model.
|
||||||
|
model <- TF.build $ createModel batchSize
|
||||||
|
|
||||||
|
-- Helpers for generate batches.
|
||||||
|
let selectBatch i xs = take size $ drop (i * size) $ cycle xs
|
||||||
|
where size = fromIntegral batchSize
|
||||||
|
let getImageBatch i xs = TF.encodeTensorData
|
||||||
|
[batchSize, numPixels]
|
||||||
|
$ fromIntegral <$> mconcat (selectBatch i xs)
|
||||||
|
let getExpectedLabelBatch i xs =
|
||||||
|
fromIntegral <$> V.fromList (selectBatch i xs)
|
||||||
|
|
||||||
|
-- Train.
|
||||||
|
forM_ ([0..1000] :: [Int]) $ \i -> do
|
||||||
|
let images = getImageBatch i trainingImages
|
||||||
|
labels = getExpectedLabelBatch i trainingLabels
|
||||||
|
train model images (TF.encodeTensorData [batchSize] labels)
|
||||||
|
when (i `mod` 100 == 0) $ do
|
||||||
|
preds <- infer model images
|
||||||
|
liftIO $ putStrLn $
|
||||||
|
"training error " ++ show (errorRate preds labels * 100)
|
||||||
|
liftIO $ putStrLn ""
|
||||||
|
|
||||||
|
-- Test.
|
||||||
|
let numTestBatches = length testImages `div` fromIntegral batchSize
|
||||||
|
testPreds <- fmap mconcat $ forM [0..numTestBatches] $ \i -> do
|
||||||
|
infer model (getImageBatch i testImages)
|
||||||
|
let testExpected = fromIntegral <$> V.fromList testLabels
|
||||||
|
liftIO $ putStrLn $
|
||||||
|
"test error " ++ show (errorRate testPreds testExpected * 100)
|
||||||
|
|
||||||
|
-- Show some predictions.
|
||||||
|
liftIO $ forM_ ([0..3] :: [Int]) $ \i -> do
|
||||||
|
putStrLn ""
|
||||||
|
T.putStrLn $ drawMNIST $ testImages !! i
|
||||||
|
putStrLn $ "expected " ++ show (testLabels !! i)
|
||||||
|
putStrLn $ " got " ++ show (testPreds V.! i)
|
BIN
tensorflow-mnist/data/MNIST.pb
Normal file
BIN
tensorflow-mnist/data/MNIST.pb
Normal file
Binary file not shown.
BIN
tensorflow-mnist/data/MNISTBias.ckpt
Normal file
BIN
tensorflow-mnist/data/MNISTBias.ckpt
Normal file
Binary file not shown.
BIN
tensorflow-mnist/data/MNISTWts.ckpt
Normal file
BIN
tensorflow-mnist/data/MNISTWts.ckpt
Normal file
Binary file not shown.
|
@ -0,0 +1,30 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
-- | Paths to test helper files.
|
||||||
|
module TensorFlow.Examples.MNIST.TrainedGraph where
|
||||||
|
|
||||||
|
import Paths_tensorflow_mnist (getDataFileName)
|
||||||
|
import Data.ByteString (ByteString)
|
||||||
|
import Data.ByteString.Char8 (pack)
|
||||||
|
|
||||||
|
-- | File containing a Tensorflow serialized proto of MNIST.
|
||||||
|
mnistPb :: IO FilePath
|
||||||
|
mnistPb = getDataFileName "data/MNIST.pb"
|
||||||
|
|
||||||
|
-- | Files containing pre-trained weights for MNIST.
|
||||||
|
wtsCkpt, biasCkpt :: IO ByteString
|
||||||
|
wtsCkpt = pack <$> getDataFileName "data/MNISTWts.ckpt"
|
||||||
|
biasCkpt = pack <$> getDataFileName "data/MNISTBias.ckpt"
|
96
tensorflow-mnist/src/TensorFlow/Examples/MNIST/Parse.hs
Normal file
96
tensorflow-mnist/src/TensorFlow/Examples/MNIST/Parse.hs
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
|
{-# LANGUAGE TypeSynonymInstances #-}
|
||||||
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
|
{-# LANGUAGE ViewPatterns #-}
|
||||||
|
|
||||||
|
module TensorFlow.Examples.MNIST.Parse where
|
||||||
|
|
||||||
|
import Control.Monad (when, liftM)
|
||||||
|
import Data.Binary.Get (Get, runGet, getWord32be, getLazyByteString)
|
||||||
|
import Data.ByteString.Lazy (toStrict, readFile)
|
||||||
|
import Data.List.Split (chunksOf)
|
||||||
|
import Data.Monoid ((<>))
|
||||||
|
import Data.ProtoLens (Message, decodeMessageOrDie)
|
||||||
|
import Data.Text (Text)
|
||||||
|
import Data.Word (Word8, Word32)
|
||||||
|
import Prelude hiding (readFile)
|
||||||
|
import qualified Codec.Compression.GZip as GZip
|
||||||
|
import qualified Data.ByteString.Lazy as L
|
||||||
|
import qualified Data.Text as Text
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
|
-- | Utilities specific to MNIST.
|
||||||
|
type MNIST = V.Vector Word8
|
||||||
|
|
||||||
|
-- | Produces a unicode rendering of the MNIST digit sample.
|
||||||
|
drawMNIST :: MNIST -> Text
|
||||||
|
drawMNIST = chunk . block
|
||||||
|
where
|
||||||
|
block :: V.Vector Word8 -> Text
|
||||||
|
block (V.splitAt 1 -> ([0], xs)) = " " <> block xs
|
||||||
|
block (V.splitAt 1 -> ([n], xs)) = c `Text.cons` block xs
|
||||||
|
where c = "\9617\9618\9619\9608" !! fromIntegral (n `div` 64)
|
||||||
|
block (V.splitAt 1 -> _) = ""
|
||||||
|
chunk :: Text -> Text
|
||||||
|
chunk "" = "\n"
|
||||||
|
chunk xs = Text.take 28 xs <> "\n" <> chunk (Text.drop 28 xs)
|
||||||
|
|
||||||
|
-- | Check's the file's endianess, throwing an error if it's not as expected.
|
||||||
|
checkEndian :: Get ()
|
||||||
|
checkEndian = do
|
||||||
|
magic <- getWord32be
|
||||||
|
when (magic `notElem` ([2049, 2051] :: [Word32])) $
|
||||||
|
fail "Expected big endian, but image file is little endian."
|
||||||
|
|
||||||
|
-- | Reads an MNIST file and returns a list of samples.
|
||||||
|
readMNISTSamples :: FilePath -> IO [MNIST]
|
||||||
|
readMNISTSamples path = do
|
||||||
|
raw <- GZip.decompress <$> readFile path
|
||||||
|
return $ runGet getMNIST raw
|
||||||
|
where
|
||||||
|
getMNIST :: Get [MNIST]
|
||||||
|
getMNIST = do
|
||||||
|
checkEndian
|
||||||
|
-- Parse header data.
|
||||||
|
cnt <- liftM fromIntegral getWord32be
|
||||||
|
rows <- liftM fromIntegral getWord32be
|
||||||
|
cols <- liftM fromIntegral getWord32be
|
||||||
|
-- Read all of the data, then split into samples.
|
||||||
|
pixels <- getLazyByteString $ fromIntegral $ cnt * rows * cols
|
||||||
|
return $ V.fromList <$> chunksOf (rows * cols) (L.unpack pixels)
|
||||||
|
|
||||||
|
-- | Reads a list of MNIST labels from a file and returns them.
|
||||||
|
readMNISTLabels :: FilePath -> IO [Word8]
|
||||||
|
readMNISTLabels path = do
|
||||||
|
raw <- GZip.decompress <$> readFile path
|
||||||
|
return $ runGet getLabels raw
|
||||||
|
where getLabels :: Get [Word8]
|
||||||
|
getLabels = do
|
||||||
|
checkEndian
|
||||||
|
-- Parse header data.
|
||||||
|
cnt <- liftM fromIntegral getWord32be
|
||||||
|
-- Read all of the labels.
|
||||||
|
L.unpack <$> getLazyByteString cnt
|
||||||
|
|
||||||
|
readMessageFromFileOrDie :: Message m => FilePath -> IO m
|
||||||
|
readMessageFromFileOrDie path = do
|
||||||
|
pb <- readFile path
|
||||||
|
return $ decodeMessageOrDie $ toStrict pb
|
||||||
|
|
||||||
|
-- TODO: Write a writeMessageFromFileOrDie and read/write non-lethal
|
||||||
|
-- versions.
|
80
tensorflow-mnist/tensorflow-mnist.cabal
Normal file
80
tensorflow-mnist/tensorflow-mnist.cabal
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
name: tensorflow-mnist
|
||||||
|
version: 0.1.0.0
|
||||||
|
synopsis: TensorFlow demo application for learning MNIST model.
|
||||||
|
description: Please see README.md
|
||||||
|
homepage: https://github.com/tensorflow/haskell#readme
|
||||||
|
license: Apache
|
||||||
|
author: TensorFlow authors
|
||||||
|
maintainer: tensorflow-haskell@googlegroups.com
|
||||||
|
copyright: Google Inc.
|
||||||
|
category: Machine Learning
|
||||||
|
build-type: Simple
|
||||||
|
cabal-version: >=1.22
|
||||||
|
data-files: data/*.ckpt
|
||||||
|
, data/*.pb
|
||||||
|
|
||||||
|
library
|
||||||
|
hs-source-dirs: src
|
||||||
|
, src-data
|
||||||
|
exposed-modules: TensorFlow.Examples.MNIST.Parse
|
||||||
|
, TensorFlow.Examples.MNIST.TrainedGraph
|
||||||
|
other-modules: Paths_tensorflow_mnist
|
||||||
|
build-depends: proto-lens == 0.1.*
|
||||||
|
, base >= 4.7 && < 5
|
||||||
|
, binary
|
||||||
|
, bytestring
|
||||||
|
, filepath
|
||||||
|
, lens-family
|
||||||
|
, containers
|
||||||
|
, split
|
||||||
|
, tensorflow-proto == 0.1.*
|
||||||
|
, tensorflow-core-ops == 0.1.*
|
||||||
|
, tensorflow
|
||||||
|
, text
|
||||||
|
, vector
|
||||||
|
, zlib
|
||||||
|
default-language: Haskell2010
|
||||||
|
|
||||||
|
executable Main
|
||||||
|
default-language: Haskell2010
|
||||||
|
main-is: Main.hs
|
||||||
|
hs-source-dirs: app
|
||||||
|
build-depends: base
|
||||||
|
, bytestring
|
||||||
|
, filepath
|
||||||
|
, lens-family
|
||||||
|
, proto-lens
|
||||||
|
, tensorflow
|
||||||
|
, tensorflow-mnist
|
||||||
|
, tensorflow-mnist-input-data
|
||||||
|
, tensorflow-ops
|
||||||
|
, tensorflow-proto
|
||||||
|
, text
|
||||||
|
, transformers
|
||||||
|
, vector
|
||||||
|
|
||||||
|
Test-Suite ParseTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: ParseTest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: HUnit
|
||||||
|
, base
|
||||||
|
, bytestring
|
||||||
|
, proto-lens
|
||||||
|
, lens-family
|
||||||
|
, google-shim
|
||||||
|
, tensorflow
|
||||||
|
, tensorflow-mnist
|
||||||
|
, tensorflow-mnist-input-data
|
||||||
|
, tensorflow-ops
|
||||||
|
, tensorflow-proto
|
||||||
|
, test-framework
|
||||||
|
, test-framework-hunit
|
||||||
|
, text
|
||||||
|
, transformers
|
||||||
|
, vector
|
||||||
|
|
||||||
|
source-repository head
|
||||||
|
type: git
|
||||||
|
location: https://github.com/tensorflow/haskell
|
170
tensorflow-mnist/tests/ParseTest.hs
Normal file
170
tensorflow-mnist/tests/ParseTest.hs
Normal file
|
@ -0,0 +1,170 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
import Data.Int (Int64)
|
||||||
|
import Data.Text (Text)
|
||||||
|
import qualified Data.Text.IO as Text
|
||||||
|
import Lens.Family2 ((&), (.~), (^.))
|
||||||
|
import Prelude hiding (abs)
|
||||||
|
import Proto.Tensorflow.Core.Framework.Graph
|
||||||
|
( GraphDef(..)
|
||||||
|
, version
|
||||||
|
, node )
|
||||||
|
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||||
|
( NodeDef(..)
|
||||||
|
, op )
|
||||||
|
import System.IO as IO
|
||||||
|
import TensorFlow.Examples.MNIST.InputData
|
||||||
|
import TensorFlow.Examples.MNIST.Parse
|
||||||
|
import TensorFlow.Examples.MNIST.TrainedGraph
|
||||||
|
import TensorFlow.Build
|
||||||
|
( asGraphDef
|
||||||
|
, addGraphDef
|
||||||
|
, render
|
||||||
|
)
|
||||||
|
import TensorFlow.Tensor
|
||||||
|
( Tensor(..)
|
||||||
|
, Ref
|
||||||
|
, Value
|
||||||
|
, feed
|
||||||
|
, TensorKind(..)
|
||||||
|
, tensorFromName
|
||||||
|
)
|
||||||
|
import TensorFlow.Ops
|
||||||
|
import TensorFlow.Nodes (unScalar)
|
||||||
|
import TensorFlow.Session
|
||||||
|
(runSession, run, run_, runWithFeeds, build, buildAnd)
|
||||||
|
import TensorFlow.Types (TensorType(..), Shape(..))
|
||||||
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
import Test.HUnit ((@=?), Assertion)
|
||||||
|
import Google.Test
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
|
-- | Test that a file can be read and the GraphDef proto correctly parsed.
|
||||||
|
testReadMessageFromFileOrDie = testCase "testReadMessageFromFileOrDie" $ do
|
||||||
|
-- Check the function on a known well-formatted file.
|
||||||
|
mnist <- readMessageFromFileOrDie =<< mnistPb :: IO GraphDef
|
||||||
|
-- Simple field read.
|
||||||
|
1 @=? mnist^.version
|
||||||
|
-- Count the number of nodes.
|
||||||
|
let nodes :: [NodeDef]
|
||||||
|
nodes = mnist^.node
|
||||||
|
100 @=? length nodes
|
||||||
|
-- Check that the expected op is found at an arbitrary index.
|
||||||
|
"Variable" @=? nodes!!6^.op
|
||||||
|
|
||||||
|
-- | Parse the test set for label and image data. Will only fail if the file is
|
||||||
|
-- missing or incredibly corrupt.
|
||||||
|
testReadMNIST = testCase "testReadMNIST" $ do
|
||||||
|
imageData <- readMNISTSamples =<< testImageData
|
||||||
|
10000 @=? length imageData
|
||||||
|
labelData <- readMNISTLabels =<< testLabelData
|
||||||
|
10000 @=? length labelData
|
||||||
|
|
||||||
|
testNodeName :: Text -> Tensor v a -> Assertion
|
||||||
|
testNodeName n g = n @=? opName
|
||||||
|
where
|
||||||
|
opName = head (gDef^.node)^.op
|
||||||
|
gDef = asGraphDef $ render g
|
||||||
|
|
||||||
|
testGraphDefGen = testCase "testGraphDefGen" $ do
|
||||||
|
-- Test the inferred operation type.
|
||||||
|
let f0 :: Tensor Value Float
|
||||||
|
f0 = 0
|
||||||
|
testNodeName "Const" f0
|
||||||
|
testNodeName "Add" $ 1 + f0
|
||||||
|
testNodeName "Mul" $ 1 * f0
|
||||||
|
testNodeName "Sub" $ 1 - f0
|
||||||
|
testNodeName "Abs" $ abs f0
|
||||||
|
testNodeName "Sign" $ signum f0
|
||||||
|
testNodeName "Neg" $ -f0
|
||||||
|
-- Test the grouping.
|
||||||
|
testNodeName "Add" $ 1 + f0 * 2
|
||||||
|
testNodeName "Add" $ 1 + (f0 * 2)
|
||||||
|
testNodeName "Mul" $ (1 + f0) * 2
|
||||||
|
|
||||||
|
-- | Convert a simple graph to GraphDef, load it, run it, and check the output.
|
||||||
|
testGraphDefExec = testCase "testGraphDefExec" $ do
|
||||||
|
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
|
||||||
|
runSession $ do
|
||||||
|
build $ addGraphDef graphDef
|
||||||
|
x <- run $ tensorFromName ValueKind "Mul_2"
|
||||||
|
liftIO $ (50 :: Float) @=? unScalar x
|
||||||
|
|
||||||
|
-- | Load MNIST from a GraphDef and the weights from a checkpoint and run on
|
||||||
|
-- sample data.
|
||||||
|
testMNISTExec = testCase "testMNISTExec" $ do
|
||||||
|
-- Switch to unicode to enable pretty printing of MNIST digits.
|
||||||
|
IO.hSetEncoding IO.stdout IO.utf8
|
||||||
|
|
||||||
|
-- Parse the Graph definition, samples, & labels from files.
|
||||||
|
mnist <- readMessageFromFileOrDie =<< mnistPb :: IO GraphDef
|
||||||
|
mnistSamples <- readMNISTSamples =<< testImageData
|
||||||
|
mnistLabels <- readMNISTLabels =<< testLabelData
|
||||||
|
|
||||||
|
-- Select a sample to run on and convert it into a TensorData of Floats.
|
||||||
|
let idx = 12
|
||||||
|
sample :: MNIST
|
||||||
|
sample = mnistSamples !! idx
|
||||||
|
label = mnistLabels !! idx
|
||||||
|
tensorSample = encodeTensorData (Shape [1,784]) floatSample
|
||||||
|
where
|
||||||
|
floatSample :: V.Vector Float
|
||||||
|
floatSample = V.map fromIntegral sample
|
||||||
|
Text.putStrLn $ drawMNIST sample
|
||||||
|
|
||||||
|
-- Execute the graph on the sample data.
|
||||||
|
runSession $ do
|
||||||
|
-- The version of this session is 0, but the version of the graph is 1.
|
||||||
|
-- Change the graph version to 0 so they're compatible.
|
||||||
|
build $ addGraphDef $ mnist & version .~ 0
|
||||||
|
-- Define nodes that restore saved weights and biases.
|
||||||
|
let bias, wts :: Tensor Ref Float
|
||||||
|
bias = tensorFromName RefKind "Variable"
|
||||||
|
wts = tensorFromName RefKind "weights"
|
||||||
|
wtsCkptPath <- liftIO wtsCkpt
|
||||||
|
biasCkptPath <- liftIO biasCkpt
|
||||||
|
-- Run those restoring nodes on the graph in the current session.
|
||||||
|
buildAnd run_ $ (sequence :: Monad m => [m a] -> m [a])
|
||||||
|
[ restore wtsCkptPath wts
|
||||||
|
, restoreFromName biasCkptPath "bias" bias
|
||||||
|
]
|
||||||
|
-- Encode the expected sample data as one-hot data.
|
||||||
|
let ty = encodeTensorData [10] oneHotLabels
|
||||||
|
where oneHotLabels = V.replicate 10 (0 :: Float) V.// updates
|
||||||
|
updates = [(fromIntegral label, 1)]
|
||||||
|
let feeds = [ feed (tensorFromName ValueKind "x-input") tensorSample
|
||||||
|
, feed (tensorFromName ValueKind "y-input") ty
|
||||||
|
]
|
||||||
|
-- Run the graph with the input feeds and read the ArgMax'd result from
|
||||||
|
-- the test (not training) side of the evaluation.
|
||||||
|
x <- runWithFeeds feeds $ tensorFromName ValueKind "test/ArgMax"
|
||||||
|
-- Print the trained model's predicted outcome.
|
||||||
|
liftIO $ putStrLn $ "Expectation: " ++ show label ++ "\n"
|
||||||
|
++ "Prediction: " ++ show (unScalar x :: Int64)
|
||||||
|
-- Check whether the prediction matches the expectation.
|
||||||
|
liftIO $ (fromInteger . toInteger $ label :: Int64) @=? unScalar x
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = googleTest [ testReadMessageFromFileOrDie
|
||||||
|
, testReadMNIST
|
||||||
|
, testGraphDefGen
|
||||||
|
, testGraphDefExec
|
||||||
|
, testMNISTExec]
|
3
tensorflow-opgen/Setup.hs
Normal file
3
tensorflow-opgen/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
import Distribution.Simple
|
||||||
|
|
||||||
|
main = defaultMain
|
457
tensorflow-opgen/src/TensorFlow/OpGen.hs
Normal file
457
tensorflow-opgen/src/TensorFlow/OpGen.hs
Normal file
|
@ -0,0 +1,457 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
-- | Rendering of TensorFlow operations as Haskell functions.
|
||||||
|
|
||||||
|
module TensorFlow.OpGen
|
||||||
|
( OpGenFlags(..)
|
||||||
|
, docOpList
|
||||||
|
, flagParser)
|
||||||
|
where
|
||||||
|
|
||||||
|
import Prelude hiding (head, tail)
|
||||||
|
|
||||||
|
import Control.Monad (guard)
|
||||||
|
import Data.Char (toLower, toUpper)
|
||||||
|
import Data.Foldable (toList)
|
||||||
|
import Data.Maybe (fromMaybe, maybeToList)
|
||||||
|
import Data.ProtoLens (def, showMessage)
|
||||||
|
import Data.List.NonEmpty (NonEmpty((:|)), head)
|
||||||
|
import qualified Data.List.NonEmpty as NE
|
||||||
|
import Lens.Family2 ((^.), (.~), (&), view)
|
||||||
|
import Options.Applicative (Parser, help, long, strOption, value)
|
||||||
|
import Proto.Tensorflow.Core.Framework.OpDef
|
||||||
|
( OpList
|
||||||
|
, OpDef
|
||||||
|
, OpDef'ArgDef
|
||||||
|
, attr
|
||||||
|
, description
|
||||||
|
, inputArg
|
||||||
|
, name
|
||||||
|
, numberAttr
|
||||||
|
, op
|
||||||
|
, outputArg
|
||||||
|
, summary
|
||||||
|
, type'
|
||||||
|
, typeAttr
|
||||||
|
)
|
||||||
|
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||||
|
import System.FilePath (takeBaseName)
|
||||||
|
import TensorFlow.OpGen.AttrVal
|
||||||
|
(AttrDef
|
||||||
|
, AttrCase(..)
|
||||||
|
, AttrTemplate(..)
|
||||||
|
, Template
|
||||||
|
, attrDef
|
||||||
|
, attrOriginal
|
||||||
|
, attrTemplate
|
||||||
|
, templateDefault
|
||||||
|
, templateRestrictions
|
||||||
|
)
|
||||||
|
import Text.PrettyPrint.Mainland
|
||||||
|
( Doc
|
||||||
|
, (<>)
|
||||||
|
, (<+>)
|
||||||
|
, (</>)
|
||||||
|
, (<+/>)
|
||||||
|
, brackets
|
||||||
|
, comma
|
||||||
|
, commasep
|
||||||
|
, dquotes
|
||||||
|
, empty
|
||||||
|
, enclose
|
||||||
|
, flatten
|
||||||
|
, folddoc
|
||||||
|
, hang
|
||||||
|
, indent
|
||||||
|
, int
|
||||||
|
, parens
|
||||||
|
, sep
|
||||||
|
, stack
|
||||||
|
, strictText
|
||||||
|
, tuple
|
||||||
|
)
|
||||||
|
import qualified Data.Map.Strict as Map
|
||||||
|
import qualified Data.Set as Set
|
||||||
|
import qualified Data.Text as Text
|
||||||
|
import qualified Data.Semigroup as Semigroup
|
||||||
|
import Data.Text (Text)
|
||||||
|
|
||||||
|
data OpGenFlags = OpGenFlags
|
||||||
|
{ outputFile :: String
|
||||||
|
, prefix :: String
|
||||||
|
, excludeList :: String
|
||||||
|
}
|
||||||
|
|
||||||
|
flagParser :: Parser OpGenFlags
|
||||||
|
flagParser = OpGenFlags
|
||||||
|
<$> strOption (mconcat [ long "output"
|
||||||
|
, help "File to write."
|
||||||
|
])
|
||||||
|
<*> strOption (mconcat [ long "prefix"
|
||||||
|
, help "Haskell package prefix to use"
|
||||||
|
])
|
||||||
|
<*> strOption (mconcat [ long "exclude_list"
|
||||||
|
, value ""
|
||||||
|
, help "Comma separated Ops names to ignore"
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
docOpList :: OpGenFlags -> OpList -> Doc
|
||||||
|
docOpList flags opList =
|
||||||
|
stack [ "{-# LANGUAGE ConstraintKinds #-}"
|
||||||
|
, "{-# LANGUAGE DataKinds #-}"
|
||||||
|
, "{-# LANGUAGE FlexibleInstances #-}"
|
||||||
|
, "{-# LANGUAGE OverloadedStrings #-}"
|
||||||
|
, "{-# LANGUAGE RankNTypes #-}"
|
||||||
|
, "{-# LANGUAGE ScopedTypeVariables #-}"
|
||||||
|
, "module" <+> strictText moduleName <+> "where"
|
||||||
|
, empty
|
||||||
|
, imports
|
||||||
|
, empty
|
||||||
|
, folddoc (\x y -> x </> empty </> y)
|
||||||
|
(map renderDef $
|
||||||
|
filter (not . flip elem exclusions . view name) $
|
||||||
|
toList $ opList ^. op)
|
||||||
|
]
|
||||||
|
where moduleName =
|
||||||
|
Text.pack (prefix flags) <> "." <> camelCase
|
||||||
|
-- Discards the optional trailing _op_lib
|
||||||
|
(fromMaybe shortName (Text.stripSuffix "_op_lib" shortName))
|
||||||
|
shortName = Text.pack (takeBaseName $ outputFile flags)
|
||||||
|
exclusions = Text.splitOn "," $ Text.pack $ excludeList flags
|
||||||
|
|
||||||
|
camelCase s = Text.concat $ map upCase
|
||||||
|
$ filter (/= "ops")
|
||||||
|
$ Text.splitOn "_" s
|
||||||
|
|
||||||
|
-- | Upper-case the given text.
|
||||||
|
upCase :: Text -> Text
|
||||||
|
upCase = forceCase toUpper
|
||||||
|
|
||||||
|
-- | Lower-case the given name, and prevent it from overlapping with a reserved
|
||||||
|
-- Haskell name.
|
||||||
|
lowCase :: Text -> Text
|
||||||
|
lowCase = replaceReservedName . forceCase toLower
|
||||||
|
|
||||||
|
forceCase :: (Char -> Char) -> Text -> Text
|
||||||
|
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
|
||||||
|
(Text.uncons s)
|
||||||
|
|
||||||
|
imports = stack [
|
||||||
|
"import Data.ByteString (ByteString)"
|
||||||
|
, "import Data.Complex (Complex)"
|
||||||
|
, "import Data.Int (Int8, Int16, Int32, Int64)"
|
||||||
|
, "import Data.Word (Word8, Word16)"
|
||||||
|
, "import Lens.Family2 ((.~), (&))"
|
||||||
|
, "import TensorFlow.Build"
|
||||||
|
, "import TensorFlow.BuildOp"
|
||||||
|
, "import TensorFlow.Tensor"
|
||||||
|
, "import TensorFlow.Types"
|
||||||
|
]
|
||||||
|
|
||||||
|
renderDef :: OpDef -> Doc
|
||||||
|
renderDef d =
|
||||||
|
stack [
|
||||||
|
haddocks
|
||||||
|
, n <+> "::" <+> hang 0 (typeSig d)
|
||||||
|
, n <+> hang 0 args <+> "|" <+> funcGuard <+> "=" </> -- args are indented
|
||||||
|
-- the body needs to be indented wrt the name
|
||||||
|
indent indentation functionBody
|
||||||
|
, extras -- just for debug
|
||||||
|
]
|
||||||
|
where
|
||||||
|
n = strictText $ fixOpName (d ^. name)
|
||||||
|
args = sep $ [hsName | (_, hsName) <- mandatoryAttrs] ++ tensorArgs
|
||||||
|
tensorArgs = [strictText $ lowCase (a ^. name) | a <- d ^. inputArg]
|
||||||
|
fixOpName = lowCase
|
||||||
|
funcGuard = "eqLengthGuard" <+> brackets (commasep entries)
|
||||||
|
where
|
||||||
|
entries =
|
||||||
|
[ parens $ quotedText nAttr <> comma <+>
|
||||||
|
brackets (commasep $ toList $
|
||||||
|
NE.map renderTensorName tensorNames)
|
||||||
|
| (nAttr, tensorNames) <- Map.toList $ numberAttrMap d
|
||||||
|
]
|
||||||
|
renderTensorName x = parens $ quotedText x <> comma <+>
|
||||||
|
"length" <+> strictText x
|
||||||
|
-- Uses hang 0 to align the argument vertically on multiple lines.
|
||||||
|
functionBody = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||||
|
</> indent indentation (sep tensorArgs)
|
||||||
|
buildFunction
|
||||||
|
| null outputListsSizes = "buildOp"
|
||||||
|
| otherwise = "buildListOp" <+> brackets (commasep outputListsSizes)
|
||||||
|
outputListsSizes = [ strictText numberAttrName
|
||||||
|
| o <- d ^. outputArg
|
||||||
|
, let numberAttrName = o ^. numberAttr
|
||||||
|
, not (Text.null numberAttrName) &&
|
||||||
|
numberAttrName `Map.member` mandatoryAttrMap d
|
||||||
|
]
|
||||||
|
buildOpParts =
|
||||||
|
"opDef" <+> quotedText (d ^. name) :
|
||||||
|
-- Renders tensor arguments.
|
||||||
|
[ "& opAttr" <+> quotedText tfName <+>
|
||||||
|
".~ tensorType (undefined ::" <+> strictText hsName <> ")"
|
||||||
|
| (tfName, (hsName, _)) <- Map.toList typeMap
|
||||||
|
] ++
|
||||||
|
-- Renders mandatory attributes as function parameters.
|
||||||
|
[ "& opAttr" <+> dquotes tfName <+> ".~" <+> hsName
|
||||||
|
| (tfName, hsName) <- mandatoryAttrs
|
||||||
|
] ++
|
||||||
|
-- Renders sizes of tensor list types having number_attr.
|
||||||
|
[ "& opAttr" <+> quotedText nAttr <+> ".~" <+>
|
||||||
|
"(fromIntegral (length" <+> strictText (head tensorNames) <> ") :: Int64)"
|
||||||
|
| (nAttr, tensorNames) <- Map.toList $ numberAttrMap d
|
||||||
|
]
|
||||||
|
mandatoryAttrs = [(strictText tf, strictText hs)
|
||||||
|
| (tf, (hs, _, _)) <- Map.toList (mandatoryAttrMap d)
|
||||||
|
]
|
||||||
|
haddocks = "-- |" <+> multilineComment (d ^. summary) (d ^. description)
|
||||||
|
extras = enclose "{-\n" "\n-}" $
|
||||||
|
strictText $ Text.pack $
|
||||||
|
showMessage ((def :: OpDef)
|
||||||
|
& inputArg .~ (d ^. inputArg)
|
||||||
|
& outputArg .~ (d ^. outputArg)
|
||||||
|
& attr .~ (d ^. attr))
|
||||||
|
typeMap = opDefTypeMap d
|
||||||
|
|
||||||
|
-- | Makes a quoted string doc out of the given text value.
|
||||||
|
quotedText :: Text.Text -> Doc
|
||||||
|
quotedText = dquotes . strictText
|
||||||
|
|
||||||
|
-- | typeSig renders the type signature of the given OpDef.
|
||||||
|
typeSig :: OpDef -> Doc
|
||||||
|
typeSig d =
|
||||||
|
foralls <+> constraints <+/>
|
||||||
|
signatureFold (mandatoryAttrInputs ++ tensorInputs ++ [outputs])
|
||||||
|
where
|
||||||
|
foralls | Map.null typeMap = empty
|
||||||
|
| otherwise =
|
||||||
|
"forall"
|
||||||
|
<+> sep (refTypes ++ map (strictText . fst) (Map.elems typeMap))
|
||||||
|
<+> "."
|
||||||
|
constraints | Map.null typeMap = empty
|
||||||
|
| otherwise =
|
||||||
|
tuple (concatMap
|
||||||
|
(\(t, aDef) ->
|
||||||
|
"TensorType" <+> strictText t
|
||||||
|
: maybeToList (oneOfRestrictions aDef t))
|
||||||
|
(Map.elems typeMap)) <+> "=>"
|
||||||
|
tensorInputs = zipWith tensorArg refTypes (d ^. inputArg)
|
||||||
|
refTypes = map (\x -> "v" <> int x) [1..length (d ^. inputArg)]
|
||||||
|
tensorArg refType arg = wrapArg refType arg <+>
|
||||||
|
hang 0 ("-- ^" <+> argComment arg)
|
||||||
|
-- Argument type is a list of tensors if number_attr is set;
|
||||||
|
-- otherwise it's a single Tensor.
|
||||||
|
wrapArg refType arg =
|
||||||
|
if Text.null (arg ^. numberAttr) then typ else brackets typ
|
||||||
|
where typ = tensorType refType arg
|
||||||
|
tensorType refType arg =
|
||||||
|
"Tensor" <+> refType <+> maybe directType strictText indirectType
|
||||||
|
where
|
||||||
|
indirectType = fmap fst (Map.lookup (arg ^. typeAttr) typeMap)
|
||||||
|
directType = dtTypeToDoc (arg ^. type')
|
||||||
|
outputs =
|
||||||
|
case d ^. outputArg of
|
||||||
|
[] -> "ControlNode"
|
||||||
|
[o] -> wrappedOutput o <+> "-- ^" <+> argComment o
|
||||||
|
os -> renderTupleResult os
|
||||||
|
wrappedOutput = wrapArg "Value"
|
||||||
|
-- Tuple result case is rendered differently to give
|
||||||
|
-- individual elements their own comments.
|
||||||
|
renderTupleResult os =
|
||||||
|
stack $ [ tuple (map wrappedOutput os)
|
||||||
|
, flatten commentSummary
|
||||||
|
] ++ map commentDetails os
|
||||||
|
where
|
||||||
|
commentSummary = "-- ^" <+> tuple [bold (o ^. name) | o <- os]
|
||||||
|
commentDetails o =
|
||||||
|
stack [ "--"
|
||||||
|
, "-- *" <+> argComment o
|
||||||
|
]
|
||||||
|
signatureFold = folddoc (\x y -> x </> "->" <+> y)
|
||||||
|
mandatoryAttrInputs = [
|
||||||
|
dtTypeToDoc dtType <+>
|
||||||
|
hang 0 ("-- ^" <+> argComment' tfName descr)
|
||||||
|
| (tfName, (_, dtType, descr)) <- Map.toList $ mandatoryAttrMap d]
|
||||||
|
typeMap = opDefTypeMap d
|
||||||
|
|
||||||
|
-- | Returns the type restriction for the given tensor type if the
|
||||||
|
-- set of allowed types is not empty (i.e. restricted).
|
||||||
|
oneOfRestrictions :: AttrDef -> Text -> Maybe Doc
|
||||||
|
oneOfRestrictions aDef tName = do
|
||||||
|
typs <- onAttrType (^. templateRestrictions) aDef
|
||||||
|
guard $ not $ null typs
|
||||||
|
let typeList = commasep $ map strictText $
|
||||||
|
Set.toList $ Set.fromList $
|
||||||
|
map dtTypeToHaskell typs
|
||||||
|
return $ "OneOf" <+> "'" <> brackets typeList <+> strictText tName
|
||||||
|
|
||||||
|
-- | Identifies the attributes used as tensor cardinalities. In such
|
||||||
|
-- cases a list of tensors is supplied as an input_arg. The number of
|
||||||
|
-- such inputs is communicated as a separate opAttr.
|
||||||
|
-- The result key is TensorFlow attribute name and the value is the
|
||||||
|
-- tensor names which have number_attr set to the result key.
|
||||||
|
numberAttrMap :: OpDef -> Map.Map Text.Text (NonEmpty Text.Text)
|
||||||
|
numberAttrMap d = Map.fromListWith (Semigroup.<>) [
|
||||||
|
(nAttr, replaceReservedName (inp ^. name) :| [])
|
||||||
|
| inp <- d ^. inputArg
|
||||||
|
, nAttr <- [inp ^. numberAttr]
|
||||||
|
, not (Text.null nAttr)
|
||||||
|
]
|
||||||
|
|
||||||
|
argComment :: OpDef'ArgDef -> Doc
|
||||||
|
argComment arg = argComment' (arg ^. name) (arg ^. description)
|
||||||
|
|
||||||
|
argComment' :: Text.Text -> Text.Text -> Doc
|
||||||
|
argComment' argName argDesc =
|
||||||
|
bold argName <> splitMultilineText (":" <+>) argDesc
|
||||||
|
|
||||||
|
bold :: Text.Text -> Doc
|
||||||
|
bold n = strictText ("__" <> n <> "__")
|
||||||
|
|
||||||
|
opDefTypeMap :: OpDef -> Map.Map Text.Text (Text.Text, AttrDef)
|
||||||
|
opDefTypeMap d =
|
||||||
|
Map.fromList [(n, (lowCase n, a)) | (n, a) <- attrList d, isType a]
|
||||||
|
|
||||||
|
attrList :: OpDef -> [(Text.Text, AttrDef)]
|
||||||
|
attrList d = [(a ^. name, attrDef a) | a <- d ^. attr]
|
||||||
|
|
||||||
|
isType :: AttrDef -> Bool
|
||||||
|
isType = fromMaybe False . onAttrType (const True)
|
||||||
|
|
||||||
|
-- | Applies the given function to the data type. Is this a Prism?
|
||||||
|
onAttrType :: (Template DataType -> a) -> AttrDef -> Maybe a
|
||||||
|
onAttrType f x = case x ^. attrTemplate of
|
||||||
|
AttrSingle (AttrType a) -> Just (f a)
|
||||||
|
_ -> Nothing
|
||||||
|
|
||||||
|
-- | mandatoryAttrMap contains the attributes chosen by
|
||||||
|
-- isMandatoryAttr, excluding those which are derived from list of
|
||||||
|
-- tensor arguments. The key is the TF name of the attribute. The
|
||||||
|
-- value tuple is (haskell name, TF type, attribute comment).
|
||||||
|
mandatoryAttrMap :: OpDef -> Map.Map Text.Text (Text.Text, DataType, Text.Text)
|
||||||
|
mandatoryAttrMap d =
|
||||||
|
Map.fromList [ (n, (lowCase n, dtType, a ^. attrOriginal.description))
|
||||||
|
| (n, a) <- attrList d
|
||||||
|
, Just dtType <- [isMandatoryAttr a]
|
||||||
|
-- Excludes the attributes rendered as list lengths.
|
||||||
|
, n `Map.notMember` numberAttrMap d
|
||||||
|
]
|
||||||
|
|
||||||
|
-- | Inspects the attribute and if it is one of the implemented
|
||||||
|
-- non-tensor values lacking default, then returns Just the TF type.
|
||||||
|
isMandatoryAttr :: AttrDef -> Maybe DataType
|
||||||
|
isMandatoryAttr x =
|
||||||
|
case x ^. attrTemplate of
|
||||||
|
AttrSingle (AttrBool y) -> noDefault DT_BOOL y
|
||||||
|
AttrSingle (AttrInt64 y) -> noDefault DT_INT64 y
|
||||||
|
AttrSingle (AttrFloat y) -> noDefault DT_FLOAT y
|
||||||
|
_ -> Nothing
|
||||||
|
where
|
||||||
|
noDefault typ y = maybe (Just typ) (const Nothing) (y ^. templateDefault)
|
||||||
|
|
||||||
|
dtTypeToDoc :: DataType -> Doc
|
||||||
|
dtTypeToDoc = strictText . dtTypeToHaskell
|
||||||
|
|
||||||
|
-- NOTE: The cases of this function should be kept in sync with
|
||||||
|
-- TensorFlow.Types.AllTensorTypes.
|
||||||
|
dtTypeToHaskell :: DataType -> Text.Text
|
||||||
|
dtTypeToHaskell DT_BOOL = "Bool"
|
||||||
|
dtTypeToHaskell DT_BFLOAT16 = "Data.Word.Word16"
|
||||||
|
dtTypeToHaskell DT_COMPLEX128 = "(Data.Complex.Complex Double)"
|
||||||
|
dtTypeToHaskell DT_COMPLEX64 = "(Data.Complex.Complex Float)"
|
||||||
|
dtTypeToHaskell DT_DOUBLE = "Double"
|
||||||
|
dtTypeToHaskell DT_FLOAT = "Float"
|
||||||
|
dtTypeToHaskell DT_INT16 = "Data.Int.Int16"
|
||||||
|
dtTypeToHaskell DT_INT32 = "Data.Int.Int32"
|
||||||
|
dtTypeToHaskell DT_INT64 = "Data.Int.Int64"
|
||||||
|
dtTypeToHaskell DT_INT8 = "Data.Int.Int8"
|
||||||
|
dtTypeToHaskell DT_QINT32 = "Data.Int.Int32" -- TODO(gnezdo): make unique
|
||||||
|
dtTypeToHaskell DT_QINT8 = "Data.Word.Word8" -- TODO(gnezdo): make unique
|
||||||
|
dtTypeToHaskell DT_QINT16 = "Data.Int.Int16" -- TODO(gnezdo): make unique
|
||||||
|
dtTypeToHaskell DT_QUINT16 = "Data.Word.Word16" -- TODO(gnezdo): make unique
|
||||||
|
dtTypeToHaskell DT_QUINT8 = "Data.Word.Word8" -- TODO(gnezdo): make unique
|
||||||
|
dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString"
|
||||||
|
dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
|
||||||
|
dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique
|
||||||
|
dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
|
||||||
|
dtTypeToHaskell x =
|
||||||
|
Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x
|
||||||
|
|
||||||
|
-- | haddockComment escapes TensorFlow doc strings into haddock.
|
||||||
|
-- TODO(gnezdo): deal with the markup.
|
||||||
|
haddockComment :: Text.Text -> Doc
|
||||||
|
haddockComment = strictText
|
||||||
|
|
||||||
|
multilineComment :: Text.Text -> Text.Text -> Doc
|
||||||
|
multilineComment summary' detail =
|
||||||
|
haddockComment summary' </>
|
||||||
|
splitMultilineText insertParagraphAndComment detail
|
||||||
|
where insertParagraphAndComment x = "--" </> "--" <+> x
|
||||||
|
|
||||||
|
-- | Converts the given multi-line detail string into
|
||||||
|
-- a multi-line haddock. Applies the given lead to the
|
||||||
|
-- first line. Returns an empty document for empty detail.
|
||||||
|
splitMultilineText :: (Doc -> Doc) -> Text.Text -> Doc
|
||||||
|
splitMultilineText lead detail =
|
||||||
|
case Text.lines detail of
|
||||||
|
[] -> empty
|
||||||
|
(l : ls) -> stack $ lead (haddockComment l)
|
||||||
|
: map (("--" <+>) . haddockComment) ls
|
||||||
|
|
||||||
|
replaceReservedName :: Text -> Text
|
||||||
|
replaceReservedName n
|
||||||
|
| n `Set.member` reservedKeywords = n <> "'"
|
||||||
|
| otherwise = n
|
||||||
|
|
||||||
|
indentation = 4
|
||||||
|
|
||||||
|
reservedKeywords :: Set.Set Text
|
||||||
|
reservedKeywords = Set.fromList $
|
||||||
|
-- Haskell2010 keywords:
|
||||||
|
-- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
|
||||||
|
-- We don't include keywords that are allowed to be variable names,
|
||||||
|
-- in particular: "as", "forall", and "hiding".
|
||||||
|
[ "case"
|
||||||
|
, "class"
|
||||||
|
, "data"
|
||||||
|
, "default"
|
||||||
|
, "deriving"
|
||||||
|
, "do"
|
||||||
|
, "else"
|
||||||
|
, "foreign"
|
||||||
|
, "if"
|
||||||
|
, "import"
|
||||||
|
, "in"
|
||||||
|
, "infix"
|
||||||
|
, "infixl"
|
||||||
|
, "infixr"
|
||||||
|
, "instance"
|
||||||
|
, "let"
|
||||||
|
, "module"
|
||||||
|
, "newtype"
|
||||||
|
, "of"
|
||||||
|
, "then"
|
||||||
|
, "type"
|
||||||
|
, "where"
|
||||||
|
]
|
||||||
|
++ -- Nonstandard extensions
|
||||||
|
[ "mdo" -- RecursiveDo
|
||||||
|
, "rec" -- Arrows, RecursiveDo
|
||||||
|
, "proc" -- Arrows
|
||||||
|
]
|
120
tensorflow-opgen/src/TensorFlow/OpGen/AttrVal.hs
Normal file
120
tensorflow-opgen/src/TensorFlow/OpGen/AttrVal.hs
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
|
-- | Wrapping of TensorFlow attributes into Haskell entities.
|
||||||
|
module TensorFlow.OpGen.AttrVal
|
||||||
|
(AttrDef
|
||||||
|
, AttrCase(..)
|
||||||
|
, AttrTemplate(..)
|
||||||
|
, Template
|
||||||
|
, attrDef
|
||||||
|
, attrOriginal
|
||||||
|
, attrTemplate
|
||||||
|
, templateDefault
|
||||||
|
, templateRestrictions
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Data.Int (Int64)
|
||||||
|
import Data.Monoid ((<>))
|
||||||
|
import Lens.Family2 (Lens', (^.))
|
||||||
|
import Lens.Family2.Unchecked (lens)
|
||||||
|
import Proto.Tensorflow.Core.Framework.AttrValue as AttrValue
|
||||||
|
import Proto.Tensorflow.Core.Framework.OpDef as OpDef
|
||||||
|
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||||
|
import Proto.Tensorflow.Core.Framework.TensorShape (TensorShapeProto)
|
||||||
|
import qualified Data.ByteString as B
|
||||||
|
import qualified Data.Text as Text
|
||||||
|
|
||||||
|
-- | Specifies the optional default value and a set of allowed values
|
||||||
|
-- for the given type.
|
||||||
|
data Template a = Template {
|
||||||
|
_templateDefault :: Maybe a
|
||||||
|
-- ^ The default value (mandatory if unspecified)
|
||||||
|
, _templateRestrictions :: [a]
|
||||||
|
-- ^ The allowed set of values, empty if no restrictions
|
||||||
|
}
|
||||||
|
|
||||||
|
templateDefault :: Lens' (Template a) (Maybe a)
|
||||||
|
templateDefault = lens _templateDefault (\g x -> g { _templateDefault = x })
|
||||||
|
|
||||||
|
templateRestrictions :: Lens' (Template a) [a]
|
||||||
|
templateRestrictions = lens _templateRestrictions
|
||||||
|
(\g x -> g { _templateRestrictions = x })
|
||||||
|
|
||||||
|
data UnusedTensor
|
||||||
|
|
||||||
|
data AttrCase f
|
||||||
|
= AttrBytes (f B.ByteString) -- bytes s = 2; // "string"
|
||||||
|
| AttrInt64 (f Int64) -- int64 i = 3; // "int"
|
||||||
|
| AttrFloat (f Float) -- float f = 4; // "float"
|
||||||
|
| AttrBool (f Bool) -- bool b = 5; // "bool"
|
||||||
|
| AttrType (f DataType) -- type = 6; // "type"
|
||||||
|
-- To be translated into TensorFlow.Types.Shape before use.
|
||||||
|
-- Leaving as a proto to reduce dependencies.
|
||||||
|
| AttrShape (f TensorShapeProto) -- shape = 7; // "shape"
|
||||||
|
|
||||||
|
-- | Type-reified representation of TensorFlow AttrDef.
|
||||||
|
-- Initially limited to just the types in Op descriptors.
|
||||||
|
data AttrTemplate
|
||||||
|
= AttrSingle (AttrCase Template)
|
||||||
|
| AttrList (AttrCase [])
|
||||||
|
| AttrTensor UnusedTensor -- tensor = 8; // "tensor"
|
||||||
|
|
||||||
|
data AttrDef = AttrDef {
|
||||||
|
_attrOriginal :: OpDef'AttrDef -- ^ the proto this value was created from
|
||||||
|
, _attrTemplate :: AttrTemplate -- ^ the type of the attribute
|
||||||
|
}
|
||||||
|
|
||||||
|
attrTemplate :: Lens' AttrDef AttrTemplate
|
||||||
|
attrTemplate = lens _attrTemplate (\g x -> g { _attrTemplate = x })
|
||||||
|
|
||||||
|
attrOriginal :: Lens' AttrDef OpDef'AttrDef
|
||||||
|
attrOriginal = lens _attrOriginal (\g x -> g { _attrOriginal = x })
|
||||||
|
|
||||||
|
attrDef :: OpDef'AttrDef -> AttrDef
|
||||||
|
attrDef a = AttrDef a
|
||||||
|
$ translate (a^.OpDef.type')
|
||||||
|
(a^.OpDef.defaultValue)
|
||||||
|
(a^.allowedValues)
|
||||||
|
|
||||||
|
-- | Converts the given AttrValue with the type given by the string
|
||||||
|
-- into the AttrVal if the type is known.
|
||||||
|
translate :: Text.Text -- ^ one of the TensorFlow type strings
|
||||||
|
-> AttrValue -- ^ default value
|
||||||
|
-> AttrValue -- ^ allowed values
|
||||||
|
-> AttrTemplate
|
||||||
|
translate t defaults allowed
|
||||||
|
| t == "string" = makeVal AttrBytes maybe's s
|
||||||
|
| t == "int" = makeVal AttrInt64 maybe'i i
|
||||||
|
| t == "float" = makeVal AttrFloat maybe'f f
|
||||||
|
| t == "bool" = makeVal AttrBool maybe'b b
|
||||||
|
| t == "type" = makeVal AttrType AttrValue.maybe'type' AttrValue.type'
|
||||||
|
| t == "shape" = makeVal AttrShape maybe'shape shape
|
||||||
|
| t == "tensor" = AttrTensor $ error "tensor is unimplemented"
|
||||||
|
| t == "list(string)" = makeList AttrBytes $ list.s
|
||||||
|
| t == "list(int)" = makeList AttrInt64 $ list.i
|
||||||
|
| t == "list(float)" = makeList AttrFloat $ list.f
|
||||||
|
| t == "list(bool)" = makeList AttrBool $ list.b
|
||||||
|
| t == "list(type)" = makeList AttrType $ list.AttrValue.type'
|
||||||
|
| t == "list(shape)" = makeList AttrShape $ list.shape
|
||||||
|
| t == "list(tensor)" = AttrTensor $ error "list(tensor) is unimplemented"
|
||||||
|
| t == "func" = AttrTensor $ error "func is unimplemented"
|
||||||
|
| otherwise = error $ show ("Unknown attribute type " <> t) ++
|
||||||
|
"," ++ show defaults ++
|
||||||
|
"," ++ show allowed
|
||||||
|
where makeVal c x y = AttrSingle $ c $
|
||||||
|
Template (defaults^.x) (allowed^.list.y)
|
||||||
|
makeList c y = AttrList $ c $ defaults^.y
|
33
tensorflow-opgen/tensorflow-opgen.cabal
Normal file
33
tensorflow-opgen/tensorflow-opgen.cabal
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
name: tensorflow-opgen
|
||||||
|
version: 0.1.0.0
|
||||||
|
synopsis: Code generation for TensorFlow operations.
|
||||||
|
description: Please see README.md
|
||||||
|
homepage: https://github.com/tensorflow/haskell#readme
|
||||||
|
license: Apache
|
||||||
|
author: TensorFlow authors
|
||||||
|
maintainer: tensorflow-haskell@googlegroups.com
|
||||||
|
copyright: Google Inc.
|
||||||
|
category: Machine Learning
|
||||||
|
build-type: Simple
|
||||||
|
cabal-version: >=1.22
|
||||||
|
|
||||||
|
library
|
||||||
|
hs-source-dirs: src
|
||||||
|
exposed-modules: TensorFlow.OpGen.AttrVal
|
||||||
|
, TensorFlow.OpGen
|
||||||
|
build-depends: proto-lens == 0.1.*
|
||||||
|
, tensorflow-proto == 0.1.*
|
||||||
|
, base >= 4.7 && < 5
|
||||||
|
, bytestring
|
||||||
|
, containers
|
||||||
|
, filepath
|
||||||
|
, lens-family
|
||||||
|
, mainland-pretty
|
||||||
|
, optparse-applicative
|
||||||
|
, semigroups
|
||||||
|
, text
|
||||||
|
default-language: Haskell2010
|
||||||
|
|
||||||
|
source-repository head
|
||||||
|
type: git
|
||||||
|
location: https://github.com/tensorflow/haskell
|
1
tensorflow-opgen/third_party
Symbolic link
1
tensorflow-opgen/third_party
Symbolic link
|
@ -0,0 +1 @@
|
||||||
|
../third_party/tensorflow
|
3
tensorflow-ops/Setup.hs
Normal file
3
tensorflow-ops/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
import Distribution.Simple
|
||||||
|
|
||||||
|
main = defaultMain
|
76
tensorflow-ops/src/TensorFlow/EmbeddingOps.hs
Normal file
76
tensorflow-ops/src/TensorFlow/EmbeddingOps.hs
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE ConstraintKinds #-}
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
|
{-# LANGUAGE NoMonomorphismRestriction #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE RankNTypes #-}
|
||||||
|
|
||||||
|
-- | Parallel lookups on the list of tensors.
|
||||||
|
module TensorFlow.EmbeddingOps where
|
||||||
|
|
||||||
|
import Control.Monad (zipWithM)
|
||||||
|
import Data.Int (Int32, Int64)
|
||||||
|
import Data.List (genericLength)
|
||||||
|
import TensorFlow.Build (Build, colocateWith, render)
|
||||||
|
import TensorFlow.Ops () -- Num instance for Tensor
|
||||||
|
import TensorFlow.Tensor (Tensor, Value)
|
||||||
|
import TensorFlow.Types (OneOf, TensorType)
|
||||||
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
|
||||||
|
-- | Looks up `ids` in a list of embedding tensors.
|
||||||
|
--
|
||||||
|
-- This function is used to perform parallel lookups on the list of
|
||||||
|
-- tensors in `params`. It is a generalization of `TF.gather`, where
|
||||||
|
-- `params` is interpreted as a partition of a larger embedding
|
||||||
|
-- tensor.
|
||||||
|
--
|
||||||
|
-- The partition_strategy is "mod", we assign each id to partition
|
||||||
|
-- `p = id % len(params)`. For instance,
|
||||||
|
-- 13 ids are split across 5 partitions as:
|
||||||
|
-- `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`
|
||||||
|
--
|
||||||
|
-- The results of the lookup are concatenated into a dense
|
||||||
|
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
||||||
|
embeddingLookup :: forall a b v .
|
||||||
|
( TensorType a
|
||||||
|
, OneOf '[Int64, Int32] b
|
||||||
|
, Num b
|
||||||
|
)
|
||||||
|
=> [Tensor v a]
|
||||||
|
-- ^ A list of tensors which can be concatenated along
|
||||||
|
-- dimension 0. Each `Tensor` must be appropriately
|
||||||
|
-- sized for `mod` partition strategy.
|
||||||
|
-> Tensor Value b
|
||||||
|
-- ^ A `Tensor` with type `int32` or `int64`
|
||||||
|
-- containing the ids to be looked up in `params`.
|
||||||
|
-- The ids are required to be flat on entry and have
|
||||||
|
-- fewer than 2^31 entries.
|
||||||
|
-> Build (Tensor Value a)
|
||||||
|
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
|
||||||
|
embeddingLookup params ids =
|
||||||
|
CoreOps.dynamicStitch pindices <$> partitionedResult
|
||||||
|
where np = genericLength params
|
||||||
|
pAssignments = CoreOps.cast (ids `CoreOps.mod` np)
|
||||||
|
newIds = ids `CoreOps.div` np
|
||||||
|
originalIndices = CoreOps.range 0 (CoreOps.size ids) 1
|
||||||
|
-- Partition list of ids based on assignments into np separate lists
|
||||||
|
gatherIds = CoreOps.dynamicPartition np newIds pAssignments
|
||||||
|
-- Similarly, partition the original indices.
|
||||||
|
pindices = CoreOps.dynamicPartition np originalIndices pAssignments
|
||||||
|
-- Do np separate lookups, finding embeddings for plist[p] in params[p]
|
||||||
|
partitionedResult = zipWithM
|
||||||
|
(\p g -> colocateWith p $ render $ CoreOps.gather p g)
|
||||||
|
params gatherIds
|
697
tensorflow-ops/src/TensorFlow/Gradient.hs
Normal file
697
tensorflow-ops/src/TensorFlow/Gradient.hs
Normal file
|
@ -0,0 +1,697 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE ConstraintKinds #-}
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE RankNTypes #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
{-# LANGUAGE ViewPatterns #-}
|
||||||
|
|
||||||
|
module TensorFlow.Gradient
|
||||||
|
( gradients
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Control.Monad (forM, zipWithM)
|
||||||
|
import Control.Monad.State.Strict (State, evalState, gets, modify)
|
||||||
|
import Data.ByteString (ByteString)
|
||||||
|
import Data.Complex (Complex)
|
||||||
|
import Data.Default (def)
|
||||||
|
import Data.Int (Int32, Int64)
|
||||||
|
import Data.List (foldl', sortBy)
|
||||||
|
import Data.Map.Strict (Map)
|
||||||
|
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
|
||||||
|
import Data.Ord (comparing)
|
||||||
|
import Data.ProtoLens.TextFormat (showMessage)
|
||||||
|
import Data.Set (Set)
|
||||||
|
import Data.Text (Text)
|
||||||
|
import Data.Tuple (swap)
|
||||||
|
import Lens.Family2 (Lens', (&), (^.), (.~), (%~))
|
||||||
|
import Lens.Family2.State.Strict (uses)
|
||||||
|
import Lens.Family2.Stock (at, intAt)
|
||||||
|
import Lens.Family2.Unchecked (lens, iso)
|
||||||
|
import Prelude hiding (sum)
|
||||||
|
import Text.Printf (printf)
|
||||||
|
import qualified Data.Graph.Inductive.Basic as FGL
|
||||||
|
import qualified Data.Graph.Inductive.Graph as FGL
|
||||||
|
import qualified Data.Graph.Inductive.PatriciaTree as FGL
|
||||||
|
import qualified Data.Graph.Inductive.Query.DFS as FGL
|
||||||
|
import qualified Data.IntMap.Strict as IntMap
|
||||||
|
import qualified Data.Map.Strict as Map
|
||||||
|
import qualified Data.Set as Set
|
||||||
|
import qualified Data.Text as Text
|
||||||
|
|
||||||
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
import TensorFlow.Build
|
||||||
|
( Build
|
||||||
|
, render
|
||||||
|
, renderNodeName
|
||||||
|
, renderedNodeDefs
|
||||||
|
, opDef
|
||||||
|
, opAttr
|
||||||
|
)
|
||||||
|
import TensorFlow.BuildOp
|
||||||
|
import TensorFlow.Ops
|
||||||
|
( addN
|
||||||
|
, broadcastGradientArgs
|
||||||
|
, expandDims
|
||||||
|
, fill
|
||||||
|
, matMul
|
||||||
|
, reducedShape
|
||||||
|
, reluGrad
|
||||||
|
, reshape
|
||||||
|
, scalar
|
||||||
|
, shape
|
||||||
|
, softmaxCrossEntropyWithLogits
|
||||||
|
, sum
|
||||||
|
, vector
|
||||||
|
, zerosLike
|
||||||
|
)
|
||||||
|
import TensorFlow.Output
|
||||||
|
( NodeName(..)
|
||||||
|
, Op (Rendered)
|
||||||
|
, Output(..)
|
||||||
|
, OutputIx(..)
|
||||||
|
, outputIndex
|
||||||
|
)
|
||||||
|
import TensorFlow.Tensor
|
||||||
|
( Tensor(..)
|
||||||
|
, TensorKind (ValueKind)
|
||||||
|
, Value
|
||||||
|
, tensorOutput
|
||||||
|
, tensorAttr
|
||||||
|
)
|
||||||
|
import TensorFlow.Types (OneOf, TensorType, attrLens)
|
||||||
|
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||||
|
(NodeDef, attr, input, op, name)
|
||||||
|
|
||||||
|
type GradientCompatible a =
|
||||||
|
-- TODO(fmayle): MaxPoolGrad doesn't support Double for some reason.
|
||||||
|
(Num a, OneOf '[ Float, Complex Float, Complex Double ] a)
|
||||||
|
|
||||||
|
-- TODO(fmayle): Support control flow.
|
||||||
|
-- TODO(fmayle): Support gate_gradients-like option to avoid race conditions.
|
||||||
|
-- TODO(fmayle): Do we need to consider control inputs? See _PendingCount in
|
||||||
|
-- tensorflow/python/ops/gradients.py.
|
||||||
|
-- TODO(fmayle): Maybe store the gradient functions and numOutputs on the OpDef.
|
||||||
|
|
||||||
|
|
||||||
|
-- | Gradient of @y@ w.r.t. each element of @xs@.
|
||||||
|
gradients :: forall a v1 v2 . ( Num (Tensor v1 a)
|
||||||
|
-- TODO(gnezdo): remove indirect constraint.
|
||||||
|
-- It's a wart inherited from Num instance.
|
||||||
|
, v1 ~ Value
|
||||||
|
, GradientCompatible a
|
||||||
|
)
|
||||||
|
=> Tensor v1 a -- ^ The output of the graph.
|
||||||
|
-> [Tensor v2 a] -- ^ Tensors for which gradients are computed.
|
||||||
|
-> Build [Tensor Value a]
|
||||||
|
gradients y xs = do
|
||||||
|
-- The gradients are computed using "reverse accumulation", similarly to
|
||||||
|
-- what is described here:
|
||||||
|
-- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation
|
||||||
|
--
|
||||||
|
-- The code is summarised as follows:
|
||||||
|
--
|
||||||
|
-- 1. Create an fgl graph of the relevant nodes (ops) and edges (tensors).
|
||||||
|
-- 2. Initialize the gradient of y to 1 (∂y/∂y = 1) and the rest of tensor's
|
||||||
|
-- gradients to nothing.
|
||||||
|
-- 3. Process the nodes in reverse topological order (i.e. each node comes
|
||||||
|
-- after all of its outputs so that the output gradients for a node have
|
||||||
|
-- been completely calculated before it is processed):
|
||||||
|
-- a. Record the gradient for each of the node's output tensors (∂y/∂w
|
||||||
|
-- for each output tensor w).
|
||||||
|
-- b. Calculate the gradient of y w.r.t. each of the node's input
|
||||||
|
-- tensors using the gradients of the node's output tensors.
|
||||||
|
--
|
||||||
|
-- Written differently, for each output tensor w and input tensor v:
|
||||||
|
-- ∂y/∂w = ... (calculated in previous steps)
|
||||||
|
-- ∂w/∂v = ... (op specific)
|
||||||
|
-- ∂y/∂v = ∂y/∂w * ∂w/∂v (technically, if tensor v is an input
|
||||||
|
-- to multiple nodes, then this is only
|
||||||
|
-- part of ∂y/∂v)
|
||||||
|
--
|
||||||
|
-- 4. Lookup the recorded gradient for each x in xs.
|
||||||
|
|
||||||
|
yName <- renderNodeName y
|
||||||
|
-- TODO(fmayle): Move this into Build.hs and call it unsafeNodeDefFromName?
|
||||||
|
nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $
|
||||||
|
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
|
||||||
|
. flip Map.lookup
|
||||||
|
let (gr, nodeMap) = createGraph yName nodeDefLookup
|
||||||
|
-- Set gradient of y to one.
|
||||||
|
let initPending :: Map.Map FGL.Node (PendingGradients a)
|
||||||
|
initPending = Map.empty & at (nodeMap Map.! yName)
|
||||||
|
. nonEmpty
|
||||||
|
. outputIxAt (y ^. tensorOutput . outputIndex)
|
||||||
|
. nonEmpty
|
||||||
|
.~ [fill (shape y) (scalar 1)]
|
||||||
|
-- Calculate the gradients of y w.r.t. each node in the graph.
|
||||||
|
gradientMap <- graphGrads gr initPending
|
||||||
|
-- Lookup the gradients for each x.
|
||||||
|
forM xs $ \x -> do
|
||||||
|
xName <- renderNodeName x
|
||||||
|
render $ fromMaybe (zerosLike x) $ do
|
||||||
|
n <- nodeMap ^. at xName
|
||||||
|
let i = x ^. tensorOutput . outputIndex
|
||||||
|
gradientMap ^. at n . nonEmpty . outputIxAt i
|
||||||
|
|
||||||
|
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
|
||||||
|
outputIxAt = intAt . unOutputIx
|
||||||
|
|
||||||
|
-- | Incomplete gradients of a node's outputs.
|
||||||
|
--
|
||||||
|
-- The lists represent partial sums. The key is an OutputIx sans newtype.
|
||||||
|
type PendingGradients a = IntMap.IntMap [Tensor Value a]
|
||||||
|
|
||||||
|
-- | Gradients of a node's outputs. The key is an OutputIx sans newtype.
|
||||||
|
type Gradients a = IntMap.IntMap (Tensor Value a)
|
||||||
|
|
||||||
|
-- | Graph of TensorFlow operations.
|
||||||
|
type Graph = FGL.Gr NodeDef EdgeLabel
|
||||||
|
|
||||||
|
-- | Data associated with an edge.
|
||||||
|
--
|
||||||
|
-- Pair of
|
||||||
|
-- 1. Output index of a tensor from the source node.
|
||||||
|
-- 2. Input index that the tensor connects to on the destination node.
|
||||||
|
type EdgeLabel = (OutputIx, OutputIx)
|
||||||
|
|
||||||
|
|
||||||
|
-- | State used for calculating gradients.
|
||||||
|
data GradientsState a = GradientsState
|
||||||
|
{ _gradientsPending :: !(Map FGL.Node (PendingGradients a))
|
||||||
|
, _gradientsResult :: !(Map FGL.Node (Gradients a))
|
||||||
|
}
|
||||||
|
|
||||||
|
gradientsPending :: Lens' (GradientsState a) (Map FGL.Node (PendingGradients a))
|
||||||
|
gradientsPending = lens _gradientsPending (\x y -> x { _gradientsPending = y })
|
||||||
|
|
||||||
|
gradientsResult :: Lens' (GradientsState a) (Map FGL.Node (Gradients a))
|
||||||
|
gradientsResult = lens _gradientsResult (\x y -> x { _gradientsResult = y })
|
||||||
|
|
||||||
|
|
||||||
|
-- TODO(fmayle): Use something like Data.List.Safe.
|
||||||
|
-- | Safe version of (!!).
|
||||||
|
safeIndex :: [a] -> Int -> Maybe a
|
||||||
|
_ `safeIndex` n | n < 0 = Nothing
|
||||||
|
[] `safeIndex` _ = Nothing
|
||||||
|
(x:_) `safeIndex` 0 = Just x
|
||||||
|
(_:xs) `safeIndex` n = xs `safeIndex` (n-1)
|
||||||
|
|
||||||
|
-- Copy of http://hackage.haskell.org/package/lens-3.9.0.2/docs/Control-Lens-Iso.html#v%3anon
|
||||||
|
anon :: a -> (a -> Bool) -> Lens' (Maybe a) a
|
||||||
|
anon a p = iso (fromMaybe a) go where
|
||||||
|
go b | p b = Nothing
|
||||||
|
| otherwise = Just b
|
||||||
|
|
||||||
|
non :: Eq a => a -> Lens' (Maybe a) a
|
||||||
|
non a = anon a (a==)
|
||||||
|
|
||||||
|
-- | Lens that defaults Nothing to mempty.
|
||||||
|
nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v)
|
||||||
|
nonEmpty = anon mempty null
|
||||||
|
|
||||||
|
-- | Calculate the gradients for every node in a graph.
|
||||||
|
graphGrads :: forall a. GradientCompatible a
|
||||||
|
=> Graph
|
||||||
|
-> Map FGL.Node (PendingGradients a)
|
||||||
|
-- ^ Initial gradients (usually just 1 for the node of interest).
|
||||||
|
-> Build (Map FGL.Node (Gradients a))
|
||||||
|
graphGrads gr initPending = pure (foldl' go initState nodeOrder ^. gradientsResult)
|
||||||
|
where
|
||||||
|
initState = GradientsState initPending Map.empty
|
||||||
|
-- Reverse topological sort.
|
||||||
|
-- TODO(fmayle): Filter out nodes that are not successors of any x in xs to
|
||||||
|
-- avoid calculating gradients that won't be used.
|
||||||
|
nodeOrder = FGL.topsort $ FGL.grev gr
|
||||||
|
go state node =
|
||||||
|
-- Aggregate the accumulated gradients for this node.
|
||||||
|
let outputGrads =
|
||||||
|
sumPendingGradient (state ^. gradientsPending . at node . nonEmpty)
|
||||||
|
in if null outputGrads
|
||||||
|
then state
|
||||||
|
else
|
||||||
|
-- Calculate the gradients for each of the node's inputs.
|
||||||
|
let nextState = state & gradientsResult %~ Map.insert node outputGrads
|
||||||
|
ctx = FGL.context gr node
|
||||||
|
in updatePendingGradients
|
||||||
|
ctx
|
||||||
|
(calculateInputGrads ctx outputGrads gr)
|
||||||
|
nextState
|
||||||
|
|
||||||
|
-- | Reduce accumulated gradients for each output to one Tensor.
|
||||||
|
sumPendingGradient :: GradientCompatible a
|
||||||
|
=> PendingGradients a -> Gradients a
|
||||||
|
sumPendingGradient = IntMap.mapMaybe f
|
||||||
|
where
|
||||||
|
f [] = Nothing
|
||||||
|
f [x] = Just x
|
||||||
|
f xs = Just (addN xs)
|
||||||
|
|
||||||
|
|
||||||
|
-- | Calculate the gradients of a node's input tensors.
|
||||||
|
--
|
||||||
|
-- This is mostly just a wrapper around opGrad.
|
||||||
|
calculateInputGrads :: forall a. GradientCompatible a
|
||||||
|
=> FGL.Context NodeDef EdgeLabel
|
||||||
|
-> Gradients a -- ^ Output gradients of the node.
|
||||||
|
-> Graph
|
||||||
|
-> [Maybe (Tensor Value a)]
|
||||||
|
calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr =
|
||||||
|
opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads
|
||||||
|
where
|
||||||
|
fullOutGrads =
|
||||||
|
fullOutputGrads (numOutputs nodeDef) (Rendered nodeDef) outputGrads
|
||||||
|
-- Create a tensor from an edge (technically an Output, but it seems less
|
||||||
|
-- confusing to refer to it as a tensor here).
|
||||||
|
edgeToTensor :: (EdgeLabel, FGL.Node) -> Output
|
||||||
|
edgeToTensor ((i, _), n) =
|
||||||
|
case FGL.lab gr n of
|
||||||
|
Just edgeNodeDef -> Output i (Rendered edgeNodeDef)
|
||||||
|
Nothing -> error $ "calculateInputGrads: missing input node for "
|
||||||
|
++ Text.unpack (nodeDef ^. name)
|
||||||
|
-- Input tensors, sorted by input index.
|
||||||
|
inputTensors = map edgeToTensor $ sortBy (comparing (snd . fst)) inputEdges
|
||||||
|
|
||||||
|
-- | Convert a Map of gradients to a list, with zeros for missing outputs.
|
||||||
|
fullOutputGrads :: (TensorType a, Num a)
|
||||||
|
=> OutputIx -- ^ Number of outputs.
|
||||||
|
-> Op
|
||||||
|
-> Gradients a
|
||||||
|
-> [Tensor Value a]
|
||||||
|
fullOutputGrads n o gs =
|
||||||
|
map (\i -> fromMaybe (zero i) (gs ^. outputIxAt i)) [0..n-1]
|
||||||
|
where
|
||||||
|
-- A tensor of zeros with the same shape as the i'th output.
|
||||||
|
zero i = zerosLike $ toT (Output i o)
|
||||||
|
|
||||||
|
|
||||||
|
-- | Update the pending gradients of a node's inputs.
|
||||||
|
updatePendingGradients :: forall a. (TensorType a, Num a)
|
||||||
|
=> FGL.Context NodeDef EdgeLabel
|
||||||
|
-> [Maybe (Tensor Value a)]
|
||||||
|
-- ^ Gradient of each input tensor.
|
||||||
|
-> GradientsState a
|
||||||
|
-> GradientsState a
|
||||||
|
updatePendingGradients (inputEdges, _, nodeDef, _) inputGrads initState =
|
||||||
|
foldl' go initState inputEdges
|
||||||
|
where
|
||||||
|
go :: GradientsState a -> (EdgeLabel, FGL.Node) -> GradientsState a
|
||||||
|
go state ((outIndex, OutputIx inIndex), node) =
|
||||||
|
case maybeGradient of
|
||||||
|
Nothing -> state
|
||||||
|
Just g ->
|
||||||
|
-- Add to the list of pending gradients for this tensor.
|
||||||
|
state & gradientsPending
|
||||||
|
. at node
|
||||||
|
. nonEmpty
|
||||||
|
. outputIxAt outIndex
|
||||||
|
. nonEmpty
|
||||||
|
%~ (g:)
|
||||||
|
where
|
||||||
|
badSizeErr = error $ printf "updatePendingGradients: bad input index \
|
||||||
|
\%d for inputGrads of length %d in %s"
|
||||||
|
inIndex (length inputGrads)
|
||||||
|
(show (nodeDef ^. name))
|
||||||
|
maybeGradient = fromMaybe badSizeErr (safeIndex inputGrads inIndex)
|
||||||
|
|
||||||
|
|
||||||
|
-- | Create a graph that includes a node and its transitive dependencies.
|
||||||
|
createGraph :: NodeName -> (NodeName -> NodeDef)
|
||||||
|
-> (Graph, Map NodeName FGL.Node)
|
||||||
|
createGraph nodeName nodeDefLookup = (FGL.nmap nodeDefLookup graph, nodeMap)
|
||||||
|
where
|
||||||
|
-- Parse a tensor name.
|
||||||
|
parseTensorName :: Text -> Maybe (NodeName, OutputIx)
|
||||||
|
parseTensorName n
|
||||||
|
| Text.null n = error "parseTensorName: empty name"
|
||||||
|
| Text.head n == '^' = Nothing -- Control edge
|
||||||
|
| otherwise =
|
||||||
|
let (nm, indexStr) = Text.breakOn ":" n
|
||||||
|
index | Text.null indexStr = 0
|
||||||
|
| otherwise = read $ Text.unpack $ Text.tail indexStr
|
||||||
|
in Just (NodeName nm, OutputIx index)
|
||||||
|
|
||||||
|
-- Build a map from node name to outward edges.
|
||||||
|
--
|
||||||
|
-- The state is the set of visited nodes.
|
||||||
|
collect :: Maybe (NodeName, OutputIx, OutputIx)
|
||||||
|
-> NodeName
|
||||||
|
-> State (Set NodeName)
|
||||||
|
(Map NodeName [(NodeName, OutputIx, OutputIx)])
|
||||||
|
collect outgoingEdge nm = do
|
||||||
|
let nextLookup = Map.singleton nm (maybeToList outgoingEdge)
|
||||||
|
seen <- gets (Set.member nm)
|
||||||
|
modify (Set.insert nm)
|
||||||
|
if seen
|
||||||
|
then pure nextLookup
|
||||||
|
else do
|
||||||
|
let inputs = nodeDefLookup nm ^. input
|
||||||
|
recurse inIndex (parentName, outIndex) =
|
||||||
|
collect (Just (nm, outIndex, inIndex)) parentName
|
||||||
|
subEdgeLookups <-
|
||||||
|
zipWithM recurse [0..] $ mapMaybe parseTensorName inputs
|
||||||
|
pure $ Map.unionsWith (++) (nextLookup:subEdgeLookups)
|
||||||
|
|
||||||
|
edgeLookup = evalState (collect Nothing nodeName) Set.empty
|
||||||
|
-- Associate an ID with each node name.
|
||||||
|
nodeMap = Map.fromList $ zip (Map.keys edgeLookup) [0..]
|
||||||
|
-- Create the graph.
|
||||||
|
graph = FGL.mkGraph (swap <$> Map.toList nodeMap)
|
||||||
|
[ (nodeMap Map.! n, nodeMap Map.! m, (i, j))
|
||||||
|
| (n, edges) <- Map.toList edgeLookup
|
||||||
|
, (m, i, j) <- edges
|
||||||
|
]
|
||||||
|
|
||||||
|
-- | Function to compute the gradient of y w.r.t. each input.
|
||||||
|
--
|
||||||
|
-- Let y be an arbitrary tensor
|
||||||
|
-- and [w_0, ..., w_n] be the output tensors of a node
|
||||||
|
-- and [v_0, ..., v_n] be the input tensors of the same node.
|
||||||
|
--
|
||||||
|
-- Given [∂y/∂w_0, ..., ∂y/∂w_n] and [v_0, ..., v_n], a GradientFunc computes
|
||||||
|
-- [∂y/∂v_0, ..., ∂y/∂v_n] for a particular op type.
|
||||||
|
--
|
||||||
|
-- A Nothing gradient is equivalent to zero (but allows for short circuiting
|
||||||
|
-- computation when all the gradients for something are Nothing).
|
||||||
|
type GradientFunc a = NodeDef
|
||||||
|
-> [Output]
|
||||||
|
-- ^ Input tensors.
|
||||||
|
-> [Tensor Value a]
|
||||||
|
-- ^ Gradient of y w.r.t. each output tensor.
|
||||||
|
-> [Maybe (Tensor Value a)]
|
||||||
|
-- ^ Gradient of y w.r.t. each input tensor.
|
||||||
|
|
||||||
|
|
||||||
|
-- TODO(fmayle): Assert the type is correct.
|
||||||
|
-- | Create a Tensor from an Output.
|
||||||
|
toT :: Output -> Tensor Value a
|
||||||
|
toT = Tensor ValueKind
|
||||||
|
|
||||||
|
-- | The gradient function for an op type.
|
||||||
|
--
|
||||||
|
-- These implementations should match their python counterparts in:
|
||||||
|
-- third_party/tensorflow/python/ops/*_grad.py
|
||||||
|
opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a
|
||||||
|
|
||||||
|
opGrad "Abs" _ [toT -> x] [dz] = [Just $ dz * signum x]
|
||||||
|
opGrad "Neg" _ [_] [dz] = [Just $ -dz]
|
||||||
|
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
|
||||||
|
|
||||||
|
opGrad "Square" _ [toT -> x] [dz] =
|
||||||
|
-- TODO(fmayle): Handle complex numbers.
|
||||||
|
-- TODO(fmayle): The python code makes dz a control dependency of the 2*x
|
||||||
|
-- (for performance reasons?). Will need to put these functions in the Build
|
||||||
|
-- monad to replicate that.
|
||||||
|
[Just $ dz * (2 * x)]
|
||||||
|
|
||||||
|
opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
|
||||||
|
-- TODO(fmayle): The python version uses a better performance implementation
|
||||||
|
-- when the shape is known without having to run the graph.
|
||||||
|
-- TODO(fmayle): We shouldn't convert the result to a dense tensor. Sparse
|
||||||
|
-- tensor support will require some thinking.
|
||||||
|
[ Just $ CoreOps.unsortedSegmentSum values indices' numRows
|
||||||
|
, Nothing
|
||||||
|
]
|
||||||
|
where
|
||||||
|
-- TODO(gnezdo): Use colocateWith but it requires Build monad.
|
||||||
|
denseShape = shape (x :: Tensor Value a)
|
||||||
|
numRows = CoreOps.slice denseShape 0 (1 :: Tensor Value Int32)
|
||||||
|
valuesShape = CoreOps.concat 0 [
|
||||||
|
allDimensions
|
||||||
|
, CoreOps.slice denseShape 1 (-1 :: Tensor Value Int32)
|
||||||
|
]
|
||||||
|
values = reshape dz valuesShape
|
||||||
|
-- TODO(fmayle): This could be either Int32 or Int64.
|
||||||
|
indices' = reshape indices allDimensions :: Tensor Value Int32
|
||||||
|
|
||||||
|
opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
|
||||||
|
[Just $ indicators `CoreOps.div` numSelected * dz', Nothing]
|
||||||
|
where
|
||||||
|
sx = shape (x :: Tensor Value a)
|
||||||
|
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
|
||||||
|
x' = reshape x outputShapeKeptDims
|
||||||
|
dz' = reshape dz outputShapeKeptDims
|
||||||
|
indicators = CoreOps.cast $ CoreOps.equal x' x
|
||||||
|
numSelected = reshape (sum indicators indices) outputShapeKeptDims
|
||||||
|
|
||||||
|
-- Min and Max have identical gradient implementations.
|
||||||
|
opGrad "Min" u v w = opGrad "Max" u v w
|
||||||
|
|
||||||
|
opGrad "Sum" _ [toT -> x, toT -> indices] [dz] =
|
||||||
|
[ Just $ CoreOps.tile grad tileScaling, Nothing ]
|
||||||
|
where
|
||||||
|
-- TODO(gnezdo): Implement the fast-path from math_grad._SumGrad.
|
||||||
|
sx = shape (x :: Tensor Value a)
|
||||||
|
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
|
||||||
|
tileScaling = safeShapeDiv sx outputShapeKeptDims
|
||||||
|
grad = reshape dz outputShapeKeptDims
|
||||||
|
|
||||||
|
opGrad "Mean" u v@[toT -> x, _] w =
|
||||||
|
[Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing]
|
||||||
|
where
|
||||||
|
[Just dz, Nothing] = opGrad "Sum" u v w
|
||||||
|
inputShape = shape (x :: Tensor Value a)
|
||||||
|
outputShape = shape (dz :: Tensor Value a)
|
||||||
|
-- TODO(fmayle): Add fast path when shape is known.
|
||||||
|
inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape
|
||||||
|
outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape
|
||||||
|
factor = safeShapeDiv inputSize outputSize
|
||||||
|
|
||||||
|
opGrad "Add" _ [toT -> x, toT -> y] [dz] =
|
||||||
|
[ Just $ reshape (sum dz rx) sx
|
||||||
|
, Just $ reshape (sum dz ry) sy ]
|
||||||
|
where
|
||||||
|
sx = shape (x :: Tensor Value a)
|
||||||
|
sy = shape (y :: Tensor Value a)
|
||||||
|
(rx, ry) = broadcastGradientArgs sx sy
|
||||||
|
|
||||||
|
opGrad "Sub" u v w =
|
||||||
|
[Just x, Just (-y)]
|
||||||
|
where
|
||||||
|
[Just x, Just y] = opGrad "Add" u v w
|
||||||
|
|
||||||
|
opGrad "SoftmaxCrossEntropyWithLogits" _ [toT -> x, toT -> y] [dz, _] =
|
||||||
|
[ Just $ expandDims dz (-1) * snd (softmaxCrossEntropyWithLogits x y)
|
||||||
|
, Nothing ]
|
||||||
|
|
||||||
|
opGrad "Mul" _ [toT -> x, toT -> y] [dz] =
|
||||||
|
-- TODO(fmayle): Handle complex numbers.
|
||||||
|
[ Just $ reshape (sum (dz * y) rx) sx
|
||||||
|
, Just $ reshape (sum (x * dz) ry) sy ]
|
||||||
|
where
|
||||||
|
sx = shape (x :: Tensor Value a)
|
||||||
|
sy = shape (y :: Tensor Value a)
|
||||||
|
(rx, ry) = broadcastGradientArgs sx sy
|
||||||
|
|
||||||
|
opGrad "Div" _ [toT -> x, toT -> y] [dz] =
|
||||||
|
-- TODO(fmayle): Handle complex numbers.
|
||||||
|
-- TODO(gnezdo): Provide Fractional instance and use '/' instead of div.
|
||||||
|
[ Just $ reshape (sum (dz `CoreOps.div` y) rx) sx
|
||||||
|
, Just $ reshape (sum (dz * (negate x `CoreOps.div` (y * y))) ry) sy
|
||||||
|
]
|
||||||
|
where
|
||||||
|
sx = shape (x :: Tensor Value a)
|
||||||
|
sy = shape (y :: Tensor Value a)
|
||||||
|
(rx, ry) = broadcastGradientArgs sx sy
|
||||||
|
|
||||||
|
opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
|
||||||
|
let transposeA = lookupAttr nodeDef "transpose_a"
|
||||||
|
transposeB = lookupAttr nodeDef "transpose_b"
|
||||||
|
transAttrs a b =
|
||||||
|
(tensorAttr "transpose_a" .~ a) . (tensorAttr "transpose_b" .~ b)
|
||||||
|
in case (transposeA, transposeB) of
|
||||||
|
(False, False) ->
|
||||||
|
[ Just $ (dz `matMul` y) & transAttrs False True
|
||||||
|
, Just $ (x `matMul` dz) & transAttrs True False ]
|
||||||
|
(False, True) ->
|
||||||
|
[ Just $ dz `matMul` y
|
||||||
|
, Just $ (x `matMul` dz) & transAttrs True False ]
|
||||||
|
(True, False) ->
|
||||||
|
[ Just $ (dz `matMul` y) & transAttrs False True
|
||||||
|
, Just $ x `matMul` dz ]
|
||||||
|
(True, True) ->
|
||||||
|
[ Just $ (dz `matMul` y) & transAttrs True True
|
||||||
|
, Just $ (x `matMul` dz) & transAttrs True True ]
|
||||||
|
|
||||||
|
opGrad "Transpose" _ [_, toT -> p] [dz] =
|
||||||
|
[ Just $ CoreOps.transpose dz
|
||||||
|
(CoreOps.invertPermutation p :: Tensor Value Int32)
|
||||||
|
, Nothing
|
||||||
|
]
|
||||||
|
|
||||||
|
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
|
||||||
|
[ Just $ CoreOps.conv2DBackpropInput (shape x) y dz
|
||||||
|
& tensorAttr "strides" .~ strides
|
||||||
|
& tensorAttr "padding" .~ padding
|
||||||
|
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
|
||||||
|
& tensorAttr "data_format" .~ dataFormat
|
||||||
|
, Just $ CoreOps.conv2DBackpropFilter x (shape y) dz
|
||||||
|
& tensorAttr "strides" .~ strides
|
||||||
|
& tensorAttr "padding" .~ padding
|
||||||
|
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
|
||||||
|
& tensorAttr "data_format" .~ dataFormat
|
||||||
|
]
|
||||||
|
where
|
||||||
|
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||||
|
padding = lookupAttr nodeDef "padding" :: ByteString
|
||||||
|
useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
|
||||||
|
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
||||||
|
|
||||||
|
opGrad "MaxPool" nodeDef [toT -> x] [dz] =
|
||||||
|
[ Just $ CoreOps.maxPoolGrad x output dz
|
||||||
|
& tensorAttr "ksize" .~ ksize
|
||||||
|
& tensorAttr "strides" .~ strides
|
||||||
|
& tensorAttr "padding" .~ padding
|
||||||
|
& tensorAttr "data_format" .~ dataFormat
|
||||||
|
]
|
||||||
|
where
|
||||||
|
output :: Tensor Value a
|
||||||
|
output = toT $ Output 0 (Rendered nodeDef)
|
||||||
|
ksize = lookupAttr nodeDef "ksize" :: [Int64]
|
||||||
|
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||||
|
padding = lookupAttr nodeDef "padding" :: ByteString
|
||||||
|
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
||||||
|
|
||||||
|
opGrad "Reshape" _ [toT -> x, _] [dz] =
|
||||||
|
[Just $ reshape dz $ shape (x :: Tensor Value a), Nothing]
|
||||||
|
|
||||||
|
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
|
||||||
|
opGrad "TruncatedNormal" _ _ _ = [Nothing]
|
||||||
|
|
||||||
|
opGrad "RefIdentity" _ _ [dz] = [Just dz]
|
||||||
|
opGrad "Cast" nodeDef _ [dz] = [Just reverseCast]
|
||||||
|
where
|
||||||
|
-- TODO(gnezdo): too permissive, python only allows float types as src_type.
|
||||||
|
reverseCast =
|
||||||
|
buildOp (opDef "Cast"
|
||||||
|
& opAttr "DstT" .~ (lookupAttr nodeDef "SrcT" :: ByteString)
|
||||||
|
& opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString))
|
||||||
|
dz
|
||||||
|
|
||||||
|
opGrad "DynamicStitch" nodeDef inputs [dz] =
|
||||||
|
replicate halfLen Nothing ++ valuesGrads
|
||||||
|
where
|
||||||
|
halfLen =
|
||||||
|
let len = length inputs
|
||||||
|
half = len `div` 2
|
||||||
|
in if 2 * half == len
|
||||||
|
then half
|
||||||
|
else error ("Uneven input size " ++ show (len, showMessage nodeDef))
|
||||||
|
valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Value Int32)
|
||||||
|
| idx <- take halfLen inputs
|
||||||
|
]
|
||||||
|
|
||||||
|
opGrad "DynamicPartition" nodeDef [toT -> xs, toT -> indices] dz =
|
||||||
|
[ Just reconstructed, Nothing ]
|
||||||
|
where
|
||||||
|
reconstructed = CoreOps.reshape stitched
|
||||||
|
(CoreOps.shape (xs :: Tensor Value a) :: Tensor Value Int32)
|
||||||
|
stitched = CoreOps.dynamicStitch partitionedIndices dz
|
||||||
|
partitionedIndices = CoreOps.dynamicPartition np originalIndices indices
|
||||||
|
np = lookupAttr nodeDef "num_partitions" :: Int64
|
||||||
|
originalIndices =
|
||||||
|
CoreOps.reshape (CoreOps.range 0 (CoreOps.size indices) 1) prefixShape
|
||||||
|
prefixShape = shapeInt32 indices
|
||||||
|
shapeInt32 = CoreOps.shape :: Tensor Value Int32 -> Tensor Value Int32
|
||||||
|
|
||||||
|
opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
|
||||||
|
[ Nothing
|
||||||
|
, Just $ CoreOps.select c dz zeros
|
||||||
|
, Just $ CoreOps.select c zeros dz
|
||||||
|
]
|
||||||
|
where zeros = CoreOps.zerosLike x
|
||||||
|
|
||||||
|
-- TODO(gnezdo): Unlike Python, no control dependency on dz.
|
||||||
|
opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.inv x ]
|
||||||
|
-- TODO(gnezdo): Reuse the output instead of doing another exp,
|
||||||
|
-- though, it is probably CSE'd away anyway.
|
||||||
|
opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.exp x ]
|
||||||
|
opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
|
||||||
|
[ Just $ CoreOps.unsortedSegmentSum
|
||||||
|
(CoreOps.gather dz (t :: Tensor Value Int32))
|
||||||
|
(y :: Tensor Value Int32) inputRows
|
||||||
|
, Nothing
|
||||||
|
, Nothing
|
||||||
|
]
|
||||||
|
where inputRows = CoreOps.slice (shape (x :: Tensor Value a)) (scalar (0 :: Int32)) (scalar 1)
|
||||||
|
|
||||||
|
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
|
||||||
|
opGrad "LabelWeights" _ _ _ = [Nothing]
|
||||||
|
opGrad "Size" _ _ _ = [Nothing]
|
||||||
|
opGrad "ZerosLike" _ _ _ = [Nothing]
|
||||||
|
|
||||||
|
-- TODO(fmayle): These can go away if we properly prune the graph.
|
||||||
|
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
||||||
|
opGrad "Placeholder" _ _ _ = []
|
||||||
|
opGrad "Variable" _ _ _ = []
|
||||||
|
|
||||||
|
opGrad n nodeDef ins grads =
|
||||||
|
error $ "no gradient implemented for " ++
|
||||||
|
show (n, length ins, length grads, showMessage nodeDef, ins)
|
||||||
|
|
||||||
|
-- | The number of outputs for an op type.
|
||||||
|
numOutputs :: NodeDef -> OutputIx
|
||||||
|
numOutputs o =
|
||||||
|
case o ^. op of
|
||||||
|
"Abs" -> 1
|
||||||
|
"Add" -> 1
|
||||||
|
"Cast" -> 1
|
||||||
|
"Const" -> 1
|
||||||
|
"Conv2D" -> 1
|
||||||
|
"Div" -> 1
|
||||||
|
"DynamicStitch" -> 1
|
||||||
|
"DynamicPartition" ->
|
||||||
|
fromIntegral (lookupAttr o "num_partitions" :: Int64)
|
||||||
|
"Exp" -> 1
|
||||||
|
"Gather" -> 1
|
||||||
|
"LabelClasses" -> 1
|
||||||
|
"LabelWeights" -> 1
|
||||||
|
"Log" -> 1
|
||||||
|
"MatMul" -> 1
|
||||||
|
"Max" -> 1
|
||||||
|
"MaxPool" -> 1
|
||||||
|
"Mean" -> 1
|
||||||
|
"Min" -> 1
|
||||||
|
"Mul" -> 1
|
||||||
|
"Neg" -> 1
|
||||||
|
"Placeholder" -> 1
|
||||||
|
"OneHot" -> 1
|
||||||
|
"RefIdentity" -> 1
|
||||||
|
"Relu" -> 1
|
||||||
|
"Reshape" -> 1
|
||||||
|
"Select" -> 1
|
||||||
|
"Size" -> 1
|
||||||
|
"SoftmaxCrossEntropyWithLogits" -> 2
|
||||||
|
"Square" -> 1
|
||||||
|
"SparseSegmentSum" -> 1
|
||||||
|
"Sub" -> 1
|
||||||
|
"Sum" -> 1
|
||||||
|
"Transpose" -> 1
|
||||||
|
"TruncatedNormal" -> 1
|
||||||
|
"Variable" -> 1
|
||||||
|
"ZerosLike" -> 1
|
||||||
|
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
|
||||||
|
|
||||||
|
-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
|
||||||
|
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
|
||||||
|
|
||||||
|
allDimensions = vector [-1 :: Int32]
|
||||||
|
|
||||||
|
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1
|
||||||
|
|
||||||
|
lookupAttr nodeDef attrName = nodeDef ^. attr . at attrName . non def . attrLens
|
296
tensorflow-ops/src/TensorFlow/Ops.hs
Normal file
296
tensorflow-ops/src/TensorFlow/Ops.hs
Normal file
|
@ -0,0 +1,296 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
-- | This module contains definitions for some built-in TensorFlow operations.
|
||||||
|
--
|
||||||
|
-- Note that certain, "stateful" ops like 'variable' and 'assign' return a
|
||||||
|
-- 'Build' action (e.g., @Build (Tensor Ref a)@ instead of a pure value; the
|
||||||
|
-- returned 'Tensor's are always rendered in the current 'Build' context. This
|
||||||
|
-- approach helps us avoid problems with inlining or common subexpression
|
||||||
|
-- elimination, by writing
|
||||||
|
--
|
||||||
|
-- > do
|
||||||
|
-- > v <- variable []
|
||||||
|
-- > w <- assign v 3
|
||||||
|
-- > render $ w * w
|
||||||
|
--
|
||||||
|
-- instead of
|
||||||
|
--
|
||||||
|
-- > let
|
||||||
|
-- > v = variable []
|
||||||
|
-- > w = assign v 3
|
||||||
|
-- > in w * w
|
||||||
|
--
|
||||||
|
-- since the latter could be reasonably transformed by the compiler into (or
|
||||||
|
-- vice versa)
|
||||||
|
--
|
||||||
|
-- > let
|
||||||
|
-- > v = variable []
|
||||||
|
-- > w = assign v 3
|
||||||
|
-- > w' = assign v 3
|
||||||
|
-- > in w * w'
|
||||||
|
--
|
||||||
|
-- Ops should return a 'Build' action if their original 'OpDef' marks them as
|
||||||
|
-- stateful, or if they take any Refs as input. (This mirrors the rules that
|
||||||
|
-- TensorFlow uses to avoid common subexpression elimination.)
|
||||||
|
{-# LANGUAGE ConstraintKinds #-}
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE RankNTypes #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
{-# LANGUAGE UndecidableInstances #-}
|
||||||
|
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||||
|
|
||||||
|
module TensorFlow.Ops
|
||||||
|
( CoreOps.add
|
||||||
|
, CoreOps.abs
|
||||||
|
, CoreOps.addN
|
||||||
|
, CoreOps.argMax
|
||||||
|
, assign
|
||||||
|
, CoreOps.broadcastGradientArgs
|
||||||
|
, CoreOps.cast
|
||||||
|
, CoreOps.concat
|
||||||
|
, constant
|
||||||
|
, expandDims
|
||||||
|
, initializedVariable
|
||||||
|
, zeroInitializedVariable
|
||||||
|
, CoreOps.fill
|
||||||
|
, CoreOps.matMul
|
||||||
|
, matTranspose
|
||||||
|
, CoreOps.mul
|
||||||
|
, CoreOps.neg
|
||||||
|
, CoreOps.pack
|
||||||
|
, placeholder
|
||||||
|
, CoreOps.range
|
||||||
|
, reducedShape
|
||||||
|
, CoreOps.relu
|
||||||
|
, CoreOps.reluGrad
|
||||||
|
, CoreOps.reshape
|
||||||
|
, restore
|
||||||
|
, restoreFromName
|
||||||
|
, save
|
||||||
|
, scalar
|
||||||
|
, shape
|
||||||
|
, CoreOps.sign
|
||||||
|
, CoreOps.size
|
||||||
|
, CoreOps.softmax
|
||||||
|
, CoreOps.softmaxCrossEntropyWithLogits
|
||||||
|
, CoreOps.sparseToDense
|
||||||
|
, CoreOps.sub
|
||||||
|
, CoreOps.sum
|
||||||
|
, CoreOps.topK
|
||||||
|
, CoreOps.transpose
|
||||||
|
, truncatedNormal
|
||||||
|
, variable
|
||||||
|
, vector
|
||||||
|
, zeros
|
||||||
|
, CoreOps.zerosLike
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Data.ByteString (ByteString)
|
||||||
|
import Data.Complex (Complex)
|
||||||
|
import Data.Int (Int32, Int64)
|
||||||
|
import Prelude hiding (abs, sum, concat)
|
||||||
|
import Data.ProtoLens (def)
|
||||||
|
import Data.Text.Encoding (encodeUtf8)
|
||||||
|
import Lens.Family2 ((.~), (&))
|
||||||
|
import Text.Printf (printf)
|
||||||
|
import Proto.Tensorflow.Core.Framework.Tensor
|
||||||
|
( TensorProto
|
||||||
|
, dtype
|
||||||
|
, tensorShape
|
||||||
|
)
|
||||||
|
import qualified Proto.Tensorflow.Core.Framework.TensorShape
|
||||||
|
as TensorShape
|
||||||
|
import TensorFlow.Build
|
||||||
|
import TensorFlow.BuildOp
|
||||||
|
import TensorFlow.ControlFlow (group)
|
||||||
|
import TensorFlow.Output (unNodeName)
|
||||||
|
import TensorFlow.Tensor
|
||||||
|
import TensorFlow.Types
|
||||||
|
|
||||||
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
|
||||||
|
import qualified Prelude (abs)
|
||||||
|
|
||||||
|
-- TODO: Look into hs-boot refactoring to allow mutually recursive imports.
|
||||||
|
-- | Must be defined as an orphan because of the dependency order between Ops
|
||||||
|
-- and Tensor.
|
||||||
|
--
|
||||||
|
-- The indirect constraint "v ~ Value" helps disambiguate types, for example in
|
||||||
|
-- "neg 1 :: Tensor Value Float", it helps find the type of the subexpression
|
||||||
|
-- "1".
|
||||||
|
instance ( TensorType a
|
||||||
|
, Num a
|
||||||
|
, v ~ Value
|
||||||
|
, OneOf '[ Double, Float, Int32, Int64
|
||||||
|
, Complex Float, Complex Double] a) => Num (Tensor v a) where
|
||||||
|
(+) = CoreOps.add
|
||||||
|
(*) = CoreOps.mul
|
||||||
|
(-) = CoreOps.sub
|
||||||
|
abs = CoreOps.abs
|
||||||
|
fromInteger = scalar . fromInteger
|
||||||
|
signum = CoreOps.sign
|
||||||
|
negate = CoreOps.neg
|
||||||
|
|
||||||
|
matTranspose :: forall a v . TensorType a
|
||||||
|
=> Tensor v a -> Tensor Value a
|
||||||
|
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
|
||||||
|
|
||||||
|
-- | Create a new, uninitialized stateful Tensor of the given shape.
|
||||||
|
variable :: forall a . TensorType a => Shape -> Build (Tensor Ref a)
|
||||||
|
variable shape' = buildOp $ opDef "Variable"
|
||||||
|
& opAttr "shape" .~ shape'
|
||||||
|
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||||
|
|
||||||
|
placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a)
|
||||||
|
placeholder shape' =
|
||||||
|
buildOp $ opDef "Placeholder"
|
||||||
|
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||||
|
& opAttr "shape" .~ shape'
|
||||||
|
|
||||||
|
-- Assign returns the input ref.
|
||||||
|
assign :: forall a v . TensorType a
|
||||||
|
=> Tensor Ref a -> Tensor v a -> Build (Tensor Ref a)
|
||||||
|
assign = buildOp $ opDef "Assign"
|
||||||
|
& opAttr "T" .~ tensorType (undefined :: a)
|
||||||
|
& opAttr "use_locking" .~ True
|
||||||
|
|
||||||
|
-- | Creates a variable initialized to the given value.
|
||||||
|
-- Initialization happens next time session runs.
|
||||||
|
initializedVariable :: forall a . TensorType a
|
||||||
|
=> Tensor Value a -> Build (Tensor Ref a)
|
||||||
|
initializedVariable initializer = do
|
||||||
|
v <- variable [] -- The shape is not known initially.
|
||||||
|
(i :: Tensor Ref a) <-
|
||||||
|
buildOp (opDef "Assign"
|
||||||
|
& opAttr "T" .~ tensorType (undefined :: a)
|
||||||
|
& opAttr "use_locking" .~ True
|
||||||
|
& opAttr "validate_shape" .~ False
|
||||||
|
)
|
||||||
|
v initializer
|
||||||
|
addInitializer =<< group i
|
||||||
|
return v
|
||||||
|
|
||||||
|
-- | Creates a zero-initialized variable with the given shape.
|
||||||
|
zeroInitializedVariable
|
||||||
|
:: (TensorType a, Num a) =>
|
||||||
|
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
|
||||||
|
zeroInitializedVariable = initializedVariable . zeros
|
||||||
|
|
||||||
|
-- TODO: Support heterogeneous list of tensors.
|
||||||
|
save :: forall a v . TensorType a
|
||||||
|
=> ByteString -- ^ File path.
|
||||||
|
-> [Tensor v a] -- ^ Tensors to save.
|
||||||
|
-> Build ControlNode
|
||||||
|
save path xs = do
|
||||||
|
let toByteStringTensor = scalar . encodeUtf8 . unNodeName
|
||||||
|
names <- mapM (fmap toByteStringTensor . renderNodeName) xs
|
||||||
|
let types = replicate (length xs) (tensorType (undefined :: a))
|
||||||
|
let saveOp = buildOp $ opDef "Save"
|
||||||
|
& opAttr "T" .~ types
|
||||||
|
saveOp (scalar path) (CoreOps.pack names) xs
|
||||||
|
|
||||||
|
-- | Restore a tensor's value from a checkpoint file.
|
||||||
|
--
|
||||||
|
-- This version allows restoring from a checkpoint file that uses a different
|
||||||
|
-- tensor name than the variable.
|
||||||
|
restoreFromName :: forall a . TensorType a
|
||||||
|
=> ByteString -- ^ File path.
|
||||||
|
-> ByteString -- ^ Tensor name override.
|
||||||
|
-> Tensor Ref a -- ^ Tensor to restore.
|
||||||
|
-> Build ControlNode
|
||||||
|
restoreFromName path name x = do
|
||||||
|
let restoreOp = buildOp $ opDef "Restore"
|
||||||
|
& opAttr "dt" .~ tensorType (undefined :: a)
|
||||||
|
group =<< assign x (restoreOp (scalar path) (scalar name) :: Tensor Value a)
|
||||||
|
|
||||||
|
-- | Restore a tensor's value from a checkpoint file.
|
||||||
|
restore :: forall a . TensorType a
|
||||||
|
=> ByteString -- ^ File path.
|
||||||
|
-> Tensor Ref a -- ^ Tensor to restore.
|
||||||
|
-> Build ControlNode
|
||||||
|
restore path x = do
|
||||||
|
name <- encodeUtf8 . unNodeName <$> renderNodeName x
|
||||||
|
restoreFromName path name x
|
||||||
|
|
||||||
|
-- | Create a constant tensor.
|
||||||
|
--
|
||||||
|
-- The values should be in row major order, e.g.,
|
||||||
|
--
|
||||||
|
-- element 0: index (0, ..., 0)
|
||||||
|
-- element 1: index (0, ..., 1)
|
||||||
|
-- ...
|
||||||
|
constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a
|
||||||
|
constant (Shape shape') values
|
||||||
|
| invalidLength = error invalidLengthMsg
|
||||||
|
| otherwise = buildOp $ opDef "Const"
|
||||||
|
& opAttr "value" .~ typedNode
|
||||||
|
& opAttr "dtype" .~ nodeType
|
||||||
|
where
|
||||||
|
invalidLength = product shape' /= fromIntegral (length values)
|
||||||
|
invalidLengthMsg = printf "invalid tensor length: expected %d got %d"
|
||||||
|
(product shape')
|
||||||
|
(length values)
|
||||||
|
nodeType = tensorType (undefined :: a)
|
||||||
|
typedNode :: TensorProto
|
||||||
|
typedNode = def
|
||||||
|
& dtype .~ nodeType
|
||||||
|
& tensorShape.TensorShape.dim .~
|
||||||
|
[def & TensorShape.size .~ x | x <- shape']
|
||||||
|
& tensorVal .~ values
|
||||||
|
|
||||||
|
-- | Create a constant vector.
|
||||||
|
vector :: TensorType a => [a] -> Tensor Value a
|
||||||
|
vector xs = constant [fromIntegral $ length xs] xs
|
||||||
|
|
||||||
|
-- | Create a constant scalar.
|
||||||
|
scalar :: forall a . TensorType a => a -> Tensor Value a
|
||||||
|
scalar x = constant [] [x]
|
||||||
|
|
||||||
|
-- Random tensor from the unit normal distribution with bounded values.
|
||||||
|
truncatedNormal :: forall a v . TensorType a
|
||||||
|
=> Tensor v Int64 -- ^ Shape.
|
||||||
|
-> Build (Tensor Value a)
|
||||||
|
truncatedNormal = buildOp $ opDef "TruncatedNormal"
|
||||||
|
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||||
|
& opAttr "T" .~ tensorType (undefined :: Int64)
|
||||||
|
|
||||||
|
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
|
||||||
|
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
|
||||||
|
|
||||||
|
shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32
|
||||||
|
shape = CoreOps.shape
|
||||||
|
|
||||||
|
expandDims :: (TensorType t) => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
|
||||||
|
expandDims = CoreOps.expandDims
|
||||||
|
|
||||||
|
-- | Helper function for reduction ops (translation of math_ops.reduced_shape).
|
||||||
|
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
|
||||||
|
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32
|
||||||
|
reducedShape inputShape axes =
|
||||||
|
let inputShape32 = toInt32 inputShape -- [2, 3, 5, 7]
|
||||||
|
axes32 = toInt32 axes -- [1, 2]
|
||||||
|
toInt32 x = CoreOps.cast x :: Tensor Value Int32
|
||||||
|
inputRank = CoreOps.size inputShape32 -- 4
|
||||||
|
axesMod = (axes32 + inputRank) `CoreOps.mod` inputRank
|
||||||
|
axesShape = shape axesMod -- [2]
|
||||||
|
in CoreOps.dynamicStitch -- [2, 1, 1, 7]
|
||||||
|
[CoreOps.range 0 inputRank 1, -- [0, 1, 2, 3]
|
||||||
|
axesMod] -- [1, 2]
|
||||||
|
[inputShape32, -- [2, 3, 5, 7]
|
||||||
|
CoreOps.fill axesShape 1] -- [1, 1]
|
191
tensorflow-ops/tensorflow-ops.cabal
Normal file
191
tensorflow-ops/tensorflow-ops.cabal
Normal file
|
@ -0,0 +1,191 @@
|
||||||
|
name: tensorflow-ops
|
||||||
|
version: 0.1.0.0
|
||||||
|
synopsis: Friendly layer around TensorFlow bindings.
|
||||||
|
description: Please see README.md
|
||||||
|
homepage: https://github.com/tensorflow/haskell#readme
|
||||||
|
license: Apache
|
||||||
|
author: TensorFlow authors
|
||||||
|
maintainer: tensorflow-haskell@googlegroups.com
|
||||||
|
copyright: Google Inc.
|
||||||
|
category: Machine Learning
|
||||||
|
build-type: Simple
|
||||||
|
cabal-version: >=1.22
|
||||||
|
|
||||||
|
library
|
||||||
|
hs-source-dirs: src
|
||||||
|
exposed-modules: TensorFlow.Gradient
|
||||||
|
, TensorFlow.Ops
|
||||||
|
, TensorFlow.EmbeddingOps
|
||||||
|
build-depends: proto-lens == 0.1.*
|
||||||
|
, base >= 4.7 && < 5
|
||||||
|
, bytestring
|
||||||
|
, fgl
|
||||||
|
, mtl
|
||||||
|
, data-default
|
||||||
|
, lens-family
|
||||||
|
, containers
|
||||||
|
, tensorflow == 0.1.*
|
||||||
|
, tensorflow-proto == 0.1.*
|
||||||
|
, tensorflow-core-ops == 0.1.*
|
||||||
|
, text
|
||||||
|
default-language: Haskell2010
|
||||||
|
|
||||||
|
Test-Suite BuildTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: BuildTest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: HUnit
|
||||||
|
, base
|
||||||
|
, proto-lens
|
||||||
|
, lens-family
|
||||||
|
, google-shim
|
||||||
|
, tensorflow
|
||||||
|
, tensorflow-ops
|
||||||
|
, tensorflow-proto
|
||||||
|
, test-framework
|
||||||
|
, test-framework-hunit
|
||||||
|
, transformers
|
||||||
|
, vector
|
||||||
|
|
||||||
|
Test-Suite EmbeddingOpsTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: EmbeddingOpsTest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: HUnit
|
||||||
|
, QuickCheck
|
||||||
|
, base
|
||||||
|
, proto-lens
|
||||||
|
, lens-family
|
||||||
|
, google-shim
|
||||||
|
, tensorflow
|
||||||
|
, tensorflow-core-ops
|
||||||
|
, tensorflow-ops
|
||||||
|
, tensorflow-proto
|
||||||
|
, test-framework
|
||||||
|
, test-framework-hunit
|
||||||
|
, test-framework-quickcheck2
|
||||||
|
, vector
|
||||||
|
|
||||||
|
Test-Suite ArrayOpsTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: ArrayOpsTest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: HUnit
|
||||||
|
, QuickCheck
|
||||||
|
, base
|
||||||
|
, proto-lens
|
||||||
|
, lens-family
|
||||||
|
, google-shim
|
||||||
|
, tensorflow
|
||||||
|
, tensorflow-core-ops
|
||||||
|
, tensorflow-ops
|
||||||
|
, tensorflow-proto
|
||||||
|
, test-framework
|
||||||
|
, test-framework-hunit
|
||||||
|
, test-framework-quickcheck2
|
||||||
|
, transformers
|
||||||
|
, vector
|
||||||
|
|
||||||
|
Test-Suite OpsTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: OpsTest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: HUnit
|
||||||
|
, QuickCheck
|
||||||
|
, base
|
||||||
|
, bytestring
|
||||||
|
, proto-lens
|
||||||
|
, lens-family
|
||||||
|
, google-shim
|
||||||
|
, temporary
|
||||||
|
, tensorflow
|
||||||
|
, tensorflow-core-ops
|
||||||
|
, tensorflow-ops
|
||||||
|
, tensorflow-proto
|
||||||
|
, test-framework
|
||||||
|
, test-framework-hunit
|
||||||
|
, test-framework-quickcheck2
|
||||||
|
, transformers
|
||||||
|
, vector
|
||||||
|
|
||||||
|
Test-Suite DataFlowOpsTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: DataFlowOpsTest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: HUnit
|
||||||
|
, QuickCheck
|
||||||
|
, base
|
||||||
|
, proto-lens
|
||||||
|
, lens-family
|
||||||
|
, google-shim
|
||||||
|
, tensorflow
|
||||||
|
, tensorflow-core-ops
|
||||||
|
, tensorflow-ops
|
||||||
|
, tensorflow-proto
|
||||||
|
, test-framework
|
||||||
|
, test-framework-hunit
|
||||||
|
, test-framework-quickcheck2
|
||||||
|
, vector
|
||||||
|
|
||||||
|
Test-Suite GradientTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: GradientTest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: HUnit
|
||||||
|
, base
|
||||||
|
, proto-lens
|
||||||
|
, lens-family
|
||||||
|
, google-shim
|
||||||
|
, tensorflow
|
||||||
|
, tensorflow-ops
|
||||||
|
, tensorflow-proto
|
||||||
|
, test-framework
|
||||||
|
, test-framework-hunit
|
||||||
|
|
||||||
|
Test-Suite MiscTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: MiscTest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: HUnit
|
||||||
|
, base
|
||||||
|
, bytestring
|
||||||
|
, vector
|
||||||
|
, google-shim
|
||||||
|
, transformers
|
||||||
|
, tensorflow
|
||||||
|
, tensorflow-ops
|
||||||
|
, tensorflow-proto
|
||||||
|
, test-framework
|
||||||
|
, test-framework-hunit
|
||||||
|
|
||||||
|
Test-Suite TypesTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: TypesTest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: HUnit
|
||||||
|
, QuickCheck
|
||||||
|
, base
|
||||||
|
, bytestring
|
||||||
|
, proto-lens
|
||||||
|
, lens-family
|
||||||
|
, google-shim
|
||||||
|
, tensorflow
|
||||||
|
, tensorflow-ops
|
||||||
|
, tensorflow-proto
|
||||||
|
, transformers
|
||||||
|
, test-framework
|
||||||
|
, test-framework-hunit
|
||||||
|
, test-framework-quickcheck2
|
||||||
|
, vector
|
||||||
|
|
||||||
|
source-repository head
|
||||||
|
type: git
|
||||||
|
location: https://github.com/tensorflow/haskell
|
42
tensorflow-ops/tests/ArrayOpsTest.hs
Normal file
42
tensorflow-ops/tests/ArrayOpsTest.hs
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
import Google.Test (googleTest)
|
||||||
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
import Test.HUnit ((@=?))
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
|
import qualified TensorFlow.Ops as TF
|
||||||
|
import qualified TensorFlow.Session as TF
|
||||||
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
|
||||||
|
-- | Test split and concat are inverses.
|
||||||
|
testSplit = testCase "testSplit" $ TF.runSession $ do
|
||||||
|
let original = TF.constant [2, 3] [0..5 :: Float]
|
||||||
|
splitList = CoreOps.split 3 dim original
|
||||||
|
restored = CoreOps.concat dim splitList
|
||||||
|
dim = 1 -- dimension to split along (with size of 3 in original)
|
||||||
|
liftIO $ 3 @=? length splitList
|
||||||
|
(x, y, z) <-
|
||||||
|
TF.buildAnd TF.run $ return (original, restored, splitList !! 1)
|
||||||
|
liftIO $ x @=? (y :: V.Vector Float)
|
||||||
|
liftIO $ V.fromList [1, 4] @=? z
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = googleTest [ testSplit
|
||||||
|
]
|
181
tensorflow-ops/tests/BuildTest.hs
Normal file
181
tensorflow-ops/tests/BuildTest.hs
Normal file
|
@ -0,0 +1,181 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
import Data.Functor.Identity (runIdentity)
|
||||||
|
import Lens.Family2 ((^.))
|
||||||
|
import Data.List (sort)
|
||||||
|
import Proto.Tensorflow.Core.Framework.Graph
|
||||||
|
( node )
|
||||||
|
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||||
|
( NodeDef
|
||||||
|
, device
|
||||||
|
, name
|
||||||
|
, op )
|
||||||
|
import TensorFlow.Build
|
||||||
|
( Build
|
||||||
|
, BuildT
|
||||||
|
, asGraphDef
|
||||||
|
, evalBuildT
|
||||||
|
, flushNodeBuffer
|
||||||
|
, hoistBuildT
|
||||||
|
, render
|
||||||
|
, withDevice
|
||||||
|
, colocateWith
|
||||||
|
, withNameScope
|
||||||
|
)
|
||||||
|
import TensorFlow.ControlFlow (named)
|
||||||
|
import TensorFlow.Nodes (unScalar)
|
||||||
|
import TensorFlow.Ops
|
||||||
|
( add
|
||||||
|
, assign
|
||||||
|
, constant
|
||||||
|
, initializedVariable
|
||||||
|
, variable
|
||||||
|
)
|
||||||
|
import TensorFlow.Output (Device(..))
|
||||||
|
import TensorFlow.Tensor (Tensor, Value, Ref)
|
||||||
|
import TensorFlow.Session
|
||||||
|
( build
|
||||||
|
, buildAnd
|
||||||
|
, run
|
||||||
|
, runSession
|
||||||
|
, run_
|
||||||
|
)
|
||||||
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
import Test.HUnit ((@=?))
|
||||||
|
import Google.Test (googleTest)
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
|
-- | Test named behavior.
|
||||||
|
testNamed = testCase "testNamed" $ do
|
||||||
|
let graph = named "foo" <$> variable [] >>= render :: Build (Tensor Ref Float)
|
||||||
|
nodeDef :: NodeDef
|
||||||
|
nodeDef = head $ asGraphDef graph ^. node
|
||||||
|
"RefIdentity" @=? (nodeDef ^. op)
|
||||||
|
"foo" @=? (nodeDef ^. name)
|
||||||
|
|
||||||
|
-- | Test named deRef behavior.
|
||||||
|
testNamedDeRef = testCase "testNamedDeRef" $ do
|
||||||
|
let graph = named "foo" <$> do
|
||||||
|
v :: Tensor Ref Float <- variable []
|
||||||
|
assign v 5
|
||||||
|
-- TODO: Implement TensorFlow get_variable and test it.
|
||||||
|
runSession $ do
|
||||||
|
out <- buildAnd run graph
|
||||||
|
liftIO $ 5 @=? (unScalar out :: Float)
|
||||||
|
|
||||||
|
-- | Test that "run" will render and extend any pure ops that haven't already
|
||||||
|
-- been rendered.
|
||||||
|
testPureRender = testCase "testPureRender" $ runSession $ do
|
||||||
|
result <- run $ 2 `add` 2
|
||||||
|
liftIO $ 4 @=? (unScalar result :: Float)
|
||||||
|
|
||||||
|
-- | Test that "run" assigns any previously accumulated initializers.
|
||||||
|
testInitializedVariable =
|
||||||
|
testCase "testInitializedVariable" $ runSession $ do
|
||||||
|
(formula, reset) <- build $ do
|
||||||
|
v <- initializedVariable 42
|
||||||
|
r <- assign v 24
|
||||||
|
return (1 `add` v, r)
|
||||||
|
result <- run formula
|
||||||
|
liftIO $ 43 @=? (unScalar result :: Float)
|
||||||
|
run_ reset -- Updates v to a different value
|
||||||
|
rerunResult <- run formula
|
||||||
|
liftIO $ 25 @=? (unScalar rerunResult :: Float)
|
||||||
|
|
||||||
|
testInitializedVariableShape =
|
||||||
|
testCase "testInitializedVariableShape" $ runSession $ do
|
||||||
|
vector <- build $ initializedVariable (constant [1] [42 :: Float])
|
||||||
|
result <- run vector
|
||||||
|
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||||
|
|
||||||
|
-- | Test nameScoped behavior.
|
||||||
|
testNameScoped = testCase "testNameScoped" $ do
|
||||||
|
let graph = withNameScope "foo" $ variable [] :: Build (Tensor Ref Float)
|
||||||
|
nodeDef :: NodeDef
|
||||||
|
[nodeDef] = asGraphDef graph ^. node
|
||||||
|
"foo/Variable_0" @=? (nodeDef ^. name) -- TODO: Check prefix.
|
||||||
|
"Variable" @=? (nodeDef ^. op)
|
||||||
|
|
||||||
|
-- | Test combined named and nameScoped behavior.
|
||||||
|
testNamedAndScoped = testCase "testNamedAndScoped" $ do
|
||||||
|
let graph :: Build (Tensor Ref Float)
|
||||||
|
graph = withNameScope "foo1" ((named "bar1" <$> variable []) >>= render)
|
||||||
|
nodeDef :: NodeDef
|
||||||
|
nodeDef = head $ asGraphDef graph ^. node
|
||||||
|
"RefIdentity" @=? (nodeDef ^. op)
|
||||||
|
"foo1/bar1" @=? (nodeDef ^. name)
|
||||||
|
|
||||||
|
-- | Lift a Build action into a context for HUnit to run.
|
||||||
|
liftBuild :: Build a -> BuildT IO a
|
||||||
|
liftBuild = hoistBuildT (return . runIdentity)
|
||||||
|
|
||||||
|
-- | Flush the node buffer and sort the nodes by name (for more stable tests).
|
||||||
|
flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a]
|
||||||
|
flushed field = sort . map field <$> liftBuild flushNodeBuffer
|
||||||
|
|
||||||
|
-- | Test the interaction of rendering, CSE and scoping.
|
||||||
|
testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
|
||||||
|
liftBuild renderNodes
|
||||||
|
names <- flushed (^. name)
|
||||||
|
liftIO $ ["Const_1", "Variable_0", "Variable_2"] @=? names
|
||||||
|
-- Render the nodes in a different scope, which should cause them
|
||||||
|
-- to be distinct from the previous ones.
|
||||||
|
liftBuild $ withNameScope "foo" renderNodes
|
||||||
|
scopedNames <- flushed (^. name)
|
||||||
|
liftIO $ ["foo/Const_4", "foo/Variable_3", "foo/Variable_5"] @=? scopedNames
|
||||||
|
where
|
||||||
|
renderNodes = do
|
||||||
|
-- A stateful op and a pure op.
|
||||||
|
_ :: Tensor Ref Float <- variable []
|
||||||
|
_ :: Tensor Value Float <- render 3
|
||||||
|
-- Another stateful op, and a pure op which should be
|
||||||
|
-- deduped with the previous one.
|
||||||
|
_ :: Tensor Ref Float <- variable []
|
||||||
|
_ :: Tensor Value Float <- render 3
|
||||||
|
return ()
|
||||||
|
|
||||||
|
-- | Test the interaction of rendering, CSE and scoping.
|
||||||
|
testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do
|
||||||
|
liftBuild renderNodes
|
||||||
|
devices <- flushed (\x -> (x ^. name, x ^. device))
|
||||||
|
liftIO $ [ ("Add_2","dev0")
|
||||||
|
, ("Const_1","dev0")
|
||||||
|
, ("Variable_0","dev0")] @=? devices
|
||||||
|
where
|
||||||
|
renderNodes = do
|
||||||
|
-- A stateful op and a pure op.
|
||||||
|
var :: Tensor Ref Float <- withDevice (Just $ Device "dev0") $ variable []
|
||||||
|
-- Uses render to cause the expression be added to the graph.
|
||||||
|
_ <- colocateWith var $ render $ 3 `add` var
|
||||||
|
return ()
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = googleTest [ testInitializedVariable
|
||||||
|
, testInitializedVariableShape
|
||||||
|
, testDeviceColocation
|
||||||
|
, testNamed
|
||||||
|
, testNamedDeRef
|
||||||
|
, testNameScoped
|
||||||
|
, testNamedAndScoped
|
||||||
|
, testPureRender
|
||||||
|
, testRenderDedup
|
||||||
|
]
|
66
tensorflow-ops/tests/DataFlowOpsTest.hs
Normal file
66
tensorflow-ops/tests/DataFlowOpsTest.hs
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
|
import Data.Int (Int32, Int64)
|
||||||
|
import Data.List (genericLength)
|
||||||
|
import Google.Test (googleTest)
|
||||||
|
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||||||
|
import Test.HUnit ((@=?))
|
||||||
|
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
|
||||||
|
import Test.QuickCheck.Monadic (monadicIO, run)
|
||||||
|
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
import qualified TensorFlow.Ops as TF
|
||||||
|
import qualified TensorFlow.Session as TF
|
||||||
|
import qualified TensorFlow.Tensor as TF
|
||||||
|
import qualified TensorFlow.Types as TF
|
||||||
|
|
||||||
|
-- DynamicSplit is undone with DynamicStitch to get the original input
|
||||||
|
-- back.
|
||||||
|
testDynamicPartitionStitchInverse :: forall a.
|
||||||
|
(TF.TensorType a, Show a, Eq a) => StitchExample a -> Property
|
||||||
|
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
|
||||||
|
let splitParts :: [TF.Tensor TF.Value a] =
|
||||||
|
CoreOps.dynamicPartition numParts (TF.vector values) partTensor
|
||||||
|
partTensor = TF.vector partitions
|
||||||
|
restitchIndices = CoreOps.dynamicPartition numParts
|
||||||
|
(TF.vector [0..genericLength values-1])
|
||||||
|
partTensor
|
||||||
|
-- drop (numParts - 2) from both args to expose b/27343984
|
||||||
|
restitch = CoreOps.dynamicStitch restitchIndices splitParts
|
||||||
|
in monadicIO $ run $ do
|
||||||
|
fromIntegral numParts @=? length splitParts
|
||||||
|
valuesOut <- TF.runSession $ TF.buildAnd TF.run $ return restitch
|
||||||
|
V.fromList values @=? valuesOut
|
||||||
|
|
||||||
|
data StitchExample a = StitchExample Int64 [a] [Int32]
|
||||||
|
deriving Show
|
||||||
|
|
||||||
|
instance Arbitrary a => Arbitrary (StitchExample a) where
|
||||||
|
arbitrary = do
|
||||||
|
-- Limits the size of the vector.
|
||||||
|
size <- choose (1, 100)
|
||||||
|
values <- vectorOf size arbitrary
|
||||||
|
numParts <- choose (2, 15)
|
||||||
|
partitions <- vectorOf size (choose (0, fromIntegral numParts - 1))
|
||||||
|
return $ StitchExample numParts values partitions
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = googleTest
|
||||||
|
[ testProperty "DynamicPartitionStitchInverse"
|
||||||
|
(testDynamicPartitionStitchInverse :: StitchExample Int64 -> Property)
|
||||||
|
]
|
88
tensorflow-ops/tests/EmbeddingOpsTest.hs
Normal file
88
tensorflow-ops/tests/EmbeddingOpsTest.hs
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE RankNTypes #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
|
-- | Tests for EmbeddingOps.
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import Data.Int (Int32, Int64)
|
||||||
|
import Data.List (genericLength)
|
||||||
|
import Google.Test (googleTest)
|
||||||
|
import TensorFlow.EmbeddingOps (embeddingLookup)
|
||||||
|
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||||||
|
import Test.HUnit ((@=?))
|
||||||
|
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
|
||||||
|
import Test.QuickCheck.Monadic (monadicIO, run)
|
||||||
|
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
import qualified TensorFlow.Ops as TF
|
||||||
|
import qualified TensorFlow.Session as TF
|
||||||
|
import qualified TensorFlow.Tensor as TF
|
||||||
|
import qualified TensorFlow.Types as TF
|
||||||
|
|
||||||
|
-- Verifies that direct gather is the same as dynamic split into
|
||||||
|
-- partitions, followed by embedding lookup.
|
||||||
|
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
|
||||||
|
=> LookupExample a -> Property
|
||||||
|
testEmbeddingLookupUndoesSplit
|
||||||
|
(LookupExample numParts
|
||||||
|
shape@(TF.Shape (firstDim : restDims))
|
||||||
|
values
|
||||||
|
indices) =
|
||||||
|
let modShardedValues :: [TF.Tensor TF.Value a] =
|
||||||
|
CoreOps.dynamicPartition numParts shapedValues cyclicCounter
|
||||||
|
cyclicCounter :: TF.Tensor TF.Value Int32 =
|
||||||
|
TF.vector [0..fromIntegral firstDim-1]
|
||||||
|
`CoreOps.mod` fromIntegral numParts
|
||||||
|
indicesVector = TF.vector indices
|
||||||
|
directs = CoreOps.gather shapedValues indicesVector
|
||||||
|
shapedValues = TF.constant shape values
|
||||||
|
in monadicIO $ run $ do
|
||||||
|
(shapeOut, got, want :: V.Vector a) <-
|
||||||
|
TF.runSession $ TF.buildAnd TF.run $ do
|
||||||
|
embeddings <- embeddingLookup modShardedValues indicesVector
|
||||||
|
return (TF.cast (TF.shape embeddings), embeddings, directs)
|
||||||
|
-- Checks the explicitly documented invariant of embeddingLookup.
|
||||||
|
shapeOut @=? V.fromList (genericLength indices : restDims)
|
||||||
|
got @=? want
|
||||||
|
testEmbeddingLookupUndoesSplit _ = error "Bug in Arbitrary (LookupExample)"
|
||||||
|
|
||||||
|
-- | Consistent set of parameters for EmbeddingLookupUndoesSplit.
|
||||||
|
data LookupExample a = LookupExample
|
||||||
|
Int64 -- ^ number of ways to split.
|
||||||
|
TF.Shape -- ^ shape of the generated tensor
|
||||||
|
[a] -- ^ data for the tensor
|
||||||
|
[Int32] -- ^ indices to split the tensor by
|
||||||
|
deriving Show
|
||||||
|
|
||||||
|
instance Arbitrary a => Arbitrary (LookupExample a) where
|
||||||
|
arbitrary = do
|
||||||
|
rank <- choose (1, 4)
|
||||||
|
-- Takes rank-th root of 100 to cap the tensor size.
|
||||||
|
let maxDim = fromIntegral $ ceiling $ 100 ** (1 / fromIntegral rank)
|
||||||
|
shape@(firstDim : _) <- vectorOf rank (choose (1, maxDim))
|
||||||
|
values <- vectorOf (fromIntegral $ product shape) arbitrary
|
||||||
|
numParts <- choose (2, 15)
|
||||||
|
indSize <- choose (0, fromIntegral $ firstDim - 1)
|
||||||
|
indices <- vectorOf indSize (choose (0, fromIntegral firstDim - 1))
|
||||||
|
return $ LookupExample numParts (TF.Shape shape) values indices
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = googleTest
|
||||||
|
[ testProperty "EmbeddingLookupUndoesSplit"
|
||||||
|
(testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
|
||||||
|
]
|
158
tensorflow-ops/tests/GradientTest.hs
Normal file
158
tensorflow-ops/tests/GradientTest.hs
Normal file
|
@ -0,0 +1,158 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
|
import Data.List (sort)
|
||||||
|
import Data.ProtoLens.TextFormat (showMessage)
|
||||||
|
import Google.Test (googleTest)
|
||||||
|
import Lens.Family2 ((^..))
|
||||||
|
import Test.Framework (Test)
|
||||||
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
import Test.HUnit ((@=?))
|
||||||
|
|
||||||
|
import qualified TensorFlow.Build as TF
|
||||||
|
import qualified TensorFlow.Gradient as TF
|
||||||
|
import qualified TensorFlow.Nodes as TF
|
||||||
|
import qualified TensorFlow.Ops as TF
|
||||||
|
import qualified TensorFlow.Session as TF
|
||||||
|
import qualified TensorFlow.Tensor as TF
|
||||||
|
import qualified TensorFlow.Types as TF
|
||||||
|
|
||||||
|
import Proto.Tensorflow.Core.Framework.Graph (node)
|
||||||
|
import Proto.Tensorflow.Core.Framework.NodeDef (op)
|
||||||
|
|
||||||
|
testGradientSimple :: Test
|
||||||
|
testGradientSimple = testCase "testGradientSimple" $ do
|
||||||
|
let x = TF.scalar (3 :: Float)
|
||||||
|
b = TF.scalar (4 :: Float)
|
||||||
|
y = x*x + b
|
||||||
|
grads = TF.gradients y [x, b]
|
||||||
|
-- Assert that the gradients are right.
|
||||||
|
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads
|
||||||
|
6 @=? TF.unScalar dx
|
||||||
|
1 @=? TF.unScalar db
|
||||||
|
-- Assert that the graph has the expected ops.
|
||||||
|
let graphDef = TF.asGraphDef grads
|
||||||
|
putStrLn $ showMessage graphDef
|
||||||
|
let ops = graphDef ^.. node . traverse . op
|
||||||
|
expected = [ "Const"
|
||||||
|
, "Mul"
|
||||||
|
, "Const"
|
||||||
|
, "Add"
|
||||||
|
-- Default output gradient of y.
|
||||||
|
, "Shape"
|
||||||
|
, "Const"
|
||||||
|
, "Fill"
|
||||||
|
-- Add gradient.
|
||||||
|
, "Shape"
|
||||||
|
, "Shape"
|
||||||
|
, "BroadcastGradientArgs"
|
||||||
|
, "Sum"
|
||||||
|
, "Sum"
|
||||||
|
, "Reshape"
|
||||||
|
, "Reshape"
|
||||||
|
-- Mul gradient.
|
||||||
|
, "Shape"
|
||||||
|
-- This Op gets dedup'd because the inputs are the same.
|
||||||
|
-- TODO(fmayle): The same would happen to the Mul and Sum ops
|
||||||
|
-- below if the gradient function didn't multiply one as
|
||||||
|
-- 'dz * y' and the other as 'x * dz'. We could change the
|
||||||
|
-- order, but I'm going to keep it the same as the python
|
||||||
|
-- version for now.
|
||||||
|
--
|
||||||
|
-- , "Shape"
|
||||||
|
, "BroadcastGradientArgs"
|
||||||
|
, "Mul"
|
||||||
|
, "Mul"
|
||||||
|
, "Sum"
|
||||||
|
, "Sum"
|
||||||
|
, "Reshape"
|
||||||
|
, "Reshape"
|
||||||
|
-- AddN to combine x's output gradients.
|
||||||
|
, "AddN"
|
||||||
|
]
|
||||||
|
sort expected @=? sort ops
|
||||||
|
|
||||||
|
testGradientDisconnected :: Test
|
||||||
|
testGradientDisconnected = testCase "testGradientDisconnected" $ do
|
||||||
|
let x = TF.scalar (3 :: Float)
|
||||||
|
b = TF.scalar (4 :: Float)
|
||||||
|
grads = TF.gradients x [x, b]
|
||||||
|
-- Assert that the gradients are right.
|
||||||
|
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads
|
||||||
|
1 @=? TF.unScalar dx
|
||||||
|
0 @=? TF.unScalar db
|
||||||
|
-- Assert that the graph has the expected ops.
|
||||||
|
let graphDef = TF.asGraphDef grads
|
||||||
|
putStrLn $ showMessage graphDef
|
||||||
|
let ops = graphDef ^.. node . traverse . op
|
||||||
|
expected = [ "Const"
|
||||||
|
, "Const"
|
||||||
|
-- Default output gradient of x.
|
||||||
|
, "Shape"
|
||||||
|
, "Const"
|
||||||
|
, "Fill"
|
||||||
|
-- Default output gradient of b.
|
||||||
|
, "ZerosLike"
|
||||||
|
]
|
||||||
|
sort expected @=? sort ops
|
||||||
|
|
||||||
|
|
||||||
|
-- Test that identical "stateful" ops work with createGraph.
|
||||||
|
testCreateGraphStateful :: Test
|
||||||
|
testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
|
||||||
|
[dx, dy] <- TF.runSession $ TF.buildAnd TF.run $ do
|
||||||
|
let shape = TF.constant (TF.Shape [1]) [1]
|
||||||
|
x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
||||||
|
y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
||||||
|
TF.gradients (x + y*3) [x, y]
|
||||||
|
-- If this test fails, it will likely be caused by an exception within
|
||||||
|
-- `TF.gradients`. These asserts are extra.
|
||||||
|
1 @=? TF.unScalar dx
|
||||||
|
3 @=? TF.unScalar dy
|
||||||
|
|
||||||
|
|
||||||
|
-- Test that name scopes work with createGraph.
|
||||||
|
testCreateGraphNameScopes :: Test
|
||||||
|
testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
|
||||||
|
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
|
||||||
|
let shape = TF.constant (TF.Shape [1]) [1]
|
||||||
|
x :: TF.Tensor TF.Value Float <-
|
||||||
|
TF.withNameScope "foo" (TF.truncatedNormal shape)
|
||||||
|
TF.gradients x [x]
|
||||||
|
-- If this test fails, it will likely be caused by an exception within
|
||||||
|
-- `TF.gradients`. This assert is extra.
|
||||||
|
1 @=? TF.unScalar dx
|
||||||
|
|
||||||
|
|
||||||
|
-- Test that createGraph can handle graphs with diamond shapes.
|
||||||
|
testDiamond :: Test
|
||||||
|
testDiamond = testCase "testDiamond" $ do
|
||||||
|
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
|
||||||
|
let x = TF.vector [1]
|
||||||
|
y = x*x
|
||||||
|
z = y*y
|
||||||
|
TF.gradients z [x]
|
||||||
|
(4 :: Float) @=? TF.unScalar dx
|
||||||
|
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = googleTest [ testGradientSimple
|
||||||
|
, testGradientDisconnected
|
||||||
|
, testCreateGraphStateful
|
||||||
|
, testCreateGraphNameScopes
|
||||||
|
, testDiamond
|
||||||
|
]
|
46
tensorflow-ops/tests/MiscTest.hs
Normal file
46
tensorflow-ops/tests/MiscTest.hs
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
|
{-# LANGUAGE RankNTypes #-}
|
||||||
|
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
import Data.Int (Int32)
|
||||||
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
import Test.HUnit ((@=?))
|
||||||
|
import Google.Test
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
|
import TensorFlow.Ops
|
||||||
|
import TensorFlow.Session
|
||||||
|
|
||||||
|
-- | Test fetching multiple outputs from an op.
|
||||||
|
testMultipleOutputs = testCase "testMultipleOutputs" $
|
||||||
|
runSession $ do
|
||||||
|
(values, indices) <- run $ topK 2 $ constant [1, 4] [10, 40, 20, 30]
|
||||||
|
liftIO $ [40, 30] @=? V.toList (values :: V.Vector Float)
|
||||||
|
liftIO $ [1, 3] @=? V.toList (indices :: V.Vector Int32)
|
||||||
|
|
||||||
|
-- | Test op with variable number of inputs.
|
||||||
|
testVarargs = testCase "testVarargs" $
|
||||||
|
runSession $ do
|
||||||
|
xs <- run $ pack $ map scalar [1..8]
|
||||||
|
liftIO $ [1..8] @=? V.toList (xs :: V.Vector Float)
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = googleTest [ testMultipleOutputs
|
||||||
|
, testVarargs
|
||||||
|
]
|
70
tensorflow-ops/tests/OpsTest.hs
Normal file
70
tensorflow-ops/tests/OpsTest.hs
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE NoMonomorphismRestriction #-}
|
||||||
|
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
import Data.Int (Int32, Int64)
|
||||||
|
import Google.Test (googleTest)
|
||||||
|
import System.IO.Temp (withSystemTempDirectory)
|
||||||
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
import Test.HUnit ((@=?))
|
||||||
|
import qualified Data.ByteString.Char8 as B8
|
||||||
|
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
import qualified TensorFlow.Build as TF
|
||||||
|
import qualified TensorFlow.ControlFlow as TF
|
||||||
|
import qualified TensorFlow.Nodes as TF
|
||||||
|
import qualified TensorFlow.Ops as TF
|
||||||
|
import qualified TensorFlow.Session as TF
|
||||||
|
import qualified TensorFlow.Tensor as TF
|
||||||
|
|
||||||
|
-- | Test that one can easily determine number of elements in the tensor.
|
||||||
|
testSize = testCase "testSize" $ do
|
||||||
|
x <- eval $ TF.size (TF.constant [2, 3] [0..5 :: Float])
|
||||||
|
TF.Scalar (2 * 3 :: Int32) @=? x
|
||||||
|
|
||||||
|
eval = TF.runSession . TF.buildAnd TF.run . return
|
||||||
|
|
||||||
|
-- | Confirms that the original example from Python code works.
|
||||||
|
testReducedShape = testCase "testReducedShape" $ do
|
||||||
|
x <- eval $ TF.reducedShape (TF.vector [2, 3, 5, 7 :: Int64])
|
||||||
|
(TF.vector [1, 2 :: Int32])
|
||||||
|
V.fromList [2, 1, 1, 7 :: Int32] @=? x
|
||||||
|
|
||||||
|
testSaveRestore = testCase "testSaveRestore" $
|
||||||
|
withSystemTempDirectory "" $ \dirPath -> do
|
||||||
|
let path = B8.pack $ dirPath ++ "/checkpoint"
|
||||||
|
var :: TF.Build (TF.Tensor TF.Ref Float)
|
||||||
|
var = TF.render =<< TF.named "a" <$> TF.zeroInitializedVariable []
|
||||||
|
TF.runSession $ do
|
||||||
|
v <- TF.build var
|
||||||
|
TF.buildAnd TF.run_ $ TF.assign v 134
|
||||||
|
TF.buildAnd TF.run_ $ TF.save path [v]
|
||||||
|
result <- TF.runSession $ do
|
||||||
|
v <- TF.build var
|
||||||
|
TF.buildAnd TF.run_ $ TF.restore path v
|
||||||
|
TF.run v
|
||||||
|
liftIO $ TF.Scalar 134 @=? result
|
||||||
|
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = googleTest [ testSaveRestore
|
||||||
|
, testSize
|
||||||
|
, testReducedShape
|
||||||
|
]
|
119
tensorflow-ops/tests/TypesTest.hs
Normal file
119
tensorflow-ops/tests/TypesTest.hs
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE ConstraintKinds #-}
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
|
{-# LANGUAGE NoMonomorphismRestriction #-}
|
||||||
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||||
|
|
||||||
|
import Control.Monad (replicateM)
|
||||||
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
import Data.Int (Int64)
|
||||||
|
import Google.Test (googleTest)
|
||||||
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||||||
|
import Test.HUnit ((@=?))
|
||||||
|
import Test.QuickCheck (Arbitrary(..), listOf, suchThat)
|
||||||
|
import qualified Data.ByteString as B
|
||||||
|
import qualified Data.ByteString.Char8 as B8
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
|
import qualified TensorFlow.ControlFlow as TF
|
||||||
|
import qualified TensorFlow.Ops as TF
|
||||||
|
import qualified TensorFlow.Session as TF
|
||||||
|
import qualified TensorFlow.Tensor as TF
|
||||||
|
import qualified TensorFlow.Types as TF
|
||||||
|
|
||||||
|
instance Arbitrary B.ByteString where
|
||||||
|
arbitrary = B.pack <$> arbitrary
|
||||||
|
|
||||||
|
-- Test encoding tensors, feeding them through tensorflow, and decoding the
|
||||||
|
-- results.
|
||||||
|
testFFIRoundTrip = testCase "testFFIRoundTrip" $
|
||||||
|
TF.runSession $ do
|
||||||
|
let floatData = V.fromList [1..6 :: Float]
|
||||||
|
stringData = V.fromList [B8.pack (show x) | x <- [1..6]]
|
||||||
|
f <- TF.build $ TF.placeholder [2,3]
|
||||||
|
s <- TF.build $ TF.placeholder [2,3]
|
||||||
|
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
|
||||||
|
, TF.feed s (TF.encodeTensorData [2,3] stringData)
|
||||||
|
]
|
||||||
|
-- It is an error to fetch a tensor that is being fed, so the tensors
|
||||||
|
-- are passed through identity.
|
||||||
|
(f', s') <- TF.runWithFeeds feeds (TF.identity f, TF.identity s)
|
||||||
|
liftIO $ do
|
||||||
|
floatData @=? f'
|
||||||
|
stringData @=? s'
|
||||||
|
|
||||||
|
|
||||||
|
data TensorDataInputs a = TensorDataInputs [Int64] (V.Vector a)
|
||||||
|
deriving Show
|
||||||
|
|
||||||
|
instance Arbitrary a => Arbitrary (TensorDataInputs a) where
|
||||||
|
arbitrary = do
|
||||||
|
-- Limit the size of the final vector, and also guard against overflow
|
||||||
|
-- (i.e., p<0) when there are too many dimensions
|
||||||
|
let validProduct p = p > 0 && p < 100
|
||||||
|
sizes <- listOf (arbitrary `suchThat` (>0))
|
||||||
|
`suchThat` (validProduct . product)
|
||||||
|
elems <- replicateM (fromIntegral $ product sizes) arbitrary
|
||||||
|
return $ TensorDataInputs sizes (V.fromList elems)
|
||||||
|
|
||||||
|
-- Test that a vector is unchanged after being encoded and decoded.
|
||||||
|
encodeDecodeProp :: (TF.TensorType a, Eq a) => TensorDataInputs a -> Bool
|
||||||
|
encodeDecodeProp (TensorDataInputs shape vec) =
|
||||||
|
TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec
|
||||||
|
|
||||||
|
testEncodeDecodeQcFloat = testProperty "testEncodeDecodeQcFloat"
|
||||||
|
(encodeDecodeProp :: TensorDataInputs Float -> Bool)
|
||||||
|
|
||||||
|
testEncodeDecodeQcInt64 = testProperty "testEncodeDecodeQcInt64"
|
||||||
|
(encodeDecodeProp :: TensorDataInputs Int64 -> Bool)
|
||||||
|
|
||||||
|
testEncodeDecodeQcString = testProperty "testEncodeDecodeQcString"
|
||||||
|
(encodeDecodeProp :: TensorDataInputs B.ByteString -> Bool)
|
||||||
|
|
||||||
|
doubleOrInt64Func :: TF.OneOf '[Double, Int64] a => a -> a
|
||||||
|
doubleOrInt64Func = id
|
||||||
|
|
||||||
|
doubleOrFloatFunc :: TF.OneOf '[Double, Float] a => a -> a
|
||||||
|
doubleOrFloatFunc = id
|
||||||
|
|
||||||
|
doubleFunc :: TF.OneOf '[Double] a => a -> a
|
||||||
|
doubleFunc = doubleOrFloatFunc . doubleOrInt64Func
|
||||||
|
|
||||||
|
-- No explicit type signature; make sure it can be inferred automatically.
|
||||||
|
-- (Note: this would fail if we didn't have NoMonomorphismRestriction, since it
|
||||||
|
-- can't simplify the type all the way to `Double -> Double`.
|
||||||
|
doubleFuncNoSig = doubleOrFloatFunc . doubleOrInt64Func
|
||||||
|
|
||||||
|
typeConstraintTests = testCase "type constraints" $ do
|
||||||
|
42 @=? doubleOrInt64Func (42 :: Double)
|
||||||
|
42 @=? doubleOrInt64Func (42 :: Int64)
|
||||||
|
42 @=? doubleOrFloatFunc (42 :: Double)
|
||||||
|
42 @=? doubleOrFloatFunc (42 :: Float)
|
||||||
|
42 @=? doubleFunc (42 :: Double)
|
||||||
|
42 @=? doubleFuncNoSig (42 :: Double)
|
||||||
|
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = googleTest [ testFFIRoundTrip
|
||||||
|
, testEncodeDecodeQcFloat
|
||||||
|
, testEncodeDecodeQcInt64
|
||||||
|
, testEncodeDecodeQcString
|
||||||
|
, typeConstraintTests
|
||||||
|
]
|
17
tensorflow-proto/Setup.hs
Normal file
17
tensorflow-proto/Setup.hs
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
import Data.ProtoLens.Setup
|
||||||
|
|
||||||
|
main = defaultMainGeneratingProtos "../third_party/tensorflow"
|
40
tensorflow-proto/tensorflow-proto.cabal
Normal file
40
tensorflow-proto/tensorflow-proto.cabal
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
name: tensorflow-proto
|
||||||
|
version: 0.1.0.0
|
||||||
|
synopsis: TensorFlow protocol buffers.
|
||||||
|
description: Please see README.md
|
||||||
|
homepage: https://github.com/tensorflow/haskell#readme
|
||||||
|
license: Apache
|
||||||
|
author: TensorFlow authors
|
||||||
|
maintainer: tensorflow-haskell@googlegroups.com
|
||||||
|
copyright: Google Inc.
|
||||||
|
category: Machine Learning
|
||||||
|
build-type: Custom
|
||||||
|
cabal-version: >=1.22
|
||||||
|
extra-source-files: ../third_party/tensorflow/tensorflow/core/framework/*.proto
|
||||||
|
, ../third_party/tensorflow/tensorflow/core/protobuf/config.proto
|
||||||
|
|
||||||
|
library
|
||||||
|
exposed-modules: Proto.Tensorflow.Core.Framework.AttrValue
|
||||||
|
, Proto.Tensorflow.Core.Framework.Graph
|
||||||
|
, Proto.Tensorflow.Core.Framework.NodeDef
|
||||||
|
, Proto.Tensorflow.Core.Framework.OpDef
|
||||||
|
, Proto.Tensorflow.Core.Framework.ResourceHandle
|
||||||
|
, Proto.Tensorflow.Core.Framework.Tensor
|
||||||
|
, Proto.Tensorflow.Core.Framework.TensorShape
|
||||||
|
, Proto.Tensorflow.Core.Framework.Types
|
||||||
|
, Proto.Tensorflow.Core.Protobuf.Config
|
||||||
|
other-modules: Proto.Tensorflow.Core.Framework.AllocationDescription
|
||||||
|
, Proto.Tensorflow.Core.Framework.CostGraph
|
||||||
|
, Proto.Tensorflow.Core.Framework.Function
|
||||||
|
, Proto.Tensorflow.Core.Framework.StepStats
|
||||||
|
, Proto.Tensorflow.Core.Framework.TensorDescription
|
||||||
|
, Proto.Tensorflow.Core.Framework.Versions
|
||||||
|
build-depends: proto-lens == 0.1.*
|
||||||
|
, proto-lens-protoc == 0.1.*
|
||||||
|
, base >= 4.7 && < 5
|
||||||
|
default-language: Haskell2010
|
||||||
|
include-dirs: .
|
||||||
|
|
||||||
|
source-repository head
|
||||||
|
type: git
|
||||||
|
location: https://github.com/tensorflow/haskell
|
3
tensorflow-queue/Setup.hs
Normal file
3
tensorflow-queue/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
import Distribution.Simple
|
||||||
|
|
||||||
|
main = defaultMain
|
78
tensorflow-queue/src/TensorFlow/Queue.hs
Normal file
78
tensorflow-queue/src/TensorFlow/Queue.hs
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
|
-- | Queues in TensorFlow graph. Very limited support for now.
|
||||||
|
module TensorFlow.Queue (Queue2, makeQueue2, enqueue, dequeue) where
|
||||||
|
|
||||||
|
import Data.ByteString (ByteString)
|
||||||
|
import Data.Int (Int64)
|
||||||
|
import Lens.Family2 ((.~), (&))
|
||||||
|
import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef)
|
||||||
|
import TensorFlow.BuildOp (buildOp)
|
||||||
|
import TensorFlow.ControlFlow (group)
|
||||||
|
import TensorFlow.Tensor (Ref, Tensor)
|
||||||
|
import TensorFlow.Types (TensorType, tensorType)
|
||||||
|
|
||||||
|
-- | A queue carrying tuples. The underlying structure is more
|
||||||
|
-- versatile and can be made to support arbitrary tuples.
|
||||||
|
data Queue2 a b = Queue2 { handle :: Handle }
|
||||||
|
|
||||||
|
type Handle = Tensor Ref ByteString
|
||||||
|
|
||||||
|
-- | Adds the given values to the queue.
|
||||||
|
enqueue :: forall a b v1 v2. (TensorType a, TensorType b)
|
||||||
|
=> Queue2 a b
|
||||||
|
-> Tensor v1 a
|
||||||
|
-> Tensor v2 b
|
||||||
|
-> Build ControlNode
|
||||||
|
enqueue q =
|
||||||
|
buildOp (opDef "QueueEnqueue"
|
||||||
|
& opAttr "Tcomponents" .~ [ tensorType (undefined :: a)
|
||||||
|
, tensorType (undefined :: b)])
|
||||||
|
(handle q)
|
||||||
|
|
||||||
|
-- | Retrieves the values from the queue.
|
||||||
|
dequeue :: forall a b . (TensorType a, TensorType b)
|
||||||
|
=> Queue2 a b
|
||||||
|
-> Build (Tensor Ref a, Tensor Ref b)
|
||||||
|
-- ^ Dequeued tensors. They are paired in a sense
|
||||||
|
-- that values appear together, even if they are
|
||||||
|
-- not consumed together.
|
||||||
|
dequeue q =
|
||||||
|
buildOp (opDef "QueueDequeue"
|
||||||
|
& opAttr "component_types" .~ [ tensorType (undefined :: a)
|
||||||
|
, tensorType (undefined :: b)])
|
||||||
|
(handle q)
|
||||||
|
|
||||||
|
-- | Creates a new queue with the given capacity and shared name.
|
||||||
|
makeQueue2 :: forall a b . (TensorType a, TensorType b)
|
||||||
|
=> Int64 -- ^ The upper bound on the number of elements in
|
||||||
|
-- this queue. Negative numbers mean no limit.
|
||||||
|
-> ByteString -- ^ If non-empty, this queue will be shared
|
||||||
|
-- under the given name across multiple sessions.
|
||||||
|
-> Build (Queue2 a b)
|
||||||
|
makeQueue2 capacity sharedName = do
|
||||||
|
q <- buildOp (opDef "FIFOQueue"
|
||||||
|
& opAttr "component_types" .~ [ tensorType (undefined :: a)
|
||||||
|
, tensorType (undefined :: b)]
|
||||||
|
& opAttr "shared_name" .~ sharedName
|
||||||
|
& opAttr "capacity" .~ capacity
|
||||||
|
)
|
||||||
|
group q >>= addInitializer
|
||||||
|
return (Queue2 q)
|
||||||
|
|
||||||
|
-- TODO(gnezdo): Figure out the closing story for queues.
|
51
tensorflow-queue/tensorflow-queue.cabal
Normal file
51
tensorflow-queue/tensorflow-queue.cabal
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
name: tensorflow-queue
|
||||||
|
version: 0.1.0.0
|
||||||
|
synopsis: Basic access to TensorFlow queues.
|
||||||
|
description: Please see README.md
|
||||||
|
homepage: https://github.com/tensorflow/haskell#readme
|
||||||
|
license: Apache
|
||||||
|
author: TensorFlow authors
|
||||||
|
maintainer: tensorflow-haskell@googlegroups.com
|
||||||
|
copyright: Google Inc.
|
||||||
|
category: Machine Learning
|
||||||
|
build-type: Simple
|
||||||
|
cabal-version: >=1.22
|
||||||
|
|
||||||
|
library
|
||||||
|
hs-source-dirs: src
|
||||||
|
exposed-modules: TensorFlow.Queue
|
||||||
|
build-depends: proto-lens == 0.1.*
|
||||||
|
, base >= 4.7 && < 5
|
||||||
|
, bytestring
|
||||||
|
, lens-family
|
||||||
|
, containers
|
||||||
|
, tensorflow-proto == 0.1.*
|
||||||
|
, tensorflow-core-ops == 0.1.*
|
||||||
|
, tensorflow
|
||||||
|
, text
|
||||||
|
default-language: Haskell2010
|
||||||
|
|
||||||
|
Test-Suite QueueTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: QueueTest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
-- Uses multiple threads and blocks without this option.
|
||||||
|
ghc-options: -threaded
|
||||||
|
build-depends: HUnit
|
||||||
|
, base
|
||||||
|
, bytestring
|
||||||
|
, proto-lens
|
||||||
|
, lens-family
|
||||||
|
, google-shim
|
||||||
|
, tensorflow
|
||||||
|
, tensorflow-ops
|
||||||
|
, tensorflow-queue
|
||||||
|
, test-framework
|
||||||
|
, test-framework-hunit
|
||||||
|
, transformers
|
||||||
|
, vector
|
||||||
|
|
||||||
|
source-repository head
|
||||||
|
type: git
|
||||||
|
location: https://github.com/tensorflow/haskell
|
79
tensorflow-queue/tests/QueueTest.hs
Normal file
79
tensorflow-queue/tests/QueueTest.hs
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
import Data.Int (Int64)
|
||||||
|
import Google.Test (googleTest)
|
||||||
|
import TensorFlow.Nodes (Scalar(..))
|
||||||
|
import TensorFlow.Ops (scalar)
|
||||||
|
import TensorFlow.Queue
|
||||||
|
import TensorFlow.Session
|
||||||
|
( asyncProdNodes
|
||||||
|
, build
|
||||||
|
, buildAnd
|
||||||
|
, run
|
||||||
|
, runSession
|
||||||
|
, run_
|
||||||
|
)
|
||||||
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
import Test.HUnit ((@=?))
|
||||||
|
import qualified Data.ByteString as BS
|
||||||
|
|
||||||
|
-- | Test basic queue behaviors.
|
||||||
|
testBasic = testCase "testBasic" $ runSession $ do
|
||||||
|
(q :: Queue2 Int64 BS.ByteString) <- build $ makeQueue2 1 ""
|
||||||
|
buildAnd run_ (enqueue q 42 (scalar "Hi"))
|
||||||
|
x <- buildAnd run (dequeue q)
|
||||||
|
liftIO $ (Scalar 42, Scalar "Hi") @=? x
|
||||||
|
|
||||||
|
buildAnd run_ (enqueue q 56 (scalar "Bar"))
|
||||||
|
y <- buildAnd run (dequeue q)
|
||||||
|
liftIO $ (Scalar 56, Scalar "Bar") @=? y
|
||||||
|
|
||||||
|
-- | Test queue pumping.
|
||||||
|
testPump = testCase "testPump" $ runSession $ do
|
||||||
|
(deq, pump) <- build $ do
|
||||||
|
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue"
|
||||||
|
(,) <$> dequeue q
|
||||||
|
<*> enqueue q 31 (scalar "Baz")
|
||||||
|
-- This is a realistic use. The pump inputs are pre-bound to some
|
||||||
|
-- nodes that produce values when pumped (e.g. read from a
|
||||||
|
-- file).
|
||||||
|
run_ (pump, pump)
|
||||||
|
|
||||||
|
(x, y) <- run (deq, deq)
|
||||||
|
liftIO $ (Scalar 31, Scalar "Baz") @=? x
|
||||||
|
liftIO $ (Scalar 31, Scalar "Baz") @=? y
|
||||||
|
|
||||||
|
testAsync = testCase "testAsync" $ runSession $ do
|
||||||
|
(deq, pump) <- build $ do
|
||||||
|
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 ""
|
||||||
|
(,) <$> dequeue q
|
||||||
|
<*> enqueue q 10 (scalar "Async")
|
||||||
|
-- Pumps the queue until canceled by runSession exiting.
|
||||||
|
asyncProdNodes pump
|
||||||
|
-- Picks up a couple values and verifies they are as expected.
|
||||||
|
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?)
|
||||||
|
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?)
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = googleTest [ testBasic
|
||||||
|
, testPump
|
||||||
|
, testAsync
|
||||||
|
]
|
3
tensorflow/Setup.hs
Normal file
3
tensorflow/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
import Distribution.Simple
|
||||||
|
|
||||||
|
main = defaultMain
|
376
tensorflow/src/TensorFlow/Build.hs
Normal file
376
tensorflow/src/TensorFlow/Build.hs
Normal file
|
@ -0,0 +1,376 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||||
|
{-# LANGUAGE LambdaCase #-}
|
||||||
|
{-# LANGUAGE Rank2Types #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
module TensorFlow.Build
|
||||||
|
( -- * Graph node types
|
||||||
|
ControlNode(..)
|
||||||
|
, Unique
|
||||||
|
-- * Ops
|
||||||
|
, explicitName
|
||||||
|
, implicitName
|
||||||
|
, opDef
|
||||||
|
, opDefWithName
|
||||||
|
, opName
|
||||||
|
, opType
|
||||||
|
, opAttr
|
||||||
|
, opInputs
|
||||||
|
, opControlInputs
|
||||||
|
-- * The Build monad
|
||||||
|
, GraphState
|
||||||
|
, render
|
||||||
|
, renderNodeName
|
||||||
|
, renderedNodeDefs
|
||||||
|
, BuildT
|
||||||
|
, Build
|
||||||
|
, addInitializer
|
||||||
|
, hoistBuildT
|
||||||
|
, evalBuildT
|
||||||
|
, runBuildT
|
||||||
|
, asGraphDef
|
||||||
|
, addGraphDef
|
||||||
|
, flushInitializers
|
||||||
|
, flushNodeBuffer
|
||||||
|
-- * Creating and looking up Ops
|
||||||
|
, getOrAddOp
|
||||||
|
, addNewOp
|
||||||
|
, renderOutput
|
||||||
|
-- * Modifying all nodes in a Build action
|
||||||
|
, colocateWith
|
||||||
|
, withStateLens
|
||||||
|
, withDevice
|
||||||
|
, withNameScope
|
||||||
|
, withNodeDependencies
|
||||||
|
-- * Internal Summary related bits.
|
||||||
|
, addSummary
|
||||||
|
, SummaryTensor
|
||||||
|
, collectAllSummaries
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Control.Monad.IO.Class (MonadIO(..))
|
||||||
|
import Control.Monad.Trans.Class (MonadTrans(..))
|
||||||
|
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
|
||||||
|
import Data.ByteString (ByteString)
|
||||||
|
import Data.Default (def)
|
||||||
|
import Data.Functor.Identity (Identity(..))
|
||||||
|
import qualified Data.Map.Strict as Map
|
||||||
|
import Data.Monoid ((<>))
|
||||||
|
import qualified Data.Set as Set
|
||||||
|
import Data.Set (Set)
|
||||||
|
import Data.String (IsString(..))
|
||||||
|
import Data.Text (Text)
|
||||||
|
import qualified Data.Text as Text
|
||||||
|
import Lens.Family2 (Lens', (.~), (^.), (&))
|
||||||
|
import Lens.Family2.State.Strict (MonadState, use, uses, (.=), (<>=), (%=))
|
||||||
|
import Lens.Family2.Unchecked (lens)
|
||||||
|
import Proto.Tensorflow.Core.Framework.Graph
|
||||||
|
( GraphDef
|
||||||
|
, node
|
||||||
|
)
|
||||||
|
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||||
|
( NodeDef
|
||||||
|
, attr
|
||||||
|
, input
|
||||||
|
, device
|
||||||
|
, name
|
||||||
|
, op
|
||||||
|
)
|
||||||
|
|
||||||
|
import TensorFlow.Orphans ()
|
||||||
|
import TensorFlow.Output
|
||||||
|
import TensorFlow.Tensor
|
||||||
|
|
||||||
|
newtype Unique = Unique Int
|
||||||
|
deriving (Eq, Ord, Enum)
|
||||||
|
|
||||||
|
--------------
|
||||||
|
|
||||||
|
implicitName :: PendingNodeName
|
||||||
|
implicitName = ImplicitName
|
||||||
|
|
||||||
|
explicitName :: Text -> PendingNodeName
|
||||||
|
explicitName = ExplicitName
|
||||||
|
|
||||||
|
newtype Scope = Scope {unScope :: Text}
|
||||||
|
deriving (Eq, Ord, IsString)
|
||||||
|
|
||||||
|
instance Show Scope where
|
||||||
|
show = show . unScope
|
||||||
|
|
||||||
|
opDef :: OpType -> OpDef
|
||||||
|
opDef = opDefWithName ImplicitName
|
||||||
|
|
||||||
|
opDefWithName :: PendingNodeName -> OpType -> OpDef
|
||||||
|
opDefWithName n t = OpDef
|
||||||
|
{ _opName = n
|
||||||
|
, _opType = t
|
||||||
|
, _opAttrs = Map.empty
|
||||||
|
, _opInputs = []
|
||||||
|
, _opControlInputs = []
|
||||||
|
}
|
||||||
|
|
||||||
|
-- | Synonym for the tensors that return serialized Summary proto.
|
||||||
|
type SummaryTensor = Tensor Value ByteString
|
||||||
|
|
||||||
|
data GraphState = GraphState
|
||||||
|
{ _renderedNodes :: !(Map.Map PendingNode NodeDef)
|
||||||
|
-- ^ Nodes which have been rendered. Keeps track of the unique ID we
|
||||||
|
-- assign each implicitly-named node. Also prevents us from adding the
|
||||||
|
-- same node (implicit or explicit) more than once to the nodeBuffer.
|
||||||
|
, _renderedNodeDefs :: !(Map.Map NodeName NodeDef)
|
||||||
|
-- ^ The NodeDefs of nodes which have been rendered. Used by the
|
||||||
|
-- Gradient module to inspect the node graph.
|
||||||
|
, _nodeBuffer :: [NodeDef]
|
||||||
|
-- ^ A list of nodes that should be passed to TensorFlow during
|
||||||
|
-- the next call to Session.extend (TF_ExtendGraph).
|
||||||
|
, _nextUnique :: !Unique
|
||||||
|
-- ^ Unique ID for the next node
|
||||||
|
-- TODO(judahjacobson): watch for clashes between auto and user names.
|
||||||
|
, _defaultDevice :: !(Maybe Device)
|
||||||
|
, _currentScope :: [Scope]
|
||||||
|
, _defaultControlInputs :: !(Set NodeName)
|
||||||
|
, _initializationNodes :: [NodeName]
|
||||||
|
-- ^ The nodes to run next time a TF.run is issued, typically
|
||||||
|
-- variable initializers.
|
||||||
|
, _summaries :: [SummaryTensor]
|
||||||
|
-- ^ The tensors for summary
|
||||||
|
}
|
||||||
|
|
||||||
|
-- | A node definition without its final name. Used as a key in the
|
||||||
|
-- "renderedNodes" map.
|
||||||
|
-- The NodeDef contained inside has an empty "name" field.
|
||||||
|
data PendingNode = PendingNode [Scope] !PendingNodeName !NodeDef
|
||||||
|
deriving (Eq, Ord)
|
||||||
|
|
||||||
|
-- Returns an _incomplete_ NodeDef. The name is fixed by addNewOpFromPending.
|
||||||
|
pendingNodeDef :: PendingNode -> NodeDef
|
||||||
|
pendingNodeDef (PendingNode _ _ n) = n
|
||||||
|
|
||||||
|
initGraphState :: GraphState
|
||||||
|
initGraphState =
|
||||||
|
GraphState Map.empty Map.empty [] (Unique 0) Nothing [] Set.empty [] []
|
||||||
|
|
||||||
|
renderedNodes :: Lens' GraphState (Map.Map PendingNode NodeDef)
|
||||||
|
renderedNodes = lens _renderedNodes (\g x -> g { _renderedNodes = x })
|
||||||
|
|
||||||
|
renderedNodeDefs :: Lens' GraphState (Map.Map NodeName NodeDef)
|
||||||
|
renderedNodeDefs = lens _renderedNodeDefs (\g x -> g { _renderedNodeDefs = x })
|
||||||
|
|
||||||
|
nodeBuffer :: Lens' GraphState [NodeDef]
|
||||||
|
nodeBuffer = lens _nodeBuffer (\g x -> g { _nodeBuffer = x })
|
||||||
|
|
||||||
|
nextUnique :: Lens' GraphState Unique
|
||||||
|
nextUnique = lens _nextUnique (\g x -> g { _nextUnique = x })
|
||||||
|
|
||||||
|
defaultDevice :: Lens' GraphState (Maybe Device)
|
||||||
|
defaultDevice = lens _defaultDevice (\g x -> g { _defaultDevice = x })
|
||||||
|
|
||||||
|
currentScope :: Lens' GraphState [Scope]
|
||||||
|
currentScope = lens _currentScope (\g x -> g { _currentScope = x })
|
||||||
|
|
||||||
|
defaultControlInputs :: Lens' GraphState (Set NodeName)
|
||||||
|
defaultControlInputs = lens _defaultControlInputs
|
||||||
|
(\g x -> g { _defaultControlInputs = x })
|
||||||
|
|
||||||
|
initializationNodes :: Lens' GraphState [NodeName]
|
||||||
|
initializationNodes = lens _initializationNodes (\g x -> g { _initializationNodes = x })
|
||||||
|
|
||||||
|
summaries :: Lens' GraphState [SummaryTensor]
|
||||||
|
summaries = lens _summaries (\g x -> g { _summaries = x })
|
||||||
|
|
||||||
|
-- | An action for building nodes in a TensorFlow graph.
|
||||||
|
-- Used to manage build state internally as part of the @Session@ monad.
|
||||||
|
newtype BuildT m a = BuildT (StateT GraphState m a)
|
||||||
|
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
|
||||||
|
MonadState GraphState)
|
||||||
|
|
||||||
|
-- | An action for building nodes in a TensorFlow graph.
|
||||||
|
type Build = BuildT Identity
|
||||||
|
|
||||||
|
-- | This is Control.Monad.Morph.hoist sans the dependency.
|
||||||
|
hoistBuildT :: (forall a . m a -> n a) -> BuildT m b -> BuildT n b
|
||||||
|
hoistBuildT f (BuildT m) = BuildT $ mapStateT f m
|
||||||
|
|
||||||
|
runBuildT :: BuildT m a -> m (a, GraphState)
|
||||||
|
runBuildT (BuildT f) = runStateT f initGraphState
|
||||||
|
|
||||||
|
evalBuildT :: Monad m => BuildT m a -> m a
|
||||||
|
evalBuildT (BuildT f) = evalStateT f initGraphState
|
||||||
|
|
||||||
|
-- | Get all the NodeDefs that have accumulated so far, and clear that buffer.
|
||||||
|
flushNodeBuffer :: Monad m => BuildT m [NodeDef]
|
||||||
|
flushNodeBuffer = do
|
||||||
|
ns <- use nodeBuffer
|
||||||
|
nodeBuffer .= []
|
||||||
|
return ns
|
||||||
|
|
||||||
|
-- | Get all the initializers that have accumulated so far, and clear
|
||||||
|
-- that buffer.
|
||||||
|
flushInitializers :: Monad m => BuildT m [NodeName]
|
||||||
|
flushInitializers = do
|
||||||
|
ns <- use initializationNodes
|
||||||
|
initializationNodes .= []
|
||||||
|
return ns
|
||||||
|
|
||||||
|
-- | Registers the given node to be executed before the next
|
||||||
|
-- 'TensorFlow.Session.run'.
|
||||||
|
addInitializer :: ControlNode -> Build ()
|
||||||
|
addInitializer (ControlNode o) = do
|
||||||
|
i <- getOrAddOp o
|
||||||
|
initializationNodes %= (i:)
|
||||||
|
|
||||||
|
-- | Produce a GraphDef proto representation of the nodes that are rendered in
|
||||||
|
-- the given 'Build' action.
|
||||||
|
asGraphDef :: Build a -> GraphDef
|
||||||
|
asGraphDef b = def & node .~ gs ^. nodeBuffer
|
||||||
|
where
|
||||||
|
gs = snd $ runIdentity $ runBuildT b
|
||||||
|
|
||||||
|
-- TODO: check against existing nodes for conflicts?
|
||||||
|
addGraphDef :: GraphDef -> Build ()
|
||||||
|
addGraphDef g = nodeBuffer <>= g ^. node
|
||||||
|
|
||||||
|
-- | Render the given op if it hasn't been rendered already, and return its
|
||||||
|
-- name.
|
||||||
|
getOrAddOp :: Op -> Build NodeName
|
||||||
|
getOrAddOp o = NodeName . (^. name) <$> resolveOp o
|
||||||
|
|
||||||
|
resolveOp :: Op -> Build NodeDef
|
||||||
|
resolveOp (Rendered n) = return n
|
||||||
|
resolveOp (Unrendered o) = do
|
||||||
|
pending <- getPendingNode o
|
||||||
|
uses renderedNodes (Map.lookup pending) >>= \case
|
||||||
|
Just n -> return n
|
||||||
|
Nothing -> addNewOpFromPending pending
|
||||||
|
|
||||||
|
-- | Add a new node for a given 'OpDef'. This is used for making "stateful" ops
|
||||||
|
-- which are not safe to dedup (e.g, "variable" and "assign").
|
||||||
|
addNewOp :: OpDef -> Build NodeDef
|
||||||
|
addNewOp o = getPendingNode o >>= addNewOpFromPending
|
||||||
|
|
||||||
|
addNewOpFromPending :: PendingNode -> Build NodeDef
|
||||||
|
addNewOpFromPending pending = do
|
||||||
|
nodeName <- renderPendingNode pending
|
||||||
|
let nodeDef = pendingNodeDef pending & name .~ unNodeName nodeName
|
||||||
|
nodeBuffer %= (nodeDef :)
|
||||||
|
renderedNodes %= Map.insert pending nodeDef
|
||||||
|
renderedNodeDefs %= Map.insert nodeName nodeDef
|
||||||
|
return nodeDef
|
||||||
|
|
||||||
|
-- | Get the pending node corresponding to an OpDef, which may or may not have
|
||||||
|
-- been rendered before. Implicitly renders all of this node's inputs.
|
||||||
|
getPendingNode :: OpDef -> Build PendingNode
|
||||||
|
getPendingNode o = do
|
||||||
|
-- An empty string in the proto field means that no specific
|
||||||
|
-- device is specified.
|
||||||
|
dev <- maybe "" deviceName <$> use defaultDevice
|
||||||
|
inputs <- mapM getInput (o ^. opInputs)
|
||||||
|
scope <- use currentScope
|
||||||
|
controls <- use defaultControlInputs
|
||||||
|
let controlInputs
|
||||||
|
= map getDep (o ^. opControlInputs ++ Set.toList controls)
|
||||||
|
return $ PendingNode scope (o ^. opName)
|
||||||
|
$ def & op .~ (unOpType (o ^. opType) :: Text)
|
||||||
|
& attr .~ _opAttrs o
|
||||||
|
& input .~ (inputs ++ controlInputs)
|
||||||
|
& device .~ dev
|
||||||
|
where
|
||||||
|
getInput (Output (OutputIx k) subOp)
|
||||||
|
= (<> ":" <> Text.pack (show k)) . unNodeName <$> getOrAddOp subOp
|
||||||
|
getDep = ("^" <>) . unNodeName
|
||||||
|
|
||||||
|
-- | Pick a name for a pending node. If it has an explicit name, just use that;
|
||||||
|
-- if the name is implicit, assign a new unique name based on the op type.
|
||||||
|
renderPendingNode :: PendingNode -> Build NodeName
|
||||||
|
renderPendingNode (PendingNode scope pendingName nodeDef)
|
||||||
|
= NodeName . (scopePrefix <>) <$> getName
|
||||||
|
where
|
||||||
|
scopePrefix = Text.concat $ fmap ((<> "/") . unScope) scope
|
||||||
|
getName = case pendingName of
|
||||||
|
ExplicitName n -> return n
|
||||||
|
ImplicitName -> do
|
||||||
|
u@(Unique k) <- use nextUnique
|
||||||
|
nextUnique .= succ u
|
||||||
|
return $ nodeDef ^. op <> "_" <> Text.pack (show k)
|
||||||
|
|
||||||
|
|
||||||
|
-- | Render an 'Output' and return a string representation for the TensorFlow
|
||||||
|
-- foreign APIs.
|
||||||
|
renderOutput :: Output -> Build Text
|
||||||
|
renderOutput (Output (OutputIx i) o) = do
|
||||||
|
n <- getOrAddOp o
|
||||||
|
return $ unNodeName n <> Text.pack (":" ++ show i)
|
||||||
|
|
||||||
|
-- | Modify some part of the state, run an action, and restore the state
|
||||||
|
-- after that action is done.
|
||||||
|
withStateLens :: MonadState s m => Lens' s a -> (a -> a) -> m b -> m b
|
||||||
|
withStateLens accessor f act = do
|
||||||
|
old <- use accessor
|
||||||
|
accessor %= f
|
||||||
|
result <- act
|
||||||
|
accessor .= old
|
||||||
|
return result
|
||||||
|
|
||||||
|
-- | Set a device for all nodes rendered in the given 'Build' action
|
||||||
|
-- (unless further overridden by another use of withDevice).
|
||||||
|
withDevice :: Maybe Device -> Build a -> Build a
|
||||||
|
withDevice d = withStateLens defaultDevice (const d)
|
||||||
|
|
||||||
|
-- | Places all nodes rendered in the given 'Build' action on the same
|
||||||
|
-- device as the given Tensor (see also 'withDevice'). Make sure that
|
||||||
|
-- the action has side effects of rendering the desired tensors. A pure
|
||||||
|
-- return would not have the desired effect.
|
||||||
|
colocateWith :: forall a v b . Tensor v b -> Build a -> Build a
|
||||||
|
colocateWith t x = do
|
||||||
|
d <- Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
|
||||||
|
withDevice (Just d) x
|
||||||
|
|
||||||
|
-- | Prepend a scope to all nodes rendered in the given 'Build' action.
|
||||||
|
withNameScope :: Text -> Build a -> Build a
|
||||||
|
withNameScope s = withStateLens currentScope (Scope s :)
|
||||||
|
|
||||||
|
-- | Add control inputs to all nodes rendered in the given 'Build' action.
|
||||||
|
withNodeDependencies :: Set NodeName -> Build a -> Build a
|
||||||
|
withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
|
||||||
|
|
||||||
|
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
|
||||||
|
-- the 'Build' context. Also renders any dependencies of the 'Tensor' that
|
||||||
|
-- weren't already rendered.
|
||||||
|
--
|
||||||
|
-- This operation is idempotent; @render >=> render === render@. However,
|
||||||
|
-- rendering a (previously un-rendered) 'Tensor' in two different contexts
|
||||||
|
-- may result in two different 'Tensor's.
|
||||||
|
render :: Tensor v a -> Build (Tensor v a)
|
||||||
|
render = tensorOutput $ outputOp $ fmap Rendered . resolveOp
|
||||||
|
|
||||||
|
-- | Render a 'Tensor' and get its node's name.
|
||||||
|
renderNodeName :: Tensor v a -> Build NodeName
|
||||||
|
renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp)
|
||||||
|
|
||||||
|
-- | Records the given summary action in Build for retrieval with
|
||||||
|
-- 'collectAllSummaries'. The summary op is required to produce a
|
||||||
|
-- Summary protocol buffer in string form. For safety, use the
|
||||||
|
-- pre-composed functions: Logging.scalarSummary and
|
||||||
|
-- Logging.histogramSummary.
|
||||||
|
addSummary :: SummaryTensor -> Build ()
|
||||||
|
addSummary t = summaries %= (t :)
|
||||||
|
|
||||||
|
-- | Retrieves the summary ops collected thus far. Typically this only
|
||||||
|
-- happens once, but if 'TensorFlow.Session.buildWithSummary' is used
|
||||||
|
-- repeatedly, the values accumulate.
|
||||||
|
collectAllSummaries :: Monad m => BuildT m [SummaryTensor]
|
||||||
|
collectAllSummaries = use summaries
|
199
tensorflow/src/TensorFlow/BuildOp.hs
Normal file
199
tensorflow/src/TensorFlow/BuildOp.hs
Normal file
|
@ -0,0 +1,199 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
|
{-# LANGUAGE TupleSections #-}
|
||||||
|
|
||||||
|
module TensorFlow.BuildOp
|
||||||
|
( OpResult
|
||||||
|
, BuildOp
|
||||||
|
, buildOp
|
||||||
|
, buildListOp
|
||||||
|
, eqLengthGuard
|
||||||
|
)
|
||||||
|
where
|
||||||
|
|
||||||
|
import Control.Monad (replicateM)
|
||||||
|
import Control.Monad.Reader (ReaderT, runReaderT, ask)
|
||||||
|
import Control.Monad.State.Strict (State, runState, get, put)
|
||||||
|
import Data.Int (Int64)
|
||||||
|
import Lens.Family2 ((&), (<>~), (^.))
|
||||||
|
|
||||||
|
import TensorFlow.Build
|
||||||
|
import TensorFlow.Output
|
||||||
|
import TensorFlow.Tensor
|
||||||
|
|
||||||
|
data ResultState = ResultState !OutputIx [Int64] deriving Show
|
||||||
|
|
||||||
|
type Result = ReaderT Op (State ResultState)
|
||||||
|
|
||||||
|
-- | Class of types that can be used as op outputs.
|
||||||
|
class OpResult a where
|
||||||
|
toResult :: Result a
|
||||||
|
|
||||||
|
instance (OpResult a1, OpResult a2) => OpResult (a1, a2) where
|
||||||
|
toResult = (,) <$> toResult <*> toResult
|
||||||
|
|
||||||
|
instance (OpResult a1, OpResult a2, OpResult a3) => OpResult (a1, a2, a3) where
|
||||||
|
toResult = (,,) <$> toResult <*> toResult <*> toResult
|
||||||
|
|
||||||
|
instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4)
|
||||||
|
=> OpResult (a1, a2, a3, a4) where
|
||||||
|
toResult = (,,,) <$> toResult <*> toResult <*> toResult <*> toResult
|
||||||
|
|
||||||
|
instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4, OpResult a5)
|
||||||
|
=> OpResult (a1, a2, a3, a4, a5) where
|
||||||
|
toResult = (,,,,) <$> toResult
|
||||||
|
<*> toResult
|
||||||
|
<*> toResult
|
||||||
|
<*> toResult
|
||||||
|
<*> toResult
|
||||||
|
|
||||||
|
instance ( OpResult a1
|
||||||
|
, OpResult a2
|
||||||
|
, OpResult a3
|
||||||
|
, OpResult a4
|
||||||
|
, OpResult a5
|
||||||
|
, OpResult a6
|
||||||
|
)
|
||||||
|
=> OpResult (a1, a2, a3, a4, a5, a6) where
|
||||||
|
toResult = (,,,,,)
|
||||||
|
<$> toResult
|
||||||
|
<*> toResult
|
||||||
|
<*> toResult
|
||||||
|
<*> toResult
|
||||||
|
<*> toResult
|
||||||
|
<*> toResult
|
||||||
|
|
||||||
|
tensorResult :: TensorKind v -> Result (Tensor v a)
|
||||||
|
tensorResult v = do
|
||||||
|
o <- ask
|
||||||
|
ResultState i ns <- get
|
||||||
|
put $! ResultState (i+1) ns
|
||||||
|
return $! Tensor v $ output i o
|
||||||
|
|
||||||
|
instance OpResult (Tensor Value a) where
|
||||||
|
toResult = tensorResult ValueKind
|
||||||
|
|
||||||
|
instance OpResult (Tensor Ref a) where
|
||||||
|
toResult = tensorResult RefKind
|
||||||
|
|
||||||
|
instance OpResult ControlNode where
|
||||||
|
toResult = ControlNode <$> ask
|
||||||
|
|
||||||
|
instance OpResult a => OpResult [a] where
|
||||||
|
toResult = do
|
||||||
|
ResultState i ns <- get
|
||||||
|
case ns of
|
||||||
|
[] -> error $ "Ran out of counts in toResult. " ++
|
||||||
|
"Likely misuse of buildListOp."
|
||||||
|
(n : rest) -> do
|
||||||
|
put $! ResultState i rest
|
||||||
|
replicateM (fromIntegral n) toResult
|
||||||
|
|
||||||
|
runResult :: OpResult a => [Int64] -> Op -> a
|
||||||
|
runResult ns o =
|
||||||
|
case runState (runReaderT toResult o) (ResultState 0 ns) of
|
||||||
|
(x, ResultState _ []) -> x
|
||||||
|
(_, ns') -> error $ "Ununsed length in runResult attributes: " ++
|
||||||
|
show (ns, ns')
|
||||||
|
|
||||||
|
-- | Make a new "pure" op, which may be deduped with identical ops within
|
||||||
|
-- the same scope.
|
||||||
|
pureResult :: OpResult a => [Int64] -> OpDef -> [Output] -> a
|
||||||
|
pureResult ns o ts = runResult ns $ Unrendered $ addReversedInputs o ts
|
||||||
|
|
||||||
|
-- | Make a new "stateful" op, which will not be deduped with otherwise
|
||||||
|
-- identical ops.
|
||||||
|
buildResult :: OpResult a => [Int64] -> OpDef -> [Output] -> Build a
|
||||||
|
buildResult ns o ts
|
||||||
|
= runResult ns . Rendered <$> addNewOp (addReversedInputs o ts)
|
||||||
|
|
||||||
|
addReversedInputs :: OpDef -> [Output] -> OpDef
|
||||||
|
addReversedInputs o ts = o & opInputs <>~ reverse ts
|
||||||
|
|
||||||
|
-- | Class of types that can be used as op functions.
|
||||||
|
class BuildOp f where
|
||||||
|
buildOp' :: [Int64] -- ^ Sizes of list results (having number_attr)
|
||||||
|
-> OpDef
|
||||||
|
-> [Output] -- ^ Accumulator for inputs to the op.
|
||||||
|
-> f
|
||||||
|
|
||||||
|
-- | Starts an operation that returns a structured set of tensors
|
||||||
|
-- (singletons or tuples).
|
||||||
|
buildOp :: BuildOp f => OpDef -> f
|
||||||
|
buildOp o = buildOp' [] o []
|
||||||
|
|
||||||
|
-- | Starts an operation that returns a list of tensors.
|
||||||
|
buildListOp :: BuildOp f => [Int64]
|
||||||
|
-- ^ Cardinality of the corresponding list of tensors output.
|
||||||
|
-> OpDef -> f
|
||||||
|
buildListOp counts o = buildOp' counts o []
|
||||||
|
|
||||||
|
instance BuildOp ControlNode where
|
||||||
|
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
|
||||||
|
|
||||||
|
instance BuildOp (Tensor Value a) where
|
||||||
|
buildOp' = pureResult
|
||||||
|
|
||||||
|
instance BuildOp (Tensor Ref a) where
|
||||||
|
buildOp' = pureResult
|
||||||
|
|
||||||
|
instance BuildOp [Tensor Value a] where
|
||||||
|
buildOp' = pureResult
|
||||||
|
|
||||||
|
instance (OpResult t1, OpResult t2) => BuildOp (t1, t2) where
|
||||||
|
buildOp' = pureResult
|
||||||
|
|
||||||
|
instance (OpResult t1, OpResult t2, OpResult t3) => BuildOp (t1, t2, t3) where
|
||||||
|
buildOp' = pureResult
|
||||||
|
|
||||||
|
instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4)
|
||||||
|
=> BuildOp (t1, t2, t3, t4) where
|
||||||
|
buildOp' = pureResult
|
||||||
|
|
||||||
|
instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4, OpResult t5)
|
||||||
|
=> BuildOp (t1, t2, t3, t4, t5) where
|
||||||
|
buildOp' = pureResult
|
||||||
|
|
||||||
|
instance ( OpResult t1
|
||||||
|
, OpResult t2
|
||||||
|
, OpResult t3
|
||||||
|
, OpResult t4
|
||||||
|
, OpResult t5
|
||||||
|
, OpResult t6
|
||||||
|
)
|
||||||
|
=> BuildOp (t1, t2, t3, t4, t5, t6) where
|
||||||
|
buildOp' = pureResult
|
||||||
|
|
||||||
|
instance OpResult a => BuildOp (Build a) where
|
||||||
|
buildOp' = buildResult
|
||||||
|
|
||||||
|
instance BuildOp f => BuildOp (Tensor v a -> f) where
|
||||||
|
buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts)
|
||||||
|
|
||||||
|
instance BuildOp f => BuildOp ([Tensor v a] -> f) where
|
||||||
|
buildOp' rf o accum ts
|
||||||
|
= buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum)
|
||||||
|
|
||||||
|
-- | Returns true if all the integers in each tuple are identical.
|
||||||
|
-- Throws an error with a descriptive message if not.
|
||||||
|
eqLengthGuard :: [(String, [(String, Int)])] -> Bool
|
||||||
|
eqLengthGuard = all eachOk
|
||||||
|
where
|
||||||
|
eachOk (_, []) = True
|
||||||
|
-- The next line has (== 1) . length . nub in disguise
|
||||||
|
eachOk (numberAttrName, pairs@((_, x) : zs)) = all (\z -> snd z == x) zs ||
|
||||||
|
error ("number_attr " ++ numberAttrName ++
|
||||||
|
" contains tensors with different length " ++ show pairs)
|
87
tensorflow/src/TensorFlow/ControlFlow.hs
Normal file
87
tensorflow/src/TensorFlow/ControlFlow.hs
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE GADTs #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE RankNTypes #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
|
module TensorFlow.ControlFlow
|
||||||
|
( -- * Dependencies
|
||||||
|
withControlDependencies
|
||||||
|
, group
|
||||||
|
-- * Operations
|
||||||
|
, identity
|
||||||
|
, noOp
|
||||||
|
, named
|
||||||
|
) where
|
||||||
|
|
||||||
|
import qualified Data.Set as Set
|
||||||
|
import Data.Text (Text)
|
||||||
|
import Lens.Family2 ((&), (^.), (.~))
|
||||||
|
|
||||||
|
import TensorFlow.BuildOp
|
||||||
|
import TensorFlow.Build
|
||||||
|
import TensorFlow.Nodes
|
||||||
|
import TensorFlow.Output
|
||||||
|
import TensorFlow.Tensor
|
||||||
|
import TensorFlow.Types
|
||||||
|
|
||||||
|
-- | Modify a 'Build' action, such that all new ops rendered in it will depend
|
||||||
|
-- on the nodes in the first argument.
|
||||||
|
withControlDependencies :: Nodes t => t -> Build a -> Build a
|
||||||
|
withControlDependencies deps act = do
|
||||||
|
nodes <- getNodes deps
|
||||||
|
withNodeDependencies nodes act
|
||||||
|
|
||||||
|
-- TODO(judahjacobson): Reimplement withDependencies.
|
||||||
|
|
||||||
|
-- | Create an op that groups multiple operations.
|
||||||
|
--
|
||||||
|
-- When this op finishes, all ops in the input @n@ have finished. This op has
|
||||||
|
-- no output.
|
||||||
|
group :: Nodes t => t -> Build ControlNode
|
||||||
|
group deps = do
|
||||||
|
nodes <- Set.toList <$> getNodes deps
|
||||||
|
-- TODO: slicker way
|
||||||
|
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
|
||||||
|
|
||||||
|
|
||||||
|
-- | Returns a 'Tensor' with the same shape and contents as the input.
|
||||||
|
identity :: TensorType a => Tensor v a -> Tensor v a
|
||||||
|
identity = namedIdentity implicitName
|
||||||
|
|
||||||
|
-- | Returns a 'Tensor' with a given name and the same shape and contents as
|
||||||
|
-- the input.
|
||||||
|
--
|
||||||
|
-- TODO(judahjacobson): This breaks when used with uninitialize @Tensor Ref@s,
|
||||||
|
-- since @RefIdentity@ doesn't have SetAllowsUninitializedInput(). Look into
|
||||||
|
-- whether we can change that op.
|
||||||
|
named :: TensorType a => Text -> Tensor v a -> Tensor v a
|
||||||
|
named = namedIdentity . explicitName
|
||||||
|
|
||||||
|
-- | An internal version of "identity" that allows setting the name
|
||||||
|
-- of the output Tensor.
|
||||||
|
namedIdentity :: forall a v . TensorType a
|
||||||
|
=> PendingNodeName -> Tensor v a -> Tensor v a
|
||||||
|
namedIdentity n t = case t ^. tensorKind of
|
||||||
|
ValueKind -> buildOp (opDefWithName n "Identity" & setTypeAttr) t
|
||||||
|
RefKind -> buildOp (opDefWithName n "RefIdentity" & setTypeAttr) t
|
||||||
|
where
|
||||||
|
setTypeAttr = opAttr "T" .~ tensorType (undefined :: a)
|
||||||
|
|
||||||
|
|
||||||
|
-- | Does nothing. Only useful as a placeholder for control edges.
|
||||||
|
noOp :: ControlNode
|
||||||
|
noOp = buildOp $ opDef "NoOp"
|
243
tensorflow/src/TensorFlow/Internal/FFI.hs
Normal file
243
tensorflow/src/TensorFlow/Internal/FFI.hs
Normal file
|
@ -0,0 +1,243 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE DeriveDataTypeable #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
|
module TensorFlow.Internal.FFI
|
||||||
|
( TensorFlowException(..)
|
||||||
|
, Raw.Session
|
||||||
|
, withSession
|
||||||
|
, extendGraph
|
||||||
|
, run
|
||||||
|
, TensorData(..)
|
||||||
|
, setSessionConfig
|
||||||
|
, setSessionTarget
|
||||||
|
, getAllOpList
|
||||||
|
-- * Internal helper.
|
||||||
|
, useProtoAsVoidPtrLen
|
||||||
|
)
|
||||||
|
where
|
||||||
|
|
||||||
|
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
|
||||||
|
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
|
||||||
|
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
|
||||||
|
import Control.Monad (when)
|
||||||
|
import Data.Int (Int64)
|
||||||
|
import Data.Typeable (Typeable)
|
||||||
|
import Data.Word (Word8)
|
||||||
|
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
|
||||||
|
import Foreign.C.String (CString)
|
||||||
|
import Foreign.ForeignPtr (newForeignPtr, withForeignPtr)
|
||||||
|
import Foreign.Marshal.Alloc (free)
|
||||||
|
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
|
||||||
|
import System.IO.Unsafe (unsafePerformIO)
|
||||||
|
import qualified Data.ByteString as B
|
||||||
|
import qualified Data.Text as T
|
||||||
|
import qualified Data.Text.Encoding as T
|
||||||
|
import qualified Data.Text.Encoding.Error as T
|
||||||
|
import qualified Data.Vector.Storable as S
|
||||||
|
|
||||||
|
import Data.ProtoLens (Message, encodeMessage)
|
||||||
|
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
|
||||||
|
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||||
|
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
||||||
|
|
||||||
|
import qualified TensorFlow.Internal.Raw as Raw
|
||||||
|
|
||||||
|
data TensorFlowException = TensorFlowException Raw.Code T.Text
|
||||||
|
deriving (Show, Eq, Typeable)
|
||||||
|
|
||||||
|
instance Exception TensorFlowException
|
||||||
|
|
||||||
|
-- | All of the data needed to represent a tensor.
|
||||||
|
data TensorData = TensorData
|
||||||
|
{ tensorDataDimensions :: [Int64]
|
||||||
|
, tensorDataType :: !DataType
|
||||||
|
, tensorDataBytes :: !(S.Vector Word8)
|
||||||
|
}
|
||||||
|
deriving (Show, Eq)
|
||||||
|
|
||||||
|
-- | Runs the given action after creating a session with options
|
||||||
|
-- populated by the given optionSetter.
|
||||||
|
withSession :: (Raw.SessionOptions -> IO ())
|
||||||
|
-> ((IO () -> IO ()) -> Raw.Session -> IO a)
|
||||||
|
-- ^ The action can spawn concurrent tasks which will
|
||||||
|
-- be canceled before withSession returns.
|
||||||
|
-> IO a
|
||||||
|
withSession optionSetter action = do
|
||||||
|
drain <- newMVar []
|
||||||
|
let cleanup s =
|
||||||
|
-- Closes the session to nudge the pending run calls to fail and exit.
|
||||||
|
finally (checkStatus (Raw.closeSession s)) $ do
|
||||||
|
runners <- takeMVar drain
|
||||||
|
-- Collects all runners before deleting the session.
|
||||||
|
mapM_ shutDownRunner runners
|
||||||
|
checkStatus (Raw.deleteSession s)
|
||||||
|
bracket Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
|
||||||
|
optionSetter options
|
||||||
|
bracket
|
||||||
|
(checkStatus (Raw.newSession options))
|
||||||
|
cleanup
|
||||||
|
(action (asyncCollector drain))
|
||||||
|
|
||||||
|
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
|
||||||
|
asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord
|
||||||
|
where
|
||||||
|
launchAndRecord restRunners = (: restRunners) <$> async runner
|
||||||
|
|
||||||
|
shutDownRunner :: Async () -> IO ()
|
||||||
|
shutDownRunner r = do
|
||||||
|
cancel r
|
||||||
|
-- TODO(gnezdo): manage exceptions better than print.
|
||||||
|
either print (const (return ())) =<< waitCatch r
|
||||||
|
|
||||||
|
extendGraph :: Raw.Session -> GraphDef -> IO ()
|
||||||
|
extendGraph session pb =
|
||||||
|
useProtoAsVoidPtrLen pb $ \ptr len ->
|
||||||
|
checkStatus $ Raw.extendGraph session ptr len
|
||||||
|
|
||||||
|
|
||||||
|
run :: Raw.Session
|
||||||
|
-> [(B.ByteString, TensorData)] -- ^ Feeds.
|
||||||
|
-> [B.ByteString] -- ^ Fetches.
|
||||||
|
-> [B.ByteString] -- ^ Targets.
|
||||||
|
-> IO [TensorData]
|
||||||
|
run session feeds fetches targets = do
|
||||||
|
let nullTensor = Raw.Tensor nullPtr
|
||||||
|
-- Use mask to avoid leaking input tensors before they are passed to 'run'
|
||||||
|
-- and output tensors before they are passed to 'createTensorData'.
|
||||||
|
mask_ $
|
||||||
|
-- Feeds
|
||||||
|
withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
|
||||||
|
mapM (createRawTensor . snd) feeds >>= \feedTensors ->
|
||||||
|
withArrayLen feedTensors $ \_ cFeedTensors ->
|
||||||
|
-- Fetches.
|
||||||
|
withStringArrayLen fetches $ \fetchesLen fetchNames ->
|
||||||
|
-- tensorOuts is an array of null Tensor pointers that will be filled
|
||||||
|
-- by the call to Raw.run.
|
||||||
|
withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
|
||||||
|
-- Targets.
|
||||||
|
withStringArrayLen targets $ \targetsLen ctargets -> do
|
||||||
|
checkStatus $ Raw.run
|
||||||
|
session
|
||||||
|
nullPtr
|
||||||
|
feedNames cFeedTensors (fromIntegral feedsLen)
|
||||||
|
fetchNames tensorOuts (fromIntegral fetchesLen)
|
||||||
|
ctargets (fromIntegral targetsLen)
|
||||||
|
nullPtr
|
||||||
|
outTensors <- peekArray fetchesLen tensorOuts
|
||||||
|
mapM createTensorData outTensors
|
||||||
|
|
||||||
|
|
||||||
|
-- Internal.
|
||||||
|
|
||||||
|
|
||||||
|
-- | Use a list of ByteString as a list of CString.
|
||||||
|
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
|
||||||
|
withStringList strings fn = go strings []
|
||||||
|
where
|
||||||
|
go [] cs = fn (reverse cs)
|
||||||
|
-- TODO(fmayle): Is it worth using unsafeAsCString here?
|
||||||
|
go (x:xs) cs = B.useAsCString x $ \c -> go xs (c:cs)
|
||||||
|
|
||||||
|
|
||||||
|
-- | Use a list of ByteString as an array of CString.
|
||||||
|
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
|
||||||
|
withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
|
||||||
|
|
||||||
|
|
||||||
|
-- | Create a Raw.Tensor from a TensorData.
|
||||||
|
createRawTensor :: TensorData -> IO Raw.Tensor
|
||||||
|
createRawTensor (TensorData dims dt byteVec) =
|
||||||
|
withArrayLen (map fromIntegral dims) $ \cdimsLen cdims -> do
|
||||||
|
let len = S.length byteVec
|
||||||
|
dest <- mallocArray len
|
||||||
|
S.unsafeWith byteVec $ \x -> copyArray dest x len
|
||||||
|
Raw.newTensor (toEnum $ fromEnum dt)
|
||||||
|
cdims (fromIntegral cdimsLen)
|
||||||
|
(castPtr dest) (fromIntegral len)
|
||||||
|
tensorDeallocFunPtr nullPtr
|
||||||
|
|
||||||
|
{-# NOINLINE tensorDeallocFunPtr #-}
|
||||||
|
tensorDeallocFunPtr :: FunPtr Raw.TensorDeallocFn
|
||||||
|
tensorDeallocFunPtr = unsafePerformIO $ Raw.wrapTensorDealloc $ \x _ _ -> free x
|
||||||
|
|
||||||
|
-- | Create a TensorData from a Raw.Tensor.
|
||||||
|
--
|
||||||
|
-- Takes ownership of the Raw.Tensor.
|
||||||
|
createTensorData :: Raw.Tensor -> IO TensorData
|
||||||
|
createTensorData t = do
|
||||||
|
-- Read dimensions.
|
||||||
|
numDims <- Raw.numDims t
|
||||||
|
dims <- mapM (Raw.dim t) [0..numDims-1]
|
||||||
|
-- Read type.
|
||||||
|
dtype <- toEnum . fromEnum <$> Raw.tensorType t
|
||||||
|
-- Read data.
|
||||||
|
len <- fromIntegral <$> Raw.tensorByteSize t
|
||||||
|
bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
|
||||||
|
-- TODO(fmayle): Don't copy the data.
|
||||||
|
v <- S.fromList <$> peekArray len bytes
|
||||||
|
-- Free tensor.
|
||||||
|
Raw.deleteTensor t
|
||||||
|
return $ TensorData (map fromIntegral dims) dtype v
|
||||||
|
|
||||||
|
-- | Runs the given action which does FFI calls updating a provided
|
||||||
|
-- status object. If the status is not OK it is thrown as
|
||||||
|
-- TensorFlowException.
|
||||||
|
checkStatus :: (Raw.Status -> IO a) -> IO a
|
||||||
|
checkStatus fn =
|
||||||
|
bracket Raw.newStatus Raw.deleteStatus $ \status -> do
|
||||||
|
result <- fn status
|
||||||
|
code <- Raw.getCode status
|
||||||
|
when (code /= Raw.TF_OK) $ do
|
||||||
|
msg <- T.decodeUtf8With T.lenientDecode <$>
|
||||||
|
(Raw.message status >>= B.packCString)
|
||||||
|
throwIO $ TensorFlowException code msg
|
||||||
|
return result
|
||||||
|
|
||||||
|
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
|
||||||
|
setSessionConfig pb opt =
|
||||||
|
useProtoAsVoidPtrLen pb $ \ptr len ->
|
||||||
|
checkStatus (Raw.setConfig opt ptr len)
|
||||||
|
|
||||||
|
setSessionTarget :: B.ByteString -> Raw.SessionOptions -> IO ()
|
||||||
|
setSessionTarget target = B.useAsCString target . Raw.setTarget
|
||||||
|
|
||||||
|
-- | Serializes the given msg and provides it as (ptr,len) argument
|
||||||
|
-- to the given action.
|
||||||
|
useProtoAsVoidPtrLen :: (Message msg, Num c) =>
|
||||||
|
msg -> (Ptr b -> c -> IO a) -> IO a
|
||||||
|
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
|
||||||
|
\(bytes, len) -> f (castPtr bytes) (fromIntegral len)
|
||||||
|
|
||||||
|
-- | Returns the serialized OpList of all OpDefs defined in this
|
||||||
|
-- address space.
|
||||||
|
getAllOpList :: IO B.ByteString
|
||||||
|
getAllOpList = do
|
||||||
|
foreignPtr <-
|
||||||
|
mask_ (newForeignPtr Raw.deleteBuffer =<< checkCall)
|
||||||
|
-- Makes a copy because it is more reliable than eviscerating
|
||||||
|
-- Buffer to steal its memory (including custom deallocator).
|
||||||
|
withForeignPtr foreignPtr $
|
||||||
|
\ptr -> B.packCStringLen =<< (,)
|
||||||
|
<$> (castPtr <$> Raw.getBufferData ptr)
|
||||||
|
<*> (fromIntegral <$> Raw.getBufferLength ptr)
|
||||||
|
where
|
||||||
|
checkCall = do
|
||||||
|
p <- Raw.getAllOpList
|
||||||
|
when (p == nullPtr) (throwIO exception)
|
||||||
|
return p
|
||||||
|
exception = TensorFlowException
|
||||||
|
Raw.TF_UNKNOWN "GetAllOpList failure, check logs"
|
152
tensorflow/src/TensorFlow/Internal/Raw.chs
Normal file
152
tensorflow/src/TensorFlow/Internal/Raw.chs
Normal file
|
@ -0,0 +1,152 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE ForeignFunctionInterface #-}
|
||||||
|
|
||||||
|
module TensorFlow.Internal.Raw where
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/c/c_api.h"
|
||||||
|
|
||||||
|
import Foreign
|
||||||
|
import Foreign.C
|
||||||
|
|
||||||
|
{#enum TF_DataType as DataType {} deriving (Show, Eq) #}
|
||||||
|
{#enum TF_Code as Code {} deriving (Show, Eq) #}
|
||||||
|
|
||||||
|
|
||||||
|
-- Status.
|
||||||
|
{#pointer *TF_Status as Status newtype #}
|
||||||
|
|
||||||
|
newStatus :: IO Status
|
||||||
|
newStatus = {# call TF_NewStatus as ^ #}
|
||||||
|
|
||||||
|
deleteStatus :: Status -> IO ()
|
||||||
|
deleteStatus = {# call TF_DeleteStatus as ^ #}
|
||||||
|
|
||||||
|
setStatus :: Status -> Code -> CString -> IO ()
|
||||||
|
setStatus s c = {# call TF_SetStatus as ^ #} s (fromIntegral $ fromEnum c)
|
||||||
|
|
||||||
|
getCode :: Status -> IO Code
|
||||||
|
getCode s = toEnum . fromIntegral <$> {# call TF_GetCode as ^ #} s
|
||||||
|
|
||||||
|
message :: Status -> IO CString
|
||||||
|
message = {# call TF_Message as ^ #}
|
||||||
|
|
||||||
|
|
||||||
|
-- Buffer.
|
||||||
|
data Buffer
|
||||||
|
{#pointer *TF_Buffer as BufferPtr -> Buffer #}
|
||||||
|
|
||||||
|
getBufferData :: BufferPtr -> IO (Ptr ())
|
||||||
|
getBufferData = {#get TF_Buffer->data #}
|
||||||
|
|
||||||
|
getBufferLength :: BufferPtr -> IO CULong
|
||||||
|
getBufferLength ={#get TF_Buffer->length #}
|
||||||
|
|
||||||
|
-- Tensor.
|
||||||
|
{#pointer *TF_Tensor as Tensor newtype #}
|
||||||
|
|
||||||
|
instance Storable Tensor where
|
||||||
|
sizeOf (Tensor t) = sizeOf t
|
||||||
|
alignment (Tensor t) = alignment t
|
||||||
|
peek p = fmap Tensor (peek (castPtr p))
|
||||||
|
poke p (Tensor t) = poke (castPtr p) t
|
||||||
|
|
||||||
|
newTensor :: DataType
|
||||||
|
-> Ptr CLong -- dimensions array
|
||||||
|
-> CInt -- num dimensions
|
||||||
|
-> Ptr () -- data
|
||||||
|
-> CULong -- data len
|
||||||
|
-> FunPtr (Ptr () -> CULong -> Ptr () -> IO ()) -- deallocator
|
||||||
|
-> Ptr () -- deallocator arg
|
||||||
|
-> IO Tensor
|
||||||
|
newTensor dt = {# call TF_NewTensor as ^ #} (fromIntegral $ fromEnum dt)
|
||||||
|
|
||||||
|
deleteTensor :: Tensor -> IO ()
|
||||||
|
deleteTensor = {# call TF_DeleteTensor as ^ #}
|
||||||
|
|
||||||
|
tensorType :: Tensor -> IO DataType
|
||||||
|
tensorType t = toEnum . fromIntegral <$> {# call TF_TensorType as ^ #} t
|
||||||
|
|
||||||
|
numDims :: Tensor -> IO CInt
|
||||||
|
numDims = {# call TF_NumDims as ^ #}
|
||||||
|
|
||||||
|
dim :: Tensor -> CInt -> IO CLong
|
||||||
|
dim = {# call TF_Dim as ^ #}
|
||||||
|
|
||||||
|
tensorByteSize :: Tensor -> IO CULong
|
||||||
|
tensorByteSize = {# call TF_TensorByteSize as ^ #}
|
||||||
|
|
||||||
|
tensorData :: Tensor -> IO (Ptr ())
|
||||||
|
tensorData = {# call TF_TensorData as ^ #}
|
||||||
|
|
||||||
|
|
||||||
|
-- Session Options.
|
||||||
|
{# pointer *TF_SessionOptions as SessionOptions newtype #}
|
||||||
|
|
||||||
|
newSessionOptions :: IO SessionOptions
|
||||||
|
newSessionOptions = {# call TF_NewSessionOptions as ^ #}
|
||||||
|
|
||||||
|
setTarget :: SessionOptions -> CString -> IO ()
|
||||||
|
setTarget = {# call TF_SetTarget as ^ #}
|
||||||
|
|
||||||
|
setConfig :: SessionOptions -> Ptr () -> CULong -> Status -> IO ()
|
||||||
|
setConfig = {# call TF_SetConfig as ^ #}
|
||||||
|
|
||||||
|
deleteSessionOptions :: SessionOptions -> IO ()
|
||||||
|
deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #}
|
||||||
|
|
||||||
|
|
||||||
|
-- Session.
|
||||||
|
{# pointer *TF_Session as Session newtype #}
|
||||||
|
|
||||||
|
newSession :: SessionOptions -> Status -> IO Session
|
||||||
|
newSession = {# call TF_NewSession as ^ #}
|
||||||
|
|
||||||
|
closeSession :: Session -> Status -> IO ()
|
||||||
|
closeSession = {# call TF_CloseSession as ^ #}
|
||||||
|
|
||||||
|
deleteSession :: Session -> Status -> IO ()
|
||||||
|
deleteSession = {# call TF_DeleteSession as ^ #}
|
||||||
|
|
||||||
|
extendGraph :: Session -> Ptr () -> CULong -> Status -> IO ()
|
||||||
|
extendGraph = {# call TF_ExtendGraph as ^ #}
|
||||||
|
|
||||||
|
run :: Session
|
||||||
|
-> BufferPtr -- RunOptions proto.
|
||||||
|
-> Ptr CString -> Ptr Tensor -> CInt -- Input (names, tensors, count).
|
||||||
|
-> Ptr CString -> Ptr Tensor -> CInt -- Output (names, tensors, count).
|
||||||
|
-> Ptr CString -> CInt -- Target nodes (names, count).
|
||||||
|
-> BufferPtr -- RunMetadata proto.
|
||||||
|
-> Status
|
||||||
|
-> IO ()
|
||||||
|
run = {# call TF_Run as ^ #}
|
||||||
|
|
||||||
|
-- FFI helpers.
|
||||||
|
type TensorDeallocFn = Ptr () -> CULong -> Ptr () -> IO ()
|
||||||
|
foreign import ccall "wrapper"
|
||||||
|
wrapTensorDealloc :: TensorDeallocFn -> IO (FunPtr TensorDeallocFn)
|
||||||
|
|
||||||
|
|
||||||
|
-- | Get the OpList of all OpDefs defined in this address space.
|
||||||
|
-- Returns a BufferPtr, ownership of which is transferred to the caller
|
||||||
|
-- (and can be freed using deleteBuffer).
|
||||||
|
--
|
||||||
|
-- The data in the buffer will be the serialized OpList proto for ops registered
|
||||||
|
-- in this address space.
|
||||||
|
getAllOpList :: IO BufferPtr
|
||||||
|
getAllOpList = {# call TF_GetAllOpList as ^ #}
|
||||||
|
|
||||||
|
foreign import ccall "&TF_DeleteBuffer"
|
||||||
|
deleteBuffer :: FunPtr (BufferPtr -> IO ())
|
50
tensorflow/src/TensorFlow/Internal/VarInt.hs
Normal file
50
tensorflow/src/TensorFlow/Internal/VarInt.hs
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE BangPatterns #-}
|
||||||
|
|
||||||
|
{-|
|
||||||
|
Module : TensorFlow.Internal.VarInt
|
||||||
|
Description : Encoders and decoders for varint types.
|
||||||
|
|
||||||
|
Originally taken from internal proto-lens code.
|
||||||
|
-}
|
||||||
|
module TensorFlow.Internal.VarInt
|
||||||
|
( getVarInt
|
||||||
|
, putVarInt
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Data.Attoparsec.ByteString as Parse
|
||||||
|
import Data.Bits
|
||||||
|
import Data.ByteString.Lazy.Builder as Builder
|
||||||
|
import Data.Monoid ((<>))
|
||||||
|
import Data.Word (Word64)
|
||||||
|
|
||||||
|
-- | Decode an unsigned varint.
|
||||||
|
getVarInt :: Parser Word64
|
||||||
|
getVarInt = loop 1 0
|
||||||
|
where
|
||||||
|
loop !s !n = do
|
||||||
|
b <- anyWord8
|
||||||
|
let n' = n + s * fromIntegral (b .&. 127)
|
||||||
|
if (b .&. 128) == 0
|
||||||
|
then return n'
|
||||||
|
else loop (128*s) n'
|
||||||
|
|
||||||
|
-- | Encode a Word64.
|
||||||
|
putVarInt :: Word64 -> Builder
|
||||||
|
putVarInt n
|
||||||
|
| n < 128 = Builder.word8 (fromIntegral n)
|
||||||
|
| otherwise = Builder.word8 (fromIntegral $ n .&. 127 .|. 128)
|
||||||
|
<> putVarInt (n `shiftR` 7)
|
141
tensorflow/src/TensorFlow/Nodes.hs
Normal file
141
tensorflow/src/TensorFlow/Nodes.hs
Normal file
|
@ -0,0 +1,141 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||||
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
|
{-# LANGUAGE RankNTypes #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
module TensorFlow.Nodes where
|
||||||
|
|
||||||
|
import Control.Applicative (liftA2, liftA3)
|
||||||
|
import Data.Map.Strict (Map)
|
||||||
|
import Data.Monoid ((<>))
|
||||||
|
import Data.Set (Set)
|
||||||
|
import Data.String (IsString)
|
||||||
|
import Data.Text (Text)
|
||||||
|
import Lens.Family2 ((^.))
|
||||||
|
import qualified Data.Map.Strict as Map
|
||||||
|
import qualified Data.Set as Set
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
|
import TensorFlow.Build
|
||||||
|
import TensorFlow.Output
|
||||||
|
import TensorFlow.Tensor
|
||||||
|
import TensorFlow.Types
|
||||||
|
import qualified TensorFlow.Internal.FFI as FFI
|
||||||
|
|
||||||
|
-- | Types that contain ops which can be run.
|
||||||
|
class Nodes t where
|
||||||
|
getNodes :: t -> Build (Set NodeName)
|
||||||
|
|
||||||
|
-- | Types that tensor representations (e.g. 'Tensor', 'ControlNode') can be
|
||||||
|
-- fetched into.
|
||||||
|
--
|
||||||
|
-- Includes collections of tensors (e.g. tuples).
|
||||||
|
class Nodes t => Fetchable t a where
|
||||||
|
getFetch :: t -> Build (Fetch a)
|
||||||
|
|
||||||
|
-- | Fetch action. Keeps track of what needs to be fetched and how to decode
|
||||||
|
-- the fetched data.
|
||||||
|
data Fetch a = Fetch
|
||||||
|
{ -- | Nodes to fetch
|
||||||
|
fetches :: Set Text
|
||||||
|
-- | Function to create an 'a' from the fetched data.
|
||||||
|
, fetchRestore :: Map Text FFI.TensorData -> a
|
||||||
|
}
|
||||||
|
|
||||||
|
instance Functor Fetch where
|
||||||
|
fmap f (Fetch fetch restore) = Fetch fetch (f . restore)
|
||||||
|
|
||||||
|
instance Applicative Fetch where
|
||||||
|
pure x = Fetch Set.empty (const x)
|
||||||
|
Fetch fetch restore <*> Fetch fetch' restore' =
|
||||||
|
Fetch (fetch <> fetch') (restore <*> restore')
|
||||||
|
|
||||||
|
nodesUnion :: (Monoid b, Traversable t, Applicative f) => t (f b) -> f b
|
||||||
|
nodesUnion = fmap (foldMap id) . sequenceA
|
||||||
|
|
||||||
|
instance (Nodes t1, Nodes t2) => Nodes (t1, t2) where
|
||||||
|
getNodes (x, y) = nodesUnion [getNodes x, getNodes y]
|
||||||
|
|
||||||
|
instance (Nodes t1, Nodes t2, Nodes t3) => Nodes (t1, t2, t3) where
|
||||||
|
getNodes (x, y, z) = nodesUnion [getNodes x, getNodes y, getNodes z]
|
||||||
|
|
||||||
|
instance (Fetchable t1 a1, Fetchable t2 a2) => Fetchable (t1, t2) (a1, a2) where
|
||||||
|
getFetch (x, y) = liftA2 (,) <$> getFetch x <*> getFetch y
|
||||||
|
|
||||||
|
instance (Fetchable t1 a1, Fetchable t2 a2, Fetchable t3 a3)
|
||||||
|
=> Fetchable (t1, t2, t3) (a1, a2, a3) where
|
||||||
|
getFetch (x, y, z) =
|
||||||
|
liftA3 (,,) <$> getFetch x <*> getFetch y <*> getFetch z
|
||||||
|
|
||||||
|
instance Nodes t => Nodes [t] where
|
||||||
|
getNodes = nodesUnion . map getNodes
|
||||||
|
|
||||||
|
instance Fetchable t a => Fetchable [t] [a] where
|
||||||
|
getFetch ts = sequenceA <$> mapM getFetch ts
|
||||||
|
|
||||||
|
instance Nodes ControlNode where
|
||||||
|
getNodes (ControlNode o) = Set.singleton <$> getOrAddOp o
|
||||||
|
|
||||||
|
-- We use the constraint @(a ~ ())@ to help with type inference. For example,
|
||||||
|
-- if @t :: ControlNode@, then this constraint ensures that @run t :: Session
|
||||||
|
-- ()@. If we used @instance Fetchable ControlNode ()@ instead, then that
|
||||||
|
-- expression would be ambiguous without explicitly specifying the return type.
|
||||||
|
instance a ~ () => Fetchable ControlNode a where
|
||||||
|
getFetch _ = return $ pure ()
|
||||||
|
|
||||||
|
instance Nodes (Tensor v a) where
|
||||||
|
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)
|
||||||
|
|
||||||
|
fetchTensorList :: TensorType a => Tensor v a -> Build (Fetch (Shape, [a]))
|
||||||
|
fetchTensorList t = fmap (fmap V.toList) <$> fetchTensorVector t
|
||||||
|
|
||||||
|
fetchTensorVector :: forall a v . TensorType a
|
||||||
|
=> Tensor v a -> Build (Fetch (Shape, V.Vector a))
|
||||||
|
fetchTensorVector (Tensor _ o) = do
|
||||||
|
outputName <- renderOutput o
|
||||||
|
return $ Fetch (Set.singleton outputName) $ \tensors ->
|
||||||
|
let tensorData = tensors Map.! outputName
|
||||||
|
shape = Shape $ FFI.tensorDataDimensions tensorData
|
||||||
|
vec = decodeTensorData $ TensorData tensorData
|
||||||
|
|
||||||
|
expectedType = tensorType (undefined :: a)
|
||||||
|
actualType = FFI.tensorDataType tensorData
|
||||||
|
badTypeError = error $ "Bad tensor type: expected "
|
||||||
|
++ show expectedType
|
||||||
|
++ ", got "
|
||||||
|
++ show actualType
|
||||||
|
in if expectedType /= actualType
|
||||||
|
then badTypeError
|
||||||
|
else (shape, vec)
|
||||||
|
|
||||||
|
-- The constraint "a ~ a'" means that the input/output of fetch can constrain
|
||||||
|
-- the TensorType of each other.
|
||||||
|
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (V.Vector a') where
|
||||||
|
getFetch t = fmap snd <$> fetchTensorVector t
|
||||||
|
|
||||||
|
newtype Scalar a = Scalar {unScalar :: a}
|
||||||
|
deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat,
|
||||||
|
RealFrac, IsString)
|
||||||
|
|
||||||
|
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (Scalar a') where
|
||||||
|
getFetch t = fmap (Scalar . headFromSingleton . snd) <$> fetchTensorList t
|
||||||
|
where
|
||||||
|
headFromSingleton [x] = x
|
||||||
|
headFromSingleton xs
|
||||||
|
= error $ "Unable to extract singleton from tensor of length "
|
||||||
|
++ show (length xs)
|
46
tensorflow/src/TensorFlow/Orphans.hs
Normal file
46
tensorflow/src/TensorFlow/Orphans.hs
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
{-# LANGUAGE StandaloneDeriving #-}
|
||||||
|
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||||
|
-- Orphan instances for certain proto messages/enums, used internally.
|
||||||
|
-- TODO(judahjacobson): consider making proto-lens generate some or all of
|
||||||
|
-- these automatically; or, alternately, make new Haskell datatypes.
|
||||||
|
module TensorFlow.Orphans() where
|
||||||
|
|
||||||
|
import Proto.Tensorflow.Core.Framework.AttrValue
|
||||||
|
( AttrValue(..)
|
||||||
|
, AttrValue'ListValue(..)
|
||||||
|
, NameAttrList(..)
|
||||||
|
)
|
||||||
|
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||||
|
( NodeDef(..))
|
||||||
|
import Proto.Tensorflow.Core.Framework.ResourceHandle
|
||||||
|
( ResourceHandle(..))
|
||||||
|
import Proto.Tensorflow.Core.Framework.Tensor
|
||||||
|
(TensorProto(..))
|
||||||
|
import Proto.Tensorflow.Core.Framework.TensorShape
|
||||||
|
(TensorShapeProto(..), TensorShapeProto'Dim(..))
|
||||||
|
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||||
|
|
||||||
|
deriving instance Ord AttrValue
|
||||||
|
deriving instance Ord AttrValue'ListValue
|
||||||
|
deriving instance Ord DataType
|
||||||
|
deriving instance Ord NameAttrList
|
||||||
|
deriving instance Ord NodeDef
|
||||||
|
deriving instance Ord ResourceHandle
|
||||||
|
deriving instance Ord TensorProto
|
||||||
|
deriving instance Ord TensorShapeProto
|
||||||
|
deriving instance Ord TensorShapeProto'Dim
|
156
tensorflow/src/TensorFlow/Output.hs
Normal file
156
tensorflow/src/TensorFlow/Output.hs
Normal file
|
@ -0,0 +1,156 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||||
|
{-# LANGUAGE Rank2Types #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
|
module TensorFlow.Output
|
||||||
|
( ControlNode(..)
|
||||||
|
, Device(..)
|
||||||
|
-- * Ops
|
||||||
|
, NodeName(..)
|
||||||
|
, Op(..)
|
||||||
|
, opUnrendered
|
||||||
|
, OpDef(..)
|
||||||
|
, opName
|
||||||
|
, opType
|
||||||
|
, opAttr
|
||||||
|
, opInputs
|
||||||
|
, opControlInputs
|
||||||
|
, OpType(..)
|
||||||
|
, OutputIx(..)
|
||||||
|
, Output(..)
|
||||||
|
, output
|
||||||
|
, outputIndex
|
||||||
|
, outputOp
|
||||||
|
, PendingNodeName(..)
|
||||||
|
) where
|
||||||
|
|
||||||
|
import qualified Data.Map.Strict as Map
|
||||||
|
import Data.ProtoLens.TextFormat (showMessage)
|
||||||
|
import Data.String (IsString(..))
|
||||||
|
import Data.Text (Text)
|
||||||
|
import qualified Data.Text as Text
|
||||||
|
import Lens.Family2 (Lens', Traversal', (.~), (&), (^.))
|
||||||
|
import Lens.Family2.Unchecked (lens)
|
||||||
|
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..))
|
||||||
|
import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef(..), name)
|
||||||
|
import Data.Default (def)
|
||||||
|
import TensorFlow.Types (Attribute, attrLens)
|
||||||
|
import TensorFlow.Orphans ()
|
||||||
|
|
||||||
|
-- | A type of graph node which has no outputs. These nodes are
|
||||||
|
-- valuable for causing side effects when they are run.
|
||||||
|
newtype ControlNode = ControlNode { unControlNode :: Op }
|
||||||
|
|
||||||
|
-- | The type of op of a node in the graph. This corresponds to the proto field
|
||||||
|
-- NodeDef.op.
|
||||||
|
newtype OpType = OpType { unOpType :: Text }
|
||||||
|
deriving (Eq, Ord, Show)
|
||||||
|
|
||||||
|
instance IsString OpType where
|
||||||
|
fromString = OpType . Text.pack
|
||||||
|
|
||||||
|
-- | An output of a TensorFlow node.
|
||||||
|
data Output = Output !OutputIx !Op
|
||||||
|
deriving (Eq, Ord, Show)
|
||||||
|
|
||||||
|
output :: OutputIx -> Op -> Output
|
||||||
|
output = Output
|
||||||
|
|
||||||
|
outputOp :: Lens' Output Op
|
||||||
|
outputOp = lens (\(Output _ o) -> o) (\(Output i _) o -> Output i o)
|
||||||
|
|
||||||
|
outputIndex :: Lens' Output OutputIx
|
||||||
|
outputIndex = lens (\(Output i _) -> i) (\(Output _ o) i -> Output i o)
|
||||||
|
|
||||||
|
newtype OutputIx = OutputIx { unOutputIx :: Int }
|
||||||
|
deriving (Eq, Ord, Num, Enum, Show)
|
||||||
|
|
||||||
|
-- | A device that a node can be assigned to.
|
||||||
|
-- There's a naming convention where the device names
|
||||||
|
-- are constructed from job and replica names.
|
||||||
|
newtype Device = Device {deviceName :: Text}
|
||||||
|
deriving (Eq, Ord, IsString)
|
||||||
|
|
||||||
|
instance Show Device where
|
||||||
|
show (Device d) = show d
|
||||||
|
|
||||||
|
-- | The representation of a node in a TensorFlow graph.
|
||||||
|
data Op
|
||||||
|
= Rendered !NodeDef -- ^ Properties are fixed, including the
|
||||||
|
-- device, name, and scope.
|
||||||
|
| Unrendered !OpDef -- ^ Properties are not fixed, and may change depending
|
||||||
|
-- on which context this op is rendered in.
|
||||||
|
deriving (Eq, Ord)
|
||||||
|
|
||||||
|
instance Show Op where
|
||||||
|
show (Rendered n) = "Rendered " ++ showMessage n
|
||||||
|
show (Unrendered o) = "Unrendered " ++ show (o ^. opName)
|
||||||
|
|
||||||
|
-- | Traverse on the 'Unrendered' of an 'Op'.
|
||||||
|
--
|
||||||
|
-- Same implementation as _Left.
|
||||||
|
opUnrendered :: Traversal' Op OpDef
|
||||||
|
opUnrendered f (Unrendered a) = Unrendered <$> f a
|
||||||
|
opUnrendered _ (Rendered b) = pure (Rendered b)
|
||||||
|
|
||||||
|
-- | Op definition. This corresponds somewhat to the 'NodeDef' proto.
|
||||||
|
data OpDef = OpDef
|
||||||
|
{ _opName :: !PendingNodeName
|
||||||
|
, _opType :: !OpType
|
||||||
|
, _opAttrs :: !(Map.Map Text AttrValue)
|
||||||
|
, _opInputs :: [Output]
|
||||||
|
, _opControlInputs :: [NodeName]
|
||||||
|
} deriving (Eq, Ord)
|
||||||
|
|
||||||
|
-- | The name specified for an unrendered Op. If an Op has an
|
||||||
|
-- ImplicitName, it will be assigned based on the opType plus a
|
||||||
|
-- unique identifier. Does not contain the "scope" prefix.
|
||||||
|
data PendingNodeName = ExplicitName !Text | ImplicitName
|
||||||
|
deriving (Eq, Ord, Show)
|
||||||
|
|
||||||
|
-- | The name of a node in the graph. This corresponds to the proto field
|
||||||
|
-- NodeDef.name. Includes the scope prefix (if any) and a unique identifier
|
||||||
|
-- (if the node was implicitly named).
|
||||||
|
newtype NodeName = NodeName { unNodeName :: Text }
|
||||||
|
deriving (Eq, Ord, Show)
|
||||||
|
|
||||||
|
opName :: Lens' OpDef PendingNodeName
|
||||||
|
opName = lens _opName (\o x -> o {_opName = x})
|
||||||
|
|
||||||
|
opType :: Lens' OpDef OpType
|
||||||
|
opType = lens _opType (\o x -> o { _opType = x})
|
||||||
|
|
||||||
|
opAttr :: Attribute a => Text -> Lens' OpDef a
|
||||||
|
opAttr n = lens _opAttrs (\o x -> o {_opAttrs = x})
|
||||||
|
. lens (Map.findWithDefault def n) (flip (Map.insert n))
|
||||||
|
. attrLens
|
||||||
|
|
||||||
|
opInputs :: Lens' OpDef [Output]
|
||||||
|
opInputs = lens _opInputs (\o x -> o {_opInputs = x})
|
||||||
|
|
||||||
|
opControlInputs :: Lens' OpDef [NodeName]
|
||||||
|
opControlInputs = lens _opControlInputs (\o x -> o {_opControlInputs = x})
|
||||||
|
|
||||||
|
-- TODO(gnezdo): IsString instance is weird and we should move that
|
||||||
|
-- code into a Build function
|
||||||
|
instance IsString Output where
|
||||||
|
fromString s = case break (==':') s of
|
||||||
|
(n, ':':ixStr)
|
||||||
|
| [(ix, "")] <- read ixStr -> Output (fromInteger ix) $ assigned n
|
||||||
|
_ -> Output 0 $ assigned s
|
||||||
|
where assigned n = Rendered $ def & name .~ Text.pack n
|
||||||
|
|
202
tensorflow/src/TensorFlow/Session.hs
Normal file
202
tensorflow/src/TensorFlow/Session.hs
Normal file
|
@ -0,0 +1,202 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||||
|
{-# LANGUAGE Rank2Types #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
{-# LANGUAGE TupleSections #-}
|
||||||
|
|
||||||
|
module TensorFlow.Session (
|
||||||
|
Session,
|
||||||
|
-- * Opaque value created via 'sessionConfig' and 'sessionTarget'.
|
||||||
|
SessionOption,
|
||||||
|
sessionConfig,
|
||||||
|
sessionTarget,
|
||||||
|
runSession,
|
||||||
|
runSessionWithOptions,
|
||||||
|
build,
|
||||||
|
buildAnd,
|
||||||
|
buildWithSummary,
|
||||||
|
extend,
|
||||||
|
addGraphDef,
|
||||||
|
run,
|
||||||
|
runWithFeeds,
|
||||||
|
run_,
|
||||||
|
runWithFeeds_,
|
||||||
|
asyncProdNodes,
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Control.Monad (forever, unless, void)
|
||||||
|
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||||
|
import Control.Monad.Trans.Class (lift)
|
||||||
|
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
|
||||||
|
import Data.ByteString (ByteString)
|
||||||
|
import Data.Functor.Identity (runIdentity)
|
||||||
|
import qualified Data.Map.Strict as Map
|
||||||
|
import qualified Data.Set as Set
|
||||||
|
import Data.Set (Set)
|
||||||
|
import Data.Text.Encoding (encodeUtf8)
|
||||||
|
import Data.ProtoLens (def)
|
||||||
|
import Lens.Family2 ((&), (.~))
|
||||||
|
import Proto.Tensorflow.Core.Framework.Graph (node)
|
||||||
|
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
||||||
|
|
||||||
|
import TensorFlow.Build
|
||||||
|
import qualified TensorFlow.Internal.FFI as FFI
|
||||||
|
import qualified TensorFlow.Internal.Raw as Raw
|
||||||
|
import TensorFlow.Nodes
|
||||||
|
import TensorFlow.Output (NodeName, unNodeName)
|
||||||
|
import TensorFlow.Tensor
|
||||||
|
|
||||||
|
-- Common state threaded through the session.
|
||||||
|
data SessionState
|
||||||
|
= SessionState {
|
||||||
|
rawSession :: FFI.Session
|
||||||
|
, asyncCollector :: IO () -> IO ()
|
||||||
|
-- ^ Starts the given action concurrently.
|
||||||
|
}
|
||||||
|
|
||||||
|
newtype Session a
|
||||||
|
= Session (ReaderT SessionState (BuildT IO) a)
|
||||||
|
deriving (Functor, Applicative, Monad, MonadIO)
|
||||||
|
|
||||||
|
-- | Run 'Session' actions in a new TensorFlow session.
|
||||||
|
runSession :: Session a -> IO a
|
||||||
|
runSession = runSessionWithOptions []
|
||||||
|
|
||||||
|
-- | Setting of an option for the session (see 'runSessionWithOptions').
|
||||||
|
newtype SessionOption =
|
||||||
|
SessionOption { unSesssionOption :: Raw.SessionOptions -> IO () }
|
||||||
|
|
||||||
|
-- | Target can be: "local", ip:port, host:port.
|
||||||
|
-- The set of supported factories depends on the linked in libraries.
|
||||||
|
-- REQUIRES "//learning/brain/public:tensorflow_remote" dependency for the binary.
|
||||||
|
sessionTarget :: ByteString -> SessionOption
|
||||||
|
sessionTarget = SessionOption . FFI.setSessionTarget
|
||||||
|
|
||||||
|
-- | Uses the specified config for the created session.
|
||||||
|
sessionConfig :: ConfigProto -> SessionOption
|
||||||
|
sessionConfig = SessionOption . FFI.setSessionConfig
|
||||||
|
|
||||||
|
-- | Run 'Session' actions in a new TensorFlow session created with
|
||||||
|
-- the given option setter actions ('sessionTarget', 'sessionConfig').
|
||||||
|
runSessionWithOptions :: [SessionOption] -> Session a -> IO a
|
||||||
|
runSessionWithOptions options (Session m) =
|
||||||
|
FFI.withSession applyOptions $
|
||||||
|
\as rs -> evalBuildT (runReaderT m (SessionState rs as))
|
||||||
|
where applyOptions opt = mapM_ (`unSesssionOption` opt) options
|
||||||
|
|
||||||
|
-- | Lift a 'Build' action into a 'Session', including any explicit op
|
||||||
|
-- renderings.
|
||||||
|
build :: Build a -> Session a
|
||||||
|
build = Session . lift . hoistBuildT (return . runIdentity)
|
||||||
|
|
||||||
|
-- | Lift a 'Build' action into a 'Session', including any explicit op
|
||||||
|
-- renderings. Returns the merged summary ops which can be used for
|
||||||
|
-- logging, see 'TensorFlow.Logging.build' for a convenient wrapper.
|
||||||
|
buildWithSummary :: forall a . Build a -> Session (a, [SummaryTensor])
|
||||||
|
buildWithSummary b = Session $ lift $ (,) <$> v <*> collectAllSummaries
|
||||||
|
where v :: BuildT IO a
|
||||||
|
v = hoistBuildT (return . runIdentity) b
|
||||||
|
|
||||||
|
-- | Add all pending rendered nodes to the TensorFlow graph and runs
|
||||||
|
-- any pending initializers.
|
||||||
|
--
|
||||||
|
-- Note that run, runWithFeeds, etc. will all call this function implicitly.
|
||||||
|
extend :: Session ()
|
||||||
|
extend = do
|
||||||
|
let withSessionWhen vs action =
|
||||||
|
unless (null vs) $ Session (asks rawSession) >>= action
|
||||||
|
nodesToExtend <- build flushNodeBuffer
|
||||||
|
withSessionWhen nodesToExtend $ \session ->
|
||||||
|
liftIO $ FFI.extendGraph session
|
||||||
|
$ def & node .~ nodesToExtend
|
||||||
|
-- Now that all the nodes are created, run the initializers.
|
||||||
|
initializers <- build flushInitializers
|
||||||
|
withSessionWhen initializers $ \session ->
|
||||||
|
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
|
||||||
|
|
||||||
|
-- | Helper combinator for doing something with the result of a 'Build' action.
|
||||||
|
-- Example usage:
|
||||||
|
--
|
||||||
|
-- > buildAnd run :: Fetchable t a => Build t -> Session a
|
||||||
|
buildAnd :: (a -> Session b) -> Build a -> Session b
|
||||||
|
buildAnd f m = build m >>= f
|
||||||
|
|
||||||
|
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||||
|
-- rendered, and fetch the corresponding values for 'a'.
|
||||||
|
run :: Fetchable t a => t -> Session a
|
||||||
|
run = runWithFeeds []
|
||||||
|
|
||||||
|
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||||
|
-- rendered, feed the given input values, and fetch the corresponding result
|
||||||
|
-- values for 'a'.
|
||||||
|
runWithFeeds :: Fetchable t a => [Feed] -> t -> Session a
|
||||||
|
runWithFeeds feeds t = do
|
||||||
|
ns <- build $ getNodes t
|
||||||
|
-- Note that this call to "fetch" shouldn't affect the following "extend"
|
||||||
|
-- call, since all nodes in t and its inputs/deps will be rendered by the
|
||||||
|
-- above call to getNodes.
|
||||||
|
fetch <- build $ getFetch t
|
||||||
|
runFetchWithFeeds feeds ns fetch
|
||||||
|
|
||||||
|
runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a
|
||||||
|
runFetchWithFeeds feeds target (Fetch fetch restore) = do
|
||||||
|
extend
|
||||||
|
feeds' <- build $ fixFeeds feeds
|
||||||
|
let fetchNames = encodeUtf8 <$> Set.toList fetch
|
||||||
|
targetNames = toNodeNames $ Set.toList target
|
||||||
|
session <- Session (asks rawSession)
|
||||||
|
runResult <- liftIO $ FFI.run session
|
||||||
|
feeds'
|
||||||
|
fetchNames
|
||||||
|
targetNames
|
||||||
|
let resultTensorsMap = Map.fromList $ zip (Set.toList fetch) runResult
|
||||||
|
return $ restore resultTensorsMap
|
||||||
|
|
||||||
|
toNodeNames :: [NodeName] -> [ByteString]
|
||||||
|
toNodeNames = map (encodeUtf8 . unNodeName)
|
||||||
|
|
||||||
|
-- | Run a subgraph 't', rendering and extending any dependent nodes that aren't
|
||||||
|
-- already rendered. This behaves like 'run' except that it doesn't do any
|
||||||
|
-- fetches.
|
||||||
|
run_ :: Nodes t => t -> Session ()
|
||||||
|
run_ = runWithFeeds_ []
|
||||||
|
|
||||||
|
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||||
|
-- rendered, feed the given input values, and fetch the corresponding result
|
||||||
|
-- values for 'a'. This behaves like 'runWithFeeds' except that it doesn't do
|
||||||
|
-- any fetches.
|
||||||
|
runWithFeeds_ :: Nodes t => [Feed] -> t -> Session ()
|
||||||
|
runWithFeeds_ feeds t = do
|
||||||
|
ns <- build $ getNodes t
|
||||||
|
runFetchWithFeeds feeds ns (pure ())
|
||||||
|
|
||||||
|
fixFeeds :: [Feed] -> Build [(ByteString, FFI.TensorData)]
|
||||||
|
fixFeeds = mapM $ \(Feed o d) -> (,d) . encodeUtf8 <$> renderOutput o
|
||||||
|
|
||||||
|
-- | Starts a concurrent thread which evaluates the given Nodes
|
||||||
|
-- forever until runSession exits or an exception occurs. Graph
|
||||||
|
-- extension happens synchronously, but the resultant run proceeds as
|
||||||
|
-- a separate thread.
|
||||||
|
asyncProdNodes :: Nodes t
|
||||||
|
=> t -- ^ Node to evaluate concurrently.
|
||||||
|
-> Session ()
|
||||||
|
asyncProdNodes nodes = do
|
||||||
|
target <- build (getNodes nodes)
|
||||||
|
extend
|
||||||
|
let targetNames = toNodeNames $ Set.toList target
|
||||||
|
state <- Session ask
|
||||||
|
let loop = forever (void (FFI.run (rawSession state) [] [] targetNames))
|
||||||
|
liftIO (asyncCollector state loop)
|
85
tensorflow/src/TensorFlow/Tensor.hs
Normal file
85
tensorflow/src/TensorFlow/Tensor.hs
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
|
{-# LANGUAGE GADTs #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE Rank2Types #-}
|
||||||
|
|
||||||
|
module TensorFlow.Tensor where
|
||||||
|
|
||||||
|
import Data.String (IsString(..))
|
||||||
|
import qualified Data.Text as Text
|
||||||
|
import Lens.Family2 (Lens', Traversal')
|
||||||
|
import Lens.Family2.Unchecked (lens)
|
||||||
|
|
||||||
|
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr)
|
||||||
|
import TensorFlow.Types (TensorData(..), Attribute)
|
||||||
|
import qualified TensorFlow.Internal.FFI as FFI
|
||||||
|
|
||||||
|
-- | A named output of a TensorFlow operation.
|
||||||
|
--
|
||||||
|
-- The type parameter @a@ is the type of the elements in the 'Tensor'. The
|
||||||
|
-- parameter @v@ is either 'Value' or 'Ref', depending on whether the graph is
|
||||||
|
-- treating this op output as an immutable 'Value' or a stateful 'Ref' (e.g., a
|
||||||
|
-- variable). Note that a @Tensor Ref@ can be casted into a @Tensor Value@ via
|
||||||
|
-- 'value'.
|
||||||
|
data Tensor v a = Tensor (TensorKind v) Output
|
||||||
|
|
||||||
|
data Value
|
||||||
|
data Ref
|
||||||
|
|
||||||
|
-- | This class provides a runtime switch on whether a 'Tensor' should be
|
||||||
|
-- treated as a 'Value' or as a 'Ref'.
|
||||||
|
data TensorKind v where
|
||||||
|
ValueKind :: TensorKind Value
|
||||||
|
RefKind :: TensorKind Ref
|
||||||
|
|
||||||
|
tensorKind :: Lens' (Tensor v a) (TensorKind v)
|
||||||
|
tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o)
|
||||||
|
|
||||||
|
tensorOutput :: Lens' (Tensor v a) Output
|
||||||
|
tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o)
|
||||||
|
|
||||||
|
-- TODO: Come up with a better API for handling attributes.
|
||||||
|
-- | Lens for the attributes of a tensor.
|
||||||
|
--
|
||||||
|
-- Only valid if the tensor has not yet been rendered. If the tensor has been
|
||||||
|
-- rendered, the traversal will be over nothing (nothing can be read or
|
||||||
|
-- written).
|
||||||
|
tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr
|
||||||
|
tensorAttr x = tensorOutput . outputOp . opUnrendered . opAttr x
|
||||||
|
|
||||||
|
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
|
||||||
|
-- Ref into Value. This behaves like a no-op.
|
||||||
|
value :: Tensor v a -> Tensor Value a
|
||||||
|
value (Tensor _ o) = Tensor ValueKind o
|
||||||
|
|
||||||
|
-- | A pair of a 'Tensor' and some data that should be fed into that 'Tensor'
|
||||||
|
-- when running the graph.
|
||||||
|
data Feed = Feed Output FFI.TensorData
|
||||||
|
|
||||||
|
-- | Create a 'Feed' for feeding the given data into a 'Tensor' when running
|
||||||
|
-- the graph.
|
||||||
|
--
|
||||||
|
-- Note that if a 'Tensor' is rendered, its identity may change; so feeding the
|
||||||
|
-- rendered 'Tensor' may be different than feeding the original 'Tensor'.
|
||||||
|
feed :: Tensor v a -> TensorData a -> Feed
|
||||||
|
feed (Tensor _ o) (TensorData td) = Feed o td
|
||||||
|
|
||||||
|
-- | Create a 'Tensor' for a given name. This can be used to reference nodes
|
||||||
|
-- in a 'GraphDef' that was loaded via 'addGraphDef'.
|
||||||
|
-- TODO(judahjacobson): add more safety checks here.
|
||||||
|
tensorFromName :: TensorKind v -> Text.Text -> Tensor v a
|
||||||
|
tensorFromName v = Tensor v . fromString . Text.unpack
|
382
tensorflow/src/TensorFlow/Types.hs
Normal file
382
tensorflow/src/TensorFlow/Types.hs
Normal file
|
@ -0,0 +1,382 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE ConstraintKinds #-}
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE RankNTypes #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
{-# LANGUAGE TypeOperators #-}
|
||||||
|
-- We use UndecidableInstances for type families with recursive definitions
|
||||||
|
-- like "\\". Those instances will terminate since each equation unwraps one
|
||||||
|
-- cons cell of a type-level list.
|
||||||
|
{-# LANGUAGE UndecidableInstances #-}
|
||||||
|
|
||||||
|
module TensorFlow.Types
|
||||||
|
( TensorType(..)
|
||||||
|
, TensorData(..)
|
||||||
|
, Shape(..)
|
||||||
|
, protoShape
|
||||||
|
, Attribute(..)
|
||||||
|
-- * Type constraints
|
||||||
|
, OneOf
|
||||||
|
, type (/=)
|
||||||
|
-- ** Implementation of constraints
|
||||||
|
, TypeError
|
||||||
|
, ExcludedCase
|
||||||
|
, TensorTypes
|
||||||
|
, NoneOf
|
||||||
|
, type (\\)
|
||||||
|
, Delete
|
||||||
|
, AllTensorTypes
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Data.Complex (Complex)
|
||||||
|
import Data.Default (def)
|
||||||
|
import Data.Int (Int8, Int16, Int32, Int64)
|
||||||
|
import Data.Monoid ((<>))
|
||||||
|
import Data.Word (Word8, Word16, Word64)
|
||||||
|
import Foreign.Storable (Storable)
|
||||||
|
import GHC.Exts (Constraint, IsList(..))
|
||||||
|
import Lens.Family2 (Lens', view, (&), (.~))
|
||||||
|
import Lens.Family2.Unchecked (iso)
|
||||||
|
import qualified Data.Attoparsec.ByteString as Atto
|
||||||
|
import Data.ByteString (ByteString)
|
||||||
|
import qualified Data.ByteString as B
|
||||||
|
import Data.ByteString.Builder (Builder)
|
||||||
|
import qualified Data.ByteString.Builder as Builder
|
||||||
|
import qualified Data.ByteString.Lazy as L
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
import qualified Data.Vector.Storable as S
|
||||||
|
import Proto.Tensorflow.Core.Framework.AttrValue
|
||||||
|
( AttrValue(..)
|
||||||
|
, AttrValue'ListValue(..)
|
||||||
|
, b
|
||||||
|
, f
|
||||||
|
, i
|
||||||
|
, s
|
||||||
|
, list
|
||||||
|
, type'
|
||||||
|
, shape
|
||||||
|
, tensor
|
||||||
|
)
|
||||||
|
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
|
||||||
|
( TensorProto(..)
|
||||||
|
, floatVal
|
||||||
|
, doubleVal
|
||||||
|
, intVal
|
||||||
|
, stringVal
|
||||||
|
, int64Val
|
||||||
|
, stringVal
|
||||||
|
, boolVal
|
||||||
|
)
|
||||||
|
import Proto.Tensorflow.Core.Framework.TensorShape
|
||||||
|
( TensorShapeProto(..)
|
||||||
|
, dim
|
||||||
|
, size
|
||||||
|
)
|
||||||
|
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||||
|
|
||||||
|
import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
|
||||||
|
import qualified TensorFlow.Internal.FFI as FFI
|
||||||
|
|
||||||
|
-- | Data about a tensor that is encoded for the TensorFlow APIs.
|
||||||
|
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
|
||||||
|
|
||||||
|
-- | The class of scalar types supported by tensorflow.
|
||||||
|
class TensorType a where
|
||||||
|
tensorType :: a -> DataType
|
||||||
|
tensorRefType :: a -> DataType
|
||||||
|
tensorVal :: Lens' TensorProto [a]
|
||||||
|
-- | Decode the bytes of a TensorData into a Vector.
|
||||||
|
decodeTensorData :: TensorData a -> V.Vector a
|
||||||
|
-- | Encode a Vector into a TensorData.
|
||||||
|
--
|
||||||
|
-- The values should be in row major order, e.g.,
|
||||||
|
--
|
||||||
|
-- element 0: index (0, ..., 0)
|
||||||
|
-- element 1: index (0, ..., 1)
|
||||||
|
-- ...
|
||||||
|
encodeTensorData :: Shape -> V.Vector a -> TensorData a
|
||||||
|
|
||||||
|
-- All types, besides ByteString, are encoded as simple arrays and we can use
|
||||||
|
-- Vector.Storable to encode/decode by type casting pointers.
|
||||||
|
|
||||||
|
-- TODO(fmayle): Assert that the data type matches the return type.
|
||||||
|
simpleDecode :: Storable a => TensorData a -> V.Vector a
|
||||||
|
simpleDecode = S.convert . S.unsafeCast . FFI.tensorDataBytes . unTensorData
|
||||||
|
|
||||||
|
simpleEncode :: forall a . (TensorType a, Storable a)
|
||||||
|
=> Shape -> V.Vector a -> TensorData a
|
||||||
|
simpleEncode (Shape xs)
|
||||||
|
= TensorData . FFI.TensorData xs dt . S.unsafeCast . S.convert
|
||||||
|
where
|
||||||
|
dt = tensorType (undefined :: a)
|
||||||
|
|
||||||
|
instance TensorType Float where
|
||||||
|
tensorType _ = DT_FLOAT
|
||||||
|
tensorRefType _ = DT_FLOAT_REF
|
||||||
|
tensorVal = floatVal
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
instance TensorType Double where
|
||||||
|
tensorType _ = DT_DOUBLE
|
||||||
|
tensorRefType _ = DT_DOUBLE_REF
|
||||||
|
tensorVal = doubleVal
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
instance TensorType Int32 where
|
||||||
|
tensorType _ = DT_INT32
|
||||||
|
tensorRefType _ = DT_INT32_REF
|
||||||
|
tensorVal = intVal
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
instance TensorType Int64 where
|
||||||
|
tensorType _ = DT_INT64
|
||||||
|
tensorRefType _ = DT_INT64_REF
|
||||||
|
tensorVal = int64Val
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
integral :: Integral a => Lens' [Int32] [a]
|
||||||
|
integral = iso (fmap fromIntegral) (fmap fromIntegral)
|
||||||
|
|
||||||
|
instance TensorType Word8 where
|
||||||
|
tensorType _ = DT_UINT8
|
||||||
|
tensorRefType _ = DT_UINT8_REF
|
||||||
|
tensorVal = intVal . integral
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
instance TensorType Word16 where
|
||||||
|
tensorType _ = DT_UINT16
|
||||||
|
tensorRefType _ = DT_UINT16_REF
|
||||||
|
tensorVal = intVal . integral
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
instance TensorType Int16 where
|
||||||
|
tensorType _ = DT_INT16
|
||||||
|
tensorRefType _ = DT_INT16_REF
|
||||||
|
tensorVal = intVal . integral
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
instance TensorType Int8 where
|
||||||
|
tensorType _ = DT_INT8
|
||||||
|
tensorRefType _ = DT_INT8_REF
|
||||||
|
tensorVal = intVal . integral
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
instance TensorType ByteString where
|
||||||
|
tensorType _ = DT_STRING
|
||||||
|
tensorRefType _ = DT_STRING_REF
|
||||||
|
tensorVal = stringVal
|
||||||
|
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
|
||||||
|
-- table offsets for each element :: [Word64]
|
||||||
|
-- at each element offset:
|
||||||
|
-- string length :: VarInt64
|
||||||
|
-- string data :: [Word8]
|
||||||
|
-- TODO(fmayle): Benchmark these functions.
|
||||||
|
decodeTensorData tensorData =
|
||||||
|
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
|
||||||
|
if expected /= count
|
||||||
|
then Left $ "decodeTensorData for ByteString count mismatch " ++
|
||||||
|
show (expected, count)
|
||||||
|
else V.mapM decodeString (S.convert offsets)
|
||||||
|
where
|
||||||
|
expected = S.length offsets
|
||||||
|
count = fromIntegral $ product $ FFI.tensorDataDimensions
|
||||||
|
$ unTensorData tensorData
|
||||||
|
bytes = FFI.tensorDataBytes $ unTensorData tensorData
|
||||||
|
offsets = S.take count $ S.unsafeCast bytes :: S.Vector Word64
|
||||||
|
dataBytes = B.pack $ S.toList $ S.drop (count * 8) bytes
|
||||||
|
decodeString :: Word64 -> Either String ByteString
|
||||||
|
decodeString offset =
|
||||||
|
let stringDataStart = B.drop (fromIntegral offset) dataBytes
|
||||||
|
in Atto.eitherResult $ Atto.parse stringParser stringDataStart
|
||||||
|
stringParser :: Atto.Parser ByteString
|
||||||
|
stringParser = getVarInt >>= Atto.take . fromIntegral
|
||||||
|
encodeTensorData (Shape xs) vec =
|
||||||
|
TensorData $ FFI.TensorData xs dt byteVector
|
||||||
|
where
|
||||||
|
dt = tensorType (undefined :: ByteString)
|
||||||
|
-- Add a string to an offset table and data blob.
|
||||||
|
addString :: (Builder, Builder, Word64)
|
||||||
|
-> ByteString
|
||||||
|
-> (Builder, Builder, Word64)
|
||||||
|
addString (table, strings, offset) str =
|
||||||
|
( table <> Builder.word64LE offset
|
||||||
|
, strings <> lengthBytes <> Builder.byteString str
|
||||||
|
, offset + lengthBytesLen + strLen
|
||||||
|
)
|
||||||
|
where
|
||||||
|
strLen = fromIntegral $ B.length str
|
||||||
|
lengthBytes = putVarInt $ fromIntegral $ B.length str
|
||||||
|
lengthBytesLen =
|
||||||
|
fromIntegral $ L.length $ Builder.toLazyByteString lengthBytes
|
||||||
|
-- Encode all strings.
|
||||||
|
(table', strings', _) = V.foldl' addString (mempty, mempty, 0) vec
|
||||||
|
-- Concat offset table with data.
|
||||||
|
bytes = table' <> strings'
|
||||||
|
-- Convert to Vector Word8.
|
||||||
|
byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes
|
||||||
|
|
||||||
|
|
||||||
|
instance TensorType Bool where
|
||||||
|
tensorType _ = DT_BOOL
|
||||||
|
tensorRefType _ = DT_BOOL_REF
|
||||||
|
tensorVal = boolVal
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
instance TensorType (Complex Float) where
|
||||||
|
tensorType _ = DT_COMPLEX64
|
||||||
|
tensorRefType _ = DT_COMPLEX64
|
||||||
|
tensorVal = error "TODO (Complex Float)"
|
||||||
|
decodeTensorData = error "TODO (Complex Float)"
|
||||||
|
encodeTensorData = error "TODO (Complex Float)"
|
||||||
|
|
||||||
|
instance TensorType (Complex Double) where
|
||||||
|
tensorType _ = DT_COMPLEX128
|
||||||
|
tensorRefType _ = DT_COMPLEX128
|
||||||
|
tensorVal = error "TODO (Complex Double)"
|
||||||
|
decodeTensorData = error "TODO (Complex Double)"
|
||||||
|
encodeTensorData = error "TODO (Complex Double)"
|
||||||
|
|
||||||
|
-- | Shape (dimensions) of a tensor.
|
||||||
|
newtype Shape = Shape [Int64] deriving Show
|
||||||
|
|
||||||
|
instance IsList Shape where
|
||||||
|
type Item Shape = Int64
|
||||||
|
fromList = Shape . fromList
|
||||||
|
toList (Shape ss) = toList ss
|
||||||
|
|
||||||
|
protoShape :: Lens' TensorShapeProto Shape
|
||||||
|
protoShape = iso protoToShape shapeToProto
|
||||||
|
where
|
||||||
|
protoToShape = Shape . fmap (view size) . view dim
|
||||||
|
shapeToProto (Shape ds) = def & dim .~ fmap (\d -> def & size .~ d) ds
|
||||||
|
|
||||||
|
|
||||||
|
class Attribute a where
|
||||||
|
attrLens :: Lens' AttrValue a
|
||||||
|
|
||||||
|
instance Attribute Float where
|
||||||
|
attrLens = f
|
||||||
|
|
||||||
|
instance Attribute ByteString where
|
||||||
|
attrLens = s
|
||||||
|
|
||||||
|
instance Attribute Int64 where
|
||||||
|
attrLens = i
|
||||||
|
|
||||||
|
instance Attribute DataType where
|
||||||
|
attrLens = type'
|
||||||
|
|
||||||
|
instance Attribute TensorProto where
|
||||||
|
attrLens = tensor
|
||||||
|
|
||||||
|
instance Attribute Bool where
|
||||||
|
attrLens = b
|
||||||
|
|
||||||
|
instance Attribute Shape where
|
||||||
|
attrLens = shape . protoShape
|
||||||
|
|
||||||
|
-- TODO(gnezdo): support generating list(Foo) from [Foo].
|
||||||
|
instance Attribute AttrValue'ListValue where
|
||||||
|
attrLens = list
|
||||||
|
|
||||||
|
instance Attribute [DataType] where
|
||||||
|
attrLens = list . type'
|
||||||
|
|
||||||
|
instance Attribute [Int64] where
|
||||||
|
attrLens = list . i
|
||||||
|
|
||||||
|
-- | A 'Constraint' specifying the possible choices of a 'TensorType'.
|
||||||
|
--
|
||||||
|
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the
|
||||||
|
-- natural representation as a conjunction, i.e.,
|
||||||
|
--
|
||||||
|
-- @
|
||||||
|
-- a == Double || a == Float
|
||||||
|
-- @
|
||||||
|
--
|
||||||
|
-- into a disjunction like
|
||||||
|
--
|
||||||
|
-- @
|
||||||
|
-- a \/= Int32 && a \/= Int64 && a \/= ByteString && ...
|
||||||
|
-- @
|
||||||
|
--
|
||||||
|
-- using an enumeration of all the possible 'TensorType's.
|
||||||
|
type OneOf ts a
|
||||||
|
= (TensorType a, TensorTypes ts, NoneOf (AllTensorTypes \\ ts) a)
|
||||||
|
|
||||||
|
-- | A 'Constraint' checking that the input is a list of 'TensorType's.
|
||||||
|
-- Helps improve error messages when using 'OneOf'.
|
||||||
|
type family TensorTypes ts :: Constraint where
|
||||||
|
TensorTypes '[] = ()
|
||||||
|
TensorTypes (t ': ts) = (TensorType t, TensorTypes ts)
|
||||||
|
|
||||||
|
-- | A constraint checking that two types are different.
|
||||||
|
type family a /= b :: Constraint where
|
||||||
|
a /= a = TypeError a ~ ExcludedCase
|
||||||
|
a /= b = ()
|
||||||
|
|
||||||
|
-- | Helper types to produce a reasonable type error message when the Constraint
|
||||||
|
-- "a /= a" fails.
|
||||||
|
-- TODO(judahjacobson): Use ghc-8's CustomTypeErrors for this.
|
||||||
|
data TypeError a
|
||||||
|
data ExcludedCase
|
||||||
|
|
||||||
|
-- | An enumeration of all valid 'TensorType's.
|
||||||
|
type AllTensorTypes =
|
||||||
|
-- NOTE: This list should be kept in sync with
|
||||||
|
-- TensorFlow.OpGen.dtTypeToHaskell.
|
||||||
|
-- TODO: Add support for Complex Float/Double.
|
||||||
|
'[ Float
|
||||||
|
, Double
|
||||||
|
, Int8
|
||||||
|
, Int16
|
||||||
|
, Int32
|
||||||
|
, Int64
|
||||||
|
, Word8
|
||||||
|
, Word16
|
||||||
|
, ByteString
|
||||||
|
, Bool
|
||||||
|
]
|
||||||
|
|
||||||
|
-- | Removes a type from the given list of types.
|
||||||
|
type family Delete a as where
|
||||||
|
Delete a '[] = '[]
|
||||||
|
Delete a (a ': as) = Delete a as
|
||||||
|
Delete a (b ': as) = b ': Delete a as
|
||||||
|
|
||||||
|
-- | Takes the difference of two lists of types.
|
||||||
|
type family as \\ bs where
|
||||||
|
as \\ '[] = as
|
||||||
|
as \\ b ': bs = Delete b as \\ bs
|
||||||
|
|
||||||
|
-- | A constraint that the type @a@ doesn't appear in the type list @ts@.
|
||||||
|
-- Assumes that @a@ and each of the elements of @ts@ are 'TensorType's.
|
||||||
|
type family NoneOf ts a :: Constraint where
|
||||||
|
NoneOf '[] a = ()
|
||||||
|
NoneOf (t ': ts) a = (a /= t, NoneOf ts a)
|
84
tensorflow/tensorflow.cabal
Normal file
84
tensorflow/tensorflow.cabal
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
name: tensorflow
|
||||||
|
version: 0.1.0.0
|
||||||
|
synopsis: TensorFlow bindings.
|
||||||
|
description: Please see README.md
|
||||||
|
homepage: https://github.com/tensorflow/haskell#readme
|
||||||
|
license: Apache
|
||||||
|
author: TensorFlow authors
|
||||||
|
maintainer: tensorflow-haskell@googlegroups.com
|
||||||
|
copyright: Google Inc.
|
||||||
|
category: Machine Learning
|
||||||
|
build-type: Simple
|
||||||
|
cabal-version: >=1.22
|
||||||
|
|
||||||
|
library
|
||||||
|
hs-source-dirs: src
|
||||||
|
exposed-modules: TensorFlow.Build
|
||||||
|
, TensorFlow.BuildOp
|
||||||
|
, TensorFlow.ControlFlow
|
||||||
|
, TensorFlow.Internal.FFI
|
||||||
|
, TensorFlow.Nodes
|
||||||
|
, TensorFlow.Output
|
||||||
|
, TensorFlow.Session
|
||||||
|
, TensorFlow.Tensor
|
||||||
|
, TensorFlow.Types
|
||||||
|
, TensorFlow.Internal.VarInt
|
||||||
|
other-modules: TensorFlow.Internal.Raw
|
||||||
|
, TensorFlow.Orphans
|
||||||
|
build-tools: c2hs
|
||||||
|
build-depends: proto-lens == 0.1.*
|
||||||
|
-- Used by the custom Setup script (for the test-suite).
|
||||||
|
, proto-lens-protoc == 0.1.*
|
||||||
|
, tensorflow-proto == 0.1.*
|
||||||
|
, base >= 4.7 && < 5
|
||||||
|
, async
|
||||||
|
, attoparsec
|
||||||
|
, bytestring
|
||||||
|
, containers
|
||||||
|
, data-default
|
||||||
|
, fgl
|
||||||
|
, lens-family
|
||||||
|
, mainland-pretty
|
||||||
|
, mtl
|
||||||
|
, semigroups
|
||||||
|
, split
|
||||||
|
, text
|
||||||
|
, temporary
|
||||||
|
, transformers
|
||||||
|
, vector
|
||||||
|
extra-libraries: tensorflow_c
|
||||||
|
default-language: Haskell2010
|
||||||
|
include-dirs: .
|
||||||
|
|
||||||
|
Test-Suite FFITest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: FFITest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: HUnit
|
||||||
|
, base
|
||||||
|
, bytestring
|
||||||
|
, lens-family
|
||||||
|
, proto-lens
|
||||||
|
, tensorflow
|
||||||
|
, tensorflow-proto
|
||||||
|
, test-framework
|
||||||
|
, test-framework-hunit
|
||||||
|
|
||||||
|
|
||||||
|
Test-Suite VarIntTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: VarIntTest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: base
|
||||||
|
, attoparsec
|
||||||
|
, bytestring
|
||||||
|
, google-shim
|
||||||
|
, tensorflow
|
||||||
|
, test-framework
|
||||||
|
, test-framework-quickcheck2
|
||||||
|
|
||||||
|
source-repository head
|
||||||
|
type: git
|
||||||
|
location: https://github.com/tensorflow/haskell
|
38
tensorflow/tests/FFITest.hs
Normal file
38
tensorflow/tests/FFITest.hs
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
-- | Tests for FFI.
|
||||||
|
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import Data.ProtoLens (decodeMessage)
|
||||||
|
import Lens.Family2 (view)
|
||||||
|
import TensorFlow.Internal.FFI (getAllOpList)
|
||||||
|
import Test.HUnit (assertBool, assertFailure)
|
||||||
|
import Test.Framework (defaultMain)
|
||||||
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
import Proto.Tensorflow.Core.Framework.OpDef (OpList, op)
|
||||||
|
|
||||||
|
testParseAll :: IO ()
|
||||||
|
testParseAll = do
|
||||||
|
opList <- getAllOpList
|
||||||
|
either
|
||||||
|
assertFailure
|
||||||
|
(assertBool "Expected non-empty list of default Ops"
|
||||||
|
. not . null . view op)
|
||||||
|
(decodeMessage opList :: Either String OpList)
|
||||||
|
|
||||||
|
main = defaultMain
|
||||||
|
[ testCase "ParseAllOps" testParseAll
|
||||||
|
]
|
32
tensorflow/tests/VarIntTest.hs
Normal file
32
tensorflow/tests/VarIntTest.hs
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import Data.ByteString.Builder (toLazyByteString)
|
||||||
|
import Google.Test (googleTest)
|
||||||
|
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||||||
|
import qualified Data.Attoparsec.ByteString.Lazy as Atto
|
||||||
|
|
||||||
|
import TensorFlow.Internal.VarInt
|
||||||
|
|
||||||
|
testEncodeDecode = testProperty "testEncodeDecode" $ \x ->
|
||||||
|
let bytes = toLazyByteString (putVarInt x)
|
||||||
|
in case Atto.eitherResult $ Atto.parse getVarInt bytes of
|
||||||
|
Left _ -> False
|
||||||
|
Right y -> x == y
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = googleTest [ testEncodeDecode
|
||||||
|
]
|
1
tensorflow/third_party
Symbolic link
1
tensorflow/third_party
Symbolic link
|
@ -0,0 +1 @@
|
||||||
|
../third_party/tensorflow
|
1
third_party/tensorflow
vendored
Submodule
1
third_party/tensorflow
vendored
Submodule
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit bac7faa9a3eb5b60687a83336202cd3493de5385
|
Loading…
Reference in a new issue