mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-11 11:29:47 +01:00
Initial commit
This commit is contained in:
commit
67690d1499
67 changed files with 6400 additions and 0 deletions
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
**/.stack-work
|
||||
.stack/
|
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
[submodule "third_party/tensorflow"]
|
||||
path = third_party/tensorflow
|
||||
url = https://github.com/tensorflow/tensorflow.git
|
25
CONTRIBUTING.md
Normal file
25
CONTRIBUTING.md
Normal file
|
@ -0,0 +1,25 @@
|
|||
Want to contribute? Great! First, read this page (including the small print at the end).
|
||||
|
||||
### Before you contribute
|
||||
Before we can use your code, you must sign the
|
||||
[Google Individual Contributor License Agreement](https://cla.developers.google.com/about/google-individual)
|
||||
(CLA), which you can do online. The CLA is necessary mainly because you own the
|
||||
copyright to your changes, even after your contribution becomes part of our
|
||||
codebase, so we need your permission to use and distribute your code. We also
|
||||
need to be sure of various other things—for instance that you'll tell us if you
|
||||
know that your code infringes on other people's patents. You don't have to sign
|
||||
the CLA until after you've submitted your code for review and a member has
|
||||
approved it, but you must do it before we can put your code into our codebase.
|
||||
Before you start working on a larger contribution, you should get in touch with
|
||||
us first through the issue tracker with your idea so that we can help out and
|
||||
possibly guide you. Coordinating up front makes it much easier to avoid
|
||||
frustration later on.
|
||||
|
||||
### Code reviews
|
||||
All submissions, including submissions by project members, require review. We
|
||||
use Github pull requests for this purpose.
|
||||
|
||||
### The small print
|
||||
Contributions made by corporations are covered by a different agreement than
|
||||
the one above, the
|
||||
[Software Grant and Corporate Contributor License Agreement](https://cla.developers.google.com/about/google-corporate).
|
203
LICENSE
Normal file
203
LICENSE
Normal file
|
@ -0,0 +1,203 @@
|
|||
Copyright 2016 The TensorFlow Authors. All rights reserved.
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2016, The TensorFlow Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
24
README.md
Normal file
24
README.md
Normal file
|
@ -0,0 +1,24 @@
|
|||
The tensorflow-haskell package provides Haskell bindings to
|
||||
[TensorFlow](https://www.tensorflow.org/).
|
||||
|
||||
This is not an official Google product.
|
||||
|
||||
# Instructions
|
||||
|
||||
## Build
|
||||
|
||||
For now [docker](https://www.docker.com/) is required. Once you have docker
|
||||
working, the following commands will compile and run the tests.
|
||||
|
||||
git clone --recursive https://github.com/tensorflow/haskell.git tensorflow-haskell
|
||||
cd tensorflow-haskell
|
||||
IMAGE_NAME=tensorflow/haskell:v0
|
||||
docker build -t $IMAGE_NAME docker
|
||||
# TODO: move the setup step to the docker script.
|
||||
stack --docker --docker-image=$IMAGE_NAME setup
|
||||
stack --docker --docker-image=$IMAGE_NAME test
|
||||
|
||||
There is also a demo application:
|
||||
|
||||
cd tensorflow-mnist
|
||||
stack --docker --docker-image=$IMAGE_NAME build --exec Main
|
29
docker/Dockerfile
Normal file
29
docker/Dockerfile
Normal file
|
@ -0,0 +1,29 @@
|
|||
# Prepare the image with:
|
||||
# docker build -t gnezdo/tfhs:2016-10-14-1 docker
|
||||
FROM gcr.io/tensorflow/tensorflow:latest-devel
|
||||
MAINTAINER Greg Steuck <gnezdo+tfhs@google.com>
|
||||
# Installs protoc and the libraries.
|
||||
RUN \
|
||||
cd /tensorflow && \
|
||||
bazel --batch build -c opt '@protobuf//:protoc' && \
|
||||
install -s bazel-bin/external/protobuf/protoc /usr/local/bin && \
|
||||
bazel --batch build -c opt '//tensorflow:libtensorflow_c.so' && \
|
||||
install bazel-bin/tensorflow/libtensorflow_c.so /usr/local/lib && \
|
||||
bazel --batch clean
|
||||
|
||||
RUN apt-get update
|
||||
|
||||
RUN apt-get install -y \
|
||||
# Avoids /usr/bin/ld: cannot find -ltinfo
|
||||
libncurses5-dev \
|
||||
# Makes stack viable in the container
|
||||
libgmp-dev \
|
||||
# Required for locales configuration.
|
||||
locales
|
||||
|
||||
# Our MNIST demo program outputs Unicode characters.
|
||||
RUN dpkg-reconfigure locales && \
|
||||
locale-gen en_US.UTF-8 && \
|
||||
update-locale LANG=en_US.UTF-8
|
||||
|
||||
ENV LANG en_US.UTF-8
|
3
google-shim/Setup.hs
Normal file
3
google-shim/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
|||
import Distribution.Simple
|
||||
|
||||
main = defaultMain
|
23
google-shim/google-shim.cabal
Normal file
23
google-shim/google-shim.cabal
Normal file
|
@ -0,0 +1,23 @@
|
|||
name: google-shim
|
||||
version: 0.1.0.0
|
||||
synopsis: Adapters to externalize TensorFlow code.
|
||||
description: Please see README.md
|
||||
homepage: https://github.com/tensorflow/haskell#readme
|
||||
license: Apache
|
||||
author: TensorFlow authors
|
||||
maintainer: tensorflow-haskell@googlegroups.com
|
||||
copyright: Google Inc.
|
||||
category: Machine Learning
|
||||
build-type: Simple
|
||||
cabal-version: >=1.22
|
||||
|
||||
library
|
||||
hs-source-dirs: src
|
||||
exposed-modules: Google.Test
|
||||
build-depends: base >= 4.7 && < 5
|
||||
, test-framework
|
||||
default-language: Haskell2010
|
||||
|
||||
source-repository head
|
||||
type: git
|
||||
location: https://github.com/tensorflow/haskell
|
22
google-shim/src/Google/Test.hs
Normal file
22
google-shim/src/Google/Test.hs
Normal file
|
@ -0,0 +1,22 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
-- | Alternative implementations to make dependent code work without
|
||||
-- changes.
|
||||
module Google.Test where
|
||||
|
||||
import Test.Framework (Test, defaultMain)
|
||||
|
||||
googleTest :: [Test] -> IO ()
|
||||
googleTest = defaultMain
|
24
stack.yaml
Normal file
24
stack.yaml
Normal file
|
@ -0,0 +1,24 @@
|
|||
resolver: lts-6.2
|
||||
|
||||
packages:
|
||||
- google-shim
|
||||
- tensorflow
|
||||
- tensorflow-core-ops
|
||||
- tensorflow-opgen
|
||||
- tensorflow-ops
|
||||
- tensorflow-proto
|
||||
- tensorflow-mnist
|
||||
- tensorflow-mnist-input-data
|
||||
- tensorflow-queue
|
||||
|
||||
extra-deps:
|
||||
# proto-lens is not yet in Stackage.
|
||||
- proto-lens-0.1.0.4
|
||||
- proto-lens-protoc-0.1.0.4
|
||||
|
||||
# Allow our custom Setup.hs scripts to import Data.ProtoLens.Setup from the version of
|
||||
# `proto-lens-protoc` in stack's local DB. See:
|
||||
# https://github.com/google/proto-lens/blob/master/README.md#using-cabal
|
||||
explicit-setup-deps:
|
||||
"*": true
|
||||
|
100
tensorflow-core-ops/Setup.hs
Normal file
100
tensorflow-core-ops/Setup.hs
Normal file
|
@ -0,0 +1,100 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
-- | Generates the wrappers for Ops shipped with tensorflow_c.
|
||||
module Main where
|
||||
|
||||
import Distribution.Simple.BuildPaths (autogenModulesDir)
|
||||
import Distribution.Simple.LocalBuildInfo (LocalBuildInfo)
|
||||
import Distribution.Simple
|
||||
( defaultMainWithHooks
|
||||
, simpleUserHooks
|
||||
, UserHooks(..)
|
||||
)
|
||||
import Data.List (intercalate)
|
||||
import Data.ProtoLens (decodeMessage)
|
||||
import System.Directory (createDirectoryIfMissing)
|
||||
import System.Exit (exitFailure)
|
||||
import System.FilePath ((</>))
|
||||
import System.IO (hPutStrLn, stderr)
|
||||
import TensorFlow.Internal.FFI (getAllOpList)
|
||||
import TensorFlow.OpGen (docOpList, OpGenFlags(..))
|
||||
import Text.PrettyPrint.Mainland (prettyLazyText)
|
||||
import qualified Data.Text.Lazy.IO as Text
|
||||
|
||||
main = defaultMainWithHooks generatingOpsWrappers
|
||||
|
||||
-- TODO: Generalize for user libraries by replacing getAllOpList with
|
||||
-- a wrapper around TF_LoadLibrary. The complicated part is interplay
|
||||
-- between bazel and Haskell build system.
|
||||
generatingOpsWrappers :: UserHooks
|
||||
generatingOpsWrappers = hooks
|
||||
{ buildHook = \p l h f -> generateSources l >> buildHook hooks p l h f
|
||||
, haddockHook = \p l h f -> generateSources l >> haddockHook hooks p l h f
|
||||
, replHook = \p l h f args -> generateSources l
|
||||
>> replHook hooks p l h f args
|
||||
}
|
||||
where
|
||||
flagsBuilder dir = OpGenFlags
|
||||
{ outputFile = dir </> "Core.hs"
|
||||
, prefix = "TensorFlow.GenOps"
|
||||
, excludeList = intercalate "," blackList
|
||||
}
|
||||
hooks = simpleUserHooks
|
||||
generateSources :: LocalBuildInfo -> IO ()
|
||||
generateSources l = do
|
||||
let dir = autogenModulesDir l </> "TensorFlow/GenOps"
|
||||
createDirectoryIfMissing True dir
|
||||
let flags = flagsBuilder dir
|
||||
pb <- getAllOpList
|
||||
case decodeMessage pb of
|
||||
Left e -> hPutStrLn stderr e >> exitFailure
|
||||
Right x -> Text.writeFile (outputFile flags)
|
||||
(prettyLazyText 80 $ docOpList flags x)
|
||||
|
||||
blackList =
|
||||
-- A few data flow ops take a list of heterogeneous
|
||||
-- parameters which we don't support in general form.
|
||||
[ "HashTable"
|
||||
, "MutableDenseHashTable"
|
||||
, "MutableHashTable"
|
||||
, "MutableHashTableOfTensors"
|
||||
, "QueueDequeue"
|
||||
, "QueueDequeueMany"
|
||||
, "QueueDequeueUpTo"
|
||||
, "Stack"
|
||||
, "TensorArray"
|
||||
-- These should be possible to support by adding a bunch of
|
||||
-- overloads with a variable number of tuple arguments.
|
||||
, "Assert"
|
||||
, "BarrierTakeMany"
|
||||
, "Print"
|
||||
, "QueueEnqueue"
|
||||
, "QueueEnqueueMany"
|
||||
-- These have type ambiguities because one of the type arguments
|
||||
-- doesn't appear in the signature.
|
||||
, "ConditionalAccumulator"
|
||||
, "SparseConditionalAccumulator"
|
||||
-- Need list of types support.
|
||||
, "DecodeCSV"
|
||||
, "ParseExample"
|
||||
, "ParseSingleSequenceExample"
|
||||
, "Save"
|
||||
, "SaveSlices"
|
||||
, "SymbolicGradient"
|
||||
, "_ArrayToList"
|
||||
, "_ListToArray"
|
||||
-- Easy: support larger result tuples.
|
||||
, "Skipgram"
|
||||
]
|
30
tensorflow-core-ops/tensorflow-core-ops.cabal
Normal file
30
tensorflow-core-ops/tensorflow-core-ops.cabal
Normal file
|
@ -0,0 +1,30 @@
|
|||
name: tensorflow-core-ops
|
||||
version: 0.1.0.0
|
||||
synopsis: Haskell wrappers for Core Tensorflow Ops.
|
||||
description: Code generated signatures for the Ops in libtensorflow_c.
|
||||
homepage: https://github.com/tensorflow/haskell#readme
|
||||
license: Apache
|
||||
author: TensorFlow authors
|
||||
maintainer: tensorflow-haskell@googlegroups.com
|
||||
copyright: Google Inc.
|
||||
category: Machine Learning
|
||||
build-type: Custom
|
||||
cabal-version: >=1.22
|
||||
|
||||
library
|
||||
exposed-modules: TensorFlow.GenOps.Core
|
||||
build-depends: Cabal >= 1.22 && < 1.25
|
||||
, bytestring
|
||||
, proto-lens == 0.1.*
|
||||
, tensorflow-opgen == 0.1.*
|
||||
, tensorflow == 0.1.*
|
||||
, base >= 4.7 && < 5
|
||||
, filepath
|
||||
, mainland-pretty
|
||||
, lens-family
|
||||
, text
|
||||
default-language: Haskell2010
|
||||
|
||||
source-repository head
|
||||
type: git
|
||||
location: https://github.com/tensorflow/haskell
|
113
tensorflow-mnist-input-data/Setup.hs
Normal file
113
tensorflow-mnist-input-data/Setup.hs
Normal file
|
@ -0,0 +1,113 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
|
||||
-- | Downloads the MNIST data set and packages them as data files.
|
||||
module Main where
|
||||
|
||||
import Control.Monad (when)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Distribution.PackageDescription
|
||||
( GenericPackageDescription(packageDescription)
|
||||
, dataDir
|
||||
)
|
||||
import Distribution.Simple
|
||||
( UserHooks(..)
|
||||
, defaultMainWithHooks
|
||||
, simpleUserHooks
|
||||
)
|
||||
import System.IO (hPutStrLn, stderr)
|
||||
import System.FilePath ((</>))
|
||||
import System.Directory (doesFileExist)
|
||||
import qualified Crypto.Hash as Hash
|
||||
import qualified Data.ByteString.Lazy as B
|
||||
import qualified Network.HTTP as HTTP
|
||||
import qualified Network.URI as URI
|
||||
|
||||
main :: IO ()
|
||||
main = defaultMainWithHooks downloadingDataFiles
|
||||
|
||||
downloadingDataFiles :: UserHooks
|
||||
downloadingDataFiles = hooks
|
||||
{ confHook = \gh@(g, _) c -> downloadFiles g >> confHook hooks gh c
|
||||
}
|
||||
where
|
||||
hooks = simpleUserHooks
|
||||
downloadFiles :: GenericPackageDescription -> IO ()
|
||||
downloadFiles g = do
|
||||
let dir = dataDir (packageDescription g)
|
||||
mapM_ (maybeDownload dir) fileInfos
|
||||
|
||||
maybeDownload :: FilePath -> (String, String) -> IO ()
|
||||
maybeDownload dataDir (basename, sha256) = do
|
||||
let filePath = dataDir </> basename
|
||||
exists <- doesFileExist filePath
|
||||
when (not exists) $ do
|
||||
let url = urlPrefix ++ basename
|
||||
hPutStrLn stderr ("Downloading " ++ url)
|
||||
httpDownload url filePath
|
||||
verify filePath sha256
|
||||
|
||||
httpDownload :: String -> FilePath -> IO ()
|
||||
httpDownload url outFile = do
|
||||
let uri = fromMaybe
|
||||
(error ("Can't be: invalid URI " ++ url))
|
||||
(URI.parseURI url)
|
||||
result <- HTTP.simpleHTTP (HTTP.defaultGETRequest_ uri)
|
||||
HTTP.getResponseCode result >>= \case
|
||||
(2, 0, 0) -> HTTP.getResponseBody result >>= B.writeFile outFile
|
||||
s -> error ( "Failed to download " ++ url ++ " error code " ++ show s
|
||||
++ helpfulMessage
|
||||
)
|
||||
|
||||
verify :: FilePath -> String -> IO ()
|
||||
verify filePath hash = do
|
||||
let sha256 = Hash.hashlazy :: B.ByteString -> Hash.Digest Hash.SHA256
|
||||
computed <- show . sha256 <$> B.readFile filePath
|
||||
when (hash /= computed) $
|
||||
error ( "Incorrect checksum for " ++ filePath
|
||||
++ "\nexpected " ++ hash
|
||||
++ "\ncomputed " ++ computed
|
||||
++ helpfulMessage
|
||||
)
|
||||
|
||||
urlPrefix = "http://yann.lecun.com/exdb/mnist/"
|
||||
|
||||
-- | File names relative to 'urlPrefix' and their sha256.
|
||||
fileInfos = [
|
||||
( "train-images-idx3-ubyte.gz"
|
||||
, "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
|
||||
)
|
||||
,
|
||||
( "train-labels-idx1-ubyte.gz"
|
||||
, "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
|
||||
)
|
||||
,
|
||||
( "t10k-images-idx3-ubyte.gz"
|
||||
, "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
|
||||
)
|
||||
,
|
||||
( "t10k-labels-idx1-ubyte.gz"
|
||||
, "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
|
||||
)
|
||||
]
|
||||
|
||||
helpfulMessage =
|
||||
unlines
|
||||
( ""
|
||||
: ""
|
||||
: "Please download the following URLs manually and put them in data/"
|
||||
: [ urlPrefix ++ h | (h, _) <- fileInfos ]
|
||||
)
|
|
@ -0,0 +1,31 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
module TensorFlow.Examples.MNIST.InputData
|
||||
( trainingImageData
|
||||
, trainingLabelData
|
||||
, testImageData
|
||||
, testLabelData
|
||||
) where
|
||||
|
||||
import Paths_tensorflow_mnist_input_data (getDataFileName)
|
||||
|
||||
-- | Download the files containing the canonical MNIST samples and labels.
|
||||
trainingImageData, trainingLabelData :: IO FilePath
|
||||
trainingImageData = getDataFileName "train-images-idx3-ubyte.gz"
|
||||
trainingLabelData = getDataFileName "train-labels-idx1-ubyte.gz"
|
||||
|
||||
testImageData, testLabelData :: IO FilePath
|
||||
testImageData = getDataFileName "t10k-images-idx3-ubyte.gz"
|
||||
testLabelData = getDataFileName "t10k-labels-idx1-ubyte.gz"
|
|
@ -0,0 +1,35 @@
|
|||
name: tensorflow-mnist-input-data
|
||||
version: 0.1.0.0
|
||||
synopsis: Downloader of input data for training MNIST.
|
||||
description: Please see README.md
|
||||
homepage: https://github.com/tensorflow/haskell#readme
|
||||
license: Apache
|
||||
author: TensorFlow authors
|
||||
maintainer: tensorflow-haskell@googlegroups.com
|
||||
copyright: Google Inc.
|
||||
category: Machine Learning
|
||||
build-type: Custom
|
||||
cabal-version: >=1.22
|
||||
-- These files are downloaded automatically by Setup.hs. If the
|
||||
-- automatic download fails, follow the instructions in error messages
|
||||
-- displayed by Setup.hs.
|
||||
data-dir: data
|
||||
data-files: *.gz
|
||||
|
||||
library
|
||||
hs-source-dirs: src
|
||||
exposed-modules: TensorFlow.Examples.MNIST.InputData
|
||||
other-modules: Paths_tensorflow_mnist_input_data
|
||||
build-depends: Cabal
|
||||
, HTTP
|
||||
, base >= 4.7 && < 5
|
||||
, bytestring
|
||||
, cryptonite
|
||||
, directory
|
||||
, filepath
|
||||
, network-uri
|
||||
default-language: Haskell2010
|
||||
|
||||
source-repository head
|
||||
type: git
|
||||
location: https://github.com/tensorflow/haskell
|
3
tensorflow-mnist/Setup.hs
Normal file
3
tensorflow-mnist/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
|||
import Distribution.Simple
|
||||
|
||||
main = defaultMain
|
161
tensorflow-mnist/app/Main.hs
Normal file
161
tensorflow-mnist/app/Main.hs
Normal file
|
@ -0,0 +1,161 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
|
||||
import Control.Monad (zipWithM, when, forM, forM_)
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int32, Int64)
|
||||
import qualified Data.Text.IO as T
|
||||
import qualified Data.Vector as V
|
||||
|
||||
import qualified TensorFlow.ControlFlow as TF
|
||||
import qualified TensorFlow.Build as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
|
||||
import TensorFlow.Examples.MNIST.InputData
|
||||
import TensorFlow.Examples.MNIST.Parse
|
||||
|
||||
numPixels = 28^2 :: Int64
|
||||
numLabels = 10 :: Int64
|
||||
|
||||
-- | Create tensor with random values where the stddev depends on the width.
|
||||
randomParam :: Int64 -> TF.Shape -> TF.Build (TF.Tensor TF.Value Float)
|
||||
randomParam width (TF.Shape shape) =
|
||||
(* stddev) <$> TF.truncatedNormal (TF.vector shape)
|
||||
where
|
||||
stddev = TF.scalar (1 / sqrt (fromIntegral width))
|
||||
|
||||
-- Types must match due to model structure (sparseToDense requires
|
||||
-- index types to match)
|
||||
type LabelType = Int32
|
||||
type BatchSize = Int32
|
||||
|
||||
-- | Convert scalar labels to one-hot vectors.
|
||||
labelClasses :: TF.Tensor TF.Value LabelType
|
||||
-> LabelType
|
||||
-> BatchSize
|
||||
-> TF.Tensor TF.Value Float
|
||||
labelClasses labels numClasses batchSize =
|
||||
let indices = TF.range 0 (TF.scalar batchSize) 1
|
||||
concated = TF.concat 1 [TF.expandDims indices 1, TF.expandDims labels 1]
|
||||
in TF.sparseToDense concated
|
||||
(TF.constant [2] [batchSize, numClasses])
|
||||
1 {- ON value -}
|
||||
0 {- default (OFF) value -}
|
||||
|
||||
-- | Fraction of elements that differ between two vectors.
|
||||
errorRate :: Eq a => V.Vector a -> V.Vector a -> Double
|
||||
errorRate xs ys = fromIntegral (len - numCorrect) / fromIntegral len
|
||||
where
|
||||
numCorrect = V.length $ V.filter id $ V.zipWith (==) xs ys
|
||||
len = V.length xs
|
||||
|
||||
data Model = Model {
|
||||
train :: TF.TensorData Float -- ^ images
|
||||
-> TF.TensorData LabelType
|
||||
-> TF.Session ()
|
||||
, infer :: TF.TensorData Float -- ^ images
|
||||
-> TF.Session (V.Vector LabelType) -- ^ predictions
|
||||
}
|
||||
|
||||
createModel :: Int64 -> TF.Build Model
|
||||
createModel batchSize = do
|
||||
-- Inputs.
|
||||
images <- TF.placeholder [batchSize, numPixels]
|
||||
-- Hidden layer.
|
||||
let numUnits = 500
|
||||
hiddenWeights <-
|
||||
TF.initializedVariable =<< randomParam numPixels [numPixels, numUnits]
|
||||
hiddenBiases <- TF.zeroInitializedVariable [numUnits]
|
||||
let hiddenZ = (images `TF.matMul` hiddenWeights) `TF.add` hiddenBiases
|
||||
let hidden = TF.relu hiddenZ
|
||||
-- Logits.
|
||||
logitWeights <-
|
||||
TF.initializedVariable =<< randomParam numUnits [numUnits, numLabels]
|
||||
logitBiases <- TF.zeroInitializedVariable [numLabels]
|
||||
let logits = (hidden `TF.matMul` logitWeights) `TF.add` logitBiases
|
||||
predict <- TF.render $ TF.cast $
|
||||
TF.argMax (TF.softmax logits) (TF.scalar (1 :: LabelType))
|
||||
|
||||
-- Create training action.
|
||||
labels <- TF.placeholder [batchSize]
|
||||
let labelVecs = labelClasses labels 10 (fromIntegral batchSize)
|
||||
loss = fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
|
||||
params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases]
|
||||
grads <- TF.gradients loss params
|
||||
|
||||
let lr = TF.scalar $ 0.001 / fromIntegral batchSize
|
||||
applyGrad param grad
|
||||
= TF.assign param $ param `TF.sub` (lr * grad)
|
||||
trainStep <- TF.group =<< zipWithM applyGrad params grads
|
||||
|
||||
return Model {
|
||||
train = \imFeed lFeed -> TF.runWithFeeds_ [
|
||||
TF.feed images imFeed
|
||||
, TF.feed labels lFeed
|
||||
] trainStep
|
||||
, infer = \imFeed -> TF.runWithFeeds [TF.feed images imFeed] predict
|
||||
}
|
||||
|
||||
main = TF.runSession $ do
|
||||
-- Read training and test data.
|
||||
trainingImages <- liftIO (readMNISTSamples =<< trainingImageData)
|
||||
trainingLabels <- liftIO (readMNISTLabels =<< trainingLabelData)
|
||||
testImages <- liftIO (readMNISTSamples =<< testImageData)
|
||||
testLabels <- liftIO (readMNISTLabels =<< testLabelData)
|
||||
|
||||
let batchSize = 100 :: Int64
|
||||
|
||||
-- Create the model.
|
||||
model <- TF.build $ createModel batchSize
|
||||
|
||||
-- Helpers for generate batches.
|
||||
let selectBatch i xs = take size $ drop (i * size) $ cycle xs
|
||||
where size = fromIntegral batchSize
|
||||
let getImageBatch i xs = TF.encodeTensorData
|
||||
[batchSize, numPixels]
|
||||
$ fromIntegral <$> mconcat (selectBatch i xs)
|
||||
let getExpectedLabelBatch i xs =
|
||||
fromIntegral <$> V.fromList (selectBatch i xs)
|
||||
|
||||
-- Train.
|
||||
forM_ ([0..1000] :: [Int]) $ \i -> do
|
||||
let images = getImageBatch i trainingImages
|
||||
labels = getExpectedLabelBatch i trainingLabels
|
||||
train model images (TF.encodeTensorData [batchSize] labels)
|
||||
when (i `mod` 100 == 0) $ do
|
||||
preds <- infer model images
|
||||
liftIO $ putStrLn $
|
||||
"training error " ++ show (errorRate preds labels * 100)
|
||||
liftIO $ putStrLn ""
|
||||
|
||||
-- Test.
|
||||
let numTestBatches = length testImages `div` fromIntegral batchSize
|
||||
testPreds <- fmap mconcat $ forM [0..numTestBatches] $ \i -> do
|
||||
infer model (getImageBatch i testImages)
|
||||
let testExpected = fromIntegral <$> V.fromList testLabels
|
||||
liftIO $ putStrLn $
|
||||
"test error " ++ show (errorRate testPreds testExpected * 100)
|
||||
|
||||
-- Show some predictions.
|
||||
liftIO $ forM_ ([0..3] :: [Int]) $ \i -> do
|
||||
putStrLn ""
|
||||
T.putStrLn $ drawMNIST $ testImages !! i
|
||||
putStrLn $ "expected " ++ show (testLabels !! i)
|
||||
putStrLn $ " got " ++ show (testPreds V.! i)
|
BIN
tensorflow-mnist/data/MNIST.pb
Normal file
BIN
tensorflow-mnist/data/MNIST.pb
Normal file
Binary file not shown.
BIN
tensorflow-mnist/data/MNISTBias.ckpt
Normal file
BIN
tensorflow-mnist/data/MNISTBias.ckpt
Normal file
Binary file not shown.
BIN
tensorflow-mnist/data/MNISTWts.ckpt
Normal file
BIN
tensorflow-mnist/data/MNISTWts.ckpt
Normal file
Binary file not shown.
|
@ -0,0 +1,30 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
-- | Paths to test helper files.
|
||||
module TensorFlow.Examples.MNIST.TrainedGraph where
|
||||
|
||||
import Paths_tensorflow_mnist (getDataFileName)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.ByteString.Char8 (pack)
|
||||
|
||||
-- | File containing a Tensorflow serialized proto of MNIST.
|
||||
mnistPb :: IO FilePath
|
||||
mnistPb = getDataFileName "data/MNIST.pb"
|
||||
|
||||
-- | Files containing pre-trained weights for MNIST.
|
||||
wtsCkpt, biasCkpt :: IO ByteString
|
||||
wtsCkpt = pack <$> getDataFileName "data/MNISTWts.ckpt"
|
||||
biasCkpt = pack <$> getDataFileName "data/MNISTBias.ckpt"
|
96
tensorflow-mnist/src/TensorFlow/Examples/MNIST/Parse.hs
Normal file
96
tensorflow-mnist/src/TensorFlow/Examples/MNIST/Parse.hs
Normal file
|
@ -0,0 +1,96 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE TypeSynonymInstances #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE ViewPatterns #-}
|
||||
|
||||
module TensorFlow.Examples.MNIST.Parse where
|
||||
|
||||
import Control.Monad (when, liftM)
|
||||
import Data.Binary.Get (Get, runGet, getWord32be, getLazyByteString)
|
||||
import Data.ByteString.Lazy (toStrict, readFile)
|
||||
import Data.List.Split (chunksOf)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.ProtoLens (Message, decodeMessageOrDie)
|
||||
import Data.Text (Text)
|
||||
import Data.Word (Word8, Word32)
|
||||
import Prelude hiding (readFile)
|
||||
import qualified Codec.Compression.GZip as GZip
|
||||
import qualified Data.ByteString.Lazy as L
|
||||
import qualified Data.Text as Text
|
||||
import qualified Data.Vector as V
|
||||
|
||||
-- | Utilities specific to MNIST.
|
||||
type MNIST = V.Vector Word8
|
||||
|
||||
-- | Produces a unicode rendering of the MNIST digit sample.
|
||||
drawMNIST :: MNIST -> Text
|
||||
drawMNIST = chunk . block
|
||||
where
|
||||
block :: V.Vector Word8 -> Text
|
||||
block (V.splitAt 1 -> ([0], xs)) = " " <> block xs
|
||||
block (V.splitAt 1 -> ([n], xs)) = c `Text.cons` block xs
|
||||
where c = "\9617\9618\9619\9608" !! fromIntegral (n `div` 64)
|
||||
block (V.splitAt 1 -> _) = ""
|
||||
chunk :: Text -> Text
|
||||
chunk "" = "\n"
|
||||
chunk xs = Text.take 28 xs <> "\n" <> chunk (Text.drop 28 xs)
|
||||
|
||||
-- | Check's the file's endianess, throwing an error if it's not as expected.
|
||||
checkEndian :: Get ()
|
||||
checkEndian = do
|
||||
magic <- getWord32be
|
||||
when (magic `notElem` ([2049, 2051] :: [Word32])) $
|
||||
fail "Expected big endian, but image file is little endian."
|
||||
|
||||
-- | Reads an MNIST file and returns a list of samples.
|
||||
readMNISTSamples :: FilePath -> IO [MNIST]
|
||||
readMNISTSamples path = do
|
||||
raw <- GZip.decompress <$> readFile path
|
||||
return $ runGet getMNIST raw
|
||||
where
|
||||
getMNIST :: Get [MNIST]
|
||||
getMNIST = do
|
||||
checkEndian
|
||||
-- Parse header data.
|
||||
cnt <- liftM fromIntegral getWord32be
|
||||
rows <- liftM fromIntegral getWord32be
|
||||
cols <- liftM fromIntegral getWord32be
|
||||
-- Read all of the data, then split into samples.
|
||||
pixels <- getLazyByteString $ fromIntegral $ cnt * rows * cols
|
||||
return $ V.fromList <$> chunksOf (rows * cols) (L.unpack pixels)
|
||||
|
||||
-- | Reads a list of MNIST labels from a file and returns them.
|
||||
readMNISTLabels :: FilePath -> IO [Word8]
|
||||
readMNISTLabels path = do
|
||||
raw <- GZip.decompress <$> readFile path
|
||||
return $ runGet getLabels raw
|
||||
where getLabels :: Get [Word8]
|
||||
getLabels = do
|
||||
checkEndian
|
||||
-- Parse header data.
|
||||
cnt <- liftM fromIntegral getWord32be
|
||||
-- Read all of the labels.
|
||||
L.unpack <$> getLazyByteString cnt
|
||||
|
||||
readMessageFromFileOrDie :: Message m => FilePath -> IO m
|
||||
readMessageFromFileOrDie path = do
|
||||
pb <- readFile path
|
||||
return $ decodeMessageOrDie $ toStrict pb
|
||||
|
||||
-- TODO: Write a writeMessageFromFileOrDie and read/write non-lethal
|
||||
-- versions.
|
80
tensorflow-mnist/tensorflow-mnist.cabal
Normal file
80
tensorflow-mnist/tensorflow-mnist.cabal
Normal file
|
@ -0,0 +1,80 @@
|
|||
name: tensorflow-mnist
|
||||
version: 0.1.0.0
|
||||
synopsis: TensorFlow demo application for learning MNIST model.
|
||||
description: Please see README.md
|
||||
homepage: https://github.com/tensorflow/haskell#readme
|
||||
license: Apache
|
||||
author: TensorFlow authors
|
||||
maintainer: tensorflow-haskell@googlegroups.com
|
||||
copyright: Google Inc.
|
||||
category: Machine Learning
|
||||
build-type: Simple
|
||||
cabal-version: >=1.22
|
||||
data-files: data/*.ckpt
|
||||
, data/*.pb
|
||||
|
||||
library
|
||||
hs-source-dirs: src
|
||||
, src-data
|
||||
exposed-modules: TensorFlow.Examples.MNIST.Parse
|
||||
, TensorFlow.Examples.MNIST.TrainedGraph
|
||||
other-modules: Paths_tensorflow_mnist
|
||||
build-depends: proto-lens == 0.1.*
|
||||
, base >= 4.7 && < 5
|
||||
, binary
|
||||
, bytestring
|
||||
, filepath
|
||||
, lens-family
|
||||
, containers
|
||||
, split
|
||||
, tensorflow-proto == 0.1.*
|
||||
, tensorflow-core-ops == 0.1.*
|
||||
, tensorflow
|
||||
, text
|
||||
, vector
|
||||
, zlib
|
||||
default-language: Haskell2010
|
||||
|
||||
executable Main
|
||||
default-language: Haskell2010
|
||||
main-is: Main.hs
|
||||
hs-source-dirs: app
|
||||
build-depends: base
|
||||
, bytestring
|
||||
, filepath
|
||||
, lens-family
|
||||
, proto-lens
|
||||
, tensorflow
|
||||
, tensorflow-mnist
|
||||
, tensorflow-mnist-input-data
|
||||
, tensorflow-ops
|
||||
, tensorflow-proto
|
||||
, text
|
||||
, transformers
|
||||
, vector
|
||||
|
||||
Test-Suite ParseTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: ParseTest.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: HUnit
|
||||
, base
|
||||
, bytestring
|
||||
, proto-lens
|
||||
, lens-family
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-mnist
|
||||
, tensorflow-mnist-input-data
|
||||
, tensorflow-ops
|
||||
, tensorflow-proto
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
, text
|
||||
, transformers
|
||||
, vector
|
||||
|
||||
source-repository head
|
||||
type: git
|
||||
location: https://github.com/tensorflow/haskell
|
170
tensorflow-mnist/tests/ParseTest.hs
Normal file
170
tensorflow-mnist/tests/ParseTest.hs
Normal file
|
@ -0,0 +1,170 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Main where
|
||||
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int64)
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text.IO as Text
|
||||
import Lens.Family2 ((&), (.~), (^.))
|
||||
import Prelude hiding (abs)
|
||||
import Proto.Tensorflow.Core.Framework.Graph
|
||||
( GraphDef(..)
|
||||
, version
|
||||
, node )
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||
( NodeDef(..)
|
||||
, op )
|
||||
import System.IO as IO
|
||||
import TensorFlow.Examples.MNIST.InputData
|
||||
import TensorFlow.Examples.MNIST.Parse
|
||||
import TensorFlow.Examples.MNIST.TrainedGraph
|
||||
import TensorFlow.Build
|
||||
( asGraphDef
|
||||
, addGraphDef
|
||||
, render
|
||||
)
|
||||
import TensorFlow.Tensor
|
||||
( Tensor(..)
|
||||
, Ref
|
||||
, Value
|
||||
, feed
|
||||
, TensorKind(..)
|
||||
, tensorFromName
|
||||
)
|
||||
import TensorFlow.Ops
|
||||
import TensorFlow.Nodes (unScalar)
|
||||
import TensorFlow.Session
|
||||
(runSession, run, run_, runWithFeeds, build, buildAnd)
|
||||
import TensorFlow.Types (TensorType(..), Shape(..))
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@=?), Assertion)
|
||||
import Google.Test
|
||||
import qualified Data.Vector as V
|
||||
|
||||
-- | Test that a file can be read and the GraphDef proto correctly parsed.
|
||||
testReadMessageFromFileOrDie = testCase "testReadMessageFromFileOrDie" $ do
|
||||
-- Check the function on a known well-formatted file.
|
||||
mnist <- readMessageFromFileOrDie =<< mnistPb :: IO GraphDef
|
||||
-- Simple field read.
|
||||
1 @=? mnist^.version
|
||||
-- Count the number of nodes.
|
||||
let nodes :: [NodeDef]
|
||||
nodes = mnist^.node
|
||||
100 @=? length nodes
|
||||
-- Check that the expected op is found at an arbitrary index.
|
||||
"Variable" @=? nodes!!6^.op
|
||||
|
||||
-- | Parse the test set for label and image data. Will only fail if the file is
|
||||
-- missing or incredibly corrupt.
|
||||
testReadMNIST = testCase "testReadMNIST" $ do
|
||||
imageData <- readMNISTSamples =<< testImageData
|
||||
10000 @=? length imageData
|
||||
labelData <- readMNISTLabels =<< testLabelData
|
||||
10000 @=? length labelData
|
||||
|
||||
testNodeName :: Text -> Tensor v a -> Assertion
|
||||
testNodeName n g = n @=? opName
|
||||
where
|
||||
opName = head (gDef^.node)^.op
|
||||
gDef = asGraphDef $ render g
|
||||
|
||||
testGraphDefGen = testCase "testGraphDefGen" $ do
|
||||
-- Test the inferred operation type.
|
||||
let f0 :: Tensor Value Float
|
||||
f0 = 0
|
||||
testNodeName "Const" f0
|
||||
testNodeName "Add" $ 1 + f0
|
||||
testNodeName "Mul" $ 1 * f0
|
||||
testNodeName "Sub" $ 1 - f0
|
||||
testNodeName "Abs" $ abs f0
|
||||
testNodeName "Sign" $ signum f0
|
||||
testNodeName "Neg" $ -f0
|
||||
-- Test the grouping.
|
||||
testNodeName "Add" $ 1 + f0 * 2
|
||||
testNodeName "Add" $ 1 + (f0 * 2)
|
||||
testNodeName "Mul" $ (1 + f0) * 2
|
||||
|
||||
-- | Convert a simple graph to GraphDef, load it, run it, and check the output.
|
||||
testGraphDefExec = testCase "testGraphDefExec" $ do
|
||||
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
|
||||
runSession $ do
|
||||
build $ addGraphDef graphDef
|
||||
x <- run $ tensorFromName ValueKind "Mul_2"
|
||||
liftIO $ (50 :: Float) @=? unScalar x
|
||||
|
||||
-- | Load MNIST from a GraphDef and the weights from a checkpoint and run on
|
||||
-- sample data.
|
||||
testMNISTExec = testCase "testMNISTExec" $ do
|
||||
-- Switch to unicode to enable pretty printing of MNIST digits.
|
||||
IO.hSetEncoding IO.stdout IO.utf8
|
||||
|
||||
-- Parse the Graph definition, samples, & labels from files.
|
||||
mnist <- readMessageFromFileOrDie =<< mnistPb :: IO GraphDef
|
||||
mnistSamples <- readMNISTSamples =<< testImageData
|
||||
mnistLabels <- readMNISTLabels =<< testLabelData
|
||||
|
||||
-- Select a sample to run on and convert it into a TensorData of Floats.
|
||||
let idx = 12
|
||||
sample :: MNIST
|
||||
sample = mnistSamples !! idx
|
||||
label = mnistLabels !! idx
|
||||
tensorSample = encodeTensorData (Shape [1,784]) floatSample
|
||||
where
|
||||
floatSample :: V.Vector Float
|
||||
floatSample = V.map fromIntegral sample
|
||||
Text.putStrLn $ drawMNIST sample
|
||||
|
||||
-- Execute the graph on the sample data.
|
||||
runSession $ do
|
||||
-- The version of this session is 0, but the version of the graph is 1.
|
||||
-- Change the graph version to 0 so they're compatible.
|
||||
build $ addGraphDef $ mnist & version .~ 0
|
||||
-- Define nodes that restore saved weights and biases.
|
||||
let bias, wts :: Tensor Ref Float
|
||||
bias = tensorFromName RefKind "Variable"
|
||||
wts = tensorFromName RefKind "weights"
|
||||
wtsCkptPath <- liftIO wtsCkpt
|
||||
biasCkptPath <- liftIO biasCkpt
|
||||
-- Run those restoring nodes on the graph in the current session.
|
||||
buildAnd run_ $ (sequence :: Monad m => [m a] -> m [a])
|
||||
[ restore wtsCkptPath wts
|
||||
, restoreFromName biasCkptPath "bias" bias
|
||||
]
|
||||
-- Encode the expected sample data as one-hot data.
|
||||
let ty = encodeTensorData [10] oneHotLabels
|
||||
where oneHotLabels = V.replicate 10 (0 :: Float) V.// updates
|
||||
updates = [(fromIntegral label, 1)]
|
||||
let feeds = [ feed (tensorFromName ValueKind "x-input") tensorSample
|
||||
, feed (tensorFromName ValueKind "y-input") ty
|
||||
]
|
||||
-- Run the graph with the input feeds and read the ArgMax'd result from
|
||||
-- the test (not training) side of the evaluation.
|
||||
x <- runWithFeeds feeds $ tensorFromName ValueKind "test/ArgMax"
|
||||
-- Print the trained model's predicted outcome.
|
||||
liftIO $ putStrLn $ "Expectation: " ++ show label ++ "\n"
|
||||
++ "Prediction: " ++ show (unScalar x :: Int64)
|
||||
-- Check whether the prediction matches the expectation.
|
||||
liftIO $ (fromInteger . toInteger $ label :: Int64) @=? unScalar x
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testReadMessageFromFileOrDie
|
||||
, testReadMNIST
|
||||
, testGraphDefGen
|
||||
, testGraphDefExec
|
||||
, testMNISTExec]
|
3
tensorflow-opgen/Setup.hs
Normal file
3
tensorflow-opgen/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
|||
import Distribution.Simple
|
||||
|
||||
main = defaultMain
|
457
tensorflow-opgen/src/TensorFlow/OpGen.hs
Normal file
457
tensorflow-opgen/src/TensorFlow/OpGen.hs
Normal file
|
@ -0,0 +1,457 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
-- | Rendering of TensorFlow operations as Haskell functions.
|
||||
|
||||
module TensorFlow.OpGen
|
||||
( OpGenFlags(..)
|
||||
, docOpList
|
||||
, flagParser)
|
||||
where
|
||||
|
||||
import Prelude hiding (head, tail)
|
||||
|
||||
import Control.Monad (guard)
|
||||
import Data.Char (toLower, toUpper)
|
||||
import Data.Foldable (toList)
|
||||
import Data.Maybe (fromMaybe, maybeToList)
|
||||
import Data.ProtoLens (def, showMessage)
|
||||
import Data.List.NonEmpty (NonEmpty((:|)), head)
|
||||
import qualified Data.List.NonEmpty as NE
|
||||
import Lens.Family2 ((^.), (.~), (&), view)
|
||||
import Options.Applicative (Parser, help, long, strOption, value)
|
||||
import Proto.Tensorflow.Core.Framework.OpDef
|
||||
( OpList
|
||||
, OpDef
|
||||
, OpDef'ArgDef
|
||||
, attr
|
||||
, description
|
||||
, inputArg
|
||||
, name
|
||||
, numberAttr
|
||||
, op
|
||||
, outputArg
|
||||
, summary
|
||||
, type'
|
||||
, typeAttr
|
||||
)
|
||||
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||
import System.FilePath (takeBaseName)
|
||||
import TensorFlow.OpGen.AttrVal
|
||||
(AttrDef
|
||||
, AttrCase(..)
|
||||
, AttrTemplate(..)
|
||||
, Template
|
||||
, attrDef
|
||||
, attrOriginal
|
||||
, attrTemplate
|
||||
, templateDefault
|
||||
, templateRestrictions
|
||||
)
|
||||
import Text.PrettyPrint.Mainland
|
||||
( Doc
|
||||
, (<>)
|
||||
, (<+>)
|
||||
, (</>)
|
||||
, (<+/>)
|
||||
, brackets
|
||||
, comma
|
||||
, commasep
|
||||
, dquotes
|
||||
, empty
|
||||
, enclose
|
||||
, flatten
|
||||
, folddoc
|
||||
, hang
|
||||
, indent
|
||||
, int
|
||||
, parens
|
||||
, sep
|
||||
, stack
|
||||
, strictText
|
||||
, tuple
|
||||
)
|
||||
import qualified Data.Map.Strict as Map
|
||||
import qualified Data.Set as Set
|
||||
import qualified Data.Text as Text
|
||||
import qualified Data.Semigroup as Semigroup
|
||||
import Data.Text (Text)
|
||||
|
||||
data OpGenFlags = OpGenFlags
|
||||
{ outputFile :: String
|
||||
, prefix :: String
|
||||
, excludeList :: String
|
||||
}
|
||||
|
||||
flagParser :: Parser OpGenFlags
|
||||
flagParser = OpGenFlags
|
||||
<$> strOption (mconcat [ long "output"
|
||||
, help "File to write."
|
||||
])
|
||||
<*> strOption (mconcat [ long "prefix"
|
||||
, help "Haskell package prefix to use"
|
||||
])
|
||||
<*> strOption (mconcat [ long "exclude_list"
|
||||
, value ""
|
||||
, help "Comma separated Ops names to ignore"
|
||||
])
|
||||
|
||||
|
||||
docOpList :: OpGenFlags -> OpList -> Doc
|
||||
docOpList flags opList =
|
||||
stack [ "{-# LANGUAGE ConstraintKinds #-}"
|
||||
, "{-# LANGUAGE DataKinds #-}"
|
||||
, "{-# LANGUAGE FlexibleInstances #-}"
|
||||
, "{-# LANGUAGE OverloadedStrings #-}"
|
||||
, "{-# LANGUAGE RankNTypes #-}"
|
||||
, "{-# LANGUAGE ScopedTypeVariables #-}"
|
||||
, "module" <+> strictText moduleName <+> "where"
|
||||
, empty
|
||||
, imports
|
||||
, empty
|
||||
, folddoc (\x y -> x </> empty </> y)
|
||||
(map renderDef $
|
||||
filter (not . flip elem exclusions . view name) $
|
||||
toList $ opList ^. op)
|
||||
]
|
||||
where moduleName =
|
||||
Text.pack (prefix flags) <> "." <> camelCase
|
||||
-- Discards the optional trailing _op_lib
|
||||
(fromMaybe shortName (Text.stripSuffix "_op_lib" shortName))
|
||||
shortName = Text.pack (takeBaseName $ outputFile flags)
|
||||
exclusions = Text.splitOn "," $ Text.pack $ excludeList flags
|
||||
|
||||
camelCase s = Text.concat $ map upCase
|
||||
$ filter (/= "ops")
|
||||
$ Text.splitOn "_" s
|
||||
|
||||
-- | Upper-case the given text.
|
||||
upCase :: Text -> Text
|
||||
upCase = forceCase toUpper
|
||||
|
||||
-- | Lower-case the given name, and prevent it from overlapping with a reserved
|
||||
-- Haskell name.
|
||||
lowCase :: Text -> Text
|
||||
lowCase = replaceReservedName . forceCase toLower
|
||||
|
||||
forceCase :: (Char -> Char) -> Text -> Text
|
||||
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
|
||||
(Text.uncons s)
|
||||
|
||||
imports = stack [
|
||||
"import Data.ByteString (ByteString)"
|
||||
, "import Data.Complex (Complex)"
|
||||
, "import Data.Int (Int8, Int16, Int32, Int64)"
|
||||
, "import Data.Word (Word8, Word16)"
|
||||
, "import Lens.Family2 ((.~), (&))"
|
||||
, "import TensorFlow.Build"
|
||||
, "import TensorFlow.BuildOp"
|
||||
, "import TensorFlow.Tensor"
|
||||
, "import TensorFlow.Types"
|
||||
]
|
||||
|
||||
renderDef :: OpDef -> Doc
|
||||
renderDef d =
|
||||
stack [
|
||||
haddocks
|
||||
, n <+> "::" <+> hang 0 (typeSig d)
|
||||
, n <+> hang 0 args <+> "|" <+> funcGuard <+> "=" </> -- args are indented
|
||||
-- the body needs to be indented wrt the name
|
||||
indent indentation functionBody
|
||||
, extras -- just for debug
|
||||
]
|
||||
where
|
||||
n = strictText $ fixOpName (d ^. name)
|
||||
args = sep $ [hsName | (_, hsName) <- mandatoryAttrs] ++ tensorArgs
|
||||
tensorArgs = [strictText $ lowCase (a ^. name) | a <- d ^. inputArg]
|
||||
fixOpName = lowCase
|
||||
funcGuard = "eqLengthGuard" <+> brackets (commasep entries)
|
||||
where
|
||||
entries =
|
||||
[ parens $ quotedText nAttr <> comma <+>
|
||||
brackets (commasep $ toList $
|
||||
NE.map renderTensorName tensorNames)
|
||||
| (nAttr, tensorNames) <- Map.toList $ numberAttrMap d
|
||||
]
|
||||
renderTensorName x = parens $ quotedText x <> comma <+>
|
||||
"length" <+> strictText x
|
||||
-- Uses hang 0 to align the argument vertically on multiple lines.
|
||||
functionBody = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||
</> indent indentation (sep tensorArgs)
|
||||
buildFunction
|
||||
| null outputListsSizes = "buildOp"
|
||||
| otherwise = "buildListOp" <+> brackets (commasep outputListsSizes)
|
||||
outputListsSizes = [ strictText numberAttrName
|
||||
| o <- d ^. outputArg
|
||||
, let numberAttrName = o ^. numberAttr
|
||||
, not (Text.null numberAttrName) &&
|
||||
numberAttrName `Map.member` mandatoryAttrMap d
|
||||
]
|
||||
buildOpParts =
|
||||
"opDef" <+> quotedText (d ^. name) :
|
||||
-- Renders tensor arguments.
|
||||
[ "& opAttr" <+> quotedText tfName <+>
|
||||
".~ tensorType (undefined ::" <+> strictText hsName <> ")"
|
||||
| (tfName, (hsName, _)) <- Map.toList typeMap
|
||||
] ++
|
||||
-- Renders mandatory attributes as function parameters.
|
||||
[ "& opAttr" <+> dquotes tfName <+> ".~" <+> hsName
|
||||
| (tfName, hsName) <- mandatoryAttrs
|
||||
] ++
|
||||
-- Renders sizes of tensor list types having number_attr.
|
||||
[ "& opAttr" <+> quotedText nAttr <+> ".~" <+>
|
||||
"(fromIntegral (length" <+> strictText (head tensorNames) <> ") :: Int64)"
|
||||
| (nAttr, tensorNames) <- Map.toList $ numberAttrMap d
|
||||
]
|
||||
mandatoryAttrs = [(strictText tf, strictText hs)
|
||||
| (tf, (hs, _, _)) <- Map.toList (mandatoryAttrMap d)
|
||||
]
|
||||
haddocks = "-- |" <+> multilineComment (d ^. summary) (d ^. description)
|
||||
extras = enclose "{-\n" "\n-}" $
|
||||
strictText $ Text.pack $
|
||||
showMessage ((def :: OpDef)
|
||||
& inputArg .~ (d ^. inputArg)
|
||||
& outputArg .~ (d ^. outputArg)
|
||||
& attr .~ (d ^. attr))
|
||||
typeMap = opDefTypeMap d
|
||||
|
||||
-- | Makes a quoted string doc out of the given text value.
|
||||
quotedText :: Text.Text -> Doc
|
||||
quotedText = dquotes . strictText
|
||||
|
||||
-- | typeSig renders the type signature of the given OpDef.
|
||||
typeSig :: OpDef -> Doc
|
||||
typeSig d =
|
||||
foralls <+> constraints <+/>
|
||||
signatureFold (mandatoryAttrInputs ++ tensorInputs ++ [outputs])
|
||||
where
|
||||
foralls | Map.null typeMap = empty
|
||||
| otherwise =
|
||||
"forall"
|
||||
<+> sep (refTypes ++ map (strictText . fst) (Map.elems typeMap))
|
||||
<+> "."
|
||||
constraints | Map.null typeMap = empty
|
||||
| otherwise =
|
||||
tuple (concatMap
|
||||
(\(t, aDef) ->
|
||||
"TensorType" <+> strictText t
|
||||
: maybeToList (oneOfRestrictions aDef t))
|
||||
(Map.elems typeMap)) <+> "=>"
|
||||
tensorInputs = zipWith tensorArg refTypes (d ^. inputArg)
|
||||
refTypes = map (\x -> "v" <> int x) [1..length (d ^. inputArg)]
|
||||
tensorArg refType arg = wrapArg refType arg <+>
|
||||
hang 0 ("-- ^" <+> argComment arg)
|
||||
-- Argument type is a list of tensors if number_attr is set;
|
||||
-- otherwise it's a single Tensor.
|
||||
wrapArg refType arg =
|
||||
if Text.null (arg ^. numberAttr) then typ else brackets typ
|
||||
where typ = tensorType refType arg
|
||||
tensorType refType arg =
|
||||
"Tensor" <+> refType <+> maybe directType strictText indirectType
|
||||
where
|
||||
indirectType = fmap fst (Map.lookup (arg ^. typeAttr) typeMap)
|
||||
directType = dtTypeToDoc (arg ^. type')
|
||||
outputs =
|
||||
case d ^. outputArg of
|
||||
[] -> "ControlNode"
|
||||
[o] -> wrappedOutput o <+> "-- ^" <+> argComment o
|
||||
os -> renderTupleResult os
|
||||
wrappedOutput = wrapArg "Value"
|
||||
-- Tuple result case is rendered differently to give
|
||||
-- individual elements their own comments.
|
||||
renderTupleResult os =
|
||||
stack $ [ tuple (map wrappedOutput os)
|
||||
, flatten commentSummary
|
||||
] ++ map commentDetails os
|
||||
where
|
||||
commentSummary = "-- ^" <+> tuple [bold (o ^. name) | o <- os]
|
||||
commentDetails o =
|
||||
stack [ "--"
|
||||
, "-- *" <+> argComment o
|
||||
]
|
||||
signatureFold = folddoc (\x y -> x </> "->" <+> y)
|
||||
mandatoryAttrInputs = [
|
||||
dtTypeToDoc dtType <+>
|
||||
hang 0 ("-- ^" <+> argComment' tfName descr)
|
||||
| (tfName, (_, dtType, descr)) <- Map.toList $ mandatoryAttrMap d]
|
||||
typeMap = opDefTypeMap d
|
||||
|
||||
-- | Returns the type restriction for the given tensor type if the
|
||||
-- set of allowed types is not empty (i.e. restricted).
|
||||
oneOfRestrictions :: AttrDef -> Text -> Maybe Doc
|
||||
oneOfRestrictions aDef tName = do
|
||||
typs <- onAttrType (^. templateRestrictions) aDef
|
||||
guard $ not $ null typs
|
||||
let typeList = commasep $ map strictText $
|
||||
Set.toList $ Set.fromList $
|
||||
map dtTypeToHaskell typs
|
||||
return $ "OneOf" <+> "'" <> brackets typeList <+> strictText tName
|
||||
|
||||
-- | Identifies the attributes used as tensor cardinalities. In such
|
||||
-- cases a list of tensors is supplied as an input_arg. The number of
|
||||
-- such inputs is communicated as a separate opAttr.
|
||||
-- The result key is TensorFlow attribute name and the value is the
|
||||
-- tensor names which have number_attr set to the result key.
|
||||
numberAttrMap :: OpDef -> Map.Map Text.Text (NonEmpty Text.Text)
|
||||
numberAttrMap d = Map.fromListWith (Semigroup.<>) [
|
||||
(nAttr, replaceReservedName (inp ^. name) :| [])
|
||||
| inp <- d ^. inputArg
|
||||
, nAttr <- [inp ^. numberAttr]
|
||||
, not (Text.null nAttr)
|
||||
]
|
||||
|
||||
argComment :: OpDef'ArgDef -> Doc
|
||||
argComment arg = argComment' (arg ^. name) (arg ^. description)
|
||||
|
||||
argComment' :: Text.Text -> Text.Text -> Doc
|
||||
argComment' argName argDesc =
|
||||
bold argName <> splitMultilineText (":" <+>) argDesc
|
||||
|
||||
bold :: Text.Text -> Doc
|
||||
bold n = strictText ("__" <> n <> "__")
|
||||
|
||||
opDefTypeMap :: OpDef -> Map.Map Text.Text (Text.Text, AttrDef)
|
||||
opDefTypeMap d =
|
||||
Map.fromList [(n, (lowCase n, a)) | (n, a) <- attrList d, isType a]
|
||||
|
||||
attrList :: OpDef -> [(Text.Text, AttrDef)]
|
||||
attrList d = [(a ^. name, attrDef a) | a <- d ^. attr]
|
||||
|
||||
isType :: AttrDef -> Bool
|
||||
isType = fromMaybe False . onAttrType (const True)
|
||||
|
||||
-- | Applies the given function to the data type. Is this a Prism?
|
||||
onAttrType :: (Template DataType -> a) -> AttrDef -> Maybe a
|
||||
onAttrType f x = case x ^. attrTemplate of
|
||||
AttrSingle (AttrType a) -> Just (f a)
|
||||
_ -> Nothing
|
||||
|
||||
-- | mandatoryAttrMap contains the attributes chosen by
|
||||
-- isMandatoryAttr, excluding those which are derived from list of
|
||||
-- tensor arguments. The key is the TF name of the attribute. The
|
||||
-- value tuple is (haskell name, TF type, attribute comment).
|
||||
mandatoryAttrMap :: OpDef -> Map.Map Text.Text (Text.Text, DataType, Text.Text)
|
||||
mandatoryAttrMap d =
|
||||
Map.fromList [ (n, (lowCase n, dtType, a ^. attrOriginal.description))
|
||||
| (n, a) <- attrList d
|
||||
, Just dtType <- [isMandatoryAttr a]
|
||||
-- Excludes the attributes rendered as list lengths.
|
||||
, n `Map.notMember` numberAttrMap d
|
||||
]
|
||||
|
||||
-- | Inspects the attribute and if it is one of the implemented
|
||||
-- non-tensor values lacking default, then returns Just the TF type.
|
||||
isMandatoryAttr :: AttrDef -> Maybe DataType
|
||||
isMandatoryAttr x =
|
||||
case x ^. attrTemplate of
|
||||
AttrSingle (AttrBool y) -> noDefault DT_BOOL y
|
||||
AttrSingle (AttrInt64 y) -> noDefault DT_INT64 y
|
||||
AttrSingle (AttrFloat y) -> noDefault DT_FLOAT y
|
||||
_ -> Nothing
|
||||
where
|
||||
noDefault typ y = maybe (Just typ) (const Nothing) (y ^. templateDefault)
|
||||
|
||||
dtTypeToDoc :: DataType -> Doc
|
||||
dtTypeToDoc = strictText . dtTypeToHaskell
|
||||
|
||||
-- NOTE: The cases of this function should be kept in sync with
|
||||
-- TensorFlow.Types.AllTensorTypes.
|
||||
dtTypeToHaskell :: DataType -> Text.Text
|
||||
dtTypeToHaskell DT_BOOL = "Bool"
|
||||
dtTypeToHaskell DT_BFLOAT16 = "Data.Word.Word16"
|
||||
dtTypeToHaskell DT_COMPLEX128 = "(Data.Complex.Complex Double)"
|
||||
dtTypeToHaskell DT_COMPLEX64 = "(Data.Complex.Complex Float)"
|
||||
dtTypeToHaskell DT_DOUBLE = "Double"
|
||||
dtTypeToHaskell DT_FLOAT = "Float"
|
||||
dtTypeToHaskell DT_INT16 = "Data.Int.Int16"
|
||||
dtTypeToHaskell DT_INT32 = "Data.Int.Int32"
|
||||
dtTypeToHaskell DT_INT64 = "Data.Int.Int64"
|
||||
dtTypeToHaskell DT_INT8 = "Data.Int.Int8"
|
||||
dtTypeToHaskell DT_QINT32 = "Data.Int.Int32" -- TODO(gnezdo): make unique
|
||||
dtTypeToHaskell DT_QINT8 = "Data.Word.Word8" -- TODO(gnezdo): make unique
|
||||
dtTypeToHaskell DT_QINT16 = "Data.Int.Int16" -- TODO(gnezdo): make unique
|
||||
dtTypeToHaskell DT_QUINT16 = "Data.Word.Word16" -- TODO(gnezdo): make unique
|
||||
dtTypeToHaskell DT_QUINT8 = "Data.Word.Word8" -- TODO(gnezdo): make unique
|
||||
dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString"
|
||||
dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
|
||||
dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique
|
||||
dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
|
||||
dtTypeToHaskell x =
|
||||
Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x
|
||||
|
||||
-- | haddockComment escapes TensorFlow doc strings into haddock.
|
||||
-- TODO(gnezdo): deal with the markup.
|
||||
haddockComment :: Text.Text -> Doc
|
||||
haddockComment = strictText
|
||||
|
||||
multilineComment :: Text.Text -> Text.Text -> Doc
|
||||
multilineComment summary' detail =
|
||||
haddockComment summary' </>
|
||||
splitMultilineText insertParagraphAndComment detail
|
||||
where insertParagraphAndComment x = "--" </> "--" <+> x
|
||||
|
||||
-- | Converts the given multi-line detail string into
|
||||
-- a multi-line haddock. Applies the given lead to the
|
||||
-- first line. Returns an empty document for empty detail.
|
||||
splitMultilineText :: (Doc -> Doc) -> Text.Text -> Doc
|
||||
splitMultilineText lead detail =
|
||||
case Text.lines detail of
|
||||
[] -> empty
|
||||
(l : ls) -> stack $ lead (haddockComment l)
|
||||
: map (("--" <+>) . haddockComment) ls
|
||||
|
||||
replaceReservedName :: Text -> Text
|
||||
replaceReservedName n
|
||||
| n `Set.member` reservedKeywords = n <> "'"
|
||||
| otherwise = n
|
||||
|
||||
indentation = 4
|
||||
|
||||
reservedKeywords :: Set.Set Text
|
||||
reservedKeywords = Set.fromList $
|
||||
-- Haskell2010 keywords:
|
||||
-- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
|
||||
-- We don't include keywords that are allowed to be variable names,
|
||||
-- in particular: "as", "forall", and "hiding".
|
||||
[ "case"
|
||||
, "class"
|
||||
, "data"
|
||||
, "default"
|
||||
, "deriving"
|
||||
, "do"
|
||||
, "else"
|
||||
, "foreign"
|
||||
, "if"
|
||||
, "import"
|
||||
, "in"
|
||||
, "infix"
|
||||
, "infixl"
|
||||
, "infixr"
|
||||
, "instance"
|
||||
, "let"
|
||||
, "module"
|
||||
, "newtype"
|
||||
, "of"
|
||||
, "then"
|
||||
, "type"
|
||||
, "where"
|
||||
]
|
||||
++ -- Nonstandard extensions
|
||||
[ "mdo" -- RecursiveDo
|
||||
, "rec" -- Arrows, RecursiveDo
|
||||
, "proc" -- Arrows
|
||||
]
|
120
tensorflow-opgen/src/TensorFlow/OpGen/AttrVal.hs
Normal file
120
tensorflow-opgen/src/TensorFlow/OpGen/AttrVal.hs
Normal file
|
@ -0,0 +1,120 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
-- | Wrapping of TensorFlow attributes into Haskell entities.
|
||||
module TensorFlow.OpGen.AttrVal
|
||||
(AttrDef
|
||||
, AttrCase(..)
|
||||
, AttrTemplate(..)
|
||||
, Template
|
||||
, attrDef
|
||||
, attrOriginal
|
||||
, attrTemplate
|
||||
, templateDefault
|
||||
, templateRestrictions
|
||||
) where
|
||||
|
||||
import Data.Int (Int64)
|
||||
import Data.Monoid ((<>))
|
||||
import Lens.Family2 (Lens', (^.))
|
||||
import Lens.Family2.Unchecked (lens)
|
||||
import Proto.Tensorflow.Core.Framework.AttrValue as AttrValue
|
||||
import Proto.Tensorflow.Core.Framework.OpDef as OpDef
|
||||
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||
import Proto.Tensorflow.Core.Framework.TensorShape (TensorShapeProto)
|
||||
import qualified Data.ByteString as B
|
||||
import qualified Data.Text as Text
|
||||
|
||||
-- | Specifies the optional default value and a set of allowed values
|
||||
-- for the given type.
|
||||
data Template a = Template {
|
||||
_templateDefault :: Maybe a
|
||||
-- ^ The default value (mandatory if unspecified)
|
||||
, _templateRestrictions :: [a]
|
||||
-- ^ The allowed set of values, empty if no restrictions
|
||||
}
|
||||
|
||||
templateDefault :: Lens' (Template a) (Maybe a)
|
||||
templateDefault = lens _templateDefault (\g x -> g { _templateDefault = x })
|
||||
|
||||
templateRestrictions :: Lens' (Template a) [a]
|
||||
templateRestrictions = lens _templateRestrictions
|
||||
(\g x -> g { _templateRestrictions = x })
|
||||
|
||||
data UnusedTensor
|
||||
|
||||
data AttrCase f
|
||||
= AttrBytes (f B.ByteString) -- bytes s = 2; // "string"
|
||||
| AttrInt64 (f Int64) -- int64 i = 3; // "int"
|
||||
| AttrFloat (f Float) -- float f = 4; // "float"
|
||||
| AttrBool (f Bool) -- bool b = 5; // "bool"
|
||||
| AttrType (f DataType) -- type = 6; // "type"
|
||||
-- To be translated into TensorFlow.Types.Shape before use.
|
||||
-- Leaving as a proto to reduce dependencies.
|
||||
| AttrShape (f TensorShapeProto) -- shape = 7; // "shape"
|
||||
|
||||
-- | Type-reified representation of TensorFlow AttrDef.
|
||||
-- Initially limited to just the types in Op descriptors.
|
||||
data AttrTemplate
|
||||
= AttrSingle (AttrCase Template)
|
||||
| AttrList (AttrCase [])
|
||||
| AttrTensor UnusedTensor -- tensor = 8; // "tensor"
|
||||
|
||||
data AttrDef = AttrDef {
|
||||
_attrOriginal :: OpDef'AttrDef -- ^ the proto this value was created from
|
||||
, _attrTemplate :: AttrTemplate -- ^ the type of the attribute
|
||||
}
|
||||
|
||||
attrTemplate :: Lens' AttrDef AttrTemplate
|
||||
attrTemplate = lens _attrTemplate (\g x -> g { _attrTemplate = x })
|
||||
|
||||
attrOriginal :: Lens' AttrDef OpDef'AttrDef
|
||||
attrOriginal = lens _attrOriginal (\g x -> g { _attrOriginal = x })
|
||||
|
||||
attrDef :: OpDef'AttrDef -> AttrDef
|
||||
attrDef a = AttrDef a
|
||||
$ translate (a^.OpDef.type')
|
||||
(a^.OpDef.defaultValue)
|
||||
(a^.allowedValues)
|
||||
|
||||
-- | Converts the given AttrValue with the type given by the string
|
||||
-- into the AttrVal if the type is known.
|
||||
translate :: Text.Text -- ^ one of the TensorFlow type strings
|
||||
-> AttrValue -- ^ default value
|
||||
-> AttrValue -- ^ allowed values
|
||||
-> AttrTemplate
|
||||
translate t defaults allowed
|
||||
| t == "string" = makeVal AttrBytes maybe's s
|
||||
| t == "int" = makeVal AttrInt64 maybe'i i
|
||||
| t == "float" = makeVal AttrFloat maybe'f f
|
||||
| t == "bool" = makeVal AttrBool maybe'b b
|
||||
| t == "type" = makeVal AttrType AttrValue.maybe'type' AttrValue.type'
|
||||
| t == "shape" = makeVal AttrShape maybe'shape shape
|
||||
| t == "tensor" = AttrTensor $ error "tensor is unimplemented"
|
||||
| t == "list(string)" = makeList AttrBytes $ list.s
|
||||
| t == "list(int)" = makeList AttrInt64 $ list.i
|
||||
| t == "list(float)" = makeList AttrFloat $ list.f
|
||||
| t == "list(bool)" = makeList AttrBool $ list.b
|
||||
| t == "list(type)" = makeList AttrType $ list.AttrValue.type'
|
||||
| t == "list(shape)" = makeList AttrShape $ list.shape
|
||||
| t == "list(tensor)" = AttrTensor $ error "list(tensor) is unimplemented"
|
||||
| t == "func" = AttrTensor $ error "func is unimplemented"
|
||||
| otherwise = error $ show ("Unknown attribute type " <> t) ++
|
||||
"," ++ show defaults ++
|
||||
"," ++ show allowed
|
||||
where makeVal c x y = AttrSingle $ c $
|
||||
Template (defaults^.x) (allowed^.list.y)
|
||||
makeList c y = AttrList $ c $ defaults^.y
|
33
tensorflow-opgen/tensorflow-opgen.cabal
Normal file
33
tensorflow-opgen/tensorflow-opgen.cabal
Normal file
|
@ -0,0 +1,33 @@
|
|||
name: tensorflow-opgen
|
||||
version: 0.1.0.0
|
||||
synopsis: Code generation for TensorFlow operations.
|
||||
description: Please see README.md
|
||||
homepage: https://github.com/tensorflow/haskell#readme
|
||||
license: Apache
|
||||
author: TensorFlow authors
|
||||
maintainer: tensorflow-haskell@googlegroups.com
|
||||
copyright: Google Inc.
|
||||
category: Machine Learning
|
||||
build-type: Simple
|
||||
cabal-version: >=1.22
|
||||
|
||||
library
|
||||
hs-source-dirs: src
|
||||
exposed-modules: TensorFlow.OpGen.AttrVal
|
||||
, TensorFlow.OpGen
|
||||
build-depends: proto-lens == 0.1.*
|
||||
, tensorflow-proto == 0.1.*
|
||||
, base >= 4.7 && < 5
|
||||
, bytestring
|
||||
, containers
|
||||
, filepath
|
||||
, lens-family
|
||||
, mainland-pretty
|
||||
, optparse-applicative
|
||||
, semigroups
|
||||
, text
|
||||
default-language: Haskell2010
|
||||
|
||||
source-repository head
|
||||
type: git
|
||||
location: https://github.com/tensorflow/haskell
|
1
tensorflow-opgen/third_party
Symbolic link
1
tensorflow-opgen/third_party
Symbolic link
|
@ -0,0 +1 @@
|
|||
../third_party/tensorflow
|
3
tensorflow-ops/Setup.hs
Normal file
3
tensorflow-ops/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
|||
import Distribution.Simple
|
||||
|
||||
main = defaultMain
|
76
tensorflow-ops/src/TensorFlow/EmbeddingOps.hs
Normal file
76
tensorflow-ops/src/TensorFlow/EmbeddingOps.hs
Normal file
|
@ -0,0 +1,76 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE NoMonomorphismRestriction #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
|
||||
-- | Parallel lookups on the list of tensors.
|
||||
module TensorFlow.EmbeddingOps where
|
||||
|
||||
import Control.Monad (zipWithM)
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.List (genericLength)
|
||||
import TensorFlow.Build (Build, colocateWith, render)
|
||||
import TensorFlow.Ops () -- Num instance for Tensor
|
||||
import TensorFlow.Tensor (Tensor, Value)
|
||||
import TensorFlow.Types (OneOf, TensorType)
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
|
||||
-- | Looks up `ids` in a list of embedding tensors.
|
||||
--
|
||||
-- This function is used to perform parallel lookups on the list of
|
||||
-- tensors in `params`. It is a generalization of `TF.gather`, where
|
||||
-- `params` is interpreted as a partition of a larger embedding
|
||||
-- tensor.
|
||||
--
|
||||
-- The partition_strategy is "mod", we assign each id to partition
|
||||
-- `p = id % len(params)`. For instance,
|
||||
-- 13 ids are split across 5 partitions as:
|
||||
-- `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`
|
||||
--
|
||||
-- The results of the lookup are concatenated into a dense
|
||||
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
||||
embeddingLookup :: forall a b v .
|
||||
( TensorType a
|
||||
, OneOf '[Int64, Int32] b
|
||||
, Num b
|
||||
)
|
||||
=> [Tensor v a]
|
||||
-- ^ A list of tensors which can be concatenated along
|
||||
-- dimension 0. Each `Tensor` must be appropriately
|
||||
-- sized for `mod` partition strategy.
|
||||
-> Tensor Value b
|
||||
-- ^ A `Tensor` with type `int32` or `int64`
|
||||
-- containing the ids to be looked up in `params`.
|
||||
-- The ids are required to be flat on entry and have
|
||||
-- fewer than 2^31 entries.
|
||||
-> Build (Tensor Value a)
|
||||
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
|
||||
embeddingLookup params ids =
|
||||
CoreOps.dynamicStitch pindices <$> partitionedResult
|
||||
where np = genericLength params
|
||||
pAssignments = CoreOps.cast (ids `CoreOps.mod` np)
|
||||
newIds = ids `CoreOps.div` np
|
||||
originalIndices = CoreOps.range 0 (CoreOps.size ids) 1
|
||||
-- Partition list of ids based on assignments into np separate lists
|
||||
gatherIds = CoreOps.dynamicPartition np newIds pAssignments
|
||||
-- Similarly, partition the original indices.
|
||||
pindices = CoreOps.dynamicPartition np originalIndices pAssignments
|
||||
-- Do np separate lookups, finding embeddings for plist[p] in params[p]
|
||||
partitionedResult = zipWithM
|
||||
(\p g -> colocateWith p $ render $ CoreOps.gather p g)
|
||||
params gatherIds
|
697
tensorflow-ops/src/TensorFlow/Gradient.hs
Normal file
697
tensorflow-ops/src/TensorFlow/Gradient.hs
Normal file
|
@ -0,0 +1,697 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE ViewPatterns #-}
|
||||
|
||||
module TensorFlow.Gradient
|
||||
( gradients
|
||||
) where
|
||||
|
||||
import Control.Monad (forM, zipWithM)
|
||||
import Control.Monad.State.Strict (State, evalState, gets, modify)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Complex (Complex)
|
||||
import Data.Default (def)
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.List (foldl', sortBy)
|
||||
import Data.Map.Strict (Map)
|
||||
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
|
||||
import Data.Ord (comparing)
|
||||
import Data.ProtoLens.TextFormat (showMessage)
|
||||
import Data.Set (Set)
|
||||
import Data.Text (Text)
|
||||
import Data.Tuple (swap)
|
||||
import Lens.Family2 (Lens', (&), (^.), (.~), (%~))
|
||||
import Lens.Family2.State.Strict (uses)
|
||||
import Lens.Family2.Stock (at, intAt)
|
||||
import Lens.Family2.Unchecked (lens, iso)
|
||||
import Prelude hiding (sum)
|
||||
import Text.Printf (printf)
|
||||
import qualified Data.Graph.Inductive.Basic as FGL
|
||||
import qualified Data.Graph.Inductive.Graph as FGL
|
||||
import qualified Data.Graph.Inductive.PatriciaTree as FGL
|
||||
import qualified Data.Graph.Inductive.Query.DFS as FGL
|
||||
import qualified Data.IntMap.Strict as IntMap
|
||||
import qualified Data.Map.Strict as Map
|
||||
import qualified Data.Set as Set
|
||||
import qualified Data.Text as Text
|
||||
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
import TensorFlow.Build
|
||||
( Build
|
||||
, render
|
||||
, renderNodeName
|
||||
, renderedNodeDefs
|
||||
, opDef
|
||||
, opAttr
|
||||
)
|
||||
import TensorFlow.BuildOp
|
||||
import TensorFlow.Ops
|
||||
( addN
|
||||
, broadcastGradientArgs
|
||||
, expandDims
|
||||
, fill
|
||||
, matMul
|
||||
, reducedShape
|
||||
, reluGrad
|
||||
, reshape
|
||||
, scalar
|
||||
, shape
|
||||
, softmaxCrossEntropyWithLogits
|
||||
, sum
|
||||
, vector
|
||||
, zerosLike
|
||||
)
|
||||
import TensorFlow.Output
|
||||
( NodeName(..)
|
||||
, Op (Rendered)
|
||||
, Output(..)
|
||||
, OutputIx(..)
|
||||
, outputIndex
|
||||
)
|
||||
import TensorFlow.Tensor
|
||||
( Tensor(..)
|
||||
, TensorKind (ValueKind)
|
||||
, Value
|
||||
, tensorOutput
|
||||
, tensorAttr
|
||||
)
|
||||
import TensorFlow.Types (OneOf, TensorType, attrLens)
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||
(NodeDef, attr, input, op, name)
|
||||
|
||||
type GradientCompatible a =
|
||||
-- TODO(fmayle): MaxPoolGrad doesn't support Double for some reason.
|
||||
(Num a, OneOf '[ Float, Complex Float, Complex Double ] a)
|
||||
|
||||
-- TODO(fmayle): Support control flow.
|
||||
-- TODO(fmayle): Support gate_gradients-like option to avoid race conditions.
|
||||
-- TODO(fmayle): Do we need to consider control inputs? See _PendingCount in
|
||||
-- tensorflow/python/ops/gradients.py.
|
||||
-- TODO(fmayle): Maybe store the gradient functions and numOutputs on the OpDef.
|
||||
|
||||
|
||||
-- | Gradient of @y@ w.r.t. each element of @xs@.
|
||||
gradients :: forall a v1 v2 . ( Num (Tensor v1 a)
|
||||
-- TODO(gnezdo): remove indirect constraint.
|
||||
-- It's a wart inherited from Num instance.
|
||||
, v1 ~ Value
|
||||
, GradientCompatible a
|
||||
)
|
||||
=> Tensor v1 a -- ^ The output of the graph.
|
||||
-> [Tensor v2 a] -- ^ Tensors for which gradients are computed.
|
||||
-> Build [Tensor Value a]
|
||||
gradients y xs = do
|
||||
-- The gradients are computed using "reverse accumulation", similarly to
|
||||
-- what is described here:
|
||||
-- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation
|
||||
--
|
||||
-- The code is summarised as follows:
|
||||
--
|
||||
-- 1. Create an fgl graph of the relevant nodes (ops) and edges (tensors).
|
||||
-- 2. Initialize the gradient of y to 1 (∂y/∂y = 1) and the rest of tensor's
|
||||
-- gradients to nothing.
|
||||
-- 3. Process the nodes in reverse topological order (i.e. each node comes
|
||||
-- after all of its outputs so that the output gradients for a node have
|
||||
-- been completely calculated before it is processed):
|
||||
-- a. Record the gradient for each of the node's output tensors (∂y/∂w
|
||||
-- for each output tensor w).
|
||||
-- b. Calculate the gradient of y w.r.t. each of the node's input
|
||||
-- tensors using the gradients of the node's output tensors.
|
||||
--
|
||||
-- Written differently, for each output tensor w and input tensor v:
|
||||
-- ∂y/∂w = ... (calculated in previous steps)
|
||||
-- ∂w/∂v = ... (op specific)
|
||||
-- ∂y/∂v = ∂y/∂w * ∂w/∂v (technically, if tensor v is an input
|
||||
-- to multiple nodes, then this is only
|
||||
-- part of ∂y/∂v)
|
||||
--
|
||||
-- 4. Lookup the recorded gradient for each x in xs.
|
||||
|
||||
yName <- renderNodeName y
|
||||
-- TODO(fmayle): Move this into Build.hs and call it unsafeNodeDefFromName?
|
||||
nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $
|
||||
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
|
||||
. flip Map.lookup
|
||||
let (gr, nodeMap) = createGraph yName nodeDefLookup
|
||||
-- Set gradient of y to one.
|
||||
let initPending :: Map.Map FGL.Node (PendingGradients a)
|
||||
initPending = Map.empty & at (nodeMap Map.! yName)
|
||||
. nonEmpty
|
||||
. outputIxAt (y ^. tensorOutput . outputIndex)
|
||||
. nonEmpty
|
||||
.~ [fill (shape y) (scalar 1)]
|
||||
-- Calculate the gradients of y w.r.t. each node in the graph.
|
||||
gradientMap <- graphGrads gr initPending
|
||||
-- Lookup the gradients for each x.
|
||||
forM xs $ \x -> do
|
||||
xName <- renderNodeName x
|
||||
render $ fromMaybe (zerosLike x) $ do
|
||||
n <- nodeMap ^. at xName
|
||||
let i = x ^. tensorOutput . outputIndex
|
||||
gradientMap ^. at n . nonEmpty . outputIxAt i
|
||||
|
||||
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
|
||||
outputIxAt = intAt . unOutputIx
|
||||
|
||||
-- | Incomplete gradients of a node's outputs.
|
||||
--
|
||||
-- The lists represent partial sums. The key is an OutputIx sans newtype.
|
||||
type PendingGradients a = IntMap.IntMap [Tensor Value a]
|
||||
|
||||
-- | Gradients of a node's outputs. The key is an OutputIx sans newtype.
|
||||
type Gradients a = IntMap.IntMap (Tensor Value a)
|
||||
|
||||
-- | Graph of TensorFlow operations.
|
||||
type Graph = FGL.Gr NodeDef EdgeLabel
|
||||
|
||||
-- | Data associated with an edge.
|
||||
--
|
||||
-- Pair of
|
||||
-- 1. Output index of a tensor from the source node.
|
||||
-- 2. Input index that the tensor connects to on the destination node.
|
||||
type EdgeLabel = (OutputIx, OutputIx)
|
||||
|
||||
|
||||
-- | State used for calculating gradients.
|
||||
data GradientsState a = GradientsState
|
||||
{ _gradientsPending :: !(Map FGL.Node (PendingGradients a))
|
||||
, _gradientsResult :: !(Map FGL.Node (Gradients a))
|
||||
}
|
||||
|
||||
gradientsPending :: Lens' (GradientsState a) (Map FGL.Node (PendingGradients a))
|
||||
gradientsPending = lens _gradientsPending (\x y -> x { _gradientsPending = y })
|
||||
|
||||
gradientsResult :: Lens' (GradientsState a) (Map FGL.Node (Gradients a))
|
||||
gradientsResult = lens _gradientsResult (\x y -> x { _gradientsResult = y })
|
||||
|
||||
|
||||
-- TODO(fmayle): Use something like Data.List.Safe.
|
||||
-- | Safe version of (!!).
|
||||
safeIndex :: [a] -> Int -> Maybe a
|
||||
_ `safeIndex` n | n < 0 = Nothing
|
||||
[] `safeIndex` _ = Nothing
|
||||
(x:_) `safeIndex` 0 = Just x
|
||||
(_:xs) `safeIndex` n = xs `safeIndex` (n-1)
|
||||
|
||||
-- Copy of http://hackage.haskell.org/package/lens-3.9.0.2/docs/Control-Lens-Iso.html#v%3anon
|
||||
anon :: a -> (a -> Bool) -> Lens' (Maybe a) a
|
||||
anon a p = iso (fromMaybe a) go where
|
||||
go b | p b = Nothing
|
||||
| otherwise = Just b
|
||||
|
||||
non :: Eq a => a -> Lens' (Maybe a) a
|
||||
non a = anon a (a==)
|
||||
|
||||
-- | Lens that defaults Nothing to mempty.
|
||||
nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v)
|
||||
nonEmpty = anon mempty null
|
||||
|
||||
-- | Calculate the gradients for every node in a graph.
|
||||
graphGrads :: forall a. GradientCompatible a
|
||||
=> Graph
|
||||
-> Map FGL.Node (PendingGradients a)
|
||||
-- ^ Initial gradients (usually just 1 for the node of interest).
|
||||
-> Build (Map FGL.Node (Gradients a))
|
||||
graphGrads gr initPending = pure (foldl' go initState nodeOrder ^. gradientsResult)
|
||||
where
|
||||
initState = GradientsState initPending Map.empty
|
||||
-- Reverse topological sort.
|
||||
-- TODO(fmayle): Filter out nodes that are not successors of any x in xs to
|
||||
-- avoid calculating gradients that won't be used.
|
||||
nodeOrder = FGL.topsort $ FGL.grev gr
|
||||
go state node =
|
||||
-- Aggregate the accumulated gradients for this node.
|
||||
let outputGrads =
|
||||
sumPendingGradient (state ^. gradientsPending . at node . nonEmpty)
|
||||
in if null outputGrads
|
||||
then state
|
||||
else
|
||||
-- Calculate the gradients for each of the node's inputs.
|
||||
let nextState = state & gradientsResult %~ Map.insert node outputGrads
|
||||
ctx = FGL.context gr node
|
||||
in updatePendingGradients
|
||||
ctx
|
||||
(calculateInputGrads ctx outputGrads gr)
|
||||
nextState
|
||||
|
||||
-- | Reduce accumulated gradients for each output to one Tensor.
|
||||
sumPendingGradient :: GradientCompatible a
|
||||
=> PendingGradients a -> Gradients a
|
||||
sumPendingGradient = IntMap.mapMaybe f
|
||||
where
|
||||
f [] = Nothing
|
||||
f [x] = Just x
|
||||
f xs = Just (addN xs)
|
||||
|
||||
|
||||
-- | Calculate the gradients of a node's input tensors.
|
||||
--
|
||||
-- This is mostly just a wrapper around opGrad.
|
||||
calculateInputGrads :: forall a. GradientCompatible a
|
||||
=> FGL.Context NodeDef EdgeLabel
|
||||
-> Gradients a -- ^ Output gradients of the node.
|
||||
-> Graph
|
||||
-> [Maybe (Tensor Value a)]
|
||||
calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr =
|
||||
opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads
|
||||
where
|
||||
fullOutGrads =
|
||||
fullOutputGrads (numOutputs nodeDef) (Rendered nodeDef) outputGrads
|
||||
-- Create a tensor from an edge (technically an Output, but it seems less
|
||||
-- confusing to refer to it as a tensor here).
|
||||
edgeToTensor :: (EdgeLabel, FGL.Node) -> Output
|
||||
edgeToTensor ((i, _), n) =
|
||||
case FGL.lab gr n of
|
||||
Just edgeNodeDef -> Output i (Rendered edgeNodeDef)
|
||||
Nothing -> error $ "calculateInputGrads: missing input node for "
|
||||
++ Text.unpack (nodeDef ^. name)
|
||||
-- Input tensors, sorted by input index.
|
||||
inputTensors = map edgeToTensor $ sortBy (comparing (snd . fst)) inputEdges
|
||||
|
||||
-- | Convert a Map of gradients to a list, with zeros for missing outputs.
|
||||
fullOutputGrads :: (TensorType a, Num a)
|
||||
=> OutputIx -- ^ Number of outputs.
|
||||
-> Op
|
||||
-> Gradients a
|
||||
-> [Tensor Value a]
|
||||
fullOutputGrads n o gs =
|
||||
map (\i -> fromMaybe (zero i) (gs ^. outputIxAt i)) [0..n-1]
|
||||
where
|
||||
-- A tensor of zeros with the same shape as the i'th output.
|
||||
zero i = zerosLike $ toT (Output i o)
|
||||
|
||||
|
||||
-- | Update the pending gradients of a node's inputs.
|
||||
updatePendingGradients :: forall a. (TensorType a, Num a)
|
||||
=> FGL.Context NodeDef EdgeLabel
|
||||
-> [Maybe (Tensor Value a)]
|
||||
-- ^ Gradient of each input tensor.
|
||||
-> GradientsState a
|
||||
-> GradientsState a
|
||||
updatePendingGradients (inputEdges, _, nodeDef, _) inputGrads initState =
|
||||
foldl' go initState inputEdges
|
||||
where
|
||||
go :: GradientsState a -> (EdgeLabel, FGL.Node) -> GradientsState a
|
||||
go state ((outIndex, OutputIx inIndex), node) =
|
||||
case maybeGradient of
|
||||
Nothing -> state
|
||||
Just g ->
|
||||
-- Add to the list of pending gradients for this tensor.
|
||||
state & gradientsPending
|
||||
. at node
|
||||
. nonEmpty
|
||||
. outputIxAt outIndex
|
||||
. nonEmpty
|
||||
%~ (g:)
|
||||
where
|
||||
badSizeErr = error $ printf "updatePendingGradients: bad input index \
|
||||
\%d for inputGrads of length %d in %s"
|
||||
inIndex (length inputGrads)
|
||||
(show (nodeDef ^. name))
|
||||
maybeGradient = fromMaybe badSizeErr (safeIndex inputGrads inIndex)
|
||||
|
||||
|
||||
-- | Create a graph that includes a node and its transitive dependencies.
|
||||
createGraph :: NodeName -> (NodeName -> NodeDef)
|
||||
-> (Graph, Map NodeName FGL.Node)
|
||||
createGraph nodeName nodeDefLookup = (FGL.nmap nodeDefLookup graph, nodeMap)
|
||||
where
|
||||
-- Parse a tensor name.
|
||||
parseTensorName :: Text -> Maybe (NodeName, OutputIx)
|
||||
parseTensorName n
|
||||
| Text.null n = error "parseTensorName: empty name"
|
||||
| Text.head n == '^' = Nothing -- Control edge
|
||||
| otherwise =
|
||||
let (nm, indexStr) = Text.breakOn ":" n
|
||||
index | Text.null indexStr = 0
|
||||
| otherwise = read $ Text.unpack $ Text.tail indexStr
|
||||
in Just (NodeName nm, OutputIx index)
|
||||
|
||||
-- Build a map from node name to outward edges.
|
||||
--
|
||||
-- The state is the set of visited nodes.
|
||||
collect :: Maybe (NodeName, OutputIx, OutputIx)
|
||||
-> NodeName
|
||||
-> State (Set NodeName)
|
||||
(Map NodeName [(NodeName, OutputIx, OutputIx)])
|
||||
collect outgoingEdge nm = do
|
||||
let nextLookup = Map.singleton nm (maybeToList outgoingEdge)
|
||||
seen <- gets (Set.member nm)
|
||||
modify (Set.insert nm)
|
||||
if seen
|
||||
then pure nextLookup
|
||||
else do
|
||||
let inputs = nodeDefLookup nm ^. input
|
||||
recurse inIndex (parentName, outIndex) =
|
||||
collect (Just (nm, outIndex, inIndex)) parentName
|
||||
subEdgeLookups <-
|
||||
zipWithM recurse [0..] $ mapMaybe parseTensorName inputs
|
||||
pure $ Map.unionsWith (++) (nextLookup:subEdgeLookups)
|
||||
|
||||
edgeLookup = evalState (collect Nothing nodeName) Set.empty
|
||||
-- Associate an ID with each node name.
|
||||
nodeMap = Map.fromList $ zip (Map.keys edgeLookup) [0..]
|
||||
-- Create the graph.
|
||||
graph = FGL.mkGraph (swap <$> Map.toList nodeMap)
|
||||
[ (nodeMap Map.! n, nodeMap Map.! m, (i, j))
|
||||
| (n, edges) <- Map.toList edgeLookup
|
||||
, (m, i, j) <- edges
|
||||
]
|
||||
|
||||
-- | Function to compute the gradient of y w.r.t. each input.
|
||||
--
|
||||
-- Let y be an arbitrary tensor
|
||||
-- and [w_0, ..., w_n] be the output tensors of a node
|
||||
-- and [v_0, ..., v_n] be the input tensors of the same node.
|
||||
--
|
||||
-- Given [∂y/∂w_0, ..., ∂y/∂w_n] and [v_0, ..., v_n], a GradientFunc computes
|
||||
-- [∂y/∂v_0, ..., ∂y/∂v_n] for a particular op type.
|
||||
--
|
||||
-- A Nothing gradient is equivalent to zero (but allows for short circuiting
|
||||
-- computation when all the gradients for something are Nothing).
|
||||
type GradientFunc a = NodeDef
|
||||
-> [Output]
|
||||
-- ^ Input tensors.
|
||||
-> [Tensor Value a]
|
||||
-- ^ Gradient of y w.r.t. each output tensor.
|
||||
-> [Maybe (Tensor Value a)]
|
||||
-- ^ Gradient of y w.r.t. each input tensor.
|
||||
|
||||
|
||||
-- TODO(fmayle): Assert the type is correct.
|
||||
-- | Create a Tensor from an Output.
|
||||
toT :: Output -> Tensor Value a
|
||||
toT = Tensor ValueKind
|
||||
|
||||
-- | The gradient function for an op type.
|
||||
--
|
||||
-- These implementations should match their python counterparts in:
|
||||
-- third_party/tensorflow/python/ops/*_grad.py
|
||||
opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a
|
||||
|
||||
opGrad "Abs" _ [toT -> x] [dz] = [Just $ dz * signum x]
|
||||
opGrad "Neg" _ [_] [dz] = [Just $ -dz]
|
||||
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
|
||||
|
||||
opGrad "Square" _ [toT -> x] [dz] =
|
||||
-- TODO(fmayle): Handle complex numbers.
|
||||
-- TODO(fmayle): The python code makes dz a control dependency of the 2*x
|
||||
-- (for performance reasons?). Will need to put these functions in the Build
|
||||
-- monad to replicate that.
|
||||
[Just $ dz * (2 * x)]
|
||||
|
||||
opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
|
||||
-- TODO(fmayle): The python version uses a better performance implementation
|
||||
-- when the shape is known without having to run the graph.
|
||||
-- TODO(fmayle): We shouldn't convert the result to a dense tensor. Sparse
|
||||
-- tensor support will require some thinking.
|
||||
[ Just $ CoreOps.unsortedSegmentSum values indices' numRows
|
||||
, Nothing
|
||||
]
|
||||
where
|
||||
-- TODO(gnezdo): Use colocateWith but it requires Build monad.
|
||||
denseShape = shape (x :: Tensor Value a)
|
||||
numRows = CoreOps.slice denseShape 0 (1 :: Tensor Value Int32)
|
||||
valuesShape = CoreOps.concat 0 [
|
||||
allDimensions
|
||||
, CoreOps.slice denseShape 1 (-1 :: Tensor Value Int32)
|
||||
]
|
||||
values = reshape dz valuesShape
|
||||
-- TODO(fmayle): This could be either Int32 or Int64.
|
||||
indices' = reshape indices allDimensions :: Tensor Value Int32
|
||||
|
||||
opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
|
||||
[Just $ indicators `CoreOps.div` numSelected * dz', Nothing]
|
||||
where
|
||||
sx = shape (x :: Tensor Value a)
|
||||
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
|
||||
x' = reshape x outputShapeKeptDims
|
||||
dz' = reshape dz outputShapeKeptDims
|
||||
indicators = CoreOps.cast $ CoreOps.equal x' x
|
||||
numSelected = reshape (sum indicators indices) outputShapeKeptDims
|
||||
|
||||
-- Min and Max have identical gradient implementations.
|
||||
opGrad "Min" u v w = opGrad "Max" u v w
|
||||
|
||||
opGrad "Sum" _ [toT -> x, toT -> indices] [dz] =
|
||||
[ Just $ CoreOps.tile grad tileScaling, Nothing ]
|
||||
where
|
||||
-- TODO(gnezdo): Implement the fast-path from math_grad._SumGrad.
|
||||
sx = shape (x :: Tensor Value a)
|
||||
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
|
||||
tileScaling = safeShapeDiv sx outputShapeKeptDims
|
||||
grad = reshape dz outputShapeKeptDims
|
||||
|
||||
opGrad "Mean" u v@[toT -> x, _] w =
|
||||
[Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing]
|
||||
where
|
||||
[Just dz, Nothing] = opGrad "Sum" u v w
|
||||
inputShape = shape (x :: Tensor Value a)
|
||||
outputShape = shape (dz :: Tensor Value a)
|
||||
-- TODO(fmayle): Add fast path when shape is known.
|
||||
inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape
|
||||
outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape
|
||||
factor = safeShapeDiv inputSize outputSize
|
||||
|
||||
opGrad "Add" _ [toT -> x, toT -> y] [dz] =
|
||||
[ Just $ reshape (sum dz rx) sx
|
||||
, Just $ reshape (sum dz ry) sy ]
|
||||
where
|
||||
sx = shape (x :: Tensor Value a)
|
||||
sy = shape (y :: Tensor Value a)
|
||||
(rx, ry) = broadcastGradientArgs sx sy
|
||||
|
||||
opGrad "Sub" u v w =
|
||||
[Just x, Just (-y)]
|
||||
where
|
||||
[Just x, Just y] = opGrad "Add" u v w
|
||||
|
||||
opGrad "SoftmaxCrossEntropyWithLogits" _ [toT -> x, toT -> y] [dz, _] =
|
||||
[ Just $ expandDims dz (-1) * snd (softmaxCrossEntropyWithLogits x y)
|
||||
, Nothing ]
|
||||
|
||||
opGrad "Mul" _ [toT -> x, toT -> y] [dz] =
|
||||
-- TODO(fmayle): Handle complex numbers.
|
||||
[ Just $ reshape (sum (dz * y) rx) sx
|
||||
, Just $ reshape (sum (x * dz) ry) sy ]
|
||||
where
|
||||
sx = shape (x :: Tensor Value a)
|
||||
sy = shape (y :: Tensor Value a)
|
||||
(rx, ry) = broadcastGradientArgs sx sy
|
||||
|
||||
opGrad "Div" _ [toT -> x, toT -> y] [dz] =
|
||||
-- TODO(fmayle): Handle complex numbers.
|
||||
-- TODO(gnezdo): Provide Fractional instance and use '/' instead of div.
|
||||
[ Just $ reshape (sum (dz `CoreOps.div` y) rx) sx
|
||||
, Just $ reshape (sum (dz * (negate x `CoreOps.div` (y * y))) ry) sy
|
||||
]
|
||||
where
|
||||
sx = shape (x :: Tensor Value a)
|
||||
sy = shape (y :: Tensor Value a)
|
||||
(rx, ry) = broadcastGradientArgs sx sy
|
||||
|
||||
opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
|
||||
let transposeA = lookupAttr nodeDef "transpose_a"
|
||||
transposeB = lookupAttr nodeDef "transpose_b"
|
||||
transAttrs a b =
|
||||
(tensorAttr "transpose_a" .~ a) . (tensorAttr "transpose_b" .~ b)
|
||||
in case (transposeA, transposeB) of
|
||||
(False, False) ->
|
||||
[ Just $ (dz `matMul` y) & transAttrs False True
|
||||
, Just $ (x `matMul` dz) & transAttrs True False ]
|
||||
(False, True) ->
|
||||
[ Just $ dz `matMul` y
|
||||
, Just $ (x `matMul` dz) & transAttrs True False ]
|
||||
(True, False) ->
|
||||
[ Just $ (dz `matMul` y) & transAttrs False True
|
||||
, Just $ x `matMul` dz ]
|
||||
(True, True) ->
|
||||
[ Just $ (dz `matMul` y) & transAttrs True True
|
||||
, Just $ (x `matMul` dz) & transAttrs True True ]
|
||||
|
||||
opGrad "Transpose" _ [_, toT -> p] [dz] =
|
||||
[ Just $ CoreOps.transpose dz
|
||||
(CoreOps.invertPermutation p :: Tensor Value Int32)
|
||||
, Nothing
|
||||
]
|
||||
|
||||
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
|
||||
[ Just $ CoreOps.conv2DBackpropInput (shape x) y dz
|
||||
& tensorAttr "strides" .~ strides
|
||||
& tensorAttr "padding" .~ padding
|
||||
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
|
||||
& tensorAttr "data_format" .~ dataFormat
|
||||
, Just $ CoreOps.conv2DBackpropFilter x (shape y) dz
|
||||
& tensorAttr "strides" .~ strides
|
||||
& tensorAttr "padding" .~ padding
|
||||
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
|
||||
& tensorAttr "data_format" .~ dataFormat
|
||||
]
|
||||
where
|
||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||
padding = lookupAttr nodeDef "padding" :: ByteString
|
||||
useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
|
||||
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
||||
|
||||
opGrad "MaxPool" nodeDef [toT -> x] [dz] =
|
||||
[ Just $ CoreOps.maxPoolGrad x output dz
|
||||
& tensorAttr "ksize" .~ ksize
|
||||
& tensorAttr "strides" .~ strides
|
||||
& tensorAttr "padding" .~ padding
|
||||
& tensorAttr "data_format" .~ dataFormat
|
||||
]
|
||||
where
|
||||
output :: Tensor Value a
|
||||
output = toT $ Output 0 (Rendered nodeDef)
|
||||
ksize = lookupAttr nodeDef "ksize" :: [Int64]
|
||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||
padding = lookupAttr nodeDef "padding" :: ByteString
|
||||
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
||||
|
||||
opGrad "Reshape" _ [toT -> x, _] [dz] =
|
||||
[Just $ reshape dz $ shape (x :: Tensor Value a), Nothing]
|
||||
|
||||
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
|
||||
opGrad "TruncatedNormal" _ _ _ = [Nothing]
|
||||
|
||||
opGrad "RefIdentity" _ _ [dz] = [Just dz]
|
||||
opGrad "Cast" nodeDef _ [dz] = [Just reverseCast]
|
||||
where
|
||||
-- TODO(gnezdo): too permissive, python only allows float types as src_type.
|
||||
reverseCast =
|
||||
buildOp (opDef "Cast"
|
||||
& opAttr "DstT" .~ (lookupAttr nodeDef "SrcT" :: ByteString)
|
||||
& opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString))
|
||||
dz
|
||||
|
||||
opGrad "DynamicStitch" nodeDef inputs [dz] =
|
||||
replicate halfLen Nothing ++ valuesGrads
|
||||
where
|
||||
halfLen =
|
||||
let len = length inputs
|
||||
half = len `div` 2
|
||||
in if 2 * half == len
|
||||
then half
|
||||
else error ("Uneven input size " ++ show (len, showMessage nodeDef))
|
||||
valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Value Int32)
|
||||
| idx <- take halfLen inputs
|
||||
]
|
||||
|
||||
opGrad "DynamicPartition" nodeDef [toT -> xs, toT -> indices] dz =
|
||||
[ Just reconstructed, Nothing ]
|
||||
where
|
||||
reconstructed = CoreOps.reshape stitched
|
||||
(CoreOps.shape (xs :: Tensor Value a) :: Tensor Value Int32)
|
||||
stitched = CoreOps.dynamicStitch partitionedIndices dz
|
||||
partitionedIndices = CoreOps.dynamicPartition np originalIndices indices
|
||||
np = lookupAttr nodeDef "num_partitions" :: Int64
|
||||
originalIndices =
|
||||
CoreOps.reshape (CoreOps.range 0 (CoreOps.size indices) 1) prefixShape
|
||||
prefixShape = shapeInt32 indices
|
||||
shapeInt32 = CoreOps.shape :: Tensor Value Int32 -> Tensor Value Int32
|
||||
|
||||
opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
|
||||
[ Nothing
|
||||
, Just $ CoreOps.select c dz zeros
|
||||
, Just $ CoreOps.select c zeros dz
|
||||
]
|
||||
where zeros = CoreOps.zerosLike x
|
||||
|
||||
-- TODO(gnezdo): Unlike Python, no control dependency on dz.
|
||||
opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.inv x ]
|
||||
-- TODO(gnezdo): Reuse the output instead of doing another exp,
|
||||
-- though, it is probably CSE'd away anyway.
|
||||
opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.exp x ]
|
||||
opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
|
||||
[ Just $ CoreOps.unsortedSegmentSum
|
||||
(CoreOps.gather dz (t :: Tensor Value Int32))
|
||||
(y :: Tensor Value Int32) inputRows
|
||||
, Nothing
|
||||
, Nothing
|
||||
]
|
||||
where inputRows = CoreOps.slice (shape (x :: Tensor Value a)) (scalar (0 :: Int32)) (scalar 1)
|
||||
|
||||
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
|
||||
opGrad "LabelWeights" _ _ _ = [Nothing]
|
||||
opGrad "Size" _ _ _ = [Nothing]
|
||||
opGrad "ZerosLike" _ _ _ = [Nothing]
|
||||
|
||||
-- TODO(fmayle): These can go away if we properly prune the graph.
|
||||
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
||||
opGrad "Placeholder" _ _ _ = []
|
||||
opGrad "Variable" _ _ _ = []
|
||||
|
||||
opGrad n nodeDef ins grads =
|
||||
error $ "no gradient implemented for " ++
|
||||
show (n, length ins, length grads, showMessage nodeDef, ins)
|
||||
|
||||
-- | The number of outputs for an op type.
|
||||
numOutputs :: NodeDef -> OutputIx
|
||||
numOutputs o =
|
||||
case o ^. op of
|
||||
"Abs" -> 1
|
||||
"Add" -> 1
|
||||
"Cast" -> 1
|
||||
"Const" -> 1
|
||||
"Conv2D" -> 1
|
||||
"Div" -> 1
|
||||
"DynamicStitch" -> 1
|
||||
"DynamicPartition" ->
|
||||
fromIntegral (lookupAttr o "num_partitions" :: Int64)
|
||||
"Exp" -> 1
|
||||
"Gather" -> 1
|
||||
"LabelClasses" -> 1
|
||||
"LabelWeights" -> 1
|
||||
"Log" -> 1
|
||||
"MatMul" -> 1
|
||||
"Max" -> 1
|
||||
"MaxPool" -> 1
|
||||
"Mean" -> 1
|
||||
"Min" -> 1
|
||||
"Mul" -> 1
|
||||
"Neg" -> 1
|
||||
"Placeholder" -> 1
|
||||
"OneHot" -> 1
|
||||
"RefIdentity" -> 1
|
||||
"Relu" -> 1
|
||||
"Reshape" -> 1
|
||||
"Select" -> 1
|
||||
"Size" -> 1
|
||||
"SoftmaxCrossEntropyWithLogits" -> 2
|
||||
"Square" -> 1
|
||||
"SparseSegmentSum" -> 1
|
||||
"Sub" -> 1
|
||||
"Sum" -> 1
|
||||
"Transpose" -> 1
|
||||
"TruncatedNormal" -> 1
|
||||
"Variable" -> 1
|
||||
"ZerosLike" -> 1
|
||||
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
|
||||
|
||||
-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
|
||||
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
|
||||
|
||||
allDimensions = vector [-1 :: Int32]
|
||||
|
||||
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1
|
||||
|
||||
lookupAttr nodeDef attrName = nodeDef ^. attr . at attrName . non def . attrLens
|
296
tensorflow-ops/src/TensorFlow/Ops.hs
Normal file
296
tensorflow-ops/src/TensorFlow/Ops.hs
Normal file
|
@ -0,0 +1,296 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
-- | This module contains definitions for some built-in TensorFlow operations.
|
||||
--
|
||||
-- Note that certain, "stateful" ops like 'variable' and 'assign' return a
|
||||
-- 'Build' action (e.g., @Build (Tensor Ref a)@ instead of a pure value; the
|
||||
-- returned 'Tensor's are always rendered in the current 'Build' context. This
|
||||
-- approach helps us avoid problems with inlining or common subexpression
|
||||
-- elimination, by writing
|
||||
--
|
||||
-- > do
|
||||
-- > v <- variable []
|
||||
-- > w <- assign v 3
|
||||
-- > render $ w * w
|
||||
--
|
||||
-- instead of
|
||||
--
|
||||
-- > let
|
||||
-- > v = variable []
|
||||
-- > w = assign v 3
|
||||
-- > in w * w
|
||||
--
|
||||
-- since the latter could be reasonably transformed by the compiler into (or
|
||||
-- vice versa)
|
||||
--
|
||||
-- > let
|
||||
-- > v = variable []
|
||||
-- > w = assign v 3
|
||||
-- > w' = assign v 3
|
||||
-- > in w * w'
|
||||
--
|
||||
-- Ops should return a 'Build' action if their original 'OpDef' marks them as
|
||||
-- stateful, or if they take any Refs as input. (This mirrors the rules that
|
||||
-- TensorFlow uses to avoid common subexpression elimination.)
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
|
||||
module TensorFlow.Ops
|
||||
( CoreOps.add
|
||||
, CoreOps.abs
|
||||
, CoreOps.addN
|
||||
, CoreOps.argMax
|
||||
, assign
|
||||
, CoreOps.broadcastGradientArgs
|
||||
, CoreOps.cast
|
||||
, CoreOps.concat
|
||||
, constant
|
||||
, expandDims
|
||||
, initializedVariable
|
||||
, zeroInitializedVariable
|
||||
, CoreOps.fill
|
||||
, CoreOps.matMul
|
||||
, matTranspose
|
||||
, CoreOps.mul
|
||||
, CoreOps.neg
|
||||
, CoreOps.pack
|
||||
, placeholder
|
||||
, CoreOps.range
|
||||
, reducedShape
|
||||
, CoreOps.relu
|
||||
, CoreOps.reluGrad
|
||||
, CoreOps.reshape
|
||||
, restore
|
||||
, restoreFromName
|
||||
, save
|
||||
, scalar
|
||||
, shape
|
||||
, CoreOps.sign
|
||||
, CoreOps.size
|
||||
, CoreOps.softmax
|
||||
, CoreOps.softmaxCrossEntropyWithLogits
|
||||
, CoreOps.sparseToDense
|
||||
, CoreOps.sub
|
||||
, CoreOps.sum
|
||||
, CoreOps.topK
|
||||
, CoreOps.transpose
|
||||
, truncatedNormal
|
||||
, variable
|
||||
, vector
|
||||
, zeros
|
||||
, CoreOps.zerosLike
|
||||
) where
|
||||
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Complex (Complex)
|
||||
import Data.Int (Int32, Int64)
|
||||
import Prelude hiding (abs, sum, concat)
|
||||
import Data.ProtoLens (def)
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Lens.Family2 ((.~), (&))
|
||||
import Text.Printf (printf)
|
||||
import Proto.Tensorflow.Core.Framework.Tensor
|
||||
( TensorProto
|
||||
, dtype
|
||||
, tensorShape
|
||||
)
|
||||
import qualified Proto.Tensorflow.Core.Framework.TensorShape
|
||||
as TensorShape
|
||||
import TensorFlow.Build
|
||||
import TensorFlow.BuildOp
|
||||
import TensorFlow.ControlFlow (group)
|
||||
import TensorFlow.Output (unNodeName)
|
||||
import TensorFlow.Tensor
|
||||
import TensorFlow.Types
|
||||
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
|
||||
import qualified Prelude (abs)
|
||||
|
||||
-- TODO: Look into hs-boot refactoring to allow mutually recursive imports.
|
||||
-- | Must be defined as an orphan because of the dependency order between Ops
|
||||
-- and Tensor.
|
||||
--
|
||||
-- The indirect constraint "v ~ Value" helps disambiguate types, for example in
|
||||
-- "neg 1 :: Tensor Value Float", it helps find the type of the subexpression
|
||||
-- "1".
|
||||
instance ( TensorType a
|
||||
, Num a
|
||||
, v ~ Value
|
||||
, OneOf '[ Double, Float, Int32, Int64
|
||||
, Complex Float, Complex Double] a) => Num (Tensor v a) where
|
||||
(+) = CoreOps.add
|
||||
(*) = CoreOps.mul
|
||||
(-) = CoreOps.sub
|
||||
abs = CoreOps.abs
|
||||
fromInteger = scalar . fromInteger
|
||||
signum = CoreOps.sign
|
||||
negate = CoreOps.neg
|
||||
|
||||
matTranspose :: forall a v . TensorType a
|
||||
=> Tensor v a -> Tensor Value a
|
||||
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
|
||||
|
||||
-- | Create a new, uninitialized stateful Tensor of the given shape.
|
||||
variable :: forall a . TensorType a => Shape -> Build (Tensor Ref a)
|
||||
variable shape' = buildOp $ opDef "Variable"
|
||||
& opAttr "shape" .~ shape'
|
||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
|
||||
placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a)
|
||||
placeholder shape' =
|
||||
buildOp $ opDef "Placeholder"
|
||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
& opAttr "shape" .~ shape'
|
||||
|
||||
-- Assign returns the input ref.
|
||||
assign :: forall a v . TensorType a
|
||||
=> Tensor Ref a -> Tensor v a -> Build (Tensor Ref a)
|
||||
assign = buildOp $ opDef "Assign"
|
||||
& opAttr "T" .~ tensorType (undefined :: a)
|
||||
& opAttr "use_locking" .~ True
|
||||
|
||||
-- | Creates a variable initialized to the given value.
|
||||
-- Initialization happens next time session runs.
|
||||
initializedVariable :: forall a . TensorType a
|
||||
=> Tensor Value a -> Build (Tensor Ref a)
|
||||
initializedVariable initializer = do
|
||||
v <- variable [] -- The shape is not known initially.
|
||||
(i :: Tensor Ref a) <-
|
||||
buildOp (opDef "Assign"
|
||||
& opAttr "T" .~ tensorType (undefined :: a)
|
||||
& opAttr "use_locking" .~ True
|
||||
& opAttr "validate_shape" .~ False
|
||||
)
|
||||
v initializer
|
||||
addInitializer =<< group i
|
||||
return v
|
||||
|
||||
-- | Creates a zero-initialized variable with the given shape.
|
||||
zeroInitializedVariable
|
||||
:: (TensorType a, Num a) =>
|
||||
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
|
||||
zeroInitializedVariable = initializedVariable . zeros
|
||||
|
||||
-- TODO: Support heterogeneous list of tensors.
|
||||
save :: forall a v . TensorType a
|
||||
=> ByteString -- ^ File path.
|
||||
-> [Tensor v a] -- ^ Tensors to save.
|
||||
-> Build ControlNode
|
||||
save path xs = do
|
||||
let toByteStringTensor = scalar . encodeUtf8 . unNodeName
|
||||
names <- mapM (fmap toByteStringTensor . renderNodeName) xs
|
||||
let types = replicate (length xs) (tensorType (undefined :: a))
|
||||
let saveOp = buildOp $ opDef "Save"
|
||||
& opAttr "T" .~ types
|
||||
saveOp (scalar path) (CoreOps.pack names) xs
|
||||
|
||||
-- | Restore a tensor's value from a checkpoint file.
|
||||
--
|
||||
-- This version allows restoring from a checkpoint file that uses a different
|
||||
-- tensor name than the variable.
|
||||
restoreFromName :: forall a . TensorType a
|
||||
=> ByteString -- ^ File path.
|
||||
-> ByteString -- ^ Tensor name override.
|
||||
-> Tensor Ref a -- ^ Tensor to restore.
|
||||
-> Build ControlNode
|
||||
restoreFromName path name x = do
|
||||
let restoreOp = buildOp $ opDef "Restore"
|
||||
& opAttr "dt" .~ tensorType (undefined :: a)
|
||||
group =<< assign x (restoreOp (scalar path) (scalar name) :: Tensor Value a)
|
||||
|
||||
-- | Restore a tensor's value from a checkpoint file.
|
||||
restore :: forall a . TensorType a
|
||||
=> ByteString -- ^ File path.
|
||||
-> Tensor Ref a -- ^ Tensor to restore.
|
||||
-> Build ControlNode
|
||||
restore path x = do
|
||||
name <- encodeUtf8 . unNodeName <$> renderNodeName x
|
||||
restoreFromName path name x
|
||||
|
||||
-- | Create a constant tensor.
|
||||
--
|
||||
-- The values should be in row major order, e.g.,
|
||||
--
|
||||
-- element 0: index (0, ..., 0)
|
||||
-- element 1: index (0, ..., 1)
|
||||
-- ...
|
||||
constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a
|
||||
constant (Shape shape') values
|
||||
| invalidLength = error invalidLengthMsg
|
||||
| otherwise = buildOp $ opDef "Const"
|
||||
& opAttr "value" .~ typedNode
|
||||
& opAttr "dtype" .~ nodeType
|
||||
where
|
||||
invalidLength = product shape' /= fromIntegral (length values)
|
||||
invalidLengthMsg = printf "invalid tensor length: expected %d got %d"
|
||||
(product shape')
|
||||
(length values)
|
||||
nodeType = tensorType (undefined :: a)
|
||||
typedNode :: TensorProto
|
||||
typedNode = def
|
||||
& dtype .~ nodeType
|
||||
& tensorShape.TensorShape.dim .~
|
||||
[def & TensorShape.size .~ x | x <- shape']
|
||||
& tensorVal .~ values
|
||||
|
||||
-- | Create a constant vector.
|
||||
vector :: TensorType a => [a] -> Tensor Value a
|
||||
vector xs = constant [fromIntegral $ length xs] xs
|
||||
|
||||
-- | Create a constant scalar.
|
||||
scalar :: forall a . TensorType a => a -> Tensor Value a
|
||||
scalar x = constant [] [x]
|
||||
|
||||
-- Random tensor from the unit normal distribution with bounded values.
|
||||
truncatedNormal :: forall a v . TensorType a
|
||||
=> Tensor v Int64 -- ^ Shape.
|
||||
-> Build (Tensor Value a)
|
||||
truncatedNormal = buildOp $ opDef "TruncatedNormal"
|
||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
& opAttr "T" .~ tensorType (undefined :: Int64)
|
||||
|
||||
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
|
||||
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
|
||||
|
||||
shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32
|
||||
shape = CoreOps.shape
|
||||
|
||||
expandDims :: (TensorType t) => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
|
||||
expandDims = CoreOps.expandDims
|
||||
|
||||
-- | Helper function for reduction ops (translation of math_ops.reduced_shape).
|
||||
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
|
||||
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32
|
||||
reducedShape inputShape axes =
|
||||
let inputShape32 = toInt32 inputShape -- [2, 3, 5, 7]
|
||||
axes32 = toInt32 axes -- [1, 2]
|
||||
toInt32 x = CoreOps.cast x :: Tensor Value Int32
|
||||
inputRank = CoreOps.size inputShape32 -- 4
|
||||
axesMod = (axes32 + inputRank) `CoreOps.mod` inputRank
|
||||
axesShape = shape axesMod -- [2]
|
||||
in CoreOps.dynamicStitch -- [2, 1, 1, 7]
|
||||
[CoreOps.range 0 inputRank 1, -- [0, 1, 2, 3]
|
||||
axesMod] -- [1, 2]
|
||||
[inputShape32, -- [2, 3, 5, 7]
|
||||
CoreOps.fill axesShape 1] -- [1, 1]
|
191
tensorflow-ops/tensorflow-ops.cabal
Normal file
191
tensorflow-ops/tensorflow-ops.cabal
Normal file
|
@ -0,0 +1,191 @@
|
|||
name: tensorflow-ops
|
||||
version: 0.1.0.0
|
||||
synopsis: Friendly layer around TensorFlow bindings.
|
||||
description: Please see README.md
|
||||
homepage: https://github.com/tensorflow/haskell#readme
|
||||
license: Apache
|
||||
author: TensorFlow authors
|
||||
maintainer: tensorflow-haskell@googlegroups.com
|
||||
copyright: Google Inc.
|
||||
category: Machine Learning
|
||||
build-type: Simple
|
||||
cabal-version: >=1.22
|
||||
|
||||
library
|
||||
hs-source-dirs: src
|
||||
exposed-modules: TensorFlow.Gradient
|
||||
, TensorFlow.Ops
|
||||
, TensorFlow.EmbeddingOps
|
||||
build-depends: proto-lens == 0.1.*
|
||||
, base >= 4.7 && < 5
|
||||
, bytestring
|
||||
, fgl
|
||||
, mtl
|
||||
, data-default
|
||||
, lens-family
|
||||
, containers
|
||||
, tensorflow == 0.1.*
|
||||
, tensorflow-proto == 0.1.*
|
||||
, tensorflow-core-ops == 0.1.*
|
||||
, text
|
||||
default-language: Haskell2010
|
||||
|
||||
Test-Suite BuildTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: BuildTest.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: HUnit
|
||||
, base
|
||||
, proto-lens
|
||||
, lens-family
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-ops
|
||||
, tensorflow-proto
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
, transformers
|
||||
, vector
|
||||
|
||||
Test-Suite EmbeddingOpsTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: EmbeddingOpsTest.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: HUnit
|
||||
, QuickCheck
|
||||
, base
|
||||
, proto-lens
|
||||
, lens-family
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-core-ops
|
||||
, tensorflow-ops
|
||||
, tensorflow-proto
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
, test-framework-quickcheck2
|
||||
, vector
|
||||
|
||||
Test-Suite ArrayOpsTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: ArrayOpsTest.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: HUnit
|
||||
, QuickCheck
|
||||
, base
|
||||
, proto-lens
|
||||
, lens-family
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-core-ops
|
||||
, tensorflow-ops
|
||||
, tensorflow-proto
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
, test-framework-quickcheck2
|
||||
, transformers
|
||||
, vector
|
||||
|
||||
Test-Suite OpsTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: OpsTest.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: HUnit
|
||||
, QuickCheck
|
||||
, base
|
||||
, bytestring
|
||||
, proto-lens
|
||||
, lens-family
|
||||
, google-shim
|
||||
, temporary
|
||||
, tensorflow
|
||||
, tensorflow-core-ops
|
||||
, tensorflow-ops
|
||||
, tensorflow-proto
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
, test-framework-quickcheck2
|
||||
, transformers
|
||||
, vector
|
||||
|
||||
Test-Suite DataFlowOpsTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: DataFlowOpsTest.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: HUnit
|
||||
, QuickCheck
|
||||
, base
|
||||
, proto-lens
|
||||
, lens-family
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-core-ops
|
||||
, tensorflow-ops
|
||||
, tensorflow-proto
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
, test-framework-quickcheck2
|
||||
, vector
|
||||
|
||||
Test-Suite GradientTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: GradientTest.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: HUnit
|
||||
, base
|
||||
, proto-lens
|
||||
, lens-family
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-ops
|
||||
, tensorflow-proto
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
|
||||
Test-Suite MiscTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: MiscTest.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: HUnit
|
||||
, base
|
||||
, bytestring
|
||||
, vector
|
||||
, google-shim
|
||||
, transformers
|
||||
, tensorflow
|
||||
, tensorflow-ops
|
||||
, tensorflow-proto
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
|
||||
Test-Suite TypesTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: TypesTest.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: HUnit
|
||||
, QuickCheck
|
||||
, base
|
||||
, bytestring
|
||||
, proto-lens
|
||||
, lens-family
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-ops
|
||||
, tensorflow-proto
|
||||
, transformers
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
, test-framework-quickcheck2
|
||||
, vector
|
||||
|
||||
source-repository head
|
||||
type: git
|
||||
location: https://github.com/tensorflow/haskell
|
42
tensorflow-ops/tests/ArrayOpsTest.hs
Normal file
42
tensorflow-ops/tests/ArrayOpsTest.hs
Normal file
|
@ -0,0 +1,42 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
module Main where
|
||||
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Google.Test (googleTest)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@=?))
|
||||
import qualified Data.Vector as V
|
||||
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
|
||||
-- | Test split and concat are inverses.
|
||||
testSplit = testCase "testSplit" $ TF.runSession $ do
|
||||
let original = TF.constant [2, 3] [0..5 :: Float]
|
||||
splitList = CoreOps.split 3 dim original
|
||||
restored = CoreOps.concat dim splitList
|
||||
dim = 1 -- dimension to split along (with size of 3 in original)
|
||||
liftIO $ 3 @=? length splitList
|
||||
(x, y, z) <-
|
||||
TF.buildAnd TF.run $ return (original, restored, splitList !! 1)
|
||||
liftIO $ x @=? (y :: V.Vector Float)
|
||||
liftIO $ V.fromList [1, 4] @=? z
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testSplit
|
||||
]
|
181
tensorflow-ops/tests/BuildTest.hs
Normal file
181
tensorflow-ops/tests/BuildTest.hs
Normal file
|
@ -0,0 +1,181 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Main where
|
||||
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Functor.Identity (runIdentity)
|
||||
import Lens.Family2 ((^.))
|
||||
import Data.List (sort)
|
||||
import Proto.Tensorflow.Core.Framework.Graph
|
||||
( node )
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||
( NodeDef
|
||||
, device
|
||||
, name
|
||||
, op )
|
||||
import TensorFlow.Build
|
||||
( Build
|
||||
, BuildT
|
||||
, asGraphDef
|
||||
, evalBuildT
|
||||
, flushNodeBuffer
|
||||
, hoistBuildT
|
||||
, render
|
||||
, withDevice
|
||||
, colocateWith
|
||||
, withNameScope
|
||||
)
|
||||
import TensorFlow.ControlFlow (named)
|
||||
import TensorFlow.Nodes (unScalar)
|
||||
import TensorFlow.Ops
|
||||
( add
|
||||
, assign
|
||||
, constant
|
||||
, initializedVariable
|
||||
, variable
|
||||
)
|
||||
import TensorFlow.Output (Device(..))
|
||||
import TensorFlow.Tensor (Tensor, Value, Ref)
|
||||
import TensorFlow.Session
|
||||
( build
|
||||
, buildAnd
|
||||
, run
|
||||
, runSession
|
||||
, run_
|
||||
)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@=?))
|
||||
import Google.Test (googleTest)
|
||||
import qualified Data.Vector as V
|
||||
|
||||
-- | Test named behavior.
|
||||
testNamed = testCase "testNamed" $ do
|
||||
let graph = named "foo" <$> variable [] >>= render :: Build (Tensor Ref Float)
|
||||
nodeDef :: NodeDef
|
||||
nodeDef = head $ asGraphDef graph ^. node
|
||||
"RefIdentity" @=? (nodeDef ^. op)
|
||||
"foo" @=? (nodeDef ^. name)
|
||||
|
||||
-- | Test named deRef behavior.
|
||||
testNamedDeRef = testCase "testNamedDeRef" $ do
|
||||
let graph = named "foo" <$> do
|
||||
v :: Tensor Ref Float <- variable []
|
||||
assign v 5
|
||||
-- TODO: Implement TensorFlow get_variable and test it.
|
||||
runSession $ do
|
||||
out <- buildAnd run graph
|
||||
liftIO $ 5 @=? (unScalar out :: Float)
|
||||
|
||||
-- | Test that "run" will render and extend any pure ops that haven't already
|
||||
-- been rendered.
|
||||
testPureRender = testCase "testPureRender" $ runSession $ do
|
||||
result <- run $ 2 `add` 2
|
||||
liftIO $ 4 @=? (unScalar result :: Float)
|
||||
|
||||
-- | Test that "run" assigns any previously accumulated initializers.
|
||||
testInitializedVariable =
|
||||
testCase "testInitializedVariable" $ runSession $ do
|
||||
(formula, reset) <- build $ do
|
||||
v <- initializedVariable 42
|
||||
r <- assign v 24
|
||||
return (1 `add` v, r)
|
||||
result <- run formula
|
||||
liftIO $ 43 @=? (unScalar result :: Float)
|
||||
run_ reset -- Updates v to a different value
|
||||
rerunResult <- run formula
|
||||
liftIO $ 25 @=? (unScalar rerunResult :: Float)
|
||||
|
||||
testInitializedVariableShape =
|
||||
testCase "testInitializedVariableShape" $ runSession $ do
|
||||
vector <- build $ initializedVariable (constant [1] [42 :: Float])
|
||||
result <- run vector
|
||||
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||
|
||||
-- | Test nameScoped behavior.
|
||||
testNameScoped = testCase "testNameScoped" $ do
|
||||
let graph = withNameScope "foo" $ variable [] :: Build (Tensor Ref Float)
|
||||
nodeDef :: NodeDef
|
||||
[nodeDef] = asGraphDef graph ^. node
|
||||
"foo/Variable_0" @=? (nodeDef ^. name) -- TODO: Check prefix.
|
||||
"Variable" @=? (nodeDef ^. op)
|
||||
|
||||
-- | Test combined named and nameScoped behavior.
|
||||
testNamedAndScoped = testCase "testNamedAndScoped" $ do
|
||||
let graph :: Build (Tensor Ref Float)
|
||||
graph = withNameScope "foo1" ((named "bar1" <$> variable []) >>= render)
|
||||
nodeDef :: NodeDef
|
||||
nodeDef = head $ asGraphDef graph ^. node
|
||||
"RefIdentity" @=? (nodeDef ^. op)
|
||||
"foo1/bar1" @=? (nodeDef ^. name)
|
||||
|
||||
-- | Lift a Build action into a context for HUnit to run.
|
||||
liftBuild :: Build a -> BuildT IO a
|
||||
liftBuild = hoistBuildT (return . runIdentity)
|
||||
|
||||
-- | Flush the node buffer and sort the nodes by name (for more stable tests).
|
||||
flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a]
|
||||
flushed field = sort . map field <$> liftBuild flushNodeBuffer
|
||||
|
||||
-- | Test the interaction of rendering, CSE and scoping.
|
||||
testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
|
||||
liftBuild renderNodes
|
||||
names <- flushed (^. name)
|
||||
liftIO $ ["Const_1", "Variable_0", "Variable_2"] @=? names
|
||||
-- Render the nodes in a different scope, which should cause them
|
||||
-- to be distinct from the previous ones.
|
||||
liftBuild $ withNameScope "foo" renderNodes
|
||||
scopedNames <- flushed (^. name)
|
||||
liftIO $ ["foo/Const_4", "foo/Variable_3", "foo/Variable_5"] @=? scopedNames
|
||||
where
|
||||
renderNodes = do
|
||||
-- A stateful op and a pure op.
|
||||
_ :: Tensor Ref Float <- variable []
|
||||
_ :: Tensor Value Float <- render 3
|
||||
-- Another stateful op, and a pure op which should be
|
||||
-- deduped with the previous one.
|
||||
_ :: Tensor Ref Float <- variable []
|
||||
_ :: Tensor Value Float <- render 3
|
||||
return ()
|
||||
|
||||
-- | Test the interaction of rendering, CSE and scoping.
|
||||
testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do
|
||||
liftBuild renderNodes
|
||||
devices <- flushed (\x -> (x ^. name, x ^. device))
|
||||
liftIO $ [ ("Add_2","dev0")
|
||||
, ("Const_1","dev0")
|
||||
, ("Variable_0","dev0")] @=? devices
|
||||
where
|
||||
renderNodes = do
|
||||
-- A stateful op and a pure op.
|
||||
var :: Tensor Ref Float <- withDevice (Just $ Device "dev0") $ variable []
|
||||
-- Uses render to cause the expression be added to the graph.
|
||||
_ <- colocateWith var $ render $ 3 `add` var
|
||||
return ()
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testInitializedVariable
|
||||
, testInitializedVariableShape
|
||||
, testDeviceColocation
|
||||
, testNamed
|
||||
, testNamedDeRef
|
||||
, testNameScoped
|
||||
, testNamedAndScoped
|
||||
, testPureRender
|
||||
, testRenderDedup
|
||||
]
|
66
tensorflow-ops/tests/DataFlowOpsTest.hs
Normal file
66
tensorflow-ops/tests/DataFlowOpsTest.hs
Normal file
|
@ -0,0 +1,66 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.List (genericLength)
|
||||
import Google.Test (googleTest)
|
||||
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||||
import Test.HUnit ((@=?))
|
||||
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
|
||||
import Test.QuickCheck.Monadic (monadicIO, run)
|
||||
|
||||
import qualified Data.Vector as V
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
|
||||
-- DynamicSplit is undone with DynamicStitch to get the original input
|
||||
-- back.
|
||||
testDynamicPartitionStitchInverse :: forall a.
|
||||
(TF.TensorType a, Show a, Eq a) => StitchExample a -> Property
|
||||
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
|
||||
let splitParts :: [TF.Tensor TF.Value a] =
|
||||
CoreOps.dynamicPartition numParts (TF.vector values) partTensor
|
||||
partTensor = TF.vector partitions
|
||||
restitchIndices = CoreOps.dynamicPartition numParts
|
||||
(TF.vector [0..genericLength values-1])
|
||||
partTensor
|
||||
-- drop (numParts - 2) from both args to expose b/27343984
|
||||
restitch = CoreOps.dynamicStitch restitchIndices splitParts
|
||||
in monadicIO $ run $ do
|
||||
fromIntegral numParts @=? length splitParts
|
||||
valuesOut <- TF.runSession $ TF.buildAnd TF.run $ return restitch
|
||||
V.fromList values @=? valuesOut
|
||||
|
||||
data StitchExample a = StitchExample Int64 [a] [Int32]
|
||||
deriving Show
|
||||
|
||||
instance Arbitrary a => Arbitrary (StitchExample a) where
|
||||
arbitrary = do
|
||||
-- Limits the size of the vector.
|
||||
size <- choose (1, 100)
|
||||
values <- vectorOf size arbitrary
|
||||
numParts <- choose (2, 15)
|
||||
partitions <- vectorOf size (choose (0, fromIntegral numParts - 1))
|
||||
return $ StitchExample numParts values partitions
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest
|
||||
[ testProperty "DynamicPartitionStitchInverse"
|
||||
(testDynamicPartitionStitchInverse :: StitchExample Int64 -> Property)
|
||||
]
|
88
tensorflow-ops/tests/EmbeddingOpsTest.hs
Normal file
88
tensorflow-ops/tests/EmbeddingOpsTest.hs
Normal file
|
@ -0,0 +1,88 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
-- | Tests for EmbeddingOps.
|
||||
module Main where
|
||||
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.List (genericLength)
|
||||
import Google.Test (googleTest)
|
||||
import TensorFlow.EmbeddingOps (embeddingLookup)
|
||||
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||||
import Test.HUnit ((@=?))
|
||||
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
|
||||
import Test.QuickCheck.Monadic (monadicIO, run)
|
||||
|
||||
import qualified Data.Vector as V
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
|
||||
-- Verifies that direct gather is the same as dynamic split into
|
||||
-- partitions, followed by embedding lookup.
|
||||
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
|
||||
=> LookupExample a -> Property
|
||||
testEmbeddingLookupUndoesSplit
|
||||
(LookupExample numParts
|
||||
shape@(TF.Shape (firstDim : restDims))
|
||||
values
|
||||
indices) =
|
||||
let modShardedValues :: [TF.Tensor TF.Value a] =
|
||||
CoreOps.dynamicPartition numParts shapedValues cyclicCounter
|
||||
cyclicCounter :: TF.Tensor TF.Value Int32 =
|
||||
TF.vector [0..fromIntegral firstDim-1]
|
||||
`CoreOps.mod` fromIntegral numParts
|
||||
indicesVector = TF.vector indices
|
||||
directs = CoreOps.gather shapedValues indicesVector
|
||||
shapedValues = TF.constant shape values
|
||||
in monadicIO $ run $ do
|
||||
(shapeOut, got, want :: V.Vector a) <-
|
||||
TF.runSession $ TF.buildAnd TF.run $ do
|
||||
embeddings <- embeddingLookup modShardedValues indicesVector
|
||||
return (TF.cast (TF.shape embeddings), embeddings, directs)
|
||||
-- Checks the explicitly documented invariant of embeddingLookup.
|
||||
shapeOut @=? V.fromList (genericLength indices : restDims)
|
||||
got @=? want
|
||||
testEmbeddingLookupUndoesSplit _ = error "Bug in Arbitrary (LookupExample)"
|
||||
|
||||
-- | Consistent set of parameters for EmbeddingLookupUndoesSplit.
|
||||
data LookupExample a = LookupExample
|
||||
Int64 -- ^ number of ways to split.
|
||||
TF.Shape -- ^ shape of the generated tensor
|
||||
[a] -- ^ data for the tensor
|
||||
[Int32] -- ^ indices to split the tensor by
|
||||
deriving Show
|
||||
|
||||
instance Arbitrary a => Arbitrary (LookupExample a) where
|
||||
arbitrary = do
|
||||
rank <- choose (1, 4)
|
||||
-- Takes rank-th root of 100 to cap the tensor size.
|
||||
let maxDim = fromIntegral $ ceiling $ 100 ** (1 / fromIntegral rank)
|
||||
shape@(firstDim : _) <- vectorOf rank (choose (1, maxDim))
|
||||
values <- vectorOf (fromIntegral $ product shape) arbitrary
|
||||
numParts <- choose (2, 15)
|
||||
indSize <- choose (0, fromIntegral $ firstDim - 1)
|
||||
indices <- vectorOf indSize (choose (0, fromIntegral firstDim - 1))
|
||||
return $ LookupExample numParts (TF.Shape shape) values indices
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest
|
||||
[ testProperty "EmbeddingLookupUndoesSplit"
|
||||
(testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
|
||||
]
|
158
tensorflow-ops/tests/GradientTest.hs
Normal file
158
tensorflow-ops/tests/GradientTest.hs
Normal file
|
@ -0,0 +1,158 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
import Data.List (sort)
|
||||
import Data.ProtoLens.TextFormat (showMessage)
|
||||
import Google.Test (googleTest)
|
||||
import Lens.Family2 ((^..))
|
||||
import Test.Framework (Test)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@=?))
|
||||
|
||||
import qualified TensorFlow.Build as TF
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Nodes as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
|
||||
import Proto.Tensorflow.Core.Framework.Graph (node)
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef (op)
|
||||
|
||||
testGradientSimple :: Test
|
||||
testGradientSimple = testCase "testGradientSimple" $ do
|
||||
let x = TF.scalar (3 :: Float)
|
||||
b = TF.scalar (4 :: Float)
|
||||
y = x*x + b
|
||||
grads = TF.gradients y [x, b]
|
||||
-- Assert that the gradients are right.
|
||||
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads
|
||||
6 @=? TF.unScalar dx
|
||||
1 @=? TF.unScalar db
|
||||
-- Assert that the graph has the expected ops.
|
||||
let graphDef = TF.asGraphDef grads
|
||||
putStrLn $ showMessage graphDef
|
||||
let ops = graphDef ^.. node . traverse . op
|
||||
expected = [ "Const"
|
||||
, "Mul"
|
||||
, "Const"
|
||||
, "Add"
|
||||
-- Default output gradient of y.
|
||||
, "Shape"
|
||||
, "Const"
|
||||
, "Fill"
|
||||
-- Add gradient.
|
||||
, "Shape"
|
||||
, "Shape"
|
||||
, "BroadcastGradientArgs"
|
||||
, "Sum"
|
||||
, "Sum"
|
||||
, "Reshape"
|
||||
, "Reshape"
|
||||
-- Mul gradient.
|
||||
, "Shape"
|
||||
-- This Op gets dedup'd because the inputs are the same.
|
||||
-- TODO(fmayle): The same would happen to the Mul and Sum ops
|
||||
-- below if the gradient function didn't multiply one as
|
||||
-- 'dz * y' and the other as 'x * dz'. We could change the
|
||||
-- order, but I'm going to keep it the same as the python
|
||||
-- version for now.
|
||||
--
|
||||
-- , "Shape"
|
||||
, "BroadcastGradientArgs"
|
||||
, "Mul"
|
||||
, "Mul"
|
||||
, "Sum"
|
||||
, "Sum"
|
||||
, "Reshape"
|
||||
, "Reshape"
|
||||
-- AddN to combine x's output gradients.
|
||||
, "AddN"
|
||||
]
|
||||
sort expected @=? sort ops
|
||||
|
||||
testGradientDisconnected :: Test
|
||||
testGradientDisconnected = testCase "testGradientDisconnected" $ do
|
||||
let x = TF.scalar (3 :: Float)
|
||||
b = TF.scalar (4 :: Float)
|
||||
grads = TF.gradients x [x, b]
|
||||
-- Assert that the gradients are right.
|
||||
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads
|
||||
1 @=? TF.unScalar dx
|
||||
0 @=? TF.unScalar db
|
||||
-- Assert that the graph has the expected ops.
|
||||
let graphDef = TF.asGraphDef grads
|
||||
putStrLn $ showMessage graphDef
|
||||
let ops = graphDef ^.. node . traverse . op
|
||||
expected = [ "Const"
|
||||
, "Const"
|
||||
-- Default output gradient of x.
|
||||
, "Shape"
|
||||
, "Const"
|
||||
, "Fill"
|
||||
-- Default output gradient of b.
|
||||
, "ZerosLike"
|
||||
]
|
||||
sort expected @=? sort ops
|
||||
|
||||
|
||||
-- Test that identical "stateful" ops work with createGraph.
|
||||
testCreateGraphStateful :: Test
|
||||
testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
|
||||
[dx, dy] <- TF.runSession $ TF.buildAnd TF.run $ do
|
||||
let shape = TF.constant (TF.Shape [1]) [1]
|
||||
x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
||||
y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
||||
TF.gradients (x + y*3) [x, y]
|
||||
-- If this test fails, it will likely be caused by an exception within
|
||||
-- `TF.gradients`. These asserts are extra.
|
||||
1 @=? TF.unScalar dx
|
||||
3 @=? TF.unScalar dy
|
||||
|
||||
|
||||
-- Test that name scopes work with createGraph.
|
||||
testCreateGraphNameScopes :: Test
|
||||
testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
|
||||
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
|
||||
let shape = TF.constant (TF.Shape [1]) [1]
|
||||
x :: TF.Tensor TF.Value Float <-
|
||||
TF.withNameScope "foo" (TF.truncatedNormal shape)
|
||||
TF.gradients x [x]
|
||||
-- If this test fails, it will likely be caused by an exception within
|
||||
-- `TF.gradients`. This assert is extra.
|
||||
1 @=? TF.unScalar dx
|
||||
|
||||
|
||||
-- Test that createGraph can handle graphs with diamond shapes.
|
||||
testDiamond :: Test
|
||||
testDiamond = testCase "testDiamond" $ do
|
||||
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
|
||||
let x = TF.vector [1]
|
||||
y = x*x
|
||||
z = y*y
|
||||
TF.gradients z [x]
|
||||
(4 :: Float) @=? TF.unScalar dx
|
||||
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testGradientSimple
|
||||
, testGradientDisconnected
|
||||
, testCreateGraphStateful
|
||||
, testCreateGraphNameScopes
|
||||
, testDiamond
|
||||
]
|
46
tensorflow-ops/tests/MiscTest.hs
Normal file
46
tensorflow-ops/tests/MiscTest.hs
Normal file
|
@ -0,0 +1,46 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
|
||||
module Main where
|
||||
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int32)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@=?))
|
||||
import Google.Test
|
||||
import qualified Data.Vector as V
|
||||
|
||||
import TensorFlow.Ops
|
||||
import TensorFlow.Session
|
||||
|
||||
-- | Test fetching multiple outputs from an op.
|
||||
testMultipleOutputs = testCase "testMultipleOutputs" $
|
||||
runSession $ do
|
||||
(values, indices) <- run $ topK 2 $ constant [1, 4] [10, 40, 20, 30]
|
||||
liftIO $ [40, 30] @=? V.toList (values :: V.Vector Float)
|
||||
liftIO $ [1, 3] @=? V.toList (indices :: V.Vector Int32)
|
||||
|
||||
-- | Test op with variable number of inputs.
|
||||
testVarargs = testCase "testVarargs" $
|
||||
runSession $ do
|
||||
xs <- run $ pack $ map scalar [1..8]
|
||||
liftIO $ [1..8] @=? V.toList (xs :: V.Vector Float)
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testMultipleOutputs
|
||||
, testVarargs
|
||||
]
|
70
tensorflow-ops/tests/OpsTest.hs
Normal file
70
tensorflow-ops/tests/OpsTest.hs
Normal file
|
@ -0,0 +1,70 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE NoMonomorphismRestriction #-}
|
||||
|
||||
module Main where
|
||||
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int32, Int64)
|
||||
import Google.Test (googleTest)
|
||||
import System.IO.Temp (withSystemTempDirectory)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@=?))
|
||||
import qualified Data.ByteString.Char8 as B8
|
||||
|
||||
import qualified Data.Vector as V
|
||||
import qualified TensorFlow.Build as TF
|
||||
import qualified TensorFlow.ControlFlow as TF
|
||||
import qualified TensorFlow.Nodes as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
|
||||
-- | Test that one can easily determine number of elements in the tensor.
|
||||
testSize = testCase "testSize" $ do
|
||||
x <- eval $ TF.size (TF.constant [2, 3] [0..5 :: Float])
|
||||
TF.Scalar (2 * 3 :: Int32) @=? x
|
||||
|
||||
eval = TF.runSession . TF.buildAnd TF.run . return
|
||||
|
||||
-- | Confirms that the original example from Python code works.
|
||||
testReducedShape = testCase "testReducedShape" $ do
|
||||
x <- eval $ TF.reducedShape (TF.vector [2, 3, 5, 7 :: Int64])
|
||||
(TF.vector [1, 2 :: Int32])
|
||||
V.fromList [2, 1, 1, 7 :: Int32] @=? x
|
||||
|
||||
testSaveRestore = testCase "testSaveRestore" $
|
||||
withSystemTempDirectory "" $ \dirPath -> do
|
||||
let path = B8.pack $ dirPath ++ "/checkpoint"
|
||||
var :: TF.Build (TF.Tensor TF.Ref Float)
|
||||
var = TF.render =<< TF.named "a" <$> TF.zeroInitializedVariable []
|
||||
TF.runSession $ do
|
||||
v <- TF.build var
|
||||
TF.buildAnd TF.run_ $ TF.assign v 134
|
||||
TF.buildAnd TF.run_ $ TF.save path [v]
|
||||
result <- TF.runSession $ do
|
||||
v <- TF.build var
|
||||
TF.buildAnd TF.run_ $ TF.restore path v
|
||||
TF.run v
|
||||
liftIO $ TF.Scalar 134 @=? result
|
||||
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testSaveRestore
|
||||
, testSize
|
||||
, testReducedShape
|
||||
]
|
119
tensorflow-ops/tests/TypesTest.hs
Normal file
119
tensorflow-ops/tests/TypesTest.hs
Normal file
|
@ -0,0 +1,119 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE NoMonomorphismRestriction #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
|
||||
import Control.Monad (replicateM)
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int64)
|
||||
import Google.Test (googleTest)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||||
import Test.HUnit ((@=?))
|
||||
import Test.QuickCheck (Arbitrary(..), listOf, suchThat)
|
||||
import qualified Data.ByteString as B
|
||||
import qualified Data.ByteString.Char8 as B8
|
||||
import qualified Data.Vector as V
|
||||
|
||||
import qualified TensorFlow.ControlFlow as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
|
||||
instance Arbitrary B.ByteString where
|
||||
arbitrary = B.pack <$> arbitrary
|
||||
|
||||
-- Test encoding tensors, feeding them through tensorflow, and decoding the
|
||||
-- results.
|
||||
testFFIRoundTrip = testCase "testFFIRoundTrip" $
|
||||
TF.runSession $ do
|
||||
let floatData = V.fromList [1..6 :: Float]
|
||||
stringData = V.fromList [B8.pack (show x) | x <- [1..6]]
|
||||
f <- TF.build $ TF.placeholder [2,3]
|
||||
s <- TF.build $ TF.placeholder [2,3]
|
||||
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
|
||||
, TF.feed s (TF.encodeTensorData [2,3] stringData)
|
||||
]
|
||||
-- It is an error to fetch a tensor that is being fed, so the tensors
|
||||
-- are passed through identity.
|
||||
(f', s') <- TF.runWithFeeds feeds (TF.identity f, TF.identity s)
|
||||
liftIO $ do
|
||||
floatData @=? f'
|
||||
stringData @=? s'
|
||||
|
||||
|
||||
data TensorDataInputs a = TensorDataInputs [Int64] (V.Vector a)
|
||||
deriving Show
|
||||
|
||||
instance Arbitrary a => Arbitrary (TensorDataInputs a) where
|
||||
arbitrary = do
|
||||
-- Limit the size of the final vector, and also guard against overflow
|
||||
-- (i.e., p<0) when there are too many dimensions
|
||||
let validProduct p = p > 0 && p < 100
|
||||
sizes <- listOf (arbitrary `suchThat` (>0))
|
||||
`suchThat` (validProduct . product)
|
||||
elems <- replicateM (fromIntegral $ product sizes) arbitrary
|
||||
return $ TensorDataInputs sizes (V.fromList elems)
|
||||
|
||||
-- Test that a vector is unchanged after being encoded and decoded.
|
||||
encodeDecodeProp :: (TF.TensorType a, Eq a) => TensorDataInputs a -> Bool
|
||||
encodeDecodeProp (TensorDataInputs shape vec) =
|
||||
TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec
|
||||
|
||||
testEncodeDecodeQcFloat = testProperty "testEncodeDecodeQcFloat"
|
||||
(encodeDecodeProp :: TensorDataInputs Float -> Bool)
|
||||
|
||||
testEncodeDecodeQcInt64 = testProperty "testEncodeDecodeQcInt64"
|
||||
(encodeDecodeProp :: TensorDataInputs Int64 -> Bool)
|
||||
|
||||
testEncodeDecodeQcString = testProperty "testEncodeDecodeQcString"
|
||||
(encodeDecodeProp :: TensorDataInputs B.ByteString -> Bool)
|
||||
|
||||
doubleOrInt64Func :: TF.OneOf '[Double, Int64] a => a -> a
|
||||
doubleOrInt64Func = id
|
||||
|
||||
doubleOrFloatFunc :: TF.OneOf '[Double, Float] a => a -> a
|
||||
doubleOrFloatFunc = id
|
||||
|
||||
doubleFunc :: TF.OneOf '[Double] a => a -> a
|
||||
doubleFunc = doubleOrFloatFunc . doubleOrInt64Func
|
||||
|
||||
-- No explicit type signature; make sure it can be inferred automatically.
|
||||
-- (Note: this would fail if we didn't have NoMonomorphismRestriction, since it
|
||||
-- can't simplify the type all the way to `Double -> Double`.
|
||||
doubleFuncNoSig = doubleOrFloatFunc . doubleOrInt64Func
|
||||
|
||||
typeConstraintTests = testCase "type constraints" $ do
|
||||
42 @=? doubleOrInt64Func (42 :: Double)
|
||||
42 @=? doubleOrInt64Func (42 :: Int64)
|
||||
42 @=? doubleOrFloatFunc (42 :: Double)
|
||||
42 @=? doubleOrFloatFunc (42 :: Float)
|
||||
42 @=? doubleFunc (42 :: Double)
|
||||
42 @=? doubleFuncNoSig (42 :: Double)
|
||||
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testFFIRoundTrip
|
||||
, testEncodeDecodeQcFloat
|
||||
, testEncodeDecodeQcInt64
|
||||
, testEncodeDecodeQcString
|
||||
, typeConstraintTests
|
||||
]
|
17
tensorflow-proto/Setup.hs
Normal file
17
tensorflow-proto/Setup.hs
Normal file
|
@ -0,0 +1,17 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
import Data.ProtoLens.Setup
|
||||
|
||||
main = defaultMainGeneratingProtos "../third_party/tensorflow"
|
40
tensorflow-proto/tensorflow-proto.cabal
Normal file
40
tensorflow-proto/tensorflow-proto.cabal
Normal file
|
@ -0,0 +1,40 @@
|
|||
name: tensorflow-proto
|
||||
version: 0.1.0.0
|
||||
synopsis: TensorFlow protocol buffers.
|
||||
description: Please see README.md
|
||||
homepage: https://github.com/tensorflow/haskell#readme
|
||||
license: Apache
|
||||
author: TensorFlow authors
|
||||
maintainer: tensorflow-haskell@googlegroups.com
|
||||
copyright: Google Inc.
|
||||
category: Machine Learning
|
||||
build-type: Custom
|
||||
cabal-version: >=1.22
|
||||
extra-source-files: ../third_party/tensorflow/tensorflow/core/framework/*.proto
|
||||
, ../third_party/tensorflow/tensorflow/core/protobuf/config.proto
|
||||
|
||||
library
|
||||
exposed-modules: Proto.Tensorflow.Core.Framework.AttrValue
|
||||
, Proto.Tensorflow.Core.Framework.Graph
|
||||
, Proto.Tensorflow.Core.Framework.NodeDef
|
||||
, Proto.Tensorflow.Core.Framework.OpDef
|
||||
, Proto.Tensorflow.Core.Framework.ResourceHandle
|
||||
, Proto.Tensorflow.Core.Framework.Tensor
|
||||
, Proto.Tensorflow.Core.Framework.TensorShape
|
||||
, Proto.Tensorflow.Core.Framework.Types
|
||||
, Proto.Tensorflow.Core.Protobuf.Config
|
||||
other-modules: Proto.Tensorflow.Core.Framework.AllocationDescription
|
||||
, Proto.Tensorflow.Core.Framework.CostGraph
|
||||
, Proto.Tensorflow.Core.Framework.Function
|
||||
, Proto.Tensorflow.Core.Framework.StepStats
|
||||
, Proto.Tensorflow.Core.Framework.TensorDescription
|
||||
, Proto.Tensorflow.Core.Framework.Versions
|
||||
build-depends: proto-lens == 0.1.*
|
||||
, proto-lens-protoc == 0.1.*
|
||||
, base >= 4.7 && < 5
|
||||
default-language: Haskell2010
|
||||
include-dirs: .
|
||||
|
||||
source-repository head
|
||||
type: git
|
||||
location: https://github.com/tensorflow/haskell
|
3
tensorflow-queue/Setup.hs
Normal file
3
tensorflow-queue/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
|||
import Distribution.Simple
|
||||
|
||||
main = defaultMain
|
78
tensorflow-queue/src/TensorFlow/Queue.hs
Normal file
78
tensorflow-queue/src/TensorFlow/Queue.hs
Normal file
|
@ -0,0 +1,78 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
-- | Queues in TensorFlow graph. Very limited support for now.
|
||||
module TensorFlow.Queue (Queue2, makeQueue2, enqueue, dequeue) where
|
||||
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Int (Int64)
|
||||
import Lens.Family2 ((.~), (&))
|
||||
import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef)
|
||||
import TensorFlow.BuildOp (buildOp)
|
||||
import TensorFlow.ControlFlow (group)
|
||||
import TensorFlow.Tensor (Ref, Tensor)
|
||||
import TensorFlow.Types (TensorType, tensorType)
|
||||
|
||||
-- | A queue carrying tuples. The underlying structure is more
|
||||
-- versatile and can be made to support arbitrary tuples.
|
||||
data Queue2 a b = Queue2 { handle :: Handle }
|
||||
|
||||
type Handle = Tensor Ref ByteString
|
||||
|
||||
-- | Adds the given values to the queue.
|
||||
enqueue :: forall a b v1 v2. (TensorType a, TensorType b)
|
||||
=> Queue2 a b
|
||||
-> Tensor v1 a
|
||||
-> Tensor v2 b
|
||||
-> Build ControlNode
|
||||
enqueue q =
|
||||
buildOp (opDef "QueueEnqueue"
|
||||
& opAttr "Tcomponents" .~ [ tensorType (undefined :: a)
|
||||
, tensorType (undefined :: b)])
|
||||
(handle q)
|
||||
|
||||
-- | Retrieves the values from the queue.
|
||||
dequeue :: forall a b . (TensorType a, TensorType b)
|
||||
=> Queue2 a b
|
||||
-> Build (Tensor Ref a, Tensor Ref b)
|
||||
-- ^ Dequeued tensors. They are paired in a sense
|
||||
-- that values appear together, even if they are
|
||||
-- not consumed together.
|
||||
dequeue q =
|
||||
buildOp (opDef "QueueDequeue"
|
||||
& opAttr "component_types" .~ [ tensorType (undefined :: a)
|
||||
, tensorType (undefined :: b)])
|
||||
(handle q)
|
||||
|
||||
-- | Creates a new queue with the given capacity and shared name.
|
||||
makeQueue2 :: forall a b . (TensorType a, TensorType b)
|
||||
=> Int64 -- ^ The upper bound on the number of elements in
|
||||
-- this queue. Negative numbers mean no limit.
|
||||
-> ByteString -- ^ If non-empty, this queue will be shared
|
||||
-- under the given name across multiple sessions.
|
||||
-> Build (Queue2 a b)
|
||||
makeQueue2 capacity sharedName = do
|
||||
q <- buildOp (opDef "FIFOQueue"
|
||||
& opAttr "component_types" .~ [ tensorType (undefined :: a)
|
||||
, tensorType (undefined :: b)]
|
||||
& opAttr "shared_name" .~ sharedName
|
||||
& opAttr "capacity" .~ capacity
|
||||
)
|
||||
group q >>= addInitializer
|
||||
return (Queue2 q)
|
||||
|
||||
-- TODO(gnezdo): Figure out the closing story for queues.
|
51
tensorflow-queue/tensorflow-queue.cabal
Normal file
51
tensorflow-queue/tensorflow-queue.cabal
Normal file
|
@ -0,0 +1,51 @@
|
|||
name: tensorflow-queue
|
||||
version: 0.1.0.0
|
||||
synopsis: Basic access to TensorFlow queues.
|
||||
description: Please see README.md
|
||||
homepage: https://github.com/tensorflow/haskell#readme
|
||||
license: Apache
|
||||
author: TensorFlow authors
|
||||
maintainer: tensorflow-haskell@googlegroups.com
|
||||
copyright: Google Inc.
|
||||
category: Machine Learning
|
||||
build-type: Simple
|
||||
cabal-version: >=1.22
|
||||
|
||||
library
|
||||
hs-source-dirs: src
|
||||
exposed-modules: TensorFlow.Queue
|
||||
build-depends: proto-lens == 0.1.*
|
||||
, base >= 4.7 && < 5
|
||||
, bytestring
|
||||
, lens-family
|
||||
, containers
|
||||
, tensorflow-proto == 0.1.*
|
||||
, tensorflow-core-ops == 0.1.*
|
||||
, tensorflow
|
||||
, text
|
||||
default-language: Haskell2010
|
||||
|
||||
Test-Suite QueueTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: QueueTest.hs
|
||||
hs-source-dirs: tests
|
||||
-- Uses multiple threads and blocks without this option.
|
||||
ghc-options: -threaded
|
||||
build-depends: HUnit
|
||||
, base
|
||||
, bytestring
|
||||
, proto-lens
|
||||
, lens-family
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-ops
|
||||
, tensorflow-queue
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
, transformers
|
||||
, vector
|
||||
|
||||
source-repository head
|
||||
type: git
|
||||
location: https://github.com/tensorflow/haskell
|
79
tensorflow-queue/tests/QueueTest.hs
Normal file
79
tensorflow-queue/tests/QueueTest.hs
Normal file
|
@ -0,0 +1,79 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Main where
|
||||
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int64)
|
||||
import Google.Test (googleTest)
|
||||
import TensorFlow.Nodes (Scalar(..))
|
||||
import TensorFlow.Ops (scalar)
|
||||
import TensorFlow.Queue
|
||||
import TensorFlow.Session
|
||||
( asyncProdNodes
|
||||
, build
|
||||
, buildAnd
|
||||
, run
|
||||
, runSession
|
||||
, run_
|
||||
)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@=?))
|
||||
import qualified Data.ByteString as BS
|
||||
|
||||
-- | Test basic queue behaviors.
|
||||
testBasic = testCase "testBasic" $ runSession $ do
|
||||
(q :: Queue2 Int64 BS.ByteString) <- build $ makeQueue2 1 ""
|
||||
buildAnd run_ (enqueue q 42 (scalar "Hi"))
|
||||
x <- buildAnd run (dequeue q)
|
||||
liftIO $ (Scalar 42, Scalar "Hi") @=? x
|
||||
|
||||
buildAnd run_ (enqueue q 56 (scalar "Bar"))
|
||||
y <- buildAnd run (dequeue q)
|
||||
liftIO $ (Scalar 56, Scalar "Bar") @=? y
|
||||
|
||||
-- | Test queue pumping.
|
||||
testPump = testCase "testPump" $ runSession $ do
|
||||
(deq, pump) <- build $ do
|
||||
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue"
|
||||
(,) <$> dequeue q
|
||||
<*> enqueue q 31 (scalar "Baz")
|
||||
-- This is a realistic use. The pump inputs are pre-bound to some
|
||||
-- nodes that produce values when pumped (e.g. read from a
|
||||
-- file).
|
||||
run_ (pump, pump)
|
||||
|
||||
(x, y) <- run (deq, deq)
|
||||
liftIO $ (Scalar 31, Scalar "Baz") @=? x
|
||||
liftIO $ (Scalar 31, Scalar "Baz") @=? y
|
||||
|
||||
testAsync = testCase "testAsync" $ runSession $ do
|
||||
(deq, pump) <- build $ do
|
||||
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 ""
|
||||
(,) <$> dequeue q
|
||||
<*> enqueue q 10 (scalar "Async")
|
||||
-- Pumps the queue until canceled by runSession exiting.
|
||||
asyncProdNodes pump
|
||||
-- Picks up a couple values and verifies they are as expected.
|
||||
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?)
|
||||
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?)
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testBasic
|
||||
, testPump
|
||||
, testAsync
|
||||
]
|
3
tensorflow/Setup.hs
Normal file
3
tensorflow/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
|||
import Distribution.Simple
|
||||
|
||||
main = defaultMain
|
376
tensorflow/src/TensorFlow/Build.hs
Normal file
376
tensorflow/src/TensorFlow/Build.hs
Normal file
|
@ -0,0 +1,376 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE Rank2Types #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
module TensorFlow.Build
|
||||
( -- * Graph node types
|
||||
ControlNode(..)
|
||||
, Unique
|
||||
-- * Ops
|
||||
, explicitName
|
||||
, implicitName
|
||||
, opDef
|
||||
, opDefWithName
|
||||
, opName
|
||||
, opType
|
||||
, opAttr
|
||||
, opInputs
|
||||
, opControlInputs
|
||||
-- * The Build monad
|
||||
, GraphState
|
||||
, render
|
||||
, renderNodeName
|
||||
, renderedNodeDefs
|
||||
, BuildT
|
||||
, Build
|
||||
, addInitializer
|
||||
, hoistBuildT
|
||||
, evalBuildT
|
||||
, runBuildT
|
||||
, asGraphDef
|
||||
, addGraphDef
|
||||
, flushInitializers
|
||||
, flushNodeBuffer
|
||||
-- * Creating and looking up Ops
|
||||
, getOrAddOp
|
||||
, addNewOp
|
||||
, renderOutput
|
||||
-- * Modifying all nodes in a Build action
|
||||
, colocateWith
|
||||
, withStateLens
|
||||
, withDevice
|
||||
, withNameScope
|
||||
, withNodeDependencies
|
||||
-- * Internal Summary related bits.
|
||||
, addSummary
|
||||
, SummaryTensor
|
||||
, collectAllSummaries
|
||||
) where
|
||||
|
||||
import Control.Monad.IO.Class (MonadIO(..))
|
||||
import Control.Monad.Trans.Class (MonadTrans(..))
|
||||
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Default (def)
|
||||
import Data.Functor.Identity (Identity(..))
|
||||
import qualified Data.Map.Strict as Map
|
||||
import Data.Monoid ((<>))
|
||||
import qualified Data.Set as Set
|
||||
import Data.Set (Set)
|
||||
import Data.String (IsString(..))
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as Text
|
||||
import Lens.Family2 (Lens', (.~), (^.), (&))
|
||||
import Lens.Family2.State.Strict (MonadState, use, uses, (.=), (<>=), (%=))
|
||||
import Lens.Family2.Unchecked (lens)
|
||||
import Proto.Tensorflow.Core.Framework.Graph
|
||||
( GraphDef
|
||||
, node
|
||||
)
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||
( NodeDef
|
||||
, attr
|
||||
, input
|
||||
, device
|
||||
, name
|
||||
, op
|
||||
)
|
||||
|
||||
import TensorFlow.Orphans ()
|
||||
import TensorFlow.Output
|
||||
import TensorFlow.Tensor
|
||||
|
||||
newtype Unique = Unique Int
|
||||
deriving (Eq, Ord, Enum)
|
||||
|
||||
--------------
|
||||
|
||||
implicitName :: PendingNodeName
|
||||
implicitName = ImplicitName
|
||||
|
||||
explicitName :: Text -> PendingNodeName
|
||||
explicitName = ExplicitName
|
||||
|
||||
newtype Scope = Scope {unScope :: Text}
|
||||
deriving (Eq, Ord, IsString)
|
||||
|
||||
instance Show Scope where
|
||||
show = show . unScope
|
||||
|
||||
opDef :: OpType -> OpDef
|
||||
opDef = opDefWithName ImplicitName
|
||||
|
||||
opDefWithName :: PendingNodeName -> OpType -> OpDef
|
||||
opDefWithName n t = OpDef
|
||||
{ _opName = n
|
||||
, _opType = t
|
||||
, _opAttrs = Map.empty
|
||||
, _opInputs = []
|
||||
, _opControlInputs = []
|
||||
}
|
||||
|
||||
-- | Synonym for the tensors that return serialized Summary proto.
|
||||
type SummaryTensor = Tensor Value ByteString
|
||||
|
||||
data GraphState = GraphState
|
||||
{ _renderedNodes :: !(Map.Map PendingNode NodeDef)
|
||||
-- ^ Nodes which have been rendered. Keeps track of the unique ID we
|
||||
-- assign each implicitly-named node. Also prevents us from adding the
|
||||
-- same node (implicit or explicit) more than once to the nodeBuffer.
|
||||
, _renderedNodeDefs :: !(Map.Map NodeName NodeDef)
|
||||
-- ^ The NodeDefs of nodes which have been rendered. Used by the
|
||||
-- Gradient module to inspect the node graph.
|
||||
, _nodeBuffer :: [NodeDef]
|
||||
-- ^ A list of nodes that should be passed to TensorFlow during
|
||||
-- the next call to Session.extend (TF_ExtendGraph).
|
||||
, _nextUnique :: !Unique
|
||||
-- ^ Unique ID for the next node
|
||||
-- TODO(judahjacobson): watch for clashes between auto and user names.
|
||||
, _defaultDevice :: !(Maybe Device)
|
||||
, _currentScope :: [Scope]
|
||||
, _defaultControlInputs :: !(Set NodeName)
|
||||
, _initializationNodes :: [NodeName]
|
||||
-- ^ The nodes to run next time a TF.run is issued, typically
|
||||
-- variable initializers.
|
||||
, _summaries :: [SummaryTensor]
|
||||
-- ^ The tensors for summary
|
||||
}
|
||||
|
||||
-- | A node definition without its final name. Used as a key in the
|
||||
-- "renderedNodes" map.
|
||||
-- The NodeDef contained inside has an empty "name" field.
|
||||
data PendingNode = PendingNode [Scope] !PendingNodeName !NodeDef
|
||||
deriving (Eq, Ord)
|
||||
|
||||
-- Returns an _incomplete_ NodeDef. The name is fixed by addNewOpFromPending.
|
||||
pendingNodeDef :: PendingNode -> NodeDef
|
||||
pendingNodeDef (PendingNode _ _ n) = n
|
||||
|
||||
initGraphState :: GraphState
|
||||
initGraphState =
|
||||
GraphState Map.empty Map.empty [] (Unique 0) Nothing [] Set.empty [] []
|
||||
|
||||
renderedNodes :: Lens' GraphState (Map.Map PendingNode NodeDef)
|
||||
renderedNodes = lens _renderedNodes (\g x -> g { _renderedNodes = x })
|
||||
|
||||
renderedNodeDefs :: Lens' GraphState (Map.Map NodeName NodeDef)
|
||||
renderedNodeDefs = lens _renderedNodeDefs (\g x -> g { _renderedNodeDefs = x })
|
||||
|
||||
nodeBuffer :: Lens' GraphState [NodeDef]
|
||||
nodeBuffer = lens _nodeBuffer (\g x -> g { _nodeBuffer = x })
|
||||
|
||||
nextUnique :: Lens' GraphState Unique
|
||||
nextUnique = lens _nextUnique (\g x -> g { _nextUnique = x })
|
||||
|
||||
defaultDevice :: Lens' GraphState (Maybe Device)
|
||||
defaultDevice = lens _defaultDevice (\g x -> g { _defaultDevice = x })
|
||||
|
||||
currentScope :: Lens' GraphState [Scope]
|
||||
currentScope = lens _currentScope (\g x -> g { _currentScope = x })
|
||||
|
||||
defaultControlInputs :: Lens' GraphState (Set NodeName)
|
||||
defaultControlInputs = lens _defaultControlInputs
|
||||
(\g x -> g { _defaultControlInputs = x })
|
||||
|
||||
initializationNodes :: Lens' GraphState [NodeName]
|
||||
initializationNodes = lens _initializationNodes (\g x -> g { _initializationNodes = x })
|
||||
|
||||
summaries :: Lens' GraphState [SummaryTensor]
|
||||
summaries = lens _summaries (\g x -> g { _summaries = x })
|
||||
|
||||
-- | An action for building nodes in a TensorFlow graph.
|
||||
-- Used to manage build state internally as part of the @Session@ monad.
|
||||
newtype BuildT m a = BuildT (StateT GraphState m a)
|
||||
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
|
||||
MonadState GraphState)
|
||||
|
||||
-- | An action for building nodes in a TensorFlow graph.
|
||||
type Build = BuildT Identity
|
||||
|
||||
-- | This is Control.Monad.Morph.hoist sans the dependency.
|
||||
hoistBuildT :: (forall a . m a -> n a) -> BuildT m b -> BuildT n b
|
||||
hoistBuildT f (BuildT m) = BuildT $ mapStateT f m
|
||||
|
||||
runBuildT :: BuildT m a -> m (a, GraphState)
|
||||
runBuildT (BuildT f) = runStateT f initGraphState
|
||||
|
||||
evalBuildT :: Monad m => BuildT m a -> m a
|
||||
evalBuildT (BuildT f) = evalStateT f initGraphState
|
||||
|
||||
-- | Get all the NodeDefs that have accumulated so far, and clear that buffer.
|
||||
flushNodeBuffer :: Monad m => BuildT m [NodeDef]
|
||||
flushNodeBuffer = do
|
||||
ns <- use nodeBuffer
|
||||
nodeBuffer .= []
|
||||
return ns
|
||||
|
||||
-- | Get all the initializers that have accumulated so far, and clear
|
||||
-- that buffer.
|
||||
flushInitializers :: Monad m => BuildT m [NodeName]
|
||||
flushInitializers = do
|
||||
ns <- use initializationNodes
|
||||
initializationNodes .= []
|
||||
return ns
|
||||
|
||||
-- | Registers the given node to be executed before the next
|
||||
-- 'TensorFlow.Session.run'.
|
||||
addInitializer :: ControlNode -> Build ()
|
||||
addInitializer (ControlNode o) = do
|
||||
i <- getOrAddOp o
|
||||
initializationNodes %= (i:)
|
||||
|
||||
-- | Produce a GraphDef proto representation of the nodes that are rendered in
|
||||
-- the given 'Build' action.
|
||||
asGraphDef :: Build a -> GraphDef
|
||||
asGraphDef b = def & node .~ gs ^. nodeBuffer
|
||||
where
|
||||
gs = snd $ runIdentity $ runBuildT b
|
||||
|
||||
-- TODO: check against existing nodes for conflicts?
|
||||
addGraphDef :: GraphDef -> Build ()
|
||||
addGraphDef g = nodeBuffer <>= g ^. node
|
||||
|
||||
-- | Render the given op if it hasn't been rendered already, and return its
|
||||
-- name.
|
||||
getOrAddOp :: Op -> Build NodeName
|
||||
getOrAddOp o = NodeName . (^. name) <$> resolveOp o
|
||||
|
||||
resolveOp :: Op -> Build NodeDef
|
||||
resolveOp (Rendered n) = return n
|
||||
resolveOp (Unrendered o) = do
|
||||
pending <- getPendingNode o
|
||||
uses renderedNodes (Map.lookup pending) >>= \case
|
||||
Just n -> return n
|
||||
Nothing -> addNewOpFromPending pending
|
||||
|
||||
-- | Add a new node for a given 'OpDef'. This is used for making "stateful" ops
|
||||
-- which are not safe to dedup (e.g, "variable" and "assign").
|
||||
addNewOp :: OpDef -> Build NodeDef
|
||||
addNewOp o = getPendingNode o >>= addNewOpFromPending
|
||||
|
||||
addNewOpFromPending :: PendingNode -> Build NodeDef
|
||||
addNewOpFromPending pending = do
|
||||
nodeName <- renderPendingNode pending
|
||||
let nodeDef = pendingNodeDef pending & name .~ unNodeName nodeName
|
||||
nodeBuffer %= (nodeDef :)
|
||||
renderedNodes %= Map.insert pending nodeDef
|
||||
renderedNodeDefs %= Map.insert nodeName nodeDef
|
||||
return nodeDef
|
||||
|
||||
-- | Get the pending node corresponding to an OpDef, which may or may not have
|
||||
-- been rendered before. Implicitly renders all of this node's inputs.
|
||||
getPendingNode :: OpDef -> Build PendingNode
|
||||
getPendingNode o = do
|
||||
-- An empty string in the proto field means that no specific
|
||||
-- device is specified.
|
||||
dev <- maybe "" deviceName <$> use defaultDevice
|
||||
inputs <- mapM getInput (o ^. opInputs)
|
||||
scope <- use currentScope
|
||||
controls <- use defaultControlInputs
|
||||
let controlInputs
|
||||
= map getDep (o ^. opControlInputs ++ Set.toList controls)
|
||||
return $ PendingNode scope (o ^. opName)
|
||||
$ def & op .~ (unOpType (o ^. opType) :: Text)
|
||||
& attr .~ _opAttrs o
|
||||
& input .~ (inputs ++ controlInputs)
|
||||
& device .~ dev
|
||||
where
|
||||
getInput (Output (OutputIx k) subOp)
|
||||
= (<> ":" <> Text.pack (show k)) . unNodeName <$> getOrAddOp subOp
|
||||
getDep = ("^" <>) . unNodeName
|
||||
|
||||
-- | Pick a name for a pending node. If it has an explicit name, just use that;
|
||||
-- if the name is implicit, assign a new unique name based on the op type.
|
||||
renderPendingNode :: PendingNode -> Build NodeName
|
||||
renderPendingNode (PendingNode scope pendingName nodeDef)
|
||||
= NodeName . (scopePrefix <>) <$> getName
|
||||
where
|
||||
scopePrefix = Text.concat $ fmap ((<> "/") . unScope) scope
|
||||
getName = case pendingName of
|
||||
ExplicitName n -> return n
|
||||
ImplicitName -> do
|
||||
u@(Unique k) <- use nextUnique
|
||||
nextUnique .= succ u
|
||||
return $ nodeDef ^. op <> "_" <> Text.pack (show k)
|
||||
|
||||
|
||||
-- | Render an 'Output' and return a string representation for the TensorFlow
|
||||
-- foreign APIs.
|
||||
renderOutput :: Output -> Build Text
|
||||
renderOutput (Output (OutputIx i) o) = do
|
||||
n <- getOrAddOp o
|
||||
return $ unNodeName n <> Text.pack (":" ++ show i)
|
||||
|
||||
-- | Modify some part of the state, run an action, and restore the state
|
||||
-- after that action is done.
|
||||
withStateLens :: MonadState s m => Lens' s a -> (a -> a) -> m b -> m b
|
||||
withStateLens accessor f act = do
|
||||
old <- use accessor
|
||||
accessor %= f
|
||||
result <- act
|
||||
accessor .= old
|
||||
return result
|
||||
|
||||
-- | Set a device for all nodes rendered in the given 'Build' action
|
||||
-- (unless further overridden by another use of withDevice).
|
||||
withDevice :: Maybe Device -> Build a -> Build a
|
||||
withDevice d = withStateLens defaultDevice (const d)
|
||||
|
||||
-- | Places all nodes rendered in the given 'Build' action on the same
|
||||
-- device as the given Tensor (see also 'withDevice'). Make sure that
|
||||
-- the action has side effects of rendering the desired tensors. A pure
|
||||
-- return would not have the desired effect.
|
||||
colocateWith :: forall a v b . Tensor v b -> Build a -> Build a
|
||||
colocateWith t x = do
|
||||
d <- Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
|
||||
withDevice (Just d) x
|
||||
|
||||
-- | Prepend a scope to all nodes rendered in the given 'Build' action.
|
||||
withNameScope :: Text -> Build a -> Build a
|
||||
withNameScope s = withStateLens currentScope (Scope s :)
|
||||
|
||||
-- | Add control inputs to all nodes rendered in the given 'Build' action.
|
||||
withNodeDependencies :: Set NodeName -> Build a -> Build a
|
||||
withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
|
||||
|
||||
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
|
||||
-- the 'Build' context. Also renders any dependencies of the 'Tensor' that
|
||||
-- weren't already rendered.
|
||||
--
|
||||
-- This operation is idempotent; @render >=> render === render@. However,
|
||||
-- rendering a (previously un-rendered) 'Tensor' in two different contexts
|
||||
-- may result in two different 'Tensor's.
|
||||
render :: Tensor v a -> Build (Tensor v a)
|
||||
render = tensorOutput $ outputOp $ fmap Rendered . resolveOp
|
||||
|
||||
-- | Render a 'Tensor' and get its node's name.
|
||||
renderNodeName :: Tensor v a -> Build NodeName
|
||||
renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp)
|
||||
|
||||
-- | Records the given summary action in Build for retrieval with
|
||||
-- 'collectAllSummaries'. The summary op is required to produce a
|
||||
-- Summary protocol buffer in string form. For safety, use the
|
||||
-- pre-composed functions: Logging.scalarSummary and
|
||||
-- Logging.histogramSummary.
|
||||
addSummary :: SummaryTensor -> Build ()
|
||||
addSummary t = summaries %= (t :)
|
||||
|
||||
-- | Retrieves the summary ops collected thus far. Typically this only
|
||||
-- happens once, but if 'TensorFlow.Session.buildWithSummary' is used
|
||||
-- repeatedly, the values accumulate.
|
||||
collectAllSummaries :: Monad m => BuildT m [SummaryTensor]
|
||||
collectAllSummaries = use summaries
|
199
tensorflow/src/TensorFlow/BuildOp.hs
Normal file
199
tensorflow/src/TensorFlow/BuildOp.hs
Normal file
|
@ -0,0 +1,199 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module TensorFlow.BuildOp
|
||||
( OpResult
|
||||
, BuildOp
|
||||
, buildOp
|
||||
, buildListOp
|
||||
, eqLengthGuard
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Monad (replicateM)
|
||||
import Control.Monad.Reader (ReaderT, runReaderT, ask)
|
||||
import Control.Monad.State.Strict (State, runState, get, put)
|
||||
import Data.Int (Int64)
|
||||
import Lens.Family2 ((&), (<>~), (^.))
|
||||
|
||||
import TensorFlow.Build
|
||||
import TensorFlow.Output
|
||||
import TensorFlow.Tensor
|
||||
|
||||
data ResultState = ResultState !OutputIx [Int64] deriving Show
|
||||
|
||||
type Result = ReaderT Op (State ResultState)
|
||||
|
||||
-- | Class of types that can be used as op outputs.
|
||||
class OpResult a where
|
||||
toResult :: Result a
|
||||
|
||||
instance (OpResult a1, OpResult a2) => OpResult (a1, a2) where
|
||||
toResult = (,) <$> toResult <*> toResult
|
||||
|
||||
instance (OpResult a1, OpResult a2, OpResult a3) => OpResult (a1, a2, a3) where
|
||||
toResult = (,,) <$> toResult <*> toResult <*> toResult
|
||||
|
||||
instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4)
|
||||
=> OpResult (a1, a2, a3, a4) where
|
||||
toResult = (,,,) <$> toResult <*> toResult <*> toResult <*> toResult
|
||||
|
||||
instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4, OpResult a5)
|
||||
=> OpResult (a1, a2, a3, a4, a5) where
|
||||
toResult = (,,,,) <$> toResult
|
||||
<*> toResult
|
||||
<*> toResult
|
||||
<*> toResult
|
||||
<*> toResult
|
||||
|
||||
instance ( OpResult a1
|
||||
, OpResult a2
|
||||
, OpResult a3
|
||||
, OpResult a4
|
||||
, OpResult a5
|
||||
, OpResult a6
|
||||
)
|
||||
=> OpResult (a1, a2, a3, a4, a5, a6) where
|
||||
toResult = (,,,,,)
|
||||
<$> toResult
|
||||
<*> toResult
|
||||
<*> toResult
|
||||
<*> toResult
|
||||
<*> toResult
|
||||
<*> toResult
|
||||
|
||||
tensorResult :: TensorKind v -> Result (Tensor v a)
|
||||
tensorResult v = do
|
||||
o <- ask
|
||||
ResultState i ns <- get
|
||||
put $! ResultState (i+1) ns
|
||||
return $! Tensor v $ output i o
|
||||
|
||||
instance OpResult (Tensor Value a) where
|
||||
toResult = tensorResult ValueKind
|
||||
|
||||
instance OpResult (Tensor Ref a) where
|
||||
toResult = tensorResult RefKind
|
||||
|
||||
instance OpResult ControlNode where
|
||||
toResult = ControlNode <$> ask
|
||||
|
||||
instance OpResult a => OpResult [a] where
|
||||
toResult = do
|
||||
ResultState i ns <- get
|
||||
case ns of
|
||||
[] -> error $ "Ran out of counts in toResult. " ++
|
||||
"Likely misuse of buildListOp."
|
||||
(n : rest) -> do
|
||||
put $! ResultState i rest
|
||||
replicateM (fromIntegral n) toResult
|
||||
|
||||
runResult :: OpResult a => [Int64] -> Op -> a
|
||||
runResult ns o =
|
||||
case runState (runReaderT toResult o) (ResultState 0 ns) of
|
||||
(x, ResultState _ []) -> x
|
||||
(_, ns') -> error $ "Ununsed length in runResult attributes: " ++
|
||||
show (ns, ns')
|
||||
|
||||
-- | Make a new "pure" op, which may be deduped with identical ops within
|
||||
-- the same scope.
|
||||
pureResult :: OpResult a => [Int64] -> OpDef -> [Output] -> a
|
||||
pureResult ns o ts = runResult ns $ Unrendered $ addReversedInputs o ts
|
||||
|
||||
-- | Make a new "stateful" op, which will not be deduped with otherwise
|
||||
-- identical ops.
|
||||
buildResult :: OpResult a => [Int64] -> OpDef -> [Output] -> Build a
|
||||
buildResult ns o ts
|
||||
= runResult ns . Rendered <$> addNewOp (addReversedInputs o ts)
|
||||
|
||||
addReversedInputs :: OpDef -> [Output] -> OpDef
|
||||
addReversedInputs o ts = o & opInputs <>~ reverse ts
|
||||
|
||||
-- | Class of types that can be used as op functions.
|
||||
class BuildOp f where
|
||||
buildOp' :: [Int64] -- ^ Sizes of list results (having number_attr)
|
||||
-> OpDef
|
||||
-> [Output] -- ^ Accumulator for inputs to the op.
|
||||
-> f
|
||||
|
||||
-- | Starts an operation that returns a structured set of tensors
|
||||
-- (singletons or tuples).
|
||||
buildOp :: BuildOp f => OpDef -> f
|
||||
buildOp o = buildOp' [] o []
|
||||
|
||||
-- | Starts an operation that returns a list of tensors.
|
||||
buildListOp :: BuildOp f => [Int64]
|
||||
-- ^ Cardinality of the corresponding list of tensors output.
|
||||
-> OpDef -> f
|
||||
buildListOp counts o = buildOp' counts o []
|
||||
|
||||
instance BuildOp ControlNode where
|
||||
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
|
||||
|
||||
instance BuildOp (Tensor Value a) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance BuildOp (Tensor Ref a) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance BuildOp [Tensor Value a] where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance (OpResult t1, OpResult t2) => BuildOp (t1, t2) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance (OpResult t1, OpResult t2, OpResult t3) => BuildOp (t1, t2, t3) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4)
|
||||
=> BuildOp (t1, t2, t3, t4) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4, OpResult t5)
|
||||
=> BuildOp (t1, t2, t3, t4, t5) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance ( OpResult t1
|
||||
, OpResult t2
|
||||
, OpResult t3
|
||||
, OpResult t4
|
||||
, OpResult t5
|
||||
, OpResult t6
|
||||
)
|
||||
=> BuildOp (t1, t2, t3, t4, t5, t6) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance OpResult a => BuildOp (Build a) where
|
||||
buildOp' = buildResult
|
||||
|
||||
instance BuildOp f => BuildOp (Tensor v a -> f) where
|
||||
buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts)
|
||||
|
||||
instance BuildOp f => BuildOp ([Tensor v a] -> f) where
|
||||
buildOp' rf o accum ts
|
||||
= buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum)
|
||||
|
||||
-- | Returns true if all the integers in each tuple are identical.
|
||||
-- Throws an error with a descriptive message if not.
|
||||
eqLengthGuard :: [(String, [(String, Int)])] -> Bool
|
||||
eqLengthGuard = all eachOk
|
||||
where
|
||||
eachOk (_, []) = True
|
||||
-- The next line has (== 1) . length . nub in disguise
|
||||
eachOk (numberAttrName, pairs@((_, x) : zs)) = all (\z -> snd z == x) zs ||
|
||||
error ("number_attr " ++ numberAttrName ++
|
||||
" contains tensors with different length " ++ show pairs)
|
87
tensorflow/src/TensorFlow/ControlFlow.hs
Normal file
87
tensorflow/src/TensorFlow/ControlFlow.hs
Normal file
|
@ -0,0 +1,87 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module TensorFlow.ControlFlow
|
||||
( -- * Dependencies
|
||||
withControlDependencies
|
||||
, group
|
||||
-- * Operations
|
||||
, identity
|
||||
, noOp
|
||||
, named
|
||||
) where
|
||||
|
||||
import qualified Data.Set as Set
|
||||
import Data.Text (Text)
|
||||
import Lens.Family2 ((&), (^.), (.~))
|
||||
|
||||
import TensorFlow.BuildOp
|
||||
import TensorFlow.Build
|
||||
import TensorFlow.Nodes
|
||||
import TensorFlow.Output
|
||||
import TensorFlow.Tensor
|
||||
import TensorFlow.Types
|
||||
|
||||
-- | Modify a 'Build' action, such that all new ops rendered in it will depend
|
||||
-- on the nodes in the first argument.
|
||||
withControlDependencies :: Nodes t => t -> Build a -> Build a
|
||||
withControlDependencies deps act = do
|
||||
nodes <- getNodes deps
|
||||
withNodeDependencies nodes act
|
||||
|
||||
-- TODO(judahjacobson): Reimplement withDependencies.
|
||||
|
||||
-- | Create an op that groups multiple operations.
|
||||
--
|
||||
-- When this op finishes, all ops in the input @n@ have finished. This op has
|
||||
-- no output.
|
||||
group :: Nodes t => t -> Build ControlNode
|
||||
group deps = do
|
||||
nodes <- Set.toList <$> getNodes deps
|
||||
-- TODO: slicker way
|
||||
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
|
||||
|
||||
|
||||
-- | Returns a 'Tensor' with the same shape and contents as the input.
|
||||
identity :: TensorType a => Tensor v a -> Tensor v a
|
||||
identity = namedIdentity implicitName
|
||||
|
||||
-- | Returns a 'Tensor' with a given name and the same shape and contents as
|
||||
-- the input.
|
||||
--
|
||||
-- TODO(judahjacobson): This breaks when used with uninitialize @Tensor Ref@s,
|
||||
-- since @RefIdentity@ doesn't have SetAllowsUninitializedInput(). Look into
|
||||
-- whether we can change that op.
|
||||
named :: TensorType a => Text -> Tensor v a -> Tensor v a
|
||||
named = namedIdentity . explicitName
|
||||
|
||||
-- | An internal version of "identity" that allows setting the name
|
||||
-- of the output Tensor.
|
||||
namedIdentity :: forall a v . TensorType a
|
||||
=> PendingNodeName -> Tensor v a -> Tensor v a
|
||||
namedIdentity n t = case t ^. tensorKind of
|
||||
ValueKind -> buildOp (opDefWithName n "Identity" & setTypeAttr) t
|
||||
RefKind -> buildOp (opDefWithName n "RefIdentity" & setTypeAttr) t
|
||||
where
|
||||
setTypeAttr = opAttr "T" .~ tensorType (undefined :: a)
|
||||
|
||||
|
||||
-- | Does nothing. Only useful as a placeholder for control edges.
|
||||
noOp :: ControlNode
|
||||
noOp = buildOp $ opDef "NoOp"
|
243
tensorflow/src/TensorFlow/Internal/FFI.hs
Normal file
243
tensorflow/src/TensorFlow/Internal/FFI.hs
Normal file
|
@ -0,0 +1,243 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE DeriveDataTypeable #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module TensorFlow.Internal.FFI
|
||||
( TensorFlowException(..)
|
||||
, Raw.Session
|
||||
, withSession
|
||||
, extendGraph
|
||||
, run
|
||||
, TensorData(..)
|
||||
, setSessionConfig
|
||||
, setSessionTarget
|
||||
, getAllOpList
|
||||
-- * Internal helper.
|
||||
, useProtoAsVoidPtrLen
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
|
||||
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
|
||||
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
|
||||
import Control.Monad (when)
|
||||
import Data.Int (Int64)
|
||||
import Data.Typeable (Typeable)
|
||||
import Data.Word (Word8)
|
||||
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
|
||||
import Foreign.C.String (CString)
|
||||
import Foreign.ForeignPtr (newForeignPtr, withForeignPtr)
|
||||
import Foreign.Marshal.Alloc (free)
|
||||
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
|
||||
import System.IO.Unsafe (unsafePerformIO)
|
||||
import qualified Data.ByteString as B
|
||||
import qualified Data.Text as T
|
||||
import qualified Data.Text.Encoding as T
|
||||
import qualified Data.Text.Encoding.Error as T
|
||||
import qualified Data.Vector.Storable as S
|
||||
|
||||
import Data.ProtoLens (Message, encodeMessage)
|
||||
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
|
||||
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
||||
|
||||
import qualified TensorFlow.Internal.Raw as Raw
|
||||
|
||||
data TensorFlowException = TensorFlowException Raw.Code T.Text
|
||||
deriving (Show, Eq, Typeable)
|
||||
|
||||
instance Exception TensorFlowException
|
||||
|
||||
-- | All of the data needed to represent a tensor.
|
||||
data TensorData = TensorData
|
||||
{ tensorDataDimensions :: [Int64]
|
||||
, tensorDataType :: !DataType
|
||||
, tensorDataBytes :: !(S.Vector Word8)
|
||||
}
|
||||
deriving (Show, Eq)
|
||||
|
||||
-- | Runs the given action after creating a session with options
|
||||
-- populated by the given optionSetter.
|
||||
withSession :: (Raw.SessionOptions -> IO ())
|
||||
-> ((IO () -> IO ()) -> Raw.Session -> IO a)
|
||||
-- ^ The action can spawn concurrent tasks which will
|
||||
-- be canceled before withSession returns.
|
||||
-> IO a
|
||||
withSession optionSetter action = do
|
||||
drain <- newMVar []
|
||||
let cleanup s =
|
||||
-- Closes the session to nudge the pending run calls to fail and exit.
|
||||
finally (checkStatus (Raw.closeSession s)) $ do
|
||||
runners <- takeMVar drain
|
||||
-- Collects all runners before deleting the session.
|
||||
mapM_ shutDownRunner runners
|
||||
checkStatus (Raw.deleteSession s)
|
||||
bracket Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
|
||||
optionSetter options
|
||||
bracket
|
||||
(checkStatus (Raw.newSession options))
|
||||
cleanup
|
||||
(action (asyncCollector drain))
|
||||
|
||||
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
|
||||
asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord
|
||||
where
|
||||
launchAndRecord restRunners = (: restRunners) <$> async runner
|
||||
|
||||
shutDownRunner :: Async () -> IO ()
|
||||
shutDownRunner r = do
|
||||
cancel r
|
||||
-- TODO(gnezdo): manage exceptions better than print.
|
||||
either print (const (return ())) =<< waitCatch r
|
||||
|
||||
extendGraph :: Raw.Session -> GraphDef -> IO ()
|
||||
extendGraph session pb =
|
||||
useProtoAsVoidPtrLen pb $ \ptr len ->
|
||||
checkStatus $ Raw.extendGraph session ptr len
|
||||
|
||||
|
||||
run :: Raw.Session
|
||||
-> [(B.ByteString, TensorData)] -- ^ Feeds.
|
||||
-> [B.ByteString] -- ^ Fetches.
|
||||
-> [B.ByteString] -- ^ Targets.
|
||||
-> IO [TensorData]
|
||||
run session feeds fetches targets = do
|
||||
let nullTensor = Raw.Tensor nullPtr
|
||||
-- Use mask to avoid leaking input tensors before they are passed to 'run'
|
||||
-- and output tensors before they are passed to 'createTensorData'.
|
||||
mask_ $
|
||||
-- Feeds
|
||||
withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
|
||||
mapM (createRawTensor . snd) feeds >>= \feedTensors ->
|
||||
withArrayLen feedTensors $ \_ cFeedTensors ->
|
||||
-- Fetches.
|
||||
withStringArrayLen fetches $ \fetchesLen fetchNames ->
|
||||
-- tensorOuts is an array of null Tensor pointers that will be filled
|
||||
-- by the call to Raw.run.
|
||||
withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
|
||||
-- Targets.
|
||||
withStringArrayLen targets $ \targetsLen ctargets -> do
|
||||
checkStatus $ Raw.run
|
||||
session
|
||||
nullPtr
|
||||
feedNames cFeedTensors (fromIntegral feedsLen)
|
||||
fetchNames tensorOuts (fromIntegral fetchesLen)
|
||||
ctargets (fromIntegral targetsLen)
|
||||
nullPtr
|
||||
outTensors <- peekArray fetchesLen tensorOuts
|
||||
mapM createTensorData outTensors
|
||||
|
||||
|
||||
-- Internal.
|
||||
|
||||
|
||||
-- | Use a list of ByteString as a list of CString.
|
||||
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
|
||||
withStringList strings fn = go strings []
|
||||
where
|
||||
go [] cs = fn (reverse cs)
|
||||
-- TODO(fmayle): Is it worth using unsafeAsCString here?
|
||||
go (x:xs) cs = B.useAsCString x $ \c -> go xs (c:cs)
|
||||
|
||||
|
||||
-- | Use a list of ByteString as an array of CString.
|
||||
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
|
||||
withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
|
||||
|
||||
|
||||
-- | Create a Raw.Tensor from a TensorData.
|
||||
createRawTensor :: TensorData -> IO Raw.Tensor
|
||||
createRawTensor (TensorData dims dt byteVec) =
|
||||
withArrayLen (map fromIntegral dims) $ \cdimsLen cdims -> do
|
||||
let len = S.length byteVec
|
||||
dest <- mallocArray len
|
||||
S.unsafeWith byteVec $ \x -> copyArray dest x len
|
||||
Raw.newTensor (toEnum $ fromEnum dt)
|
||||
cdims (fromIntegral cdimsLen)
|
||||
(castPtr dest) (fromIntegral len)
|
||||
tensorDeallocFunPtr nullPtr
|
||||
|
||||
{-# NOINLINE tensorDeallocFunPtr #-}
|
||||
tensorDeallocFunPtr :: FunPtr Raw.TensorDeallocFn
|
||||
tensorDeallocFunPtr = unsafePerformIO $ Raw.wrapTensorDealloc $ \x _ _ -> free x
|
||||
|
||||
-- | Create a TensorData from a Raw.Tensor.
|
||||
--
|
||||
-- Takes ownership of the Raw.Tensor.
|
||||
createTensorData :: Raw.Tensor -> IO TensorData
|
||||
createTensorData t = do
|
||||
-- Read dimensions.
|
||||
numDims <- Raw.numDims t
|
||||
dims <- mapM (Raw.dim t) [0..numDims-1]
|
||||
-- Read type.
|
||||
dtype <- toEnum . fromEnum <$> Raw.tensorType t
|
||||
-- Read data.
|
||||
len <- fromIntegral <$> Raw.tensorByteSize t
|
||||
bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
|
||||
-- TODO(fmayle): Don't copy the data.
|
||||
v <- S.fromList <$> peekArray len bytes
|
||||
-- Free tensor.
|
||||
Raw.deleteTensor t
|
||||
return $ TensorData (map fromIntegral dims) dtype v
|
||||
|
||||
-- | Runs the given action which does FFI calls updating a provided
|
||||
-- status object. If the status is not OK it is thrown as
|
||||
-- TensorFlowException.
|
||||
checkStatus :: (Raw.Status -> IO a) -> IO a
|
||||
checkStatus fn =
|
||||
bracket Raw.newStatus Raw.deleteStatus $ \status -> do
|
||||
result <- fn status
|
||||
code <- Raw.getCode status
|
||||
when (code /= Raw.TF_OK) $ do
|
||||
msg <- T.decodeUtf8With T.lenientDecode <$>
|
||||
(Raw.message status >>= B.packCString)
|
||||
throwIO $ TensorFlowException code msg
|
||||
return result
|
||||
|
||||
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
|
||||
setSessionConfig pb opt =
|
||||
useProtoAsVoidPtrLen pb $ \ptr len ->
|
||||
checkStatus (Raw.setConfig opt ptr len)
|
||||
|
||||
setSessionTarget :: B.ByteString -> Raw.SessionOptions -> IO ()
|
||||
setSessionTarget target = B.useAsCString target . Raw.setTarget
|
||||
|
||||
-- | Serializes the given msg and provides it as (ptr,len) argument
|
||||
-- to the given action.
|
||||
useProtoAsVoidPtrLen :: (Message msg, Num c) =>
|
||||
msg -> (Ptr b -> c -> IO a) -> IO a
|
||||
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
|
||||
\(bytes, len) -> f (castPtr bytes) (fromIntegral len)
|
||||
|
||||
-- | Returns the serialized OpList of all OpDefs defined in this
|
||||
-- address space.
|
||||
getAllOpList :: IO B.ByteString
|
||||
getAllOpList = do
|
||||
foreignPtr <-
|
||||
mask_ (newForeignPtr Raw.deleteBuffer =<< checkCall)
|
||||
-- Makes a copy because it is more reliable than eviscerating
|
||||
-- Buffer to steal its memory (including custom deallocator).
|
||||
withForeignPtr foreignPtr $
|
||||
\ptr -> B.packCStringLen =<< (,)
|
||||
<$> (castPtr <$> Raw.getBufferData ptr)
|
||||
<*> (fromIntegral <$> Raw.getBufferLength ptr)
|
||||
where
|
||||
checkCall = do
|
||||
p <- Raw.getAllOpList
|
||||
when (p == nullPtr) (throwIO exception)
|
||||
return p
|
||||
exception = TensorFlowException
|
||||
Raw.TF_UNKNOWN "GetAllOpList failure, check logs"
|
152
tensorflow/src/TensorFlow/Internal/Raw.chs
Normal file
152
tensorflow/src/TensorFlow/Internal/Raw.chs
Normal file
|
@ -0,0 +1,152 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE ForeignFunctionInterface #-}
|
||||
|
||||
module TensorFlow.Internal.Raw where
|
||||
|
||||
#include "third_party/tensorflow/c/c_api.h"
|
||||
|
||||
import Foreign
|
||||
import Foreign.C
|
||||
|
||||
{#enum TF_DataType as DataType {} deriving (Show, Eq) #}
|
||||
{#enum TF_Code as Code {} deriving (Show, Eq) #}
|
||||
|
||||
|
||||
-- Status.
|
||||
{#pointer *TF_Status as Status newtype #}
|
||||
|
||||
newStatus :: IO Status
|
||||
newStatus = {# call TF_NewStatus as ^ #}
|
||||
|
||||
deleteStatus :: Status -> IO ()
|
||||
deleteStatus = {# call TF_DeleteStatus as ^ #}
|
||||
|
||||
setStatus :: Status -> Code -> CString -> IO ()
|
||||
setStatus s c = {# call TF_SetStatus as ^ #} s (fromIntegral $ fromEnum c)
|
||||
|
||||
getCode :: Status -> IO Code
|
||||
getCode s = toEnum . fromIntegral <$> {# call TF_GetCode as ^ #} s
|
||||
|
||||
message :: Status -> IO CString
|
||||
message = {# call TF_Message as ^ #}
|
||||
|
||||
|
||||
-- Buffer.
|
||||
data Buffer
|
||||
{#pointer *TF_Buffer as BufferPtr -> Buffer #}
|
||||
|
||||
getBufferData :: BufferPtr -> IO (Ptr ())
|
||||
getBufferData = {#get TF_Buffer->data #}
|
||||
|
||||
getBufferLength :: BufferPtr -> IO CULong
|
||||
getBufferLength ={#get TF_Buffer->length #}
|
||||
|
||||
-- Tensor.
|
||||
{#pointer *TF_Tensor as Tensor newtype #}
|
||||
|
||||
instance Storable Tensor where
|
||||
sizeOf (Tensor t) = sizeOf t
|
||||
alignment (Tensor t) = alignment t
|
||||
peek p = fmap Tensor (peek (castPtr p))
|
||||
poke p (Tensor t) = poke (castPtr p) t
|
||||
|
||||
newTensor :: DataType
|
||||
-> Ptr CLong -- dimensions array
|
||||
-> CInt -- num dimensions
|
||||
-> Ptr () -- data
|
||||
-> CULong -- data len
|
||||
-> FunPtr (Ptr () -> CULong -> Ptr () -> IO ()) -- deallocator
|
||||
-> Ptr () -- deallocator arg
|
||||
-> IO Tensor
|
||||
newTensor dt = {# call TF_NewTensor as ^ #} (fromIntegral $ fromEnum dt)
|
||||
|
||||
deleteTensor :: Tensor -> IO ()
|
||||
deleteTensor = {# call TF_DeleteTensor as ^ #}
|
||||
|
||||
tensorType :: Tensor -> IO DataType
|
||||
tensorType t = toEnum . fromIntegral <$> {# call TF_TensorType as ^ #} t
|
||||
|
||||
numDims :: Tensor -> IO CInt
|
||||
numDims = {# call TF_NumDims as ^ #}
|
||||
|
||||
dim :: Tensor -> CInt -> IO CLong
|
||||
dim = {# call TF_Dim as ^ #}
|
||||
|
||||
tensorByteSize :: Tensor -> IO CULong
|
||||
tensorByteSize = {# call TF_TensorByteSize as ^ #}
|
||||
|
||||
tensorData :: Tensor -> IO (Ptr ())
|
||||
tensorData = {# call TF_TensorData as ^ #}
|
||||
|
||||
|
||||
-- Session Options.
|
||||
{# pointer *TF_SessionOptions as SessionOptions newtype #}
|
||||
|
||||
newSessionOptions :: IO SessionOptions
|
||||
newSessionOptions = {# call TF_NewSessionOptions as ^ #}
|
||||
|
||||
setTarget :: SessionOptions -> CString -> IO ()
|
||||
setTarget = {# call TF_SetTarget as ^ #}
|
||||
|
||||
setConfig :: SessionOptions -> Ptr () -> CULong -> Status -> IO ()
|
||||
setConfig = {# call TF_SetConfig as ^ #}
|
||||
|
||||
deleteSessionOptions :: SessionOptions -> IO ()
|
||||
deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #}
|
||||
|
||||
|
||||
-- Session.
|
||||
{# pointer *TF_Session as Session newtype #}
|
||||
|
||||
newSession :: SessionOptions -> Status -> IO Session
|
||||
newSession = {# call TF_NewSession as ^ #}
|
||||
|
||||
closeSession :: Session -> Status -> IO ()
|
||||
closeSession = {# call TF_CloseSession as ^ #}
|
||||
|
||||
deleteSession :: Session -> Status -> IO ()
|
||||
deleteSession = {# call TF_DeleteSession as ^ #}
|
||||
|
||||
extendGraph :: Session -> Ptr () -> CULong -> Status -> IO ()
|
||||
extendGraph = {# call TF_ExtendGraph as ^ #}
|
||||
|
||||
run :: Session
|
||||
-> BufferPtr -- RunOptions proto.
|
||||
-> Ptr CString -> Ptr Tensor -> CInt -- Input (names, tensors, count).
|
||||
-> Ptr CString -> Ptr Tensor -> CInt -- Output (names, tensors, count).
|
||||
-> Ptr CString -> CInt -- Target nodes (names, count).
|
||||
-> BufferPtr -- RunMetadata proto.
|
||||
-> Status
|
||||
-> IO ()
|
||||
run = {# call TF_Run as ^ #}
|
||||
|
||||
-- FFI helpers.
|
||||
type TensorDeallocFn = Ptr () -> CULong -> Ptr () -> IO ()
|
||||
foreign import ccall "wrapper"
|
||||
wrapTensorDealloc :: TensorDeallocFn -> IO (FunPtr TensorDeallocFn)
|
||||
|
||||
|
||||
-- | Get the OpList of all OpDefs defined in this address space.
|
||||
-- Returns a BufferPtr, ownership of which is transferred to the caller
|
||||
-- (and can be freed using deleteBuffer).
|
||||
--
|
||||
-- The data in the buffer will be the serialized OpList proto for ops registered
|
||||
-- in this address space.
|
||||
getAllOpList :: IO BufferPtr
|
||||
getAllOpList = {# call TF_GetAllOpList as ^ #}
|
||||
|
||||
foreign import ccall "&TF_DeleteBuffer"
|
||||
deleteBuffer :: FunPtr (BufferPtr -> IO ())
|
50
tensorflow/src/TensorFlow/Internal/VarInt.hs
Normal file
50
tensorflow/src/TensorFlow/Internal/VarInt.hs
Normal file
|
@ -0,0 +1,50 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
|
||||
{-|
|
||||
Module : TensorFlow.Internal.VarInt
|
||||
Description : Encoders and decoders for varint types.
|
||||
|
||||
Originally taken from internal proto-lens code.
|
||||
-}
|
||||
module TensorFlow.Internal.VarInt
|
||||
( getVarInt
|
||||
, putVarInt
|
||||
) where
|
||||
|
||||
import Data.Attoparsec.ByteString as Parse
|
||||
import Data.Bits
|
||||
import Data.ByteString.Lazy.Builder as Builder
|
||||
import Data.Monoid ((<>))
|
||||
import Data.Word (Word64)
|
||||
|
||||
-- | Decode an unsigned varint.
|
||||
getVarInt :: Parser Word64
|
||||
getVarInt = loop 1 0
|
||||
where
|
||||
loop !s !n = do
|
||||
b <- anyWord8
|
||||
let n' = n + s * fromIntegral (b .&. 127)
|
||||
if (b .&. 128) == 0
|
||||
then return n'
|
||||
else loop (128*s) n'
|
||||
|
||||
-- | Encode a Word64.
|
||||
putVarInt :: Word64 -> Builder
|
||||
putVarInt n
|
||||
| n < 128 = Builder.word8 (fromIntegral n)
|
||||
| otherwise = Builder.word8 (fromIntegral $ n .&. 127 .|. 128)
|
||||
<> putVarInt (n `shiftR` 7)
|
141
tensorflow/src/TensorFlow/Nodes.hs
Normal file
141
tensorflow/src/TensorFlow/Nodes.hs
Normal file
|
@ -0,0 +1,141 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
module TensorFlow.Nodes where
|
||||
|
||||
import Control.Applicative (liftA2, liftA3)
|
||||
import Data.Map.Strict (Map)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.Set (Set)
|
||||
import Data.String (IsString)
|
||||
import Data.Text (Text)
|
||||
import Lens.Family2 ((^.))
|
||||
import qualified Data.Map.Strict as Map
|
||||
import qualified Data.Set as Set
|
||||
import qualified Data.Vector as V
|
||||
|
||||
import TensorFlow.Build
|
||||
import TensorFlow.Output
|
||||
import TensorFlow.Tensor
|
||||
import TensorFlow.Types
|
||||
import qualified TensorFlow.Internal.FFI as FFI
|
||||
|
||||
-- | Types that contain ops which can be run.
|
||||
class Nodes t where
|
||||
getNodes :: t -> Build (Set NodeName)
|
||||
|
||||
-- | Types that tensor representations (e.g. 'Tensor', 'ControlNode') can be
|
||||
-- fetched into.
|
||||
--
|
||||
-- Includes collections of tensors (e.g. tuples).
|
||||
class Nodes t => Fetchable t a where
|
||||
getFetch :: t -> Build (Fetch a)
|
||||
|
||||
-- | Fetch action. Keeps track of what needs to be fetched and how to decode
|
||||
-- the fetched data.
|
||||
data Fetch a = Fetch
|
||||
{ -- | Nodes to fetch
|
||||
fetches :: Set Text
|
||||
-- | Function to create an 'a' from the fetched data.
|
||||
, fetchRestore :: Map Text FFI.TensorData -> a
|
||||
}
|
||||
|
||||
instance Functor Fetch where
|
||||
fmap f (Fetch fetch restore) = Fetch fetch (f . restore)
|
||||
|
||||
instance Applicative Fetch where
|
||||
pure x = Fetch Set.empty (const x)
|
||||
Fetch fetch restore <*> Fetch fetch' restore' =
|
||||
Fetch (fetch <> fetch') (restore <*> restore')
|
||||
|
||||
nodesUnion :: (Monoid b, Traversable t, Applicative f) => t (f b) -> f b
|
||||
nodesUnion = fmap (foldMap id) . sequenceA
|
||||
|
||||
instance (Nodes t1, Nodes t2) => Nodes (t1, t2) where
|
||||
getNodes (x, y) = nodesUnion [getNodes x, getNodes y]
|
||||
|
||||
instance (Nodes t1, Nodes t2, Nodes t3) => Nodes (t1, t2, t3) where
|
||||
getNodes (x, y, z) = nodesUnion [getNodes x, getNodes y, getNodes z]
|
||||
|
||||
instance (Fetchable t1 a1, Fetchable t2 a2) => Fetchable (t1, t2) (a1, a2) where
|
||||
getFetch (x, y) = liftA2 (,) <$> getFetch x <*> getFetch y
|
||||
|
||||
instance (Fetchable t1 a1, Fetchable t2 a2, Fetchable t3 a3)
|
||||
=> Fetchable (t1, t2, t3) (a1, a2, a3) where
|
||||
getFetch (x, y, z) =
|
||||
liftA3 (,,) <$> getFetch x <*> getFetch y <*> getFetch z
|
||||
|
||||
instance Nodes t => Nodes [t] where
|
||||
getNodes = nodesUnion . map getNodes
|
||||
|
||||
instance Fetchable t a => Fetchable [t] [a] where
|
||||
getFetch ts = sequenceA <$> mapM getFetch ts
|
||||
|
||||
instance Nodes ControlNode where
|
||||
getNodes (ControlNode o) = Set.singleton <$> getOrAddOp o
|
||||
|
||||
-- We use the constraint @(a ~ ())@ to help with type inference. For example,
|
||||
-- if @t :: ControlNode@, then this constraint ensures that @run t :: Session
|
||||
-- ()@. If we used @instance Fetchable ControlNode ()@ instead, then that
|
||||
-- expression would be ambiguous without explicitly specifying the return type.
|
||||
instance a ~ () => Fetchable ControlNode a where
|
||||
getFetch _ = return $ pure ()
|
||||
|
||||
instance Nodes (Tensor v a) where
|
||||
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)
|
||||
|
||||
fetchTensorList :: TensorType a => Tensor v a -> Build (Fetch (Shape, [a]))
|
||||
fetchTensorList t = fmap (fmap V.toList) <$> fetchTensorVector t
|
||||
|
||||
fetchTensorVector :: forall a v . TensorType a
|
||||
=> Tensor v a -> Build (Fetch (Shape, V.Vector a))
|
||||
fetchTensorVector (Tensor _ o) = do
|
||||
outputName <- renderOutput o
|
||||
return $ Fetch (Set.singleton outputName) $ \tensors ->
|
||||
let tensorData = tensors Map.! outputName
|
||||
shape = Shape $ FFI.tensorDataDimensions tensorData
|
||||
vec = decodeTensorData $ TensorData tensorData
|
||||
|
||||
expectedType = tensorType (undefined :: a)
|
||||
actualType = FFI.tensorDataType tensorData
|
||||
badTypeError = error $ "Bad tensor type: expected "
|
||||
++ show expectedType
|
||||
++ ", got "
|
||||
++ show actualType
|
||||
in if expectedType /= actualType
|
||||
then badTypeError
|
||||
else (shape, vec)
|
||||
|
||||
-- The constraint "a ~ a'" means that the input/output of fetch can constrain
|
||||
-- the TensorType of each other.
|
||||
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (V.Vector a') where
|
||||
getFetch t = fmap snd <$> fetchTensorVector t
|
||||
|
||||
newtype Scalar a = Scalar {unScalar :: a}
|
||||
deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat,
|
||||
RealFrac, IsString)
|
||||
|
||||
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (Scalar a') where
|
||||
getFetch t = fmap (Scalar . headFromSingleton . snd) <$> fetchTensorList t
|
||||
where
|
||||
headFromSingleton [x] = x
|
||||
headFromSingleton xs
|
||||
= error $ "Unable to extract singleton from tensor of length "
|
||||
++ show (length xs)
|
46
tensorflow/src/TensorFlow/Orphans.hs
Normal file
46
tensorflow/src/TensorFlow/Orphans.hs
Normal file
|
@ -0,0 +1,46 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
-- Orphan instances for certain proto messages/enums, used internally.
|
||||
-- TODO(judahjacobson): consider making proto-lens generate some or all of
|
||||
-- these automatically; or, alternately, make new Haskell datatypes.
|
||||
module TensorFlow.Orphans() where
|
||||
|
||||
import Proto.Tensorflow.Core.Framework.AttrValue
|
||||
( AttrValue(..)
|
||||
, AttrValue'ListValue(..)
|
||||
, NameAttrList(..)
|
||||
)
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||
( NodeDef(..))
|
||||
import Proto.Tensorflow.Core.Framework.ResourceHandle
|
||||
( ResourceHandle(..))
|
||||
import Proto.Tensorflow.Core.Framework.Tensor
|
||||
(TensorProto(..))
|
||||
import Proto.Tensorflow.Core.Framework.TensorShape
|
||||
(TensorShapeProto(..), TensorShapeProto'Dim(..))
|
||||
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||
|
||||
deriving instance Ord AttrValue
|
||||
deriving instance Ord AttrValue'ListValue
|
||||
deriving instance Ord DataType
|
||||
deriving instance Ord NameAttrList
|
||||
deriving instance Ord NodeDef
|
||||
deriving instance Ord ResourceHandle
|
||||
deriving instance Ord TensorProto
|
||||
deriving instance Ord TensorShapeProto
|
||||
deriving instance Ord TensorShapeProto'Dim
|
156
tensorflow/src/TensorFlow/Output.hs
Normal file
156
tensorflow/src/TensorFlow/Output.hs
Normal file
|
@ -0,0 +1,156 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE Rank2Types #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module TensorFlow.Output
|
||||
( ControlNode(..)
|
||||
, Device(..)
|
||||
-- * Ops
|
||||
, NodeName(..)
|
||||
, Op(..)
|
||||
, opUnrendered
|
||||
, OpDef(..)
|
||||
, opName
|
||||
, opType
|
||||
, opAttr
|
||||
, opInputs
|
||||
, opControlInputs
|
||||
, OpType(..)
|
||||
, OutputIx(..)
|
||||
, Output(..)
|
||||
, output
|
||||
, outputIndex
|
||||
, outputOp
|
||||
, PendingNodeName(..)
|
||||
) where
|
||||
|
||||
import qualified Data.Map.Strict as Map
|
||||
import Data.ProtoLens.TextFormat (showMessage)
|
||||
import Data.String (IsString(..))
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as Text
|
||||
import Lens.Family2 (Lens', Traversal', (.~), (&), (^.))
|
||||
import Lens.Family2.Unchecked (lens)
|
||||
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..))
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef(..), name)
|
||||
import Data.Default (def)
|
||||
import TensorFlow.Types (Attribute, attrLens)
|
||||
import TensorFlow.Orphans ()
|
||||
|
||||
-- | A type of graph node which has no outputs. These nodes are
|
||||
-- valuable for causing side effects when they are run.
|
||||
newtype ControlNode = ControlNode { unControlNode :: Op }
|
||||
|
||||
-- | The type of op of a node in the graph. This corresponds to the proto field
|
||||
-- NodeDef.op.
|
||||
newtype OpType = OpType { unOpType :: Text }
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
instance IsString OpType where
|
||||
fromString = OpType . Text.pack
|
||||
|
||||
-- | An output of a TensorFlow node.
|
||||
data Output = Output !OutputIx !Op
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
output :: OutputIx -> Op -> Output
|
||||
output = Output
|
||||
|
||||
outputOp :: Lens' Output Op
|
||||
outputOp = lens (\(Output _ o) -> o) (\(Output i _) o -> Output i o)
|
||||
|
||||
outputIndex :: Lens' Output OutputIx
|
||||
outputIndex = lens (\(Output i _) -> i) (\(Output _ o) i -> Output i o)
|
||||
|
||||
newtype OutputIx = OutputIx { unOutputIx :: Int }
|
||||
deriving (Eq, Ord, Num, Enum, Show)
|
||||
|
||||
-- | A device that a node can be assigned to.
|
||||
-- There's a naming convention where the device names
|
||||
-- are constructed from job and replica names.
|
||||
newtype Device = Device {deviceName :: Text}
|
||||
deriving (Eq, Ord, IsString)
|
||||
|
||||
instance Show Device where
|
||||
show (Device d) = show d
|
||||
|
||||
-- | The representation of a node in a TensorFlow graph.
|
||||
data Op
|
||||
= Rendered !NodeDef -- ^ Properties are fixed, including the
|
||||
-- device, name, and scope.
|
||||
| Unrendered !OpDef -- ^ Properties are not fixed, and may change depending
|
||||
-- on which context this op is rendered in.
|
||||
deriving (Eq, Ord)
|
||||
|
||||
instance Show Op where
|
||||
show (Rendered n) = "Rendered " ++ showMessage n
|
||||
show (Unrendered o) = "Unrendered " ++ show (o ^. opName)
|
||||
|
||||
-- | Traverse on the 'Unrendered' of an 'Op'.
|
||||
--
|
||||
-- Same implementation as _Left.
|
||||
opUnrendered :: Traversal' Op OpDef
|
||||
opUnrendered f (Unrendered a) = Unrendered <$> f a
|
||||
opUnrendered _ (Rendered b) = pure (Rendered b)
|
||||
|
||||
-- | Op definition. This corresponds somewhat to the 'NodeDef' proto.
|
||||
data OpDef = OpDef
|
||||
{ _opName :: !PendingNodeName
|
||||
, _opType :: !OpType
|
||||
, _opAttrs :: !(Map.Map Text AttrValue)
|
||||
, _opInputs :: [Output]
|
||||
, _opControlInputs :: [NodeName]
|
||||
} deriving (Eq, Ord)
|
||||
|
||||
-- | The name specified for an unrendered Op. If an Op has an
|
||||
-- ImplicitName, it will be assigned based on the opType plus a
|
||||
-- unique identifier. Does not contain the "scope" prefix.
|
||||
data PendingNodeName = ExplicitName !Text | ImplicitName
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
-- | The name of a node in the graph. This corresponds to the proto field
|
||||
-- NodeDef.name. Includes the scope prefix (if any) and a unique identifier
|
||||
-- (if the node was implicitly named).
|
||||
newtype NodeName = NodeName { unNodeName :: Text }
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
opName :: Lens' OpDef PendingNodeName
|
||||
opName = lens _opName (\o x -> o {_opName = x})
|
||||
|
||||
opType :: Lens' OpDef OpType
|
||||
opType = lens _opType (\o x -> o { _opType = x})
|
||||
|
||||
opAttr :: Attribute a => Text -> Lens' OpDef a
|
||||
opAttr n = lens _opAttrs (\o x -> o {_opAttrs = x})
|
||||
. lens (Map.findWithDefault def n) (flip (Map.insert n))
|
||||
. attrLens
|
||||
|
||||
opInputs :: Lens' OpDef [Output]
|
||||
opInputs = lens _opInputs (\o x -> o {_opInputs = x})
|
||||
|
||||
opControlInputs :: Lens' OpDef [NodeName]
|
||||
opControlInputs = lens _opControlInputs (\o x -> o {_opControlInputs = x})
|
||||
|
||||
-- TODO(gnezdo): IsString instance is weird and we should move that
|
||||
-- code into a Build function
|
||||
instance IsString Output where
|
||||
fromString s = case break (==':') s of
|
||||
(n, ':':ixStr)
|
||||
| [(ix, "")] <- read ixStr -> Output (fromInteger ix) $ assigned n
|
||||
_ -> Output 0 $ assigned s
|
||||
where assigned n = Rendered $ def & name .~ Text.pack n
|
||||
|
202
tensorflow/src/TensorFlow/Session.hs
Normal file
202
tensorflow/src/TensorFlow/Session.hs
Normal file
|
@ -0,0 +1,202 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE Rank2Types #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module TensorFlow.Session (
|
||||
Session,
|
||||
-- * Opaque value created via 'sessionConfig' and 'sessionTarget'.
|
||||
SessionOption,
|
||||
sessionConfig,
|
||||
sessionTarget,
|
||||
runSession,
|
||||
runSessionWithOptions,
|
||||
build,
|
||||
buildAnd,
|
||||
buildWithSummary,
|
||||
extend,
|
||||
addGraphDef,
|
||||
run,
|
||||
runWithFeeds,
|
||||
run_,
|
||||
runWithFeeds_,
|
||||
asyncProdNodes,
|
||||
) where
|
||||
|
||||
import Control.Monad (forever, unless, void)
|
||||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||
import Control.Monad.Trans.Class (lift)
|
||||
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Functor.Identity (runIdentity)
|
||||
import qualified Data.Map.Strict as Map
|
||||
import qualified Data.Set as Set
|
||||
import Data.Set (Set)
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Data.ProtoLens (def)
|
||||
import Lens.Family2 ((&), (.~))
|
||||
import Proto.Tensorflow.Core.Framework.Graph (node)
|
||||
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
||||
|
||||
import TensorFlow.Build
|
||||
import qualified TensorFlow.Internal.FFI as FFI
|
||||
import qualified TensorFlow.Internal.Raw as Raw
|
||||
import TensorFlow.Nodes
|
||||
import TensorFlow.Output (NodeName, unNodeName)
|
||||
import TensorFlow.Tensor
|
||||
|
||||
-- Common state threaded through the session.
|
||||
data SessionState
|
||||
= SessionState {
|
||||
rawSession :: FFI.Session
|
||||
, asyncCollector :: IO () -> IO ()
|
||||
-- ^ Starts the given action concurrently.
|
||||
}
|
||||
|
||||
newtype Session a
|
||||
= Session (ReaderT SessionState (BuildT IO) a)
|
||||
deriving (Functor, Applicative, Monad, MonadIO)
|
||||
|
||||
-- | Run 'Session' actions in a new TensorFlow session.
|
||||
runSession :: Session a -> IO a
|
||||
runSession = runSessionWithOptions []
|
||||
|
||||
-- | Setting of an option for the session (see 'runSessionWithOptions').
|
||||
newtype SessionOption =
|
||||
SessionOption { unSesssionOption :: Raw.SessionOptions -> IO () }
|
||||
|
||||
-- | Target can be: "local", ip:port, host:port.
|
||||
-- The set of supported factories depends on the linked in libraries.
|
||||
-- REQUIRES "//learning/brain/public:tensorflow_remote" dependency for the binary.
|
||||
sessionTarget :: ByteString -> SessionOption
|
||||
sessionTarget = SessionOption . FFI.setSessionTarget
|
||||
|
||||
-- | Uses the specified config for the created session.
|
||||
sessionConfig :: ConfigProto -> SessionOption
|
||||
sessionConfig = SessionOption . FFI.setSessionConfig
|
||||
|
||||
-- | Run 'Session' actions in a new TensorFlow session created with
|
||||
-- the given option setter actions ('sessionTarget', 'sessionConfig').
|
||||
runSessionWithOptions :: [SessionOption] -> Session a -> IO a
|
||||
runSessionWithOptions options (Session m) =
|
||||
FFI.withSession applyOptions $
|
||||
\as rs -> evalBuildT (runReaderT m (SessionState rs as))
|
||||
where applyOptions opt = mapM_ (`unSesssionOption` opt) options
|
||||
|
||||
-- | Lift a 'Build' action into a 'Session', including any explicit op
|
||||
-- renderings.
|
||||
build :: Build a -> Session a
|
||||
build = Session . lift . hoistBuildT (return . runIdentity)
|
||||
|
||||
-- | Lift a 'Build' action into a 'Session', including any explicit op
|
||||
-- renderings. Returns the merged summary ops which can be used for
|
||||
-- logging, see 'TensorFlow.Logging.build' for a convenient wrapper.
|
||||
buildWithSummary :: forall a . Build a -> Session (a, [SummaryTensor])
|
||||
buildWithSummary b = Session $ lift $ (,) <$> v <*> collectAllSummaries
|
||||
where v :: BuildT IO a
|
||||
v = hoistBuildT (return . runIdentity) b
|
||||
|
||||
-- | Add all pending rendered nodes to the TensorFlow graph and runs
|
||||
-- any pending initializers.
|
||||
--
|
||||
-- Note that run, runWithFeeds, etc. will all call this function implicitly.
|
||||
extend :: Session ()
|
||||
extend = do
|
||||
let withSessionWhen vs action =
|
||||
unless (null vs) $ Session (asks rawSession) >>= action
|
||||
nodesToExtend <- build flushNodeBuffer
|
||||
withSessionWhen nodesToExtend $ \session ->
|
||||
liftIO $ FFI.extendGraph session
|
||||
$ def & node .~ nodesToExtend
|
||||
-- Now that all the nodes are created, run the initializers.
|
||||
initializers <- build flushInitializers
|
||||
withSessionWhen initializers $ \session ->
|
||||
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
|
||||
|
||||
-- | Helper combinator for doing something with the result of a 'Build' action.
|
||||
-- Example usage:
|
||||
--
|
||||
-- > buildAnd run :: Fetchable t a => Build t -> Session a
|
||||
buildAnd :: (a -> Session b) -> Build a -> Session b
|
||||
buildAnd f m = build m >>= f
|
||||
|
||||
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||
-- rendered, and fetch the corresponding values for 'a'.
|
||||
run :: Fetchable t a => t -> Session a
|
||||
run = runWithFeeds []
|
||||
|
||||
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||
-- rendered, feed the given input values, and fetch the corresponding result
|
||||
-- values for 'a'.
|
||||
runWithFeeds :: Fetchable t a => [Feed] -> t -> Session a
|
||||
runWithFeeds feeds t = do
|
||||
ns <- build $ getNodes t
|
||||
-- Note that this call to "fetch" shouldn't affect the following "extend"
|
||||
-- call, since all nodes in t and its inputs/deps will be rendered by the
|
||||
-- above call to getNodes.
|
||||
fetch <- build $ getFetch t
|
||||
runFetchWithFeeds feeds ns fetch
|
||||
|
||||
runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a
|
||||
runFetchWithFeeds feeds target (Fetch fetch restore) = do
|
||||
extend
|
||||
feeds' <- build $ fixFeeds feeds
|
||||
let fetchNames = encodeUtf8 <$> Set.toList fetch
|
||||
targetNames = toNodeNames $ Set.toList target
|
||||
session <- Session (asks rawSession)
|
||||
runResult <- liftIO $ FFI.run session
|
||||
feeds'
|
||||
fetchNames
|
||||
targetNames
|
||||
let resultTensorsMap = Map.fromList $ zip (Set.toList fetch) runResult
|
||||
return $ restore resultTensorsMap
|
||||
|
||||
toNodeNames :: [NodeName] -> [ByteString]
|
||||
toNodeNames = map (encodeUtf8 . unNodeName)
|
||||
|
||||
-- | Run a subgraph 't', rendering and extending any dependent nodes that aren't
|
||||
-- already rendered. This behaves like 'run' except that it doesn't do any
|
||||
-- fetches.
|
||||
run_ :: Nodes t => t -> Session ()
|
||||
run_ = runWithFeeds_ []
|
||||
|
||||
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||
-- rendered, feed the given input values, and fetch the corresponding result
|
||||
-- values for 'a'. This behaves like 'runWithFeeds' except that it doesn't do
|
||||
-- any fetches.
|
||||
runWithFeeds_ :: Nodes t => [Feed] -> t -> Session ()
|
||||
runWithFeeds_ feeds t = do
|
||||
ns <- build $ getNodes t
|
||||
runFetchWithFeeds feeds ns (pure ())
|
||||
|
||||
fixFeeds :: [Feed] -> Build [(ByteString, FFI.TensorData)]
|
||||
fixFeeds = mapM $ \(Feed o d) -> (,d) . encodeUtf8 <$> renderOutput o
|
||||
|
||||
-- | Starts a concurrent thread which evaluates the given Nodes
|
||||
-- forever until runSession exits or an exception occurs. Graph
|
||||
-- extension happens synchronously, but the resultant run proceeds as
|
||||
-- a separate thread.
|
||||
asyncProdNodes :: Nodes t
|
||||
=> t -- ^ Node to evaluate concurrently.
|
||||
-> Session ()
|
||||
asyncProdNodes nodes = do
|
||||
target <- build (getNodes nodes)
|
||||
extend
|
||||
let targetNames = toNodeNames $ Set.toList target
|
||||
state <- Session ask
|
||||
let loop = forever (void (FFI.run (rawSession state) [] [] targetNames))
|
||||
liftIO (asyncCollector state loop)
|
85
tensorflow/src/TensorFlow/Tensor.hs
Normal file
85
tensorflow/src/TensorFlow/Tensor.hs
Normal file
|
@ -0,0 +1,85 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE Rank2Types #-}
|
||||
|
||||
module TensorFlow.Tensor where
|
||||
|
||||
import Data.String (IsString(..))
|
||||
import qualified Data.Text as Text
|
||||
import Lens.Family2 (Lens', Traversal')
|
||||
import Lens.Family2.Unchecked (lens)
|
||||
|
||||
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr)
|
||||
import TensorFlow.Types (TensorData(..), Attribute)
|
||||
import qualified TensorFlow.Internal.FFI as FFI
|
||||
|
||||
-- | A named output of a TensorFlow operation.
|
||||
--
|
||||
-- The type parameter @a@ is the type of the elements in the 'Tensor'. The
|
||||
-- parameter @v@ is either 'Value' or 'Ref', depending on whether the graph is
|
||||
-- treating this op output as an immutable 'Value' or a stateful 'Ref' (e.g., a
|
||||
-- variable). Note that a @Tensor Ref@ can be casted into a @Tensor Value@ via
|
||||
-- 'value'.
|
||||
data Tensor v a = Tensor (TensorKind v) Output
|
||||
|
||||
data Value
|
||||
data Ref
|
||||
|
||||
-- | This class provides a runtime switch on whether a 'Tensor' should be
|
||||
-- treated as a 'Value' or as a 'Ref'.
|
||||
data TensorKind v where
|
||||
ValueKind :: TensorKind Value
|
||||
RefKind :: TensorKind Ref
|
||||
|
||||
tensorKind :: Lens' (Tensor v a) (TensorKind v)
|
||||
tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o)
|
||||
|
||||
tensorOutput :: Lens' (Tensor v a) Output
|
||||
tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o)
|
||||
|
||||
-- TODO: Come up with a better API for handling attributes.
|
||||
-- | Lens for the attributes of a tensor.
|
||||
--
|
||||
-- Only valid if the tensor has not yet been rendered. If the tensor has been
|
||||
-- rendered, the traversal will be over nothing (nothing can be read or
|
||||
-- written).
|
||||
tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr
|
||||
tensorAttr x = tensorOutput . outputOp . opUnrendered . opAttr x
|
||||
|
||||
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
|
||||
-- Ref into Value. This behaves like a no-op.
|
||||
value :: Tensor v a -> Tensor Value a
|
||||
value (Tensor _ o) = Tensor ValueKind o
|
||||
|
||||
-- | A pair of a 'Tensor' and some data that should be fed into that 'Tensor'
|
||||
-- when running the graph.
|
||||
data Feed = Feed Output FFI.TensorData
|
||||
|
||||
-- | Create a 'Feed' for feeding the given data into a 'Tensor' when running
|
||||
-- the graph.
|
||||
--
|
||||
-- Note that if a 'Tensor' is rendered, its identity may change; so feeding the
|
||||
-- rendered 'Tensor' may be different than feeding the original 'Tensor'.
|
||||
feed :: Tensor v a -> TensorData a -> Feed
|
||||
feed (Tensor _ o) (TensorData td) = Feed o td
|
||||
|
||||
-- | Create a 'Tensor' for a given name. This can be used to reference nodes
|
||||
-- in a 'GraphDef' that was loaded via 'addGraphDef'.
|
||||
-- TODO(judahjacobson): add more safety checks here.
|
||||
tensorFromName :: TensorKind v -> Text.Text -> Tensor v a
|
||||
tensorFromName v = Tensor v . fromString . Text.unpack
|
382
tensorflow/src/TensorFlow/Types.hs
Normal file
382
tensorflow/src/TensorFlow/Types.hs
Normal file
|
@ -0,0 +1,382 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
-- We use UndecidableInstances for type families with recursive definitions
|
||||
-- like "\\". Those instances will terminate since each equation unwraps one
|
||||
-- cons cell of a type-level list.
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
|
||||
module TensorFlow.Types
|
||||
( TensorType(..)
|
||||
, TensorData(..)
|
||||
, Shape(..)
|
||||
, protoShape
|
||||
, Attribute(..)
|
||||
-- * Type constraints
|
||||
, OneOf
|
||||
, type (/=)
|
||||
-- ** Implementation of constraints
|
||||
, TypeError
|
||||
, ExcludedCase
|
||||
, TensorTypes
|
||||
, NoneOf
|
||||
, type (\\)
|
||||
, Delete
|
||||
, AllTensorTypes
|
||||
) where
|
||||
|
||||
import Data.Complex (Complex)
|
||||
import Data.Default (def)
|
||||
import Data.Int (Int8, Int16, Int32, Int64)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.Word (Word8, Word16, Word64)
|
||||
import Foreign.Storable (Storable)
|
||||
import GHC.Exts (Constraint, IsList(..))
|
||||
import Lens.Family2 (Lens', view, (&), (.~))
|
||||
import Lens.Family2.Unchecked (iso)
|
||||
import qualified Data.Attoparsec.ByteString as Atto
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString as B
|
||||
import Data.ByteString.Builder (Builder)
|
||||
import qualified Data.ByteString.Builder as Builder
|
||||
import qualified Data.ByteString.Lazy as L
|
||||
import qualified Data.Vector as V
|
||||
import qualified Data.Vector.Storable as S
|
||||
import Proto.Tensorflow.Core.Framework.AttrValue
|
||||
( AttrValue(..)
|
||||
, AttrValue'ListValue(..)
|
||||
, b
|
||||
, f
|
||||
, i
|
||||
, s
|
||||
, list
|
||||
, type'
|
||||
, shape
|
||||
, tensor
|
||||
)
|
||||
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
|
||||
( TensorProto(..)
|
||||
, floatVal
|
||||
, doubleVal
|
||||
, intVal
|
||||
, stringVal
|
||||
, int64Val
|
||||
, stringVal
|
||||
, boolVal
|
||||
)
|
||||
import Proto.Tensorflow.Core.Framework.TensorShape
|
||||
( TensorShapeProto(..)
|
||||
, dim
|
||||
, size
|
||||
)
|
||||
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||
|
||||
import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
|
||||
import qualified TensorFlow.Internal.FFI as FFI
|
||||
|
||||
-- | Data about a tensor that is encoded for the TensorFlow APIs.
|
||||
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
|
||||
|
||||
-- | The class of scalar types supported by tensorflow.
|
||||
class TensorType a where
|
||||
tensorType :: a -> DataType
|
||||
tensorRefType :: a -> DataType
|
||||
tensorVal :: Lens' TensorProto [a]
|
||||
-- | Decode the bytes of a TensorData into a Vector.
|
||||
decodeTensorData :: TensorData a -> V.Vector a
|
||||
-- | Encode a Vector into a TensorData.
|
||||
--
|
||||
-- The values should be in row major order, e.g.,
|
||||
--
|
||||
-- element 0: index (0, ..., 0)
|
||||
-- element 1: index (0, ..., 1)
|
||||
-- ...
|
||||
encodeTensorData :: Shape -> V.Vector a -> TensorData a
|
||||
|
||||
-- All types, besides ByteString, are encoded as simple arrays and we can use
|
||||
-- Vector.Storable to encode/decode by type casting pointers.
|
||||
|
||||
-- TODO(fmayle): Assert that the data type matches the return type.
|
||||
simpleDecode :: Storable a => TensorData a -> V.Vector a
|
||||
simpleDecode = S.convert . S.unsafeCast . FFI.tensorDataBytes . unTensorData
|
||||
|
||||
simpleEncode :: forall a . (TensorType a, Storable a)
|
||||
=> Shape -> V.Vector a -> TensorData a
|
||||
simpleEncode (Shape xs)
|
||||
= TensorData . FFI.TensorData xs dt . S.unsafeCast . S.convert
|
||||
where
|
||||
dt = tensorType (undefined :: a)
|
||||
|
||||
instance TensorType Float where
|
||||
tensorType _ = DT_FLOAT
|
||||
tensorRefType _ = DT_FLOAT_REF
|
||||
tensorVal = floatVal
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Double where
|
||||
tensorType _ = DT_DOUBLE
|
||||
tensorRefType _ = DT_DOUBLE_REF
|
||||
tensorVal = doubleVal
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Int32 where
|
||||
tensorType _ = DT_INT32
|
||||
tensorRefType _ = DT_INT32_REF
|
||||
tensorVal = intVal
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Int64 where
|
||||
tensorType _ = DT_INT64
|
||||
tensorRefType _ = DT_INT64_REF
|
||||
tensorVal = int64Val
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
integral :: Integral a => Lens' [Int32] [a]
|
||||
integral = iso (fmap fromIntegral) (fmap fromIntegral)
|
||||
|
||||
instance TensorType Word8 where
|
||||
tensorType _ = DT_UINT8
|
||||
tensorRefType _ = DT_UINT8_REF
|
||||
tensorVal = intVal . integral
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Word16 where
|
||||
tensorType _ = DT_UINT16
|
||||
tensorRefType _ = DT_UINT16_REF
|
||||
tensorVal = intVal . integral
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Int16 where
|
||||
tensorType _ = DT_INT16
|
||||
tensorRefType _ = DT_INT16_REF
|
||||
tensorVal = intVal . integral
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Int8 where
|
||||
tensorType _ = DT_INT8
|
||||
tensorRefType _ = DT_INT8_REF
|
||||
tensorVal = intVal . integral
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType ByteString where
|
||||
tensorType _ = DT_STRING
|
||||
tensorRefType _ = DT_STRING_REF
|
||||
tensorVal = stringVal
|
||||
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
|
||||
-- table offsets for each element :: [Word64]
|
||||
-- at each element offset:
|
||||
-- string length :: VarInt64
|
||||
-- string data :: [Word8]
|
||||
-- TODO(fmayle): Benchmark these functions.
|
||||
decodeTensorData tensorData =
|
||||
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
|
||||
if expected /= count
|
||||
then Left $ "decodeTensorData for ByteString count mismatch " ++
|
||||
show (expected, count)
|
||||
else V.mapM decodeString (S.convert offsets)
|
||||
where
|
||||
expected = S.length offsets
|
||||
count = fromIntegral $ product $ FFI.tensorDataDimensions
|
||||
$ unTensorData tensorData
|
||||
bytes = FFI.tensorDataBytes $ unTensorData tensorData
|
||||
offsets = S.take count $ S.unsafeCast bytes :: S.Vector Word64
|
||||
dataBytes = B.pack $ S.toList $ S.drop (count * 8) bytes
|
||||
decodeString :: Word64 -> Either String ByteString
|
||||
decodeString offset =
|
||||
let stringDataStart = B.drop (fromIntegral offset) dataBytes
|
||||
in Atto.eitherResult $ Atto.parse stringParser stringDataStart
|
||||
stringParser :: Atto.Parser ByteString
|
||||
stringParser = getVarInt >>= Atto.take . fromIntegral
|
||||
encodeTensorData (Shape xs) vec =
|
||||
TensorData $ FFI.TensorData xs dt byteVector
|
||||
where
|
||||
dt = tensorType (undefined :: ByteString)
|
||||
-- Add a string to an offset table and data blob.
|
||||
addString :: (Builder, Builder, Word64)
|
||||
-> ByteString
|
||||
-> (Builder, Builder, Word64)
|
||||
addString (table, strings, offset) str =
|
||||
( table <> Builder.word64LE offset
|
||||
, strings <> lengthBytes <> Builder.byteString str
|
||||
, offset + lengthBytesLen + strLen
|
||||
)
|
||||
where
|
||||
strLen = fromIntegral $ B.length str
|
||||
lengthBytes = putVarInt $ fromIntegral $ B.length str
|
||||
lengthBytesLen =
|
||||
fromIntegral $ L.length $ Builder.toLazyByteString lengthBytes
|
||||
-- Encode all strings.
|
||||
(table', strings', _) = V.foldl' addString (mempty, mempty, 0) vec
|
||||
-- Concat offset table with data.
|
||||
bytes = table' <> strings'
|
||||
-- Convert to Vector Word8.
|
||||
byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes
|
||||
|
||||
|
||||
instance TensorType Bool where
|
||||
tensorType _ = DT_BOOL
|
||||
tensorRefType _ = DT_BOOL_REF
|
||||
tensorVal = boolVal
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType (Complex Float) where
|
||||
tensorType _ = DT_COMPLEX64
|
||||
tensorRefType _ = DT_COMPLEX64
|
||||
tensorVal = error "TODO (Complex Float)"
|
||||
decodeTensorData = error "TODO (Complex Float)"
|
||||
encodeTensorData = error "TODO (Complex Float)"
|
||||
|
||||
instance TensorType (Complex Double) where
|
||||
tensorType _ = DT_COMPLEX128
|
||||
tensorRefType _ = DT_COMPLEX128
|
||||
tensorVal = error "TODO (Complex Double)"
|
||||
decodeTensorData = error "TODO (Complex Double)"
|
||||
encodeTensorData = error "TODO (Complex Double)"
|
||||
|
||||
-- | Shape (dimensions) of a tensor.
|
||||
newtype Shape = Shape [Int64] deriving Show
|
||||
|
||||
instance IsList Shape where
|
||||
type Item Shape = Int64
|
||||
fromList = Shape . fromList
|
||||
toList (Shape ss) = toList ss
|
||||
|
||||
protoShape :: Lens' TensorShapeProto Shape
|
||||
protoShape = iso protoToShape shapeToProto
|
||||
where
|
||||
protoToShape = Shape . fmap (view size) . view dim
|
||||
shapeToProto (Shape ds) = def & dim .~ fmap (\d -> def & size .~ d) ds
|
||||
|
||||
|
||||
class Attribute a where
|
||||
attrLens :: Lens' AttrValue a
|
||||
|
||||
instance Attribute Float where
|
||||
attrLens = f
|
||||
|
||||
instance Attribute ByteString where
|
||||
attrLens = s
|
||||
|
||||
instance Attribute Int64 where
|
||||
attrLens = i
|
||||
|
||||
instance Attribute DataType where
|
||||
attrLens = type'
|
||||
|
||||
instance Attribute TensorProto where
|
||||
attrLens = tensor
|
||||
|
||||
instance Attribute Bool where
|
||||
attrLens = b
|
||||
|
||||
instance Attribute Shape where
|
||||
attrLens = shape . protoShape
|
||||
|
||||
-- TODO(gnezdo): support generating list(Foo) from [Foo].
|
||||
instance Attribute AttrValue'ListValue where
|
||||
attrLens = list
|
||||
|
||||
instance Attribute [DataType] where
|
||||
attrLens = list . type'
|
||||
|
||||
instance Attribute [Int64] where
|
||||
attrLens = list . i
|
||||
|
||||
-- | A 'Constraint' specifying the possible choices of a 'TensorType'.
|
||||
--
|
||||
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the
|
||||
-- natural representation as a conjunction, i.e.,
|
||||
--
|
||||
-- @
|
||||
-- a == Double || a == Float
|
||||
-- @
|
||||
--
|
||||
-- into a disjunction like
|
||||
--
|
||||
-- @
|
||||
-- a \/= Int32 && a \/= Int64 && a \/= ByteString && ...
|
||||
-- @
|
||||
--
|
||||
-- using an enumeration of all the possible 'TensorType's.
|
||||
type OneOf ts a
|
||||
= (TensorType a, TensorTypes ts, NoneOf (AllTensorTypes \\ ts) a)
|
||||
|
||||
-- | A 'Constraint' checking that the input is a list of 'TensorType's.
|
||||
-- Helps improve error messages when using 'OneOf'.
|
||||
type family TensorTypes ts :: Constraint where
|
||||
TensorTypes '[] = ()
|
||||
TensorTypes (t ': ts) = (TensorType t, TensorTypes ts)
|
||||
|
||||
-- | A constraint checking that two types are different.
|
||||
type family a /= b :: Constraint where
|
||||
a /= a = TypeError a ~ ExcludedCase
|
||||
a /= b = ()
|
||||
|
||||
-- | Helper types to produce a reasonable type error message when the Constraint
|
||||
-- "a /= a" fails.
|
||||
-- TODO(judahjacobson): Use ghc-8's CustomTypeErrors for this.
|
||||
data TypeError a
|
||||
data ExcludedCase
|
||||
|
||||
-- | An enumeration of all valid 'TensorType's.
|
||||
type AllTensorTypes =
|
||||
-- NOTE: This list should be kept in sync with
|
||||
-- TensorFlow.OpGen.dtTypeToHaskell.
|
||||
-- TODO: Add support for Complex Float/Double.
|
||||
'[ Float
|
||||
, Double
|
||||
, Int8
|
||||
, Int16
|
||||
, Int32
|
||||
, Int64
|
||||
, Word8
|
||||
, Word16
|
||||
, ByteString
|
||||
, Bool
|
||||
]
|
||||
|
||||
-- | Removes a type from the given list of types.
|
||||
type family Delete a as where
|
||||
Delete a '[] = '[]
|
||||
Delete a (a ': as) = Delete a as
|
||||
Delete a (b ': as) = b ': Delete a as
|
||||
|
||||
-- | Takes the difference of two lists of types.
|
||||
type family as \\ bs where
|
||||
as \\ '[] = as
|
||||
as \\ b ': bs = Delete b as \\ bs
|
||||
|
||||
-- | A constraint that the type @a@ doesn't appear in the type list @ts@.
|
||||
-- Assumes that @a@ and each of the elements of @ts@ are 'TensorType's.
|
||||
type family NoneOf ts a :: Constraint where
|
||||
NoneOf '[] a = ()
|
||||
NoneOf (t ': ts) a = (a /= t, NoneOf ts a)
|
84
tensorflow/tensorflow.cabal
Normal file
84
tensorflow/tensorflow.cabal
Normal file
|
@ -0,0 +1,84 @@
|
|||
name: tensorflow
|
||||
version: 0.1.0.0
|
||||
synopsis: TensorFlow bindings.
|
||||
description: Please see README.md
|
||||
homepage: https://github.com/tensorflow/haskell#readme
|
||||
license: Apache
|
||||
author: TensorFlow authors
|
||||
maintainer: tensorflow-haskell@googlegroups.com
|
||||
copyright: Google Inc.
|
||||
category: Machine Learning
|
||||
build-type: Simple
|
||||
cabal-version: >=1.22
|
||||
|
||||
library
|
||||
hs-source-dirs: src
|
||||
exposed-modules: TensorFlow.Build
|
||||
, TensorFlow.BuildOp
|
||||
, TensorFlow.ControlFlow
|
||||
, TensorFlow.Internal.FFI
|
||||
, TensorFlow.Nodes
|
||||
, TensorFlow.Output
|
||||
, TensorFlow.Session
|
||||
, TensorFlow.Tensor
|
||||
, TensorFlow.Types
|
||||
, TensorFlow.Internal.VarInt
|
||||
other-modules: TensorFlow.Internal.Raw
|
||||
, TensorFlow.Orphans
|
||||
build-tools: c2hs
|
||||
build-depends: proto-lens == 0.1.*
|
||||
-- Used by the custom Setup script (for the test-suite).
|
||||
, proto-lens-protoc == 0.1.*
|
||||
, tensorflow-proto == 0.1.*
|
||||
, base >= 4.7 && < 5
|
||||
, async
|
||||
, attoparsec
|
||||
, bytestring
|
||||
, containers
|
||||
, data-default
|
||||
, fgl
|
||||
, lens-family
|
||||
, mainland-pretty
|
||||
, mtl
|
||||
, semigroups
|
||||
, split
|
||||
, text
|
||||
, temporary
|
||||
, transformers
|
||||
, vector
|
||||
extra-libraries: tensorflow_c
|
||||
default-language: Haskell2010
|
||||
include-dirs: .
|
||||
|
||||
Test-Suite FFITest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: FFITest.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: HUnit
|
||||
, base
|
||||
, bytestring
|
||||
, lens-family
|
||||
, proto-lens
|
||||
, tensorflow
|
||||
, tensorflow-proto
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
|
||||
|
||||
Test-Suite VarIntTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: VarIntTest.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: base
|
||||
, attoparsec
|
||||
, bytestring
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, test-framework
|
||||
, test-framework-quickcheck2
|
||||
|
||||
source-repository head
|
||||
type: git
|
||||
location: https://github.com/tensorflow/haskell
|
38
tensorflow/tests/FFITest.hs
Normal file
38
tensorflow/tests/FFITest.hs
Normal file
|
@ -0,0 +1,38 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
-- | Tests for FFI.
|
||||
|
||||
module Main where
|
||||
|
||||
import Data.ProtoLens (decodeMessage)
|
||||
import Lens.Family2 (view)
|
||||
import TensorFlow.Internal.FFI (getAllOpList)
|
||||
import Test.HUnit (assertBool, assertFailure)
|
||||
import Test.Framework (defaultMain)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Proto.Tensorflow.Core.Framework.OpDef (OpList, op)
|
||||
|
||||
testParseAll :: IO ()
|
||||
testParseAll = do
|
||||
opList <- getAllOpList
|
||||
either
|
||||
assertFailure
|
||||
(assertBool "Expected non-empty list of default Ops"
|
||||
. not . null . view op)
|
||||
(decodeMessage opList :: Either String OpList)
|
||||
|
||||
main = defaultMain
|
||||
[ testCase "ParseAllOps" testParseAll
|
||||
]
|
32
tensorflow/tests/VarIntTest.hs
Normal file
32
tensorflow/tests/VarIntTest.hs
Normal file
|
@ -0,0 +1,32 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
module Main where
|
||||
|
||||
import Data.ByteString.Builder (toLazyByteString)
|
||||
import Google.Test (googleTest)
|
||||
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||||
import qualified Data.Attoparsec.ByteString.Lazy as Atto
|
||||
|
||||
import TensorFlow.Internal.VarInt
|
||||
|
||||
testEncodeDecode = testProperty "testEncodeDecode" $ \x ->
|
||||
let bytes = toLazyByteString (putVarInt x)
|
||||
in case Atto.eitherResult $ Atto.parse getVarInt bytes of
|
||||
Left _ -> False
|
||||
Right y -> x == y
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testEncodeDecode
|
||||
]
|
1
tensorflow/third_party
Symbolic link
1
tensorflow/third_party
Symbolic link
|
@ -0,0 +1 @@
|
|||
../third_party/tensorflow
|
1
third_party/tensorflow
vendored
Submodule
1
third_party/tensorflow
vendored
Submodule
|
@ -0,0 +1 @@
|
|||
Subproject commit bac7faa9a3eb5b60687a83336202cd3493de5385
|
Loading…
Reference in a new issue