commit 67690d149986f3787d3768594ee544fca6ee1d83 Author: Greg Steuck Date: Mon Oct 24 19:26:42 2016 +0000 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f6172ea --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +**/.stack-work +.stack/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..b1cb803 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/tensorflow"] + path = third_party/tensorflow + url = https://github.com/tensorflow/tensorflow.git diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..21e4c71 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,25 @@ +Want to contribute? Great! First, read this page (including the small print at the end). + +### Before you contribute +Before we can use your code, you must sign the +[Google Individual Contributor License Agreement](https://cla.developers.google.com/about/google-individual) +(CLA), which you can do online. The CLA is necessary mainly because you own the +copyright to your changes, even after your contribution becomes part of our +codebase, so we need your permission to use and distribute your code. We also +need to be sure of various other things—for instance that you'll tell us if you +know that your code infringes on other people's patents. You don't have to sign +the CLA until after you've submitted your code for review and a member has +approved it, but you must do it before we can put your code into our codebase. +Before you start working on a larger contribution, you should get in touch with +us first through the issue tracker with your idea so that we can help out and +possibly guide you. Coordinating up front makes it much easier to avoid +frustration later on. + +### Code reviews +All submissions, including submissions by project members, require review. We +use Github pull requests for this purpose. + +### The small print +Contributions made by corporations are covered by a different agreement than +the one above, the +[Software Grant and Corporate Contributor License Agreement](https://cla.developers.google.com/about/google-corporate). diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f89eb33 --- /dev/null +++ b/LICENSE @@ -0,0 +1,203 @@ +Copyright 2016 The TensorFlow Authors. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2016, The TensorFlow Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..129bf6b --- /dev/null +++ b/README.md @@ -0,0 +1,24 @@ +The tensorflow-haskell package provides Haskell bindings to +[TensorFlow](https://www.tensorflow.org/). + +This is not an official Google product. + +# Instructions + +## Build + +For now [docker](https://www.docker.com/) is required. Once you have docker +working, the following commands will compile and run the tests. + + git clone --recursive https://github.com/tensorflow/haskell.git tensorflow-haskell + cd tensorflow-haskell + IMAGE_NAME=tensorflow/haskell:v0 + docker build -t $IMAGE_NAME docker + # TODO: move the setup step to the docker script. 
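+    # `stack setup` installs the GHC toolchain inside the container;
+    # `stack test` then builds every package in stack.yaml and runs
+    # their test suites.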
+ stack --docker --docker-image=$IMAGE_NAME setup + stack --docker --docker-image=$IMAGE_NAME test + +There is also a demo application: + + cd tensorflow-mnist + stack --docker --docker-image=$IMAGE_NAME build --exec Main diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..7c27d67 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,29 @@ +# Prepare the image with: +# docker build -t gnezdo/tfhs:2016-10-14-1 docker +FROM gcr.io/tensorflow/tensorflow:latest-devel +MAINTAINER Greg Steuck +# Installs protoc and the libraries. +RUN \ + cd /tensorflow && \ + bazel --batch build -c opt '@protobuf//:protoc' && \ + install -s bazel-bin/external/protobuf/protoc /usr/local/bin && \ + bazel --batch build -c opt '//tensorflow:libtensorflow_c.so' && \ + install bazel-bin/tensorflow/libtensorflow_c.so /usr/local/lib && \ + bazel --batch clean + +RUN apt-get update + +RUN apt-get install -y \ + # Avoids /usr/bin/ld: cannot find -ltinfo + libncurses5-dev \ + # Makes stack viable in the container + libgmp-dev \ + # Required for locales configuration. + locales + +# Our MNIST demo program outputs Unicode characters. +RUN dpkg-reconfigure locales && \ + locale-gen en_US.UTF-8 && \ + update-locale LANG=en_US.UTF-8 + +ENV LANG en_US.UTF-8 diff --git a/google-shim/Setup.hs b/google-shim/Setup.hs new file mode 100644 index 0000000..e8ef27d --- /dev/null +++ b/google-shim/Setup.hs @@ -0,0 +1,3 @@ +import Distribution.Simple + +main = defaultMain diff --git a/google-shim/google-shim.cabal b/google-shim/google-shim.cabal new file mode 100644 index 0000000..ff9bbba --- /dev/null +++ b/google-shim/google-shim.cabal @@ -0,0 +1,23 @@ +name: google-shim +version: 0.1.0.0 +synopsis: Adapters to externalize TensorFlow code. +description: Please see README.md +homepage: https://github.com/tensorflow/haskell#readme +license: Apache +author: TensorFlow authors +maintainer: tensorflow-haskell@googlegroups.com +copyright: Google Inc. +category: Machine Learning +build-type: Simple +cabal-version: >=1.22 + +library + hs-source-dirs: src + exposed-modules: Google.Test + build-depends: base >= 4.7 && < 5 + , test-framework + default-language: Haskell2010 + +source-repository head + type: git + location: https://github.com/tensorflow/haskell diff --git a/google-shim/src/Google/Test.hs b/google-shim/src/Google/Test.hs new file mode 100644 index 0000000..26e1659 --- /dev/null +++ b/google-shim/src/Google/Test.hs @@ -0,0 +1,22 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- | Alternative implementations to make dependent code work without +-- changes. 
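+--
+-- A minimal usage sketch (the test case below is hypothetical):
+--
+-- > import Google.Test (googleTest)
+-- > import Test.Framework.Providers.HUnit (testCase)
+-- >
+-- > main :: IO ()
+-- > main = googleTest [testCase "trivial" (return ())]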
+module Google.Test where + +import Test.Framework (Test, defaultMain) + +googleTest :: [Test] -> IO () +googleTest = defaultMain diff --git a/stack.yaml b/stack.yaml new file mode 100644 index 0000000..2f0ce20 --- /dev/null +++ b/stack.yaml @@ -0,0 +1,24 @@ +resolver: lts-6.2 + +packages: +- google-shim +- tensorflow +- tensorflow-core-ops +- tensorflow-opgen +- tensorflow-ops +- tensorflow-proto +- tensorflow-mnist +- tensorflow-mnist-input-data +- tensorflow-queue + +extra-deps: +# proto-lens is not yet in Stackage. +- proto-lens-0.1.0.4 +- proto-lens-protoc-0.1.0.4 + +# Allow our custom Setup.hs scripts to import Data.ProtoLens.Setup from the version of +# `proto-lens-protoc` in stack's local DB. See: +# https://github.com/google/proto-lens/blob/master/README.md#using-cabal +explicit-setup-deps: + "*": true + diff --git a/tensorflow-core-ops/Setup.hs b/tensorflow-core-ops/Setup.hs new file mode 100644 index 0000000..5df8ba8 --- /dev/null +++ b/tensorflow-core-ops/Setup.hs @@ -0,0 +1,100 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- | Generates the wrappers for Ops shipped with tensorflow_c. +module Main where + +import Distribution.Simple.BuildPaths (autogenModulesDir) +import Distribution.Simple.LocalBuildInfo (LocalBuildInfo) +import Distribution.Simple + ( defaultMainWithHooks + , simpleUserHooks + , UserHooks(..) + ) +import Data.List (intercalate) +import Data.ProtoLens (decodeMessage) +import System.Directory (createDirectoryIfMissing) +import System.Exit (exitFailure) +import System.FilePath (()) +import System.IO (hPutStrLn, stderr) +import TensorFlow.Internal.FFI (getAllOpList) +import TensorFlow.OpGen (docOpList, OpGenFlags(..)) +import Text.PrettyPrint.Mainland (prettyLazyText) +import qualified Data.Text.Lazy.IO as Text + +main = defaultMainWithHooks generatingOpsWrappers + +-- TODO: Generalize for user libraries by replacing getAllOpList with +-- a wrapper around TF_LoadLibrary. The complicated part is interplay +-- between bazel and Haskell build system. +generatingOpsWrappers :: UserHooks +generatingOpsWrappers = hooks + { buildHook = \p l h f -> generateSources l >> buildHook hooks p l h f + , haddockHook = \p l h f -> generateSources l >> haddockHook hooks p l h f + , replHook = \p l h f args -> generateSources l + >> replHook hooks p l h f args + } + where + flagsBuilder dir = OpGenFlags + { outputFile = dir "Core.hs" + , prefix = "TensorFlow.GenOps" + , excludeList = intercalate "," blackList + } + hooks = simpleUserHooks + generateSources :: LocalBuildInfo -> IO () + generateSources l = do + let dir = autogenModulesDir l "TensorFlow/GenOps" + createDirectoryIfMissing True dir + let flags = flagsBuilder dir + pb <- getAllOpList + case decodeMessage pb of + Left e -> hPutStrLn stderr e >> exitFailure + Right x -> Text.writeFile (outputFile flags) + (prettyLazyText 80 $ docOpList flags x) + +blackList = + -- A few data flow ops take a list of heterogeneous + -- parameters which we don't support in general form. 
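+  -- The names below are joined with commas by flagsBuilder above and
+  -- passed as OpGenFlags.excludeList, so no wrappers are generated for
+  -- these ops.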
+ [ "HashTable" + , "MutableDenseHashTable" + , "MutableHashTable" + , "MutableHashTableOfTensors" + , "QueueDequeue" + , "QueueDequeueMany" + , "QueueDequeueUpTo" + , "Stack" + , "TensorArray" + -- These should be possible to support by adding a bunch of + -- overloads with a variable number of tuple arguments. + , "Assert" + , "BarrierTakeMany" + , "Print" + , "QueueEnqueue" + , "QueueEnqueueMany" + -- These have type ambiguities because one of the type arguments + -- doesn't appear in the signature. + , "ConditionalAccumulator" + , "SparseConditionalAccumulator" + -- Need list of types support. + , "DecodeCSV" + , "ParseExample" + , "ParseSingleSequenceExample" + , "Save" + , "SaveSlices" + , "SymbolicGradient" + , "_ArrayToList" + , "_ListToArray" + -- Easy: support larger result tuples. + , "Skipgram" + ] diff --git a/tensorflow-core-ops/tensorflow-core-ops.cabal b/tensorflow-core-ops/tensorflow-core-ops.cabal new file mode 100644 index 0000000..5267efa --- /dev/null +++ b/tensorflow-core-ops/tensorflow-core-ops.cabal @@ -0,0 +1,30 @@ +name: tensorflow-core-ops +version: 0.1.0.0 +synopsis: Haskell wrappers for Core Tensorflow Ops. +description: Code generated signatures for the Ops in libtensorflow_c. +homepage: https://github.com/tensorflow/haskell#readme +license: Apache +author: TensorFlow authors +maintainer: tensorflow-haskell@googlegroups.com +copyright: Google Inc. +category: Machine Learning +build-type: Custom +cabal-version: >=1.22 + +library + exposed-modules: TensorFlow.GenOps.Core + build-depends: Cabal >= 1.22 && < 1.25 + , bytestring + , proto-lens == 0.1.* + , tensorflow-opgen == 0.1.* + , tensorflow == 0.1.* + , base >= 4.7 && < 5 + , filepath + , mainland-pretty + , lens-family + , text + default-language: Haskell2010 + +source-repository head + type: git + location: https://github.com/tensorflow/haskell diff --git a/tensorflow-mnist-input-data/Setup.hs b/tensorflow-mnist-input-data/Setup.hs new file mode 100644 index 0000000..e858987 --- /dev/null +++ b/tensorflow-mnist-input-data/Setup.hs @@ -0,0 +1,113 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE LambdaCase #-} + +-- | Downloads the MNIST data set and packages them as data files. +module Main where + +import Control.Monad (when) +import Data.Maybe (fromMaybe) +import Distribution.PackageDescription + ( GenericPackageDescription(packageDescription) + , dataDir + ) +import Distribution.Simple + ( UserHooks(..) 
+ , defaultMainWithHooks + , simpleUserHooks + ) +import System.IO (hPutStrLn, stderr) +import System.FilePath (()) +import System.Directory (doesFileExist) +import qualified Crypto.Hash as Hash +import qualified Data.ByteString.Lazy as B +import qualified Network.HTTP as HTTP +import qualified Network.URI as URI + +main :: IO () +main = defaultMainWithHooks downloadingDataFiles + +downloadingDataFiles :: UserHooks +downloadingDataFiles = hooks + { confHook = \gh@(g, _) c -> downloadFiles g >> confHook hooks gh c + } + where + hooks = simpleUserHooks + downloadFiles :: GenericPackageDescription -> IO () + downloadFiles g = do + let dir = dataDir (packageDescription g) + mapM_ (maybeDownload dir) fileInfos + +maybeDownload :: FilePath -> (String, String) -> IO () +maybeDownload dataDir (basename, sha256) = do + let filePath = dataDir basename + exists <- doesFileExist filePath + when (not exists) $ do + let url = urlPrefix ++ basename + hPutStrLn stderr ("Downloading " ++ url) + httpDownload url filePath + verify filePath sha256 + +httpDownload :: String -> FilePath -> IO () +httpDownload url outFile = do + let uri = fromMaybe + (error ("Can't be: invalid URI " ++ url)) + (URI.parseURI url) + result <- HTTP.simpleHTTP (HTTP.defaultGETRequest_ uri) + HTTP.getResponseCode result >>= \case + (2, 0, 0) -> HTTP.getResponseBody result >>= B.writeFile outFile + s -> error ( "Failed to download " ++ url ++ " error code " ++ show s + ++ helpfulMessage + ) + +verify :: FilePath -> String -> IO () +verify filePath hash = do + let sha256 = Hash.hashlazy :: B.ByteString -> Hash.Digest Hash.SHA256 + computed <- show . sha256 <$> B.readFile filePath + when (hash /= computed) $ + error ( "Incorrect checksum for " ++ filePath + ++ "\nexpected " ++ hash + ++ "\ncomputed " ++ computed + ++ helpfulMessage + ) + +urlPrefix = "http://yann.lecun.com/exdb/mnist/" + +-- | File names relative to 'urlPrefix' and their sha256. +fileInfos = [ + ( "train-images-idx3-ubyte.gz" + , "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609" + ) + , + ( "train-labels-idx1-ubyte.gz" + , "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c" + ) + , + ( "t10k-images-idx3-ubyte.gz" + , "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6" + ) + , + ( "t10k-labels-idx1-ubyte.gz" + , "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6" + ) + ] + +helpfulMessage = + unlines + ( "" + : "" + : "Please download the following URLs manually and put them in data/" + : [ urlPrefix ++ h | (h, _) <- fileInfos ] + ) diff --git a/tensorflow-mnist-input-data/data/marker-to-avoid-complaints.gz b/tensorflow-mnist-input-data/data/marker-to-avoid-complaints.gz new file mode 100644 index 0000000..e69de29 diff --git a/tensorflow-mnist-input-data/src/TensorFlow/Examples/MNIST/InputData.hs b/tensorflow-mnist-input-data/src/TensorFlow/Examples/MNIST/InputData.hs new file mode 100644 index 0000000..d8141e5 --- /dev/null +++ b/tensorflow-mnist-input-data/src/TensorFlow/Examples/MNIST/InputData.hs @@ -0,0 +1,31 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+-- See the License for the specific language governing permissions and +-- limitations under the License. + +module TensorFlow.Examples.MNIST.InputData + ( trainingImageData + , trainingLabelData + , testImageData + , testLabelData + ) where + +import Paths_tensorflow_mnist_input_data (getDataFileName) + +-- | Download the files containing the canonical MNIST samples and labels. +trainingImageData, trainingLabelData :: IO FilePath +trainingImageData = getDataFileName "train-images-idx3-ubyte.gz" +trainingLabelData = getDataFileName "train-labels-idx1-ubyte.gz" + +testImageData, testLabelData :: IO FilePath +testImageData = getDataFileName "t10k-images-idx3-ubyte.gz" +testLabelData = getDataFileName "t10k-labels-idx1-ubyte.gz" diff --git a/tensorflow-mnist-input-data/tensorflow-mnist-input-data.cabal b/tensorflow-mnist-input-data/tensorflow-mnist-input-data.cabal new file mode 100644 index 0000000..f60177b --- /dev/null +++ b/tensorflow-mnist-input-data/tensorflow-mnist-input-data.cabal @@ -0,0 +1,35 @@ +name: tensorflow-mnist-input-data +version: 0.1.0.0 +synopsis: Downloader of input data for training MNIST. +description: Please see README.md +homepage: https://github.com/tensorflow/haskell#readme +license: Apache +author: TensorFlow authors +maintainer: tensorflow-haskell@googlegroups.com +copyright: Google Inc. +category: Machine Learning +build-type: Custom +cabal-version: >=1.22 +-- These files are downloaded automatically by Setup.hs. If the +-- automatic download fails, follow the instructions in error messages +-- displayed by Setup.hs. +data-dir: data +data-files: *.gz + +library + hs-source-dirs: src + exposed-modules: TensorFlow.Examples.MNIST.InputData + other-modules: Paths_tensorflow_mnist_input_data + build-depends: Cabal + , HTTP + , base >= 4.7 && < 5 + , bytestring + , cryptonite + , directory + , filepath + , network-uri + default-language: Haskell2010 + +source-repository head + type: git + location: https://github.com/tensorflow/haskell diff --git a/tensorflow-mnist/Setup.hs b/tensorflow-mnist/Setup.hs new file mode 100644 index 0000000..e8ef27d --- /dev/null +++ b/tensorflow-mnist/Setup.hs @@ -0,0 +1,3 @@ +import Distribution.Simple + +main = defaultMain diff --git a/tensorflow-mnist/app/Main.hs b/tensorflow-mnist/app/Main.hs new file mode 100644 index 0000000..57d5ce0 --- /dev/null +++ b/tensorflow-mnist/app/Main.hs @@ -0,0 +1,161 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
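+-- For orientation: the program below builds a one-hidden-layer network
+-- (numPixels = 784 inputs -> 500 ReLU units -> numLabels = 10 softmax
+-- logits), trains it by stochastic gradient descent on batches of 100
+-- samples, and reports the error rate on the held-out test set.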
+ +{-# LANGUAGE OverloadedLists #-} + +import Control.Monad (zipWithM, when, forM, forM_) +import Control.Monad.IO.Class (liftIO) +import Data.Int (Int32, Int64) +import qualified Data.Text.IO as T +import qualified Data.Vector as V + +import qualified TensorFlow.ControlFlow as TF +import qualified TensorFlow.Build as TF +import qualified TensorFlow.Ops as TF +import qualified TensorFlow.Session as TF +import qualified TensorFlow.Tensor as TF +import qualified TensorFlow.Types as TF +import qualified TensorFlow.Gradient as TF + +import TensorFlow.Examples.MNIST.InputData +import TensorFlow.Examples.MNIST.Parse + +numPixels = 28^2 :: Int64 +numLabels = 10 :: Int64 + +-- | Create tensor with random values where the stddev depends on the width. +randomParam :: Int64 -> TF.Shape -> TF.Build (TF.Tensor TF.Value Float) +randomParam width (TF.Shape shape) = + (* stddev) <$> TF.truncatedNormal (TF.vector shape) + where + stddev = TF.scalar (1 / sqrt (fromIntegral width)) + +-- Types must match due to model structure (sparseToDense requires +-- index types to match) +type LabelType = Int32 +type BatchSize = Int32 + +-- | Convert scalar labels to one-hot vectors. +labelClasses :: TF.Tensor TF.Value LabelType + -> LabelType + -> BatchSize + -> TF.Tensor TF.Value Float +labelClasses labels numClasses batchSize = + let indices = TF.range 0 (TF.scalar batchSize) 1 + concated = TF.concat 1 [TF.expandDims indices 1, TF.expandDims labels 1] + in TF.sparseToDense concated + (TF.constant [2] [batchSize, numClasses]) + 1 {- ON value -} + 0 {- default (OFF) value -} + +-- | Fraction of elements that differ between two vectors. +errorRate :: Eq a => V.Vector a -> V.Vector a -> Double +errorRate xs ys = fromIntegral (len - numCorrect) / fromIntegral len + where + numCorrect = V.length $ V.filter id $ V.zipWith (==) xs ys + len = V.length xs + +data Model = Model { + train :: TF.TensorData Float -- ^ images + -> TF.TensorData LabelType + -> TF.Session () + , infer :: TF.TensorData Float -- ^ images + -> TF.Session (V.Vector LabelType) -- ^ predictions + } + +createModel :: Int64 -> TF.Build Model +createModel batchSize = do + -- Inputs. + images <- TF.placeholder [batchSize, numPixels] + -- Hidden layer. + let numUnits = 500 + hiddenWeights <- + TF.initializedVariable =<< randomParam numPixels [numPixels, numUnits] + hiddenBiases <- TF.zeroInitializedVariable [numUnits] + let hiddenZ = (images `TF.matMul` hiddenWeights) `TF.add` hiddenBiases + let hidden = TF.relu hiddenZ + -- Logits. + logitWeights <- + TF.initializedVariable =<< randomParam numUnits [numUnits, numLabels] + logitBiases <- TF.zeroInitializedVariable [numLabels] + let logits = (hidden `TF.matMul` logitWeights) `TF.add` logitBiases + predict <- TF.render $ TF.cast $ + TF.argMax (TF.softmax logits) (TF.scalar (1 :: LabelType)) + + -- Create training action. 
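+    -- The training step below computes a softmax cross-entropy loss
+    -- against one-hot labels and applies the plain gradient-descent
+    -- update
+    --     w := w - lr * dLoss/dw
+    -- to each parameter via TF.assign/TF.sub; TF.group bundles the four
+    -- parameter updates into a single trainStep node.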
+    labels <- TF.placeholder [batchSize]
+    let labelVecs = labelClasses labels 10 (fromIntegral batchSize)
+        loss = fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
+        params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases]
+    grads <- TF.gradients loss params
+
+    let lr = TF.scalar $ 0.001 / fromIntegral batchSize
+        applyGrad param grad
+            = TF.assign param $ param `TF.sub` (lr * grad)
+    trainStep <- TF.group =<< zipWithM applyGrad params grads
+
+    return Model {
+          train = \imFeed lFeed -> TF.runWithFeeds_ [
+                TF.feed images imFeed
+              , TF.feed labels lFeed
+              ] trainStep
+        , infer = \imFeed -> TF.runWithFeeds [TF.feed images imFeed] predict
+        }
+
+main = TF.runSession $ do
+    -- Read training and test data.
+    trainingImages <- liftIO (readMNISTSamples =<< trainingImageData)
+    trainingLabels <- liftIO (readMNISTLabels =<< trainingLabelData)
+    testImages <- liftIO (readMNISTSamples =<< testImageData)
+    testLabels <- liftIO (readMNISTLabels =<< testLabelData)
+
+    let batchSize = 100 :: Int64
+
+    -- Create the model.
+    model <- TF.build $ createModel batchSize
+
+    -- Helpers for generating batches.
+    let selectBatch i xs = take size $ drop (i * size) $ cycle xs
+          where size = fromIntegral batchSize
+    let getImageBatch i xs = TF.encodeTensorData
+            [batchSize, numPixels]
+            $ fromIntegral <$> mconcat (selectBatch i xs)
+    let getExpectedLabelBatch i xs =
+            fromIntegral <$> V.fromList (selectBatch i xs)
+
+    -- Train.
+    forM_ ([0..1000] :: [Int]) $ \i -> do
+        let images = getImageBatch i trainingImages
+            labels = getExpectedLabelBatch i trainingLabels
+        train model images (TF.encodeTensorData [batchSize] labels)
+        when (i `mod` 100 == 0) $ do
+            preds <- infer model images
+            liftIO $ putStrLn $
+                "training error " ++ show (errorRate preds labels * 100)
+    liftIO $ putStrLn ""
+
+    -- Test.
+    -- Note: iterate over exactly numTestBatches batches; [0..numTestBatches]
+    -- would wrap around (selectBatch cycles) and count some samples twice.
+    let numTestBatches = length testImages `div` fromIntegral batchSize
+    testPreds <- fmap mconcat $ forM [0..numTestBatches - 1] $ \i -> do
+        infer model (getImageBatch i testImages)
+    let testExpected = fromIntegral <$> V.fromList testLabels
+    liftIO $ putStrLn $
+        "test error " ++ show (errorRate testPreds testExpected * 100)
+
+    -- Show some predictions.
+    liftIO $ forM_ ([0..3] :: [Int]) $ \i -> do
+        putStrLn ""
+        T.putStrLn $ drawMNIST $ testImages !! i
+        putStrLn $ "expected " ++ show (testLabels !! i)
+        putStrLn $ "     got " ++ show (testPreds V.! i)
diff --git a/tensorflow-mnist/data/MNIST.pb b/tensorflow-mnist/data/MNIST.pb
new file mode 100644
index 0000000..2ec82e0
Binary files /dev/null and b/tensorflow-mnist/data/MNIST.pb differ
diff --git a/tensorflow-mnist/data/MNISTBias.ckpt b/tensorflow-mnist/data/MNISTBias.ckpt
new file mode 100644
index 0000000..a9d89e3
Binary files /dev/null and b/tensorflow-mnist/data/MNISTBias.ckpt differ
diff --git a/tensorflow-mnist/data/MNISTWts.ckpt b/tensorflow-mnist/data/MNISTWts.ckpt
new file mode 100644
index 0000000..bad8763
Binary files /dev/null and b/tensorflow-mnist/data/MNISTWts.ckpt differ
diff --git a/tensorflow-mnist/src-data/TensorFlow/Examples/MNIST/TrainedGraph.hs b/tensorflow-mnist/src-data/TensorFlow/Examples/MNIST/TrainedGraph.hs
new file mode 100644
index 0000000..fdc02c3
--- /dev/null
+++ b/tensorflow-mnist/src-data/TensorFlow/Examples/MNIST/TrainedGraph.hs
@@ -0,0 +1,30 @@
+-- Copyright 2016 TensorFlow authors.
+--
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+{-# LANGUAGE OverloadedStrings #-}
+-- | Paths to test helper files.
+module TensorFlow.Examples.MNIST.TrainedGraph where
+
+import Paths_tensorflow_mnist (getDataFileName)
+import Data.ByteString (ByteString)
+import Data.ByteString.Char8 (pack)
+
+-- | File containing a TensorFlow serialized proto of MNIST.
+mnistPb :: IO FilePath
+mnistPb = getDataFileName "data/MNIST.pb"
+
+-- | Files containing pre-trained weights for MNIST.
+wtsCkpt, biasCkpt :: IO ByteString
+wtsCkpt = pack <$> getDataFileName "data/MNISTWts.ckpt"
+biasCkpt = pack <$> getDataFileName "data/MNISTBias.ckpt"
diff --git a/tensorflow-mnist/src/TensorFlow/Examples/MNIST/Parse.hs b/tensorflow-mnist/src/TensorFlow/Examples/MNIST/Parse.hs
new file mode 100644
index 0000000..35a835a
--- /dev/null
+++ b/tensorflow-mnist/src/TensorFlow/Examples/MNIST/Parse.hs
@@ -0,0 +1,96 @@
+-- Copyright 2016 TensorFlow authors.
+--
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE OverloadedLists #-}
+{-# LANGUAGE TypeSynonymInstances #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+
+module TensorFlow.Examples.MNIST.Parse where
+
+import Control.Monad (when, liftM)
+import Data.Binary.Get (Get, runGet, getWord32be, getLazyByteString)
+import Data.ByteString.Lazy (toStrict, readFile)
+import Data.List.Split (chunksOf)
+import Data.Monoid ((<>))
+import Data.ProtoLens (Message, decodeMessageOrDie)
+import Data.Text (Text)
+import Data.Word (Word8, Word32)
+import Prelude hiding (readFile)
+import qualified Codec.Compression.GZip as GZip
+import qualified Data.ByteString.Lazy as L
+import qualified Data.Text as Text
+import qualified Data.Vector as V
+
+-- | A single MNIST digit sample: a vector of 28x28 pixel intensities.
+type MNIST = V.Vector Word8
+
+-- | Produces a Unicode rendering of the MNIST digit sample.
+drawMNIST :: MNIST -> Text
+drawMNIST = chunk . block
+  where
+    block :: V.Vector Word8 -> Text
+    block (V.splitAt 1 -> ([0], xs)) = " " <> block xs
+    block (V.splitAt 1 -> ([n], xs)) = c `Text.cons` block xs
+      where c = "\9617\9618\9619\9608" !! fromIntegral (n `div` 64)
+    block (V.splitAt 1 -> _) = ""
+    chunk :: Text -> Text
+    chunk "" = "\n"
+    chunk xs = Text.take 28 xs <> "\n" <> chunk (Text.drop 28 xs)
+
+-- | Checks the file's endianness, throwing an error if it's not as expected.
+checkEndian :: Get ()
+checkEndian = do
+    magic <- getWord32be
+    when (magic `notElem` ([2049, 2051] :: [Word32])) $
+        fail "Expected a big-endian MNIST magic number (2049 or 2051)."
+
+-- | Reads an MNIST file and returns a list of samples.
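+-- The gzipped IDX format, as parsed below: a big-endian Word32 magic
+-- number (2051 for images, 2049 for labels), then the sample count
+-- and, for images, the row and column counts, followed by the raw
+-- Word8 pixel data.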
+readMNISTSamples :: FilePath -> IO [MNIST] +readMNISTSamples path = do + raw <- GZip.decompress <$> readFile path + return $ runGet getMNIST raw + where + getMNIST :: Get [MNIST] + getMNIST = do + checkEndian + -- Parse header data. + cnt <- liftM fromIntegral getWord32be + rows <- liftM fromIntegral getWord32be + cols <- liftM fromIntegral getWord32be + -- Read all of the data, then split into samples. + pixels <- getLazyByteString $ fromIntegral $ cnt * rows * cols + return $ V.fromList <$> chunksOf (rows * cols) (L.unpack pixels) + +-- | Reads a list of MNIST labels from a file and returns them. +readMNISTLabels :: FilePath -> IO [Word8] +readMNISTLabels path = do + raw <- GZip.decompress <$> readFile path + return $ runGet getLabels raw + where getLabels :: Get [Word8] + getLabels = do + checkEndian + -- Parse header data. + cnt <- liftM fromIntegral getWord32be + -- Read all of the labels. + L.unpack <$> getLazyByteString cnt + +readMessageFromFileOrDie :: Message m => FilePath -> IO m +readMessageFromFileOrDie path = do + pb <- readFile path + return $ decodeMessageOrDie $ toStrict pb + +-- TODO: Write a writeMessageFromFileOrDie and read/write non-lethal +-- versions. diff --git a/tensorflow-mnist/tensorflow-mnist.cabal b/tensorflow-mnist/tensorflow-mnist.cabal new file mode 100644 index 0000000..5d406fe --- /dev/null +++ b/tensorflow-mnist/tensorflow-mnist.cabal @@ -0,0 +1,80 @@ +name: tensorflow-mnist +version: 0.1.0.0 +synopsis: TensorFlow demo application for learning MNIST model. +description: Please see README.md +homepage: https://github.com/tensorflow/haskell#readme +license: Apache +author: TensorFlow authors +maintainer: tensorflow-haskell@googlegroups.com +copyright: Google Inc. +category: Machine Learning +build-type: Simple +cabal-version: >=1.22 +data-files: data/*.ckpt + , data/*.pb + +library + hs-source-dirs: src + , src-data + exposed-modules: TensorFlow.Examples.MNIST.Parse + , TensorFlow.Examples.MNIST.TrainedGraph + other-modules: Paths_tensorflow_mnist + build-depends: proto-lens == 0.1.* + , base >= 4.7 && < 5 + , binary + , bytestring + , filepath + , lens-family + , containers + , split + , tensorflow-proto == 0.1.* + , tensorflow-core-ops == 0.1.* + , tensorflow + , text + , vector + , zlib + default-language: Haskell2010 + +executable Main + default-language: Haskell2010 + main-is: Main.hs + hs-source-dirs: app + build-depends: base + , bytestring + , filepath + , lens-family + , proto-lens + , tensorflow + , tensorflow-mnist + , tensorflow-mnist-input-data + , tensorflow-ops + , tensorflow-proto + , text + , transformers + , vector + +Test-Suite ParseTest + default-language: Haskell2010 + type: exitcode-stdio-1.0 + main-is: ParseTest.hs + hs-source-dirs: tests + build-depends: HUnit + , base + , bytestring + , proto-lens + , lens-family + , google-shim + , tensorflow + , tensorflow-mnist + , tensorflow-mnist-input-data + , tensorflow-ops + , tensorflow-proto + , test-framework + , test-framework-hunit + , text + , transformers + , vector + +source-repository head + type: git + location: https://github.com/tensorflow/haskell diff --git a/tensorflow-mnist/tests/ParseTest.hs b/tensorflow-mnist/tests/ParseTest.hs new file mode 100644 index 0000000..26ef124 --- /dev/null +++ b/tensorflow-mnist/tests/ParseTest.hs @@ -0,0 +1,170 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. 
+-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE OverloadedStrings #-} + +module Main where + +import Control.Monad.IO.Class (liftIO) +import Data.Int (Int64) +import Data.Text (Text) +import qualified Data.Text.IO as Text +import Lens.Family2 ((&), (.~), (^.)) +import Prelude hiding (abs) +import Proto.Tensorflow.Core.Framework.Graph + ( GraphDef(..) + , version + , node ) +import Proto.Tensorflow.Core.Framework.NodeDef + ( NodeDef(..) + , op ) +import System.IO as IO +import TensorFlow.Examples.MNIST.InputData +import TensorFlow.Examples.MNIST.Parse +import TensorFlow.Examples.MNIST.TrainedGraph +import TensorFlow.Build + ( asGraphDef + , addGraphDef + , render + ) +import TensorFlow.Tensor + ( Tensor(..) + , Ref + , Value + , feed + , TensorKind(..) + , tensorFromName + ) +import TensorFlow.Ops +import TensorFlow.Nodes (unScalar) +import TensorFlow.Session + (runSession, run, run_, runWithFeeds, build, buildAnd) +import TensorFlow.Types (TensorType(..), Shape(..)) +import Test.Framework.Providers.HUnit (testCase) +import Test.HUnit ((@=?), Assertion) +import Google.Test +import qualified Data.Vector as V + +-- | Test that a file can be read and the GraphDef proto correctly parsed. +testReadMessageFromFileOrDie = testCase "testReadMessageFromFileOrDie" $ do + -- Check the function on a known well-formatted file. + mnist <- readMessageFromFileOrDie =<< mnistPb :: IO GraphDef + -- Simple field read. + 1 @=? mnist^.version + -- Count the number of nodes. + let nodes :: [NodeDef] + nodes = mnist^.node + 100 @=? length nodes + -- Check that the expected op is found at an arbitrary index. + "Variable" @=? nodes!!6^.op + +-- | Parse the test set for label and image data. Will only fail if the file is +-- missing or incredibly corrupt. +testReadMNIST = testCase "testReadMNIST" $ do + imageData <- readMNISTSamples =<< testImageData + 10000 @=? length imageData + labelData <- readMNISTLabels =<< testLabelData + 10000 @=? length labelData + +testNodeName :: Text -> Tensor v a -> Assertion +testNodeName n g = n @=? opName + where + opName = head (gDef^.node)^.op + gDef = asGraphDef $ render g + +testGraphDefGen = testCase "testGraphDefGen" $ do + -- Test the inferred operation type. + let f0 :: Tensor Value Float + f0 = 0 + testNodeName "Const" f0 + testNodeName "Add" $ 1 + f0 + testNodeName "Mul" $ 1 * f0 + testNodeName "Sub" $ 1 - f0 + testNodeName "Abs" $ abs f0 + testNodeName "Sign" $ signum f0 + testNodeName "Neg" $ -f0 + -- Test the grouping. + testNodeName "Add" $ 1 + f0 * 2 + testNodeName "Add" $ 1 + (f0 * 2) + testNodeName "Mul" $ (1 + f0) * 2 + +-- | Convert a simple graph to GraphDef, load it, run it, and check the output. +testGraphDefExec = testCase "testGraphDefExec" $ do + let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10 + runSession $ do + build $ addGraphDef graphDef + x <- run $ tensorFromName ValueKind "Mul_2" + liftIO $ (50 :: Float) @=? unScalar x + +-- | Load MNIST from a GraphDef and the weights from a checkpoint and run on +-- sample data. 
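+-- The graph version is rewritten to 0 for session compatibility, the
+-- checkpointed weights are loaded with restore/restoreFromName, and the
+-- sample is fed to "x-input" before reading the "test/ArgMax" output.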
+testMNISTExec = testCase "testMNISTExec" $ do + -- Switch to unicode to enable pretty printing of MNIST digits. + IO.hSetEncoding IO.stdout IO.utf8 + + -- Parse the Graph definition, samples, & labels from files. + mnist <- readMessageFromFileOrDie =<< mnistPb :: IO GraphDef + mnistSamples <- readMNISTSamples =<< testImageData + mnistLabels <- readMNISTLabels =<< testLabelData + + -- Select a sample to run on and convert it into a TensorData of Floats. + let idx = 12 + sample :: MNIST + sample = mnistSamples !! idx + label = mnistLabels !! idx + tensorSample = encodeTensorData (Shape [1,784]) floatSample + where + floatSample :: V.Vector Float + floatSample = V.map fromIntegral sample + Text.putStrLn $ drawMNIST sample + + -- Execute the graph on the sample data. + runSession $ do + -- The version of this session is 0, but the version of the graph is 1. + -- Change the graph version to 0 so they're compatible. + build $ addGraphDef $ mnist & version .~ 0 + -- Define nodes that restore saved weights and biases. + let bias, wts :: Tensor Ref Float + bias = tensorFromName RefKind "Variable" + wts = tensorFromName RefKind "weights" + wtsCkptPath <- liftIO wtsCkpt + biasCkptPath <- liftIO biasCkpt + -- Run those restoring nodes on the graph in the current session. + buildAnd run_ $ (sequence :: Monad m => [m a] -> m [a]) + [ restore wtsCkptPath wts + , restoreFromName biasCkptPath "bias" bias + ] + -- Encode the expected sample data as one-hot data. + let ty = encodeTensorData [10] oneHotLabels + where oneHotLabels = V.replicate 10 (0 :: Float) V.// updates + updates = [(fromIntegral label, 1)] + let feeds = [ feed (tensorFromName ValueKind "x-input") tensorSample + , feed (tensorFromName ValueKind "y-input") ty + ] + -- Run the graph with the input feeds and read the ArgMax'd result from + -- the test (not training) side of the evaluation. + x <- runWithFeeds feeds $ tensorFromName ValueKind "test/ArgMax" + -- Print the trained model's predicted outcome. + liftIO $ putStrLn $ "Expectation: " ++ show label ++ "\n" + ++ "Prediction: " ++ show (unScalar x :: Int64) + -- Check whether the prediction matches the expectation. + liftIO $ (fromInteger . toInteger $ label :: Int64) @=? unScalar x + +main :: IO () +main = googleTest [ testReadMessageFromFileOrDie + , testReadMNIST + , testGraphDefGen + , testGraphDefExec + , testMNISTExec] diff --git a/tensorflow-opgen/Setup.hs b/tensorflow-opgen/Setup.hs new file mode 100644 index 0000000..e8ef27d --- /dev/null +++ b/tensorflow-opgen/Setup.hs @@ -0,0 +1,3 @@ +import Distribution.Simple + +main = defaultMain diff --git a/tensorflow-opgen/src/TensorFlow/OpGen.hs b/tensorflow-opgen/src/TensorFlow/OpGen.hs new file mode 100644 index 0000000..bf0fc6e --- /dev/null +++ b/tensorflow-opgen/src/TensorFlow/OpGen.hs @@ -0,0 +1,457 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
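+-- For orientation, the renderers below emit wrappers of roughly this
+-- shape (sketched for a hypothetical one-input op "Neg" whose single
+-- type attribute is "T"):
+--
+-- > neg :: forall v1 t . (TensorType t) => Tensor v1 t -> Tensor Value t
+-- > neg input | eqLengthGuard [] =
+-- >     buildOp (opDef "Neg"
+-- >              & opAttr "T" .~ tensorType (undefined :: t))
+-- >         input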
+ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TypeFamilies #-} +-- | Rendering of TensorFlow operations as Haskell functions. + +module TensorFlow.OpGen + ( OpGenFlags(..) + , docOpList + , flagParser) + where + +import Prelude hiding (head, tail) + +import Control.Monad (guard) +import Data.Char (toLower, toUpper) +import Data.Foldable (toList) +import Data.Maybe (fromMaybe, maybeToList) +import Data.ProtoLens (def, showMessage) +import Data.List.NonEmpty (NonEmpty((:|)), head) +import qualified Data.List.NonEmpty as NE +import Lens.Family2 ((^.), (.~), (&), view) +import Options.Applicative (Parser, help, long, strOption, value) +import Proto.Tensorflow.Core.Framework.OpDef + ( OpList + , OpDef + , OpDef'ArgDef + , attr + , description + , inputArg + , name + , numberAttr + , op + , outputArg + , summary + , type' + , typeAttr + ) +import Proto.Tensorflow.Core.Framework.Types (DataType(..)) +import System.FilePath (takeBaseName) +import TensorFlow.OpGen.AttrVal + (AttrDef + , AttrCase(..) + , AttrTemplate(..) + , Template + , attrDef + , attrOriginal + , attrTemplate + , templateDefault + , templateRestrictions + ) +import Text.PrettyPrint.Mainland + ( Doc + , (<>) + , (<+>) + , () + , (<+/>) + , brackets + , comma + , commasep + , dquotes + , empty + , enclose + , flatten + , folddoc + , hang + , indent + , int + , parens + , sep + , stack + , strictText + , tuple + ) +import qualified Data.Map.Strict as Map +import qualified Data.Set as Set +import qualified Data.Text as Text +import qualified Data.Semigroup as Semigroup +import Data.Text (Text) + +data OpGenFlags = OpGenFlags + { outputFile :: String + , prefix :: String + , excludeList :: String + } + +flagParser :: Parser OpGenFlags +flagParser = OpGenFlags + <$> strOption (mconcat [ long "output" + , help "File to write." + ]) + <*> strOption (mconcat [ long "prefix" + , help "Haskell package prefix to use" + ]) + <*> strOption (mconcat [ long "exclude_list" + , value "" + , help "Comma separated Ops names to ignore" + ]) + + +docOpList :: OpGenFlags -> OpList -> Doc +docOpList flags opList = + stack [ "{-# LANGUAGE ConstraintKinds #-}" + , "{-# LANGUAGE DataKinds #-}" + , "{-# LANGUAGE FlexibleInstances #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE RankNTypes #-}" + , "{-# LANGUAGE ScopedTypeVariables #-}" + , "module" <+> strictText moduleName <+> "where" + , empty + , imports + , empty + , folddoc (\x y -> x empty y) + (map renderDef $ + filter (not . flip elem exclusions . view name) $ + toList $ opList ^. op) + ] + where moduleName = + Text.pack (prefix flags) <> "." <> camelCase + -- Discards the optional trailing _op_lib + (fromMaybe shortName (Text.stripSuffix "_op_lib" shortName)) + shortName = Text.pack (takeBaseName $ outputFile flags) + exclusions = Text.splitOn "," $ Text.pack $ excludeList flags + +camelCase s = Text.concat $ map upCase + $ filter (/= "ops") + $ Text.splitOn "_" s + +-- | Upper-case the given text. +upCase :: Text -> Text +upCase = forceCase toUpper + +-- | Lower-case the given name, and prevent it from overlapping with a reserved +-- Haskell name. +lowCase :: Text -> Text +lowCase = replaceReservedName . 
forceCase toLower + +forceCase :: (Char -> Char) -> Text -> Text +forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs) + (Text.uncons s) + +imports = stack [ + "import Data.ByteString (ByteString)" + , "import Data.Complex (Complex)" + , "import Data.Int (Int8, Int16, Int32, Int64)" + , "import Data.Word (Word8, Word16)" + , "import Lens.Family2 ((.~), (&))" + , "import TensorFlow.Build" + , "import TensorFlow.BuildOp" + , "import TensorFlow.Tensor" + , "import TensorFlow.Types" + ] + +renderDef :: OpDef -> Doc +renderDef d = + stack [ + haddocks + , n <+> "::" <+> hang 0 (typeSig d) + , n <+> hang 0 args <+> "|" <+> funcGuard <+> "=" -- args are indented + -- the body needs to be indented wrt the name + indent indentation functionBody + , extras -- just for debug + ] + where + n = strictText $ fixOpName (d ^. name) + args = sep $ [hsName | (_, hsName) <- mandatoryAttrs] ++ tensorArgs + tensorArgs = [strictText $ lowCase (a ^. name) | a <- d ^. inputArg] + fixOpName = lowCase + funcGuard = "eqLengthGuard" <+> brackets (commasep entries) + where + entries = + [ parens $ quotedText nAttr <> comma <+> + brackets (commasep $ toList $ + NE.map renderTensorName tensorNames) + | (nAttr, tensorNames) <- Map.toList $ numberAttrMap d + ] + renderTensorName x = parens $ quotedText x <> comma <+> + "length" <+> strictText x + -- Uses hang 0 to align the argument vertically on multiple lines. + functionBody = buildFunction <+> parens (hang 0 (stack buildOpParts)) + indent indentation (sep tensorArgs) + buildFunction + | null outputListsSizes = "buildOp" + | otherwise = "buildListOp" <+> brackets (commasep outputListsSizes) + outputListsSizes = [ strictText numberAttrName + | o <- d ^. outputArg + , let numberAttrName = o ^. numberAttr + , not (Text.null numberAttrName) && + numberAttrName `Map.member` mandatoryAttrMap d + ] + buildOpParts = + "opDef" <+> quotedText (d ^. name) : + -- Renders tensor arguments. + [ "& opAttr" <+> quotedText tfName <+> + ".~ tensorType (undefined ::" <+> strictText hsName <> ")" + | (tfName, (hsName, _)) <- Map.toList typeMap + ] ++ + -- Renders mandatory attributes as function parameters. + [ "& opAttr" <+> dquotes tfName <+> ".~" <+> hsName + | (tfName, hsName) <- mandatoryAttrs + ] ++ + -- Renders sizes of tensor list types having number_attr. + [ "& opAttr" <+> quotedText nAttr <+> ".~" <+> + "(fromIntegral (length" <+> strictText (head tensorNames) <> ") :: Int64)" + | (nAttr, tensorNames) <- Map.toList $ numberAttrMap d + ] + mandatoryAttrs = [(strictText tf, strictText hs) + | (tf, (hs, _, _)) <- Map.toList (mandatoryAttrMap d) + ] + haddocks = "-- |" <+> multilineComment (d ^. summary) (d ^. description) + extras = enclose "{-\n" "\n-}" $ + strictText $ Text.pack $ + showMessage ((def :: OpDef) + & inputArg .~ (d ^. inputArg) + & outputArg .~ (d ^. outputArg) + & attr .~ (d ^. attr)) + typeMap = opDefTypeMap d + +-- | Makes a quoted string doc out of the given text value. +quotedText :: Text.Text -> Doc +quotedText = dquotes . strictText + +-- | typeSig renders the type signature of the given OpDef. +typeSig :: OpDef -> Doc +typeSig d = + foralls <+> constraints <+/> + signatureFold (mandatoryAttrInputs ++ tensorInputs ++ [outputs]) + where + foralls | Map.null typeMap = empty + | otherwise = + "forall" + <+> sep (refTypes ++ map (strictText . fst) (Map.elems typeMap)) + <+> "." 
+ constraints | Map.null typeMap = empty + | otherwise = + tuple (concatMap + (\(t, aDef) -> + "TensorType" <+> strictText t + : maybeToList (oneOfRestrictions aDef t)) + (Map.elems typeMap)) <+> "=>" + tensorInputs = zipWith tensorArg refTypes (d ^. inputArg) + refTypes = map (\x -> "v" <> int x) [1..length (d ^. inputArg)] + tensorArg refType arg = wrapArg refType arg <+> + hang 0 ("-- ^" <+> argComment arg) + -- Argument type is a list of tensors if number_attr is set; + -- otherwise it's a single Tensor. + wrapArg refType arg = + if Text.null (arg ^. numberAttr) then typ else brackets typ + where typ = tensorType refType arg + tensorType refType arg = + "Tensor" <+> refType <+> maybe directType strictText indirectType + where + indirectType = fmap fst (Map.lookup (arg ^. typeAttr) typeMap) + directType = dtTypeToDoc (arg ^. type') + outputs = + case d ^. outputArg of + [] -> "ControlNode" + [o] -> wrappedOutput o <+> "-- ^" <+> argComment o + os -> renderTupleResult os + wrappedOutput = wrapArg "Value" + -- Tuple result case is rendered differently to give + -- individual elements their own comments. + renderTupleResult os = + stack $ [ tuple (map wrappedOutput os) + , flatten commentSummary + ] ++ map commentDetails os + where + commentSummary = "-- ^" <+> tuple [bold (o ^. name) | o <- os] + commentDetails o = + stack [ "--" + , "-- *" <+> argComment o + ] + signatureFold = folddoc (\x y -> x "->" <+> y) + mandatoryAttrInputs = [ + dtTypeToDoc dtType <+> + hang 0 ("-- ^" <+> argComment' tfName descr) + | (tfName, (_, dtType, descr)) <- Map.toList $ mandatoryAttrMap d] + typeMap = opDefTypeMap d + +-- | Returns the type restriction for the given tensor type if the +-- set of allowed types is not empty (i.e. restricted). +oneOfRestrictions :: AttrDef -> Text -> Maybe Doc +oneOfRestrictions aDef tName = do + typs <- onAttrType (^. templateRestrictions) aDef + guard $ not $ null typs + let typeList = commasep $ map strictText $ + Set.toList $ Set.fromList $ + map dtTypeToHaskell typs + return $ "OneOf" <+> "'" <> brackets typeList <+> strictText tName + +-- | Identifies the attributes used as tensor cardinalities. In such +-- cases a list of tensors is supplied as an input_arg. The number of +-- such inputs is communicated as a separate opAttr. +-- The result key is TensorFlow attribute name and the value is the +-- tensor names which have number_attr set to the result key. +numberAttrMap :: OpDef -> Map.Map Text.Text (NonEmpty Text.Text) +numberAttrMap d = Map.fromListWith (Semigroup.<>) [ + (nAttr, replaceReservedName (inp ^. name) :| []) + | inp <- d ^. inputArg + , nAttr <- [inp ^. numberAttr] + , not (Text.null nAttr) + ] + +argComment :: OpDef'ArgDef -> Doc +argComment arg = argComment' (arg ^. name) (arg ^. description) + +argComment' :: Text.Text -> Text.Text -> Doc +argComment' argName argDesc = + bold argName <> splitMultilineText (":" <+>) argDesc + +bold :: Text.Text -> Doc +bold n = strictText ("__" <> n <> "__") + +opDefTypeMap :: OpDef -> Map.Map Text.Text (Text.Text, AttrDef) +opDefTypeMap d = + Map.fromList [(n, (lowCase n, a)) | (n, a) <- attrList d, isType a] + +attrList :: OpDef -> [(Text.Text, AttrDef)] +attrList d = [(a ^. name, attrDef a) | a <- d ^. attr] + +isType :: AttrDef -> Bool +isType = fromMaybe False . onAttrType (const True) + +-- | Applies the given function to the data type. Is this a Prism? +onAttrType :: (Template DataType -> a) -> AttrDef -> Maybe a +onAttrType f x = case x ^. 
attrTemplate of
+    AttrSingle (AttrType a) -> Just (f a)
+    _ -> Nothing
+
+-- | mandatoryAttrMap contains the attributes chosen by
+-- isMandatoryAttr, excluding those which are derived from lists of
+-- tensor arguments. The key is the TF name of the attribute. The
+-- value tuple is (haskell name, TF type, attribute comment).
+mandatoryAttrMap :: OpDef -> Map.Map Text.Text (Text.Text, DataType, Text.Text)
+mandatoryAttrMap d =
+    Map.fromList [ (n, (lowCase n, dtType, a ^. attrOriginal.description))
+                 | (n, a) <- attrList d
+                 , Just dtType <- [isMandatoryAttr a]
+                 -- Excludes the attributes rendered as list lengths.
+                 , n `Map.notMember` numberAttrMap d
+                 ]
+
+-- | Inspects the attribute and, if it is one of the implemented
+-- non-tensor values lacking a default, returns Just the TF type.
+isMandatoryAttr :: AttrDef -> Maybe DataType
+isMandatoryAttr x =
+    case x ^. attrTemplate of
+        AttrSingle (AttrBool y) -> noDefault DT_BOOL y
+        AttrSingle (AttrInt64 y) -> noDefault DT_INT64 y
+        AttrSingle (AttrFloat y) -> noDefault DT_FLOAT y
+        _ -> Nothing
+  where
+    noDefault typ y = maybe (Just typ) (const Nothing) (y ^. templateDefault)
+
+dtTypeToDoc :: DataType -> Doc
+dtTypeToDoc = strictText . dtTypeToHaskell
+
+-- NOTE: The cases of this function should be kept in sync with
+-- TensorFlow.Types.AllTensorTypes.
+dtTypeToHaskell :: DataType -> Text.Text
+dtTypeToHaskell DT_BOOL = "Bool"
+dtTypeToHaskell DT_BFLOAT16 = "Data.Word.Word16"
+dtTypeToHaskell DT_COMPLEX128 = "(Data.Complex.Complex Double)"
+dtTypeToHaskell DT_COMPLEX64 = "(Data.Complex.Complex Float)"
+dtTypeToHaskell DT_DOUBLE = "Double"
+dtTypeToHaskell DT_FLOAT = "Float"
+dtTypeToHaskell DT_INT16 = "Data.Int.Int16"
+dtTypeToHaskell DT_INT32 = "Data.Int.Int32"
+dtTypeToHaskell DT_INT64 = "Data.Int.Int64"
+dtTypeToHaskell DT_INT8 = "Data.Int.Int8"
+dtTypeToHaskell DT_QINT32 = "Data.Int.Int32"  -- TODO(gnezdo): make unique
+dtTypeToHaskell DT_QINT8 = "Data.Word.Word8"  -- TODO(gnezdo): make unique
+dtTypeToHaskell DT_QINT16 = "Data.Int.Int16"  -- TODO(gnezdo): make unique
+dtTypeToHaskell DT_QUINT16 = "Data.Word.Word16"  -- TODO(gnezdo): make unique
+dtTypeToHaskell DT_QUINT8 = "Data.Word.Word8"  -- TODO(gnezdo): make unique
+dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString"
+dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
+dtTypeToHaskell DT_HALF = "Data.Word.Word16"  -- TODO(gnezdo): make unique
+dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
+dtTypeToHaskell x =
+    Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x
+
+-- | haddockComment escapes TensorFlow doc strings into haddock.
+-- TODO(gnezdo): deal with the markup.
+haddockComment :: Text.Text -> Doc
+haddockComment = strictText
+
+multilineComment :: Text.Text -> Text.Text -> Doc
+multilineComment summary' detail =
+    haddockComment summary' </>
+    splitMultilineText insertParagraphAndComment detail
+  where insertParagraphAndComment x = "--" </> "--" <+> x
+
+-- | Converts the given multi-line detail string into
+-- a multi-line haddock. Applies the given lead to the
+-- first line. Returns an empty document for empty detail.
+splitMultilineText :: (Doc -> Doc) -> Text.Text -> Doc
+splitMultilineText lead detail =
+    case Text.lines detail of
+        [] -> empty
+        (l : ls) -> stack $ lead (haddockComment l)
+                          : map (("--" <+>) .
haddockComment) ls + +replaceReservedName :: Text -> Text +replaceReservedName n + | n `Set.member` reservedKeywords = n <> "'" + | otherwise = n + +indentation = 4 + +reservedKeywords :: Set.Set Text +reservedKeywords = Set.fromList $ + -- Haskell2010 keywords: + -- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4 + -- We don't include keywords that are allowed to be variable names, + -- in particular: "as", "forall", and "hiding". + [ "case" + , "class" + , "data" + , "default" + , "deriving" + , "do" + , "else" + , "foreign" + , "if" + , "import" + , "in" + , "infix" + , "infixl" + , "infixr" + , "instance" + , "let" + , "module" + , "newtype" + , "of" + , "then" + , "type" + , "where" + ] + ++ -- Nonstandard extensions + [ "mdo" -- RecursiveDo + , "rec" -- Arrows, RecursiveDo + , "proc" -- Arrows + ] diff --git a/tensorflow-opgen/src/TensorFlow/OpGen/AttrVal.hs b/tensorflow-opgen/src/TensorFlow/OpGen/AttrVal.hs new file mode 100644 index 0000000..997a908 --- /dev/null +++ b/tensorflow-opgen/src/TensorFlow/OpGen/AttrVal.hs @@ -0,0 +1,120 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE OverloadedStrings #-} + +-- | Wrapping of TensorFlow attributes into Haskell entities. +module TensorFlow.OpGen.AttrVal + (AttrDef + , AttrCase(..) + , AttrTemplate(..) + , Template + , attrDef + , attrOriginal + , attrTemplate + , templateDefault + , templateRestrictions + ) where + +import Data.Int (Int64) +import Data.Monoid ((<>)) +import Lens.Family2 (Lens', (^.)) +import Lens.Family2.Unchecked (lens) +import Proto.Tensorflow.Core.Framework.AttrValue as AttrValue +import Proto.Tensorflow.Core.Framework.OpDef as OpDef +import Proto.Tensorflow.Core.Framework.Types (DataType(..)) +import Proto.Tensorflow.Core.Framework.TensorShape (TensorShapeProto) +import qualified Data.ByteString as B +import qualified Data.Text as Text + +-- | Specifies the optional default value and a set of allowed values +-- for the given type. +data Template a = Template { + _templateDefault :: Maybe a + -- ^ The default value (mandatory if unspecified) + , _templateRestrictions :: [a] + -- ^ The allowed set of values, empty if no restrictions + } + +templateDefault :: Lens' (Template a) (Maybe a) +templateDefault = lens _templateDefault (\g x -> g { _templateDefault = x }) + +templateRestrictions :: Lens' (Template a) [a] +templateRestrictions = lens _templateRestrictions + (\g x -> g { _templateRestrictions = x }) + +data UnusedTensor + +data AttrCase f + = AttrBytes (f B.ByteString) -- bytes s = 2; // "string" + | AttrInt64 (f Int64) -- int64 i = 3; // "int" + | AttrFloat (f Float) -- float f = 4; // "float" + | AttrBool (f Bool) -- bool b = 5; // "bool" + | AttrType (f DataType) -- type = 6; // "type" + -- To be translated into TensorFlow.Types.Shape before use. + -- Leaving as a proto to reduce dependencies. + | AttrShape (f TensorShapeProto) -- shape = 7; // "shape" + +-- | Type-reified representation of TensorFlow AttrDef. 
+-- Initially limited to just the types in Op descriptors. +data AttrTemplate + = AttrSingle (AttrCase Template) + | AttrList (AttrCase []) + | AttrTensor UnusedTensor -- tensor = 8; // "tensor" + +data AttrDef = AttrDef { + _attrOriginal :: OpDef'AttrDef -- ^ the proto this value was created from + , _attrTemplate :: AttrTemplate -- ^ the type of the attribute + } + +attrTemplate :: Lens' AttrDef AttrTemplate +attrTemplate = lens _attrTemplate (\g x -> g { _attrTemplate = x }) + +attrOriginal :: Lens' AttrDef OpDef'AttrDef +attrOriginal = lens _attrOriginal (\g x -> g { _attrOriginal = x }) + +attrDef :: OpDef'AttrDef -> AttrDef +attrDef a = AttrDef a + $ translate (a^.OpDef.type') + (a^.OpDef.defaultValue) + (a^.allowedValues) + +-- | Converts the given AttrValue with the type given by the string +-- into the AttrVal if the type is known. +translate :: Text.Text -- ^ one of the TensorFlow type strings + -> AttrValue -- ^ default value + -> AttrValue -- ^ allowed values + -> AttrTemplate +translate t defaults allowed + | t == "string" = makeVal AttrBytes maybe's s + | t == "int" = makeVal AttrInt64 maybe'i i + | t == "float" = makeVal AttrFloat maybe'f f + | t == "bool" = makeVal AttrBool maybe'b b + | t == "type" = makeVal AttrType AttrValue.maybe'type' AttrValue.type' + | t == "shape" = makeVal AttrShape maybe'shape shape + | t == "tensor" = AttrTensor $ error "tensor is unimplemented" + | t == "list(string)" = makeList AttrBytes $ list.s + | t == "list(int)" = makeList AttrInt64 $ list.i + | t == "list(float)" = makeList AttrFloat $ list.f + | t == "list(bool)" = makeList AttrBool $ list.b + | t == "list(type)" = makeList AttrType $ list.AttrValue.type' + | t == "list(shape)" = makeList AttrShape $ list.shape + | t == "list(tensor)" = AttrTensor $ error "list(tensor) is unimplemented" + | t == "func" = AttrTensor $ error "func is unimplemented" + | otherwise = error $ show ("Unknown attribute type " <> t) ++ + "," ++ show defaults ++ + "," ++ show allowed + where makeVal c x y = AttrSingle $ c $ + Template (defaults^.x) (allowed^.list.y) + makeList c y = AttrList $ c $ defaults^.y diff --git a/tensorflow-opgen/tensorflow-opgen.cabal b/tensorflow-opgen/tensorflow-opgen.cabal new file mode 100644 index 0000000..4028799 --- /dev/null +++ b/tensorflow-opgen/tensorflow-opgen.cabal @@ -0,0 +1,33 @@ +name: tensorflow-opgen +version: 0.1.0.0 +synopsis: Code generation for TensorFlow operations. +description: Please see README.md +homepage: https://github.com/tensorflow/haskell#readme +license: Apache +author: TensorFlow authors +maintainer: tensorflow-haskell@googlegroups.com +copyright: Google Inc. 
+category: Machine Learning
+build-type: Simple
+cabal-version: >=1.22
+
+library
+    hs-source-dirs: src
+    exposed-modules: TensorFlow.OpGen.AttrVal
+                   , TensorFlow.OpGen
+    build-depends: proto-lens == 0.1.*
+                 , tensorflow-proto == 0.1.*
+                 , base >= 4.7 && < 5
+                 , bytestring
+                 , containers
+                 , filepath
+                 , lens-family
+                 , mainland-pretty
+                 , optparse-applicative
+                 , semigroups
+                 , text
+    default-language: Haskell2010
+
+source-repository head
+  type: git
+  location: https://github.com/tensorflow/haskell
diff --git a/tensorflow-opgen/third_party b/tensorflow-opgen/third_party
new file mode 120000
index 0000000..20e9ecd
--- /dev/null
+++ b/tensorflow-opgen/third_party
@@ -0,0 +1 @@
+../third_party/tensorflow
\ No newline at end of file
diff --git a/tensorflow-ops/Setup.hs b/tensorflow-ops/Setup.hs
new file mode 100644
index 0000000..e8ef27d
--- /dev/null
+++ b/tensorflow-ops/Setup.hs
@@ -0,0 +1,3 @@
+import Distribution.Simple
+
+main = defaultMain
diff --git a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs
new file mode 100644
index 0000000..9eb396b
--- /dev/null
+++ b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs
@@ -0,0 +1,76 @@
+-- Copyright 2016 TensorFlow authors.
+--
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+{-# LANGUAGE ConstraintKinds #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE NoMonomorphismRestriction #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE RankNTypes #-}
+
+-- | Parallel lookups on the list of tensors.
+module TensorFlow.EmbeddingOps where
+
+import Control.Monad (zipWithM)
+import Data.Int (Int32, Int64)
+import Data.List (genericLength)
+import TensorFlow.Build (Build, colocateWith, render)
+import TensorFlow.Ops ()  -- Num instance for Tensor
+import TensorFlow.Tensor (Tensor, Value)
+import TensorFlow.Types (OneOf, TensorType)
+import qualified TensorFlow.GenOps.Core as CoreOps
+
+-- | Looks up `ids` in a list of embedding tensors.
+--
+-- This function is used to perform parallel lookups on the list of
+-- tensors in `params`. It is a generalization of `TF.gather`, where
+-- `params` is interpreted as a partition of a larger embedding
+-- tensor.
+--
+-- With the "mod" partition strategy, each id is assigned to partition
+-- `p = id % len(params)`. For instance,
+-- 13 ids are split across 5 partitions as:
+-- `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`
+--
+-- The results of the lookup are concatenated into a dense
+-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
+embeddingLookup :: forall a b v .
+                   ( TensorType a
+                   , OneOf '[Int64, Int32] b
+                   , Num b
+                   )
+                => [Tensor v a]
+                   -- ^ A list of tensors which can be concatenated along
+                   -- dimension 0. Each `Tensor` must be appropriately
+                   -- sized for the `mod` partition strategy.
+                -> Tensor Value b
+                   -- ^ A `Tensor` with type `int32` or `int64`
+                   -- containing the ids to be looked up in `params`.
+                   -- The ids are required to be flat on entry and have
+                   -- fewer than 2^31 entries.
+ -> Build (Tensor Value a) + -- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`. +embeddingLookup params ids = + CoreOps.dynamicStitch pindices <$> partitionedResult + where np = genericLength params + pAssignments = CoreOps.cast (ids `CoreOps.mod` np) + newIds = ids `CoreOps.div` np + originalIndices = CoreOps.range 0 (CoreOps.size ids) 1 + -- Partition list of ids based on assignments into np separate lists + gatherIds = CoreOps.dynamicPartition np newIds pAssignments + -- Similarly, partition the original indices. + pindices = CoreOps.dynamicPartition np originalIndices pAssignments + -- Do np separate lookups, finding embeddings for plist[p] in params[p] + partitionedResult = zipWithM + (\p g -> colocateWith p $ render $ CoreOps.gather p g) + params gatherIds diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs new file mode 100644 index 0000000..f863e36 --- /dev/null +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -0,0 +1,697 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE ViewPatterns #-} + +module TensorFlow.Gradient + ( gradients + ) where + +import Control.Monad (forM, zipWithM) +import Control.Monad.State.Strict (State, evalState, gets, modify) +import Data.ByteString (ByteString) +import Data.Complex (Complex) +import Data.Default (def) +import Data.Int (Int32, Int64) +import Data.List (foldl', sortBy) +import Data.Map.Strict (Map) +import Data.Maybe (fromMaybe, maybeToList, mapMaybe) +import Data.Ord (comparing) +import Data.ProtoLens.TextFormat (showMessage) +import Data.Set (Set) +import Data.Text (Text) +import Data.Tuple (swap) +import Lens.Family2 (Lens', (&), (^.), (.~), (%~)) +import Lens.Family2.State.Strict (uses) +import Lens.Family2.Stock (at, intAt) +import Lens.Family2.Unchecked (lens, iso) +import Prelude hiding (sum) +import Text.Printf (printf) +import qualified Data.Graph.Inductive.Basic as FGL +import qualified Data.Graph.Inductive.Graph as FGL +import qualified Data.Graph.Inductive.PatriciaTree as FGL +import qualified Data.Graph.Inductive.Query.DFS as FGL +import qualified Data.IntMap.Strict as IntMap +import qualified Data.Map.Strict as Map +import qualified Data.Set as Set +import qualified Data.Text as Text + +import qualified TensorFlow.GenOps.Core as CoreOps +import TensorFlow.Build + ( Build + , render + , renderNodeName + , renderedNodeDefs + , opDef + , opAttr + ) +import TensorFlow.BuildOp +import TensorFlow.Ops + ( addN + , broadcastGradientArgs + , expandDims + , fill + , matMul + , reducedShape + , reluGrad + , reshape + , scalar + , shape + , softmaxCrossEntropyWithLogits + , sum + , vector + , zerosLike + ) +import TensorFlow.Output + ( NodeName(..) + , Op (Rendered) + , Output(..) 
+ , OutputIx(..) + , outputIndex + ) +import TensorFlow.Tensor + ( Tensor(..) + , TensorKind (ValueKind) + , Value + , tensorOutput + , tensorAttr + ) +import TensorFlow.Types (OneOf, TensorType, attrLens) +import Proto.Tensorflow.Core.Framework.NodeDef + (NodeDef, attr, input, op, name) + +type GradientCompatible a = + -- TODO(fmayle): MaxPoolGrad doesn't support Double for some reason. + (Num a, OneOf '[ Float, Complex Float, Complex Double ] a) + +-- TODO(fmayle): Support control flow. +-- TODO(fmayle): Support gate_gradients-like option to avoid race conditions. +-- TODO(fmayle): Do we need to consider control inputs? See _PendingCount in +-- tensorflow/python/ops/gradients.py. +-- TODO(fmayle): Maybe store the gradient functions and numOutputs on the OpDef. + + +-- | Gradient of @y@ w.r.t. each element of @xs@. +gradients :: forall a v1 v2 . ( Num (Tensor v1 a) + -- TODO(gnezdo): remove indirect constraint. + -- It's a wart inherited from Num instance. + , v1 ~ Value + , GradientCompatible a + ) + => Tensor v1 a -- ^ The output of the graph. + -> [Tensor v2 a] -- ^ Tensors for which gradients are computed. + -> Build [Tensor Value a] +gradients y xs = do + -- The gradients are computed using "reverse accumulation", similarly to + -- what is described here: + -- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation + -- + -- The code is summarised as follows: + -- + -- 1. Create an fgl graph of the relevant nodes (ops) and edges (tensors). + -- 2. Initialize the gradient of y to 1 (∂y/∂y = 1) and the rest of tensor's + -- gradients to nothing. + -- 3. Process the nodes in reverse topological order (i.e. each node comes + -- after all of its outputs so that the output gradients for a node have + -- been completely calculated before it is processed): + -- a. Record the gradient for each of the node's output tensors (∂y/∂w + -- for each output tensor w). + -- b. Calculate the gradient of y w.r.t. each of the node's input + -- tensors using the gradients of the node's output tensors. + -- + -- Written differently, for each output tensor w and input tensor v: + -- ∂y/∂w = ... (calculated in previous steps) + -- ∂w/∂v = ... (op specific) + -- ∂y/∂v = ∂y/∂w * ∂w/∂v (technically, if tensor v is an input + -- to multiple nodes, then this is only + -- part of ∂y/∂v) + -- + -- 4. Lookup the recorded gradient for each x in xs. + + yName <- renderNodeName y + -- TODO(fmayle): Move this into Build.hs and call it unsafeNodeDefFromName? + nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $ + (\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x)) + . flip Map.lookup + let (gr, nodeMap) = createGraph yName nodeDefLookup + -- Set gradient of y to one. + let initPending :: Map.Map FGL.Node (PendingGradients a) + initPending = Map.empty & at (nodeMap Map.! yName) + . nonEmpty + . outputIxAt (y ^. tensorOutput . outputIndex) + . nonEmpty + .~ [fill (shape y) (scalar 1)] + -- Calculate the gradients of y w.r.t. each node in the graph. + gradientMap <- graphGrads gr initPending + -- Lookup the gradients for each x. + forM xs $ \x -> do + xName <- renderNodeName x + render $ fromMaybe (zerosLike x) $ do + n <- nodeMap ^. at xName + let i = x ^. tensorOutput . outputIndex + gradientMap ^. at n . nonEmpty . outputIxAt i + +outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v) +outputIxAt = intAt . unOutputIx + +-- | Incomplete gradients of a node's outputs. +-- +-- The lists represent partial sums. 
The key is an OutputIx sans newtype. +type PendingGradients a = IntMap.IntMap [Tensor Value a] + +-- | Gradients of a node's outputs. The key is an OutputIx sans newtype. +type Gradients a = IntMap.IntMap (Tensor Value a) + +-- | Graph of TensorFlow operations. +type Graph = FGL.Gr NodeDef EdgeLabel + +-- | Data associated with an edge. +-- +-- Pair of +-- 1. Output index of a tensor from the source node. +-- 2. Input index that the tensor connects to on the destination node. +type EdgeLabel = (OutputIx, OutputIx) + + +-- | State used for calculating gradients. +data GradientsState a = GradientsState + { _gradientsPending :: !(Map FGL.Node (PendingGradients a)) + , _gradientsResult :: !(Map FGL.Node (Gradients a)) + } + +gradientsPending :: Lens' (GradientsState a) (Map FGL.Node (PendingGradients a)) +gradientsPending = lens _gradientsPending (\x y -> x { _gradientsPending = y }) + +gradientsResult :: Lens' (GradientsState a) (Map FGL.Node (Gradients a)) +gradientsResult = lens _gradientsResult (\x y -> x { _gradientsResult = y }) + + +-- TODO(fmayle): Use something like Data.List.Safe. +-- | Safe version of (!!). +safeIndex :: [a] -> Int -> Maybe a +_ `safeIndex` n | n < 0 = Nothing +[] `safeIndex` _ = Nothing +(x:_) `safeIndex` 0 = Just x +(_:xs) `safeIndex` n = xs `safeIndex` (n-1) + +-- Copy of http://hackage.haskell.org/package/lens-3.9.0.2/docs/Control-Lens-Iso.html#v%3anon +anon :: a -> (a -> Bool) -> Lens' (Maybe a) a +anon a p = iso (fromMaybe a) go where + go b | p b = Nothing + | otherwise = Just b + +non :: Eq a => a -> Lens' (Maybe a) a +non a = anon a (a==) + +-- | Lens that defaults Nothing to mempty. +nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v) +nonEmpty = anon mempty null + +-- | Calculate the gradients for every node in a graph. +graphGrads :: forall a. GradientCompatible a + => Graph + -> Map FGL.Node (PendingGradients a) + -- ^ Initial gradients (usually just 1 for the node of interest). + -> Build (Map FGL.Node (Gradients a)) +graphGrads gr initPending = pure (foldl' go initState nodeOrder ^. gradientsResult) + where + initState = GradientsState initPending Map.empty + -- Reverse topological sort. + -- TODO(fmayle): Filter out nodes that are not successors of any x in xs to + -- avoid calculating gradients that won't be used. + nodeOrder = FGL.topsort $ FGL.grev gr + go state node = + -- Aggregate the accumulated gradients for this node. + let outputGrads = + sumPendingGradient (state ^. gradientsPending . at node . nonEmpty) + in if null outputGrads + then state + else + -- Calculate the gradients for each of the node's inputs. + let nextState = state & gradientsResult %~ Map.insert node outputGrads + ctx = FGL.context gr node + in updatePendingGradients + ctx + (calculateInputGrads ctx outputGrads gr) + nextState + +-- | Reduce accumulated gradients for each output to one Tensor. +sumPendingGradient :: GradientCompatible a + => PendingGradients a -> Gradients a +sumPendingGradient = IntMap.mapMaybe f + where + f [] = Nothing + f [x] = Just x + f xs = Just (addN xs) + + +-- | Calculate the gradients of a node's input tensors. +-- +-- This is mostly just a wrapper around opGrad. +calculateInputGrads :: forall a. GradientCompatible a + => FGL.Context NodeDef EdgeLabel + -> Gradients a -- ^ Output gradients of the node. + -> Graph + -> [Maybe (Tensor Value a)] +calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr = + opGrad (nodeDef ^. 
op) nodeDef inputTensors fullOutGrads + where + fullOutGrads = + fullOutputGrads (numOutputs nodeDef) (Rendered nodeDef) outputGrads + -- Create a tensor from an edge (technically an Output, but it seems less + -- confusing to refer to it as a tensor here). + edgeToTensor :: (EdgeLabel, FGL.Node) -> Output + edgeToTensor ((i, _), n) = + case FGL.lab gr n of + Just edgeNodeDef -> Output i (Rendered edgeNodeDef) + Nothing -> error $ "calculateInputGrads: missing input node for " + ++ Text.unpack (nodeDef ^. name) + -- Input tensors, sorted by input index. + inputTensors = map edgeToTensor $ sortBy (comparing (snd . fst)) inputEdges + +-- | Convert a Map of gradients to a list, with zeros for missing outputs. +fullOutputGrads :: (TensorType a, Num a) + => OutputIx -- ^ Number of outputs. + -> Op + -> Gradients a + -> [Tensor Value a] +fullOutputGrads n o gs = + map (\i -> fromMaybe (zero i) (gs ^. outputIxAt i)) [0..n-1] + where + -- A tensor of zeros with the same shape as the i'th output. + zero i = zerosLike $ toT (Output i o) + + +-- | Update the pending gradients of a node's inputs. +updatePendingGradients :: forall a. (TensorType a, Num a) + => FGL.Context NodeDef EdgeLabel + -> [Maybe (Tensor Value a)] + -- ^ Gradient of each input tensor. + -> GradientsState a + -> GradientsState a +updatePendingGradients (inputEdges, _, nodeDef, _) inputGrads initState = + foldl' go initState inputEdges + where + go :: GradientsState a -> (EdgeLabel, FGL.Node) -> GradientsState a + go state ((outIndex, OutputIx inIndex), node) = + case maybeGradient of + Nothing -> state + Just g -> + -- Add to the list of pending gradients for this tensor. + state & gradientsPending + . at node + . nonEmpty + . outputIxAt outIndex + . nonEmpty + %~ (g:) + where + badSizeErr = error $ printf "updatePendingGradients: bad input index \ + \%d for inputGrads of length %d in %s" + inIndex (length inputGrads) + (show (nodeDef ^. name)) + maybeGradient = fromMaybe badSizeErr (safeIndex inputGrads inIndex) + + +-- | Create a graph that includes a node and its transitive dependencies. +createGraph :: NodeName -> (NodeName -> NodeDef) + -> (Graph, Map NodeName FGL.Node) +createGraph nodeName nodeDefLookup = (FGL.nmap nodeDefLookup graph, nodeMap) + where + -- Parse a tensor name. + parseTensorName :: Text -> Maybe (NodeName, OutputIx) + parseTensorName n + | Text.null n = error "parseTensorName: empty name" + | Text.head n == '^' = Nothing -- Control edge + | otherwise = + let (nm, indexStr) = Text.breakOn ":" n + index | Text.null indexStr = 0 + | otherwise = read $ Text.unpack $ Text.tail indexStr + in Just (NodeName nm, OutputIx index) + + -- Build a map from node name to outward edges. + -- + -- The state is the set of visited nodes. + collect :: Maybe (NodeName, OutputIx, OutputIx) + -> NodeName + -> State (Set NodeName) + (Map NodeName [(NodeName, OutputIx, OutputIx)]) + collect outgoingEdge nm = do + let nextLookup = Map.singleton nm (maybeToList outgoingEdge) + seen <- gets (Set.member nm) + modify (Set.insert nm) + if seen + then pure nextLookup + else do + let inputs = nodeDefLookup nm ^. input + recurse inIndex (parentName, outIndex) = + collect (Just (nm, outIndex, inIndex)) parentName + subEdgeLookups <- + zipWithM recurse [0..] $ mapMaybe parseTensorName inputs + pure $ Map.unionsWith (++) (nextLookup:subEdgeLookups) + + edgeLookup = evalState (collect Nothing nodeName) Set.empty + -- Associate an ID with each node name. + nodeMap = Map.fromList $ zip (Map.keys edgeLookup) [0..] + -- Create the graph. 
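+    -- For example, if node b reads output 0 of node a as its input 1,
+    -- the graph gets an edge a -> b labelled (0, 1).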
+ graph = FGL.mkGraph (swap <$> Map.toList nodeMap) + [ (nodeMap Map.! n, nodeMap Map.! m, (i, j)) + | (n, edges) <- Map.toList edgeLookup + , (m, i, j) <- edges + ] + +-- | Function to compute the gradient of y w.r.t. each input. +-- +-- Let y be an arbitrary tensor +-- and [w_0, ..., w_n] be the output tensors of a node +-- and [v_0, ..., v_n] be the input tensors of the same node. +-- +-- Given [∂y/∂w_0, ..., ∂y/∂w_n] and [v_0, ..., v_n], a GradientFunc computes +-- [∂y/∂v_0, ..., ∂y/∂v_n] for a particular op type. +-- +-- A Nothing gradient is equivalent to zero (but allows for short circuiting +-- computation when all the gradients for something are Nothing). +type GradientFunc a = NodeDef + -> [Output] + -- ^ Input tensors. + -> [Tensor Value a] + -- ^ Gradient of y w.r.t. each output tensor. + -> [Maybe (Tensor Value a)] + -- ^ Gradient of y w.r.t. each input tensor. + + +-- TODO(fmayle): Assert the type is correct. +-- | Create a Tensor from an Output. +toT :: Output -> Tensor Value a +toT = Tensor ValueKind + +-- | The gradient function for an op type. +-- +-- These implementations should match their python counterparts in: +-- third_party/tensorflow/python/ops/*_grad.py +opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a + +opGrad "Abs" _ [toT -> x] [dz] = [Just $ dz * signum x] +opGrad "Neg" _ [_] [dz] = [Just $ -dz] +opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x] + +opGrad "Square" _ [toT -> x] [dz] = + -- TODO(fmayle): Handle complex numbers. + -- TODO(fmayle): The python code makes dz a control dependency of the 2*x + -- (for performance reasons?). Will need to put these functions in the Build + -- monad to replicate that. + [Just $ dz * (2 * x)] + +opGrad "Gather" _ [toT -> x, toT -> indices] [dz] = + -- TODO(fmayle): The python version uses a better performance implementation + -- when the shape is known without having to run the graph. + -- TODO(fmayle): We shouldn't convert the result to a dense tensor. Sparse + -- tensor support will require some thinking. + [ Just $ CoreOps.unsortedSegmentSum values indices' numRows + , Nothing + ] + where + -- TODO(gnezdo): Use colocateWith but it requires Build monad. + denseShape = shape (x :: Tensor Value a) + numRows = CoreOps.slice denseShape 0 (1 :: Tensor Value Int32) + valuesShape = CoreOps.concat 0 [ + allDimensions + , CoreOps.slice denseShape 1 (-1 :: Tensor Value Int32) + ] + values = reshape dz valuesShape + -- TODO(fmayle): This could be either Int32 or Int64. + indices' = reshape indices allDimensions :: Tensor Value Int32 + +opGrad "Max" _ [toT -> x, toT -> indices] [dz] = + [Just $ indicators `CoreOps.div` numSelected * dz', Nothing] + where + sx = shape (x :: Tensor Value a) + outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32) + x' = reshape x outputShapeKeptDims + dz' = reshape dz outputShapeKeptDims + indicators = CoreOps.cast $ CoreOps.equal x' x + numSelected = reshape (sum indicators indices) outputShapeKeptDims + +-- Min and Max have identical gradient implementations. +opGrad "Min" u v w = opGrad "Max" u v w + +opGrad "Sum" _ [toT -> x, toT -> indices] [dz] = + [ Just $ CoreOps.tile grad tileScaling, Nothing ] + where + -- TODO(gnezdo): Implement the fast-path from math_grad._SumGrad. 
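+    -- The incoming gradient dz is reshaped to the output shape with the
+    -- summed dimensions restored as size 1, then tiled across those
+    -- dimensions to recover the input shape.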
+ sx = shape (x :: Tensor Value a) + outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32) + tileScaling = safeShapeDiv sx outputShapeKeptDims + grad = reshape dz outputShapeKeptDims + +opGrad "Mean" u v@[toT -> x, _] w = + [Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing] + where + [Just dz, Nothing] = opGrad "Sum" u v w + inputShape = shape (x :: Tensor Value a) + outputShape = shape (dz :: Tensor Value a) + -- TODO(fmayle): Add fast path when shape is known. + inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape + outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape + factor = safeShapeDiv inputSize outputSize + +opGrad "Add" _ [toT -> x, toT -> y] [dz] = + [ Just $ reshape (sum dz rx) sx + , Just $ reshape (sum dz ry) sy ] + where + sx = shape (x :: Tensor Value a) + sy = shape (y :: Tensor Value a) + (rx, ry) = broadcastGradientArgs sx sy + +opGrad "Sub" u v w = + [Just x, Just (-y)] + where + [Just x, Just y] = opGrad "Add" u v w + +opGrad "SoftmaxCrossEntropyWithLogits" _ [toT -> x, toT -> y] [dz, _] = + [ Just $ expandDims dz (-1) * snd (softmaxCrossEntropyWithLogits x y) + , Nothing ] + +opGrad "Mul" _ [toT -> x, toT -> y] [dz] = + -- TODO(fmayle): Handle complex numbers. + [ Just $ reshape (sum (dz * y) rx) sx + , Just $ reshape (sum (x * dz) ry) sy ] + where + sx = shape (x :: Tensor Value a) + sy = shape (y :: Tensor Value a) + (rx, ry) = broadcastGradientArgs sx sy + +opGrad "Div" _ [toT -> x, toT -> y] [dz] = + -- TODO(fmayle): Handle complex numbers. + -- TODO(gnezdo): Provide Fractional instance and use '/' instead of div. + [ Just $ reshape (sum (dz `CoreOps.div` y) rx) sx + , Just $ reshape (sum (dz * (negate x `CoreOps.div` (y * y))) ry) sy + ] + where + sx = shape (x :: Tensor Value a) + sy = shape (y :: Tensor Value a) + (rx, ry) = broadcastGradientArgs sx sy + +opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] = + let transposeA = lookupAttr nodeDef "transpose_a" + transposeB = lookupAttr nodeDef "transpose_b" + transAttrs a b = + (tensorAttr "transpose_a" .~ a) . 
(tensorAttr "transpose_b" .~ b) + in case (transposeA, transposeB) of + (False, False) -> + [ Just $ (dz `matMul` y) & transAttrs False True + , Just $ (x `matMul` dz) & transAttrs True False ] + (False, True) -> + [ Just $ dz `matMul` y + , Just $ (x `matMul` dz) & transAttrs True False ] + (True, False) -> + [ Just $ (dz `matMul` y) & transAttrs False True + , Just $ x `matMul` dz ] + (True, True) -> + [ Just $ (dz `matMul` y) & transAttrs True True + , Just $ (x `matMul` dz) & transAttrs True True ] + +opGrad "Transpose" _ [_, toT -> p] [dz] = + [ Just $ CoreOps.transpose dz + (CoreOps.invertPermutation p :: Tensor Value Int32) + , Nothing + ] + +opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] = + [ Just $ CoreOps.conv2DBackpropInput (shape x) y dz + & tensorAttr "strides" .~ strides + & tensorAttr "padding" .~ padding + & tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu + & tensorAttr "data_format" .~ dataFormat + , Just $ CoreOps.conv2DBackpropFilter x (shape y) dz + & tensorAttr "strides" .~ strides + & tensorAttr "padding" .~ padding + & tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu + & tensorAttr "data_format" .~ dataFormat + ] + where + strides = lookupAttr nodeDef "strides" :: [Int64] + padding = lookupAttr nodeDef "padding" :: ByteString + useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool + dataFormat = lookupAttr nodeDef "data_format" :: ByteString + +opGrad "MaxPool" nodeDef [toT -> x] [dz] = + [ Just $ CoreOps.maxPoolGrad x output dz + & tensorAttr "ksize" .~ ksize + & tensorAttr "strides" .~ strides + & tensorAttr "padding" .~ padding + & tensorAttr "data_format" .~ dataFormat + ] + where + output :: Tensor Value a + output = toT $ Output 0 (Rendered nodeDef) + ksize = lookupAttr nodeDef "ksize" :: [Int64] + strides = lookupAttr nodeDef "strides" :: [Int64] + padding = lookupAttr nodeDef "padding" :: ByteString + dataFormat = lookupAttr nodeDef "data_format" :: ByteString + +opGrad "Reshape" _ [toT -> x, _] [dz] = + [Just $ reshape dz $ shape (x :: Tensor Value a), Nothing] + +opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing] +opGrad "TruncatedNormal" _ _ _ = [Nothing] + +opGrad "RefIdentity" _ _ [dz] = [Just dz] +opGrad "Cast" nodeDef _ [dz] = [Just reverseCast] + where + -- TODO(gnezdo): too permissive, python only allows float types as src_type. 
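+    -- The gradient of a Cast is the reverse cast: the forward op's SrcT
+    -- and DstT attributes are swapped when building the op below.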
+    reverseCast =
+        buildOp (opDef "Cast"
+                 & opAttr "DstT" .~ (lookupAttr nodeDef "SrcT" :: ByteString)
+                 & opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString))
+        dz
+
+opGrad "DynamicStitch" nodeDef inputs [dz] =
+    replicate halfLen Nothing ++ valuesGrads
+  where
+    halfLen =
+        let len = length inputs
+            half = len `div` 2
+        in if 2 * half == len
+           then half
+           else error ("Uneven input size " ++ show (len, showMessage nodeDef))
+    valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Value Int32)
+                  | idx <- take halfLen inputs
+                  ]
+
+opGrad "DynamicPartition" nodeDef [toT -> xs, toT -> indices] dz =
+    [ Just reconstructed, Nothing ]
+  where
+    reconstructed = CoreOps.reshape stitched
+                    (CoreOps.shape (xs :: Tensor Value a) :: Tensor Value Int32)
+    stitched = CoreOps.dynamicStitch partitionedIndices dz
+    partitionedIndices = CoreOps.dynamicPartition np originalIndices indices
+    np = lookupAttr nodeDef "num_partitions" :: Int64
+    originalIndices =
+        CoreOps.reshape (CoreOps.range 0 (CoreOps.size indices) 1) prefixShape
+    prefixShape = shapeInt32 indices
+    shapeInt32 = CoreOps.shape :: Tensor Value Int32 -> Tensor Value Int32
+
+opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
+    [ Nothing
+    , Just $ CoreOps.select c dz zeros
+    , Just $ CoreOps.select c zeros dz
+    ]
+  where zeros = CoreOps.zerosLike x
+
+-- TODO(gnezdo): Unlike Python, no control dependency on dz.
+opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.inv x ]
+-- TODO(gnezdo): Reuse the output instead of doing another exp,
+-- though, it is probably CSE'd away anyway.
+opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.exp x ]
+opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
+    [ Just $ CoreOps.unsortedSegmentSum
+             (CoreOps.gather dz (t :: Tensor Value Int32))
+             (y :: Tensor Value Int32) inputRows
+    , Nothing
+    , Nothing
+    ]
+  where inputRows = CoreOps.slice (shape (x :: Tensor Value a)) (scalar (0 :: Int32)) (scalar 1)
+
+opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
+opGrad "LabelWeights" _ _ _ = [Nothing]
+opGrad "Size" _ _ _ = [Nothing]
+opGrad "ZerosLike" _ _ _ = [Nothing]
+
+-- TODO(fmayle): These can go away if we properly prune the graph.
+opGrad "Const" _ _ _ = [Nothing, Nothing]
+opGrad "Placeholder" _ _ _ = []
+opGrad "Variable" _ _ _ = []
+
+opGrad n nodeDef ins grads =
+    error $ "no gradient implemented for " ++
+            show (n, length ins, length grads, showMessage nodeDef, ins)
+
+-- | The number of outputs for an op type.
+numOutputs :: NodeDef -> OutputIx
+numOutputs o =
+    case o ^. op of
+        "Abs" -> 1
+        "Add" -> 1
+        "Cast" -> 1
+        "Const" -> 1
+        "Conv2D" -> 1
+        "Div" -> 1
+        "DynamicStitch" -> 1
+        "DynamicPartition" ->
+            fromIntegral (lookupAttr o "num_partitions" :: Int64)
+        "Exp" -> 1
+        "Gather" -> 1
+        "LabelClasses" -> 1
+        "LabelWeights" -> 1
+        "Log" -> 1
+        "MatMul" -> 1
+        "Max" -> 1
+        "MaxPool" -> 1
+        "Mean" -> 1
+        "Min" -> 1
+        "Mul" -> 1
+        "Neg" -> 1
+        "Placeholder" -> 1
+        "OneHot" -> 1
+        "RefIdentity" -> 1
+        "Relu" -> 1
+        "Reshape" -> 1
+        "Select" -> 1
+        "Size" -> 1
+        "SoftmaxCrossEntropyWithLogits" -> 2
+        "Square" -> 1
+        "SparseSegmentSum" -> 1
+        "Sub" -> 1
+        "Sum" -> 1
+        "Transpose" -> 1
+        "TruncatedNormal" -> 1
+        "Variable" -> 1
+        "ZerosLike" -> 1
+        _ -> error $ "numOutputs not implemented for " ++ show (o ^. op)
+
+-- Divides `x / y`, assuming `x, y >= 0` and treating `0 / 0 = 0`.
+safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
+
+allDimensions = vector [-1 :: Int32]
+
+rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1
+
+lookupAttr nodeDef attrName = nodeDef ^. attr . at attrName . non def . attrLens
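+
+-- Example usage (an illustrative sketch, not part of the original module;
+-- it relies on the Num instance for tensors from "TensorFlow.Ops"):
+--
+-- > -- ∂(x·x)/∂x = 2·x, so the returned gradient evaluates to 6 at x = 3.
+-- > example :: Build [Tensor Value Float]
+-- > example = do
+-- >     let x = scalar (3 :: Float)
+-- >     gradients (x * x) [x]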
diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs
new file mode 100644
index 0000000..0730363
--- /dev/null
+++ b/tensorflow-ops/src/TensorFlow/Ops.hs
@@ -0,0 +1,296 @@
+-- Copyright 2016 TensorFlow authors.
+--
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+-- | This module contains definitions for some built-in TensorFlow operations.
+--
+-- Note that certain "stateful" ops like 'variable' and 'assign' return a
+-- 'Build' action (e.g., @Build (Tensor Ref a)@) instead of a pure value; the
+-- returned 'Tensor's are always rendered in the current 'Build' context. This
+-- approach helps us avoid problems with inlining or common subexpression
+-- elimination, by writing
+--
+-- > do
+-- >     v <- variable []
+-- >     w <- assign v 3
+-- >     render $ w * w
+--
+-- instead of
+--
+-- > let
+-- >     v = variable []
+-- >     w = assign v 3
+-- > in w * w
+--
+-- since the latter could be reasonably transformed by the compiler into (or
+-- vice versa)
+--
+-- > let
+-- >     v = variable []
+-- >     w = assign v 3
+-- >     w' = assign v 3
+-- > in w * w'
+--
+-- Ops should return a 'Build' action if their original 'OpDef' marks them as
+-- stateful, or if they take any Refs as input. (This mirrors the rules that
+-- TensorFlow uses to avoid common subexpression elimination.)
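+--
+-- A minimal end-to-end sketch (illustrative; it uses the "TensorFlow.Session"
+-- API that the tests in this repository exercise):
+--
+-- > main :: IO ()
+-- > main = runSession $ do
+-- >     v <- build $ initializedVariable 3
+-- >     result <- run (v `mul` v)
+-- >     liftIO $ print (unScalar result :: Float)  -- prints 9.0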
+{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} + +module TensorFlow.Ops + ( CoreOps.add + , CoreOps.abs + , CoreOps.addN + , CoreOps.argMax + , assign + , CoreOps.broadcastGradientArgs + , CoreOps.cast + , CoreOps.concat + , constant + , expandDims + , initializedVariable + , zeroInitializedVariable + , CoreOps.fill + , CoreOps.matMul + , matTranspose + , CoreOps.mul + , CoreOps.neg + , CoreOps.pack + , placeholder + , CoreOps.range + , reducedShape + , CoreOps.relu + , CoreOps.reluGrad + , CoreOps.reshape + , restore + , restoreFromName + , save + , scalar + , shape + , CoreOps.sign + , CoreOps.size + , CoreOps.softmax + , CoreOps.softmaxCrossEntropyWithLogits + , CoreOps.sparseToDense + , CoreOps.sub + , CoreOps.sum + , CoreOps.topK + , CoreOps.transpose + , truncatedNormal + , variable + , vector + , zeros + , CoreOps.zerosLike + ) where + +import Data.ByteString (ByteString) +import Data.Complex (Complex) +import Data.Int (Int32, Int64) +import Prelude hiding (abs, sum, concat) +import Data.ProtoLens (def) +import Data.Text.Encoding (encodeUtf8) +import Lens.Family2 ((.~), (&)) +import Text.Printf (printf) +import Proto.Tensorflow.Core.Framework.Tensor + ( TensorProto + , dtype + , tensorShape + ) +import qualified Proto.Tensorflow.Core.Framework.TensorShape + as TensorShape +import TensorFlow.Build +import TensorFlow.BuildOp +import TensorFlow.ControlFlow (group) +import TensorFlow.Output (unNodeName) +import TensorFlow.Tensor +import TensorFlow.Types + +import qualified TensorFlow.GenOps.Core as CoreOps + +import qualified Prelude (abs) + +-- TODO: Look into hs-boot refactoring to allow mutually recursive imports. +-- | Must be defined as an orphan because of the dependency order between Ops +-- and Tensor. +-- +-- The indirect constraint "v ~ Value" helps disambiguate types, for example in +-- "neg 1 :: Tensor Value Float", it helps find the type of the subexpression +-- "1". +instance ( TensorType a + , Num a + , v ~ Value + , OneOf '[ Double, Float, Int32, Int64 + , Complex Float, Complex Double] a) => Num (Tensor v a) where + (+) = CoreOps.add + (*) = CoreOps.mul + (-) = CoreOps.sub + abs = CoreOps.abs + fromInteger = scalar . fromInteger + signum = CoreOps.sign + negate = CoreOps.neg + +matTranspose :: forall a v . TensorType a + => Tensor v a -> Tensor Value a +matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32]) + +-- | Create a new, uninitialized stateful Tensor of the given shape. +variable :: forall a . TensorType a => Shape -> Build (Tensor Ref a) +variable shape' = buildOp $ opDef "Variable" + & opAttr "shape" .~ shape' + & opAttr "dtype" .~ tensorType (undefined :: a) + +placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a) +placeholder shape' = + buildOp $ opDef "Placeholder" + & opAttr "dtype" .~ tensorType (undefined :: a) + & opAttr "shape" .~ shape' + +-- Assign returns the input ref. +assign :: forall a v . TensorType a + => Tensor Ref a -> Tensor v a -> Build (Tensor Ref a) +assign = buildOp $ opDef "Assign" + & opAttr "T" .~ tensorType (undefined :: a) + & opAttr "use_locking" .~ True + +-- | Creates a variable initialized to the given value. +-- Initialization happens next time session runs. +initializedVariable :: forall a . 
TensorType a + => Tensor Value a -> Build (Tensor Ref a) +initializedVariable initializer = do + v <- variable [] -- The shape is not known initially. + (i :: Tensor Ref a) <- + buildOp (opDef "Assign" + & opAttr "T" .~ tensorType (undefined :: a) + & opAttr "use_locking" .~ True + & opAttr "validate_shape" .~ False + ) + v initializer + addInitializer =<< group i + return v + +-- | Creates a zero-initialized variable with the given shape. +zeroInitializedVariable + :: (TensorType a, Num a) => + TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a) +zeroInitializedVariable = initializedVariable . zeros + +-- TODO: Support heterogeneous list of tensors. +save :: forall a v . TensorType a + => ByteString -- ^ File path. + -> [Tensor v a] -- ^ Tensors to save. + -> Build ControlNode +save path xs = do + let toByteStringTensor = scalar . encodeUtf8 . unNodeName + names <- mapM (fmap toByteStringTensor . renderNodeName) xs + let types = replicate (length xs) (tensorType (undefined :: a)) + let saveOp = buildOp $ opDef "Save" + & opAttr "T" .~ types + saveOp (scalar path) (CoreOps.pack names) xs + +-- | Restore a tensor's value from a checkpoint file. +-- +-- This version allows restoring from a checkpoint file that uses a different +-- tensor name than the variable. +restoreFromName :: forall a . TensorType a + => ByteString -- ^ File path. + -> ByteString -- ^ Tensor name override. + -> Tensor Ref a -- ^ Tensor to restore. + -> Build ControlNode +restoreFromName path name x = do + let restoreOp = buildOp $ opDef "Restore" + & opAttr "dt" .~ tensorType (undefined :: a) + group =<< assign x (restoreOp (scalar path) (scalar name) :: Tensor Value a) + +-- | Restore a tensor's value from a checkpoint file. +restore :: forall a . TensorType a + => ByteString -- ^ File path. + -> Tensor Ref a -- ^ Tensor to restore. + -> Build ControlNode +restore path x = do + name <- encodeUtf8 . unNodeName <$> renderNodeName x + restoreFromName path name x + +-- | Create a constant tensor. +-- +-- The values should be in row major order, e.g., +-- +-- element 0: index (0, ..., 0) +-- element 1: index (0, ..., 1) +-- ... +constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a +constant (Shape shape') values + | invalidLength = error invalidLengthMsg + | otherwise = buildOp $ opDef "Const" + & opAttr "value" .~ typedNode + & opAttr "dtype" .~ nodeType + where + invalidLength = product shape' /= fromIntegral (length values) + invalidLengthMsg = printf "invalid tensor length: expected %d got %d" + (product shape') + (length values) + nodeType = tensorType (undefined :: a) + typedNode :: TensorProto + typedNode = def + & dtype .~ nodeType + & tensorShape.TensorShape.dim .~ + [def & TensorShape.size .~ x | x <- shape'] + & tensorVal .~ values + +-- | Create a constant vector. +vector :: TensorType a => [a] -> Tensor Value a +vector xs = constant [fromIntegral $ length xs] xs + +-- | Create a constant scalar. +scalar :: forall a . TensorType a => a -> Tensor Value a +scalar x = constant [] [x] + +-- Random tensor from the unit normal distribution with bounded values. +truncatedNormal :: forall a v . TensorType a + => Tensor v Int64 -- ^ Shape. + -> Build (Tensor Value a) +truncatedNormal = buildOp $ opDef "TruncatedNormal" + & opAttr "dtype" .~ tensorType (undefined :: a) + & opAttr "T" .~ tensorType (undefined :: Int64) + +zeros :: forall a . 
(Num a, TensorType a) => Shape -> Tensor Value a +zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0) + +shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32 +shape = CoreOps.shape + +expandDims :: (TensorType t) => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t +expandDims = CoreOps.expandDims + +-- | Helper function for reduction ops (translation of math_ops.reduced_shape). +reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) => + Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32 +reducedShape inputShape axes = + let inputShape32 = toInt32 inputShape -- [2, 3, 5, 7] + axes32 = toInt32 axes -- [1, 2] + toInt32 x = CoreOps.cast x :: Tensor Value Int32 + inputRank = CoreOps.size inputShape32 -- 4 + axesMod = (axes32 + inputRank) `CoreOps.mod` inputRank + axesShape = shape axesMod -- [2] + in CoreOps.dynamicStitch -- [2, 1, 1, 7] + [CoreOps.range 0 inputRank 1, -- [0, 1, 2, 3] + axesMod] -- [1, 2] + [inputShape32, -- [2, 3, 5, 7] + CoreOps.fill axesShape 1] -- [1, 1] diff --git a/tensorflow-ops/tensorflow-ops.cabal b/tensorflow-ops/tensorflow-ops.cabal new file mode 100644 index 0000000..ac34d4c --- /dev/null +++ b/tensorflow-ops/tensorflow-ops.cabal @@ -0,0 +1,191 @@ +name: tensorflow-ops +version: 0.1.0.0 +synopsis: Friendly layer around TensorFlow bindings. +description: Please see README.md +homepage: https://github.com/tensorflow/haskell#readme +license: Apache +author: TensorFlow authors +maintainer: tensorflow-haskell@googlegroups.com +copyright: Google Inc. +category: Machine Learning +build-type: Simple +cabal-version: >=1.22 + +library + hs-source-dirs: src + exposed-modules: TensorFlow.Gradient + , TensorFlow.Ops + , TensorFlow.EmbeddingOps + build-depends: proto-lens == 0.1.* + , base >= 4.7 && < 5 + , bytestring + , fgl + , mtl + , data-default + , lens-family + , containers + , tensorflow == 0.1.* + , tensorflow-proto == 0.1.* + , tensorflow-core-ops == 0.1.* + , text + default-language: Haskell2010 + +Test-Suite BuildTest + default-language: Haskell2010 + type: exitcode-stdio-1.0 + main-is: BuildTest.hs + hs-source-dirs: tests + build-depends: HUnit + , base + , proto-lens + , lens-family + , google-shim + , tensorflow + , tensorflow-ops + , tensorflow-proto + , test-framework + , test-framework-hunit + , transformers + , vector + +Test-Suite EmbeddingOpsTest + default-language: Haskell2010 + type: exitcode-stdio-1.0 + main-is: EmbeddingOpsTest.hs + hs-source-dirs: tests + build-depends: HUnit + , QuickCheck + , base + , proto-lens + , lens-family + , google-shim + , tensorflow + , tensorflow-core-ops + , tensorflow-ops + , tensorflow-proto + , test-framework + , test-framework-hunit + , test-framework-quickcheck2 + , vector + +Test-Suite ArrayOpsTest + default-language: Haskell2010 + type: exitcode-stdio-1.0 + main-is: ArrayOpsTest.hs + hs-source-dirs: tests + build-depends: HUnit + , QuickCheck + , base + , proto-lens + , lens-family + , google-shim + , tensorflow + , tensorflow-core-ops + , tensorflow-ops + , tensorflow-proto + , test-framework + , test-framework-hunit + , test-framework-quickcheck2 + , transformers + , vector + +Test-Suite OpsTest + default-language: Haskell2010 + type: exitcode-stdio-1.0 + main-is: OpsTest.hs + hs-source-dirs: tests + build-depends: HUnit + , QuickCheck + , base + , bytestring + , proto-lens + , lens-family + , google-shim + , temporary + , tensorflow + , tensorflow-core-ops + , tensorflow-ops + , tensorflow-proto + , test-framework + , test-framework-hunit + , 
test-framework-quickcheck2 + , transformers + , vector + +Test-Suite DataFlowOpsTest + default-language: Haskell2010 + type: exitcode-stdio-1.0 + main-is: DataFlowOpsTest.hs + hs-source-dirs: tests + build-depends: HUnit + , QuickCheck + , base + , proto-lens + , lens-family + , google-shim + , tensorflow + , tensorflow-core-ops + , tensorflow-ops + , tensorflow-proto + , test-framework + , test-framework-hunit + , test-framework-quickcheck2 + , vector + +Test-Suite GradientTest + default-language: Haskell2010 + type: exitcode-stdio-1.0 + main-is: GradientTest.hs + hs-source-dirs: tests + build-depends: HUnit + , base + , proto-lens + , lens-family + , google-shim + , tensorflow + , tensorflow-ops + , tensorflow-proto + , test-framework + , test-framework-hunit + +Test-Suite MiscTest + default-language: Haskell2010 + type: exitcode-stdio-1.0 + main-is: MiscTest.hs + hs-source-dirs: tests + build-depends: HUnit + , base + , bytestring + , vector + , google-shim + , transformers + , tensorflow + , tensorflow-ops + , tensorflow-proto + , test-framework + , test-framework-hunit + +Test-Suite TypesTest + default-language: Haskell2010 + type: exitcode-stdio-1.0 + main-is: TypesTest.hs + hs-source-dirs: tests + build-depends: HUnit + , QuickCheck + , base + , bytestring + , proto-lens + , lens-family + , google-shim + , tensorflow + , tensorflow-ops + , tensorflow-proto + , transformers + , test-framework + , test-framework-hunit + , test-framework-quickcheck2 + , vector + +source-repository head + type: git + location: https://github.com/tensorflow/haskell diff --git a/tensorflow-ops/tests/ArrayOpsTest.hs b/tensorflow-ops/tests/ArrayOpsTest.hs new file mode 100644 index 0000000..e31fa19 --- /dev/null +++ b/tensorflow-ops/tests/ArrayOpsTest.hs @@ -0,0 +1,42 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE OverloadedLists #-} +module Main where + +import Control.Monad.IO.Class (liftIO) +import Google.Test (googleTest) +import Test.Framework.Providers.HUnit (testCase) +import Test.HUnit ((@=?)) +import qualified Data.Vector as V + +import qualified TensorFlow.Ops as TF +import qualified TensorFlow.Session as TF +import qualified TensorFlow.GenOps.Core as CoreOps + +-- | Test split and concat are inverses. +testSplit = testCase "testSplit" $ TF.runSession $ do + let original = TF.constant [2, 3] [0..5 :: Float] + splitList = CoreOps.split 3 dim original + restored = CoreOps.concat dim splitList + dim = 1 -- dimension to split along (with size of 3 in original) + liftIO $ 3 @=? length splitList + (x, y, z) <- + TF.buildAnd TF.run $ return (original, restored, splitList !! 1) + liftIO $ x @=? (y :: V.Vector Float) + liftIO $ V.fromList [1, 4] @=? z + +main :: IO () +main = googleTest [ testSplit + ] diff --git a/tensorflow-ops/tests/BuildTest.hs b/tensorflow-ops/tests/BuildTest.hs new file mode 100644 index 0000000..6f1504c --- /dev/null +++ b/tensorflow-ops/tests/BuildTest.hs @@ -0,0 +1,181 @@ +-- Copyright 2016 TensorFlow authors. 
+-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Main where + +import Control.Monad.IO.Class (liftIO) +import Data.Functor.Identity (runIdentity) +import Lens.Family2 ((^.)) +import Data.List (sort) +import Proto.Tensorflow.Core.Framework.Graph + ( node ) +import Proto.Tensorflow.Core.Framework.NodeDef + ( NodeDef + , device + , name + , op ) +import TensorFlow.Build + ( Build + , BuildT + , asGraphDef + , evalBuildT + , flushNodeBuffer + , hoistBuildT + , render + , withDevice + , colocateWith + , withNameScope + ) +import TensorFlow.ControlFlow (named) +import TensorFlow.Nodes (unScalar) +import TensorFlow.Ops + ( add + , assign + , constant + , initializedVariable + , variable + ) +import TensorFlow.Output (Device(..)) +import TensorFlow.Tensor (Tensor, Value, Ref) +import TensorFlow.Session + ( build + , buildAnd + , run + , runSession + , run_ + ) +import Test.Framework.Providers.HUnit (testCase) +import Test.HUnit ((@=?)) +import Google.Test (googleTest) +import qualified Data.Vector as V + +-- | Test named behavior. +testNamed = testCase "testNamed" $ do + let graph = named "foo" <$> variable [] >>= render :: Build (Tensor Ref Float) + nodeDef :: NodeDef + nodeDef = head $ asGraphDef graph ^. node + "RefIdentity" @=? (nodeDef ^. op) + "foo" @=? (nodeDef ^. name) + +-- | Test named deRef behavior. +testNamedDeRef = testCase "testNamedDeRef" $ do + let graph = named "foo" <$> do + v :: Tensor Ref Float <- variable [] + assign v 5 + -- TODO: Implement TensorFlow get_variable and test it. + runSession $ do + out <- buildAnd run graph + liftIO $ 5 @=? (unScalar out :: Float) + +-- | Test that "run" will render and extend any pure ops that haven't already +-- been rendered. +testPureRender = testCase "testPureRender" $ runSession $ do + result <- run $ 2 `add` 2 + liftIO $ 4 @=? (unScalar result :: Float) + +-- | Test that "run" assigns any previously accumulated initializers. +testInitializedVariable = + testCase "testInitializedVariable" $ runSession $ do + (formula, reset) <- build $ do + v <- initializedVariable 42 + r <- assign v 24 + return (1 `add` v, r) + result <- run formula + liftIO $ 43 @=? (unScalar result :: Float) + run_ reset -- Updates v to a different value + rerunResult <- run formula + liftIO $ 25 @=? (unScalar rerunResult :: Float) + +testInitializedVariableShape = + testCase "testInitializedVariableShape" $ runSession $ do + vector <- build $ initializedVariable (constant [1] [42 :: Float]) + result <- run vector + liftIO $ [42] @=? (result :: V.Vector Float) + +-- | Test nameScoped behavior. +testNameScoped = testCase "testNameScoped" $ do + let graph = withNameScope "foo" $ variable [] :: Build (Tensor Ref Float) + nodeDef :: NodeDef + [nodeDef] = asGraphDef graph ^. node + "foo/Variable_0" @=? (nodeDef ^. name) -- TODO: Check prefix. + "Variable" @=? (nodeDef ^. op) + +-- | Test combined named and nameScoped behavior. 
+testNamedAndScoped = testCase "testNamedAndScoped" $ do + let graph :: Build (Tensor Ref Float) + graph = withNameScope "foo1" ((named "bar1" <$> variable []) >>= render) + nodeDef :: NodeDef + nodeDef = head $ asGraphDef graph ^. node + "RefIdentity" @=? (nodeDef ^. op) + "foo1/bar1" @=? (nodeDef ^. name) + +-- | Lift a Build action into a context for HUnit to run. +liftBuild :: Build a -> BuildT IO a +liftBuild = hoistBuildT (return . runIdentity) + +-- | Flush the node buffer and sort the nodes by name (for more stable tests). +flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a] +flushed field = sort . map field <$> liftBuild flushNodeBuffer + +-- | Test the interaction of rendering, CSE and scoping. +testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do + liftBuild renderNodes + names <- flushed (^. name) + liftIO $ ["Const_1", "Variable_0", "Variable_2"] @=? names + -- Render the nodes in a different scope, which should cause them + -- to be distinct from the previous ones. + liftBuild $ withNameScope "foo" renderNodes + scopedNames <- flushed (^. name) + liftIO $ ["foo/Const_4", "foo/Variable_3", "foo/Variable_5"] @=? scopedNames + where + renderNodes = do + -- A stateful op and a pure op. + _ :: Tensor Ref Float <- variable [] + _ :: Tensor Value Float <- render 3 + -- Another stateful op, and a pure op which should be + -- deduped with the previous one. + _ :: Tensor Ref Float <- variable [] + _ :: Tensor Value Float <- render 3 + return () + +-- | Test the interaction of rendering, CSE and scoping. +testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do + liftBuild renderNodes + devices <- flushed (\x -> (x ^. name, x ^. device)) + liftIO $ [ ("Add_2","dev0") + , ("Const_1","dev0") + , ("Variable_0","dev0")] @=? devices + where + renderNodes = do + -- A stateful op and a pure op. + var :: Tensor Ref Float <- withDevice (Just $ Device "dev0") $ variable [] + -- Uses render to cause the expression be added to the graph. + _ <- colocateWith var $ render $ 3 `add` var + return () + +main :: IO () +main = googleTest [ testInitializedVariable + , testInitializedVariableShape + , testDeviceColocation + , testNamed + , testNamedDeRef + , testNameScoped + , testNamedAndScoped + , testPureRender + , testRenderDedup + ] diff --git a/tensorflow-ops/tests/DataFlowOpsTest.hs b/tensorflow-ops/tests/DataFlowOpsTest.hs new file mode 100644 index 0000000..cd362c9 --- /dev/null +++ b/tensorflow-ops/tests/DataFlowOpsTest.hs @@ -0,0 +1,66 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
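+-- A small worked instance of the inverse property tested below: with
+-- numParts = 2, values = [10, 20, 30] and partitions = [0, 1, 0],
+--
+--   dynamicPartition 2 values    partitions ==> [[10, 30], [20]]
+--   dynamicPartition 2 [0, 1, 2] partitions ==> [[0, 2], [1]]
+--   dynamicStitch [[0, 2], [1]] [[10, 30], [20]]
+--                                           ==> [10, 20, 30]
+--
+-- Stitching the split values with the identically-split indices puts every
+-- element back in its original position.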
+ +{-# LANGUAGE ScopedTypeVariables #-} + +import Data.Int (Int32, Int64) +import Data.List (genericLength) +import Google.Test (googleTest) +import Test.Framework.Providers.QuickCheck2 (testProperty) +import Test.HUnit ((@=?)) +import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf) +import Test.QuickCheck.Monadic (monadicIO, run) + +import qualified Data.Vector as V +import qualified TensorFlow.GenOps.Core as CoreOps +import qualified TensorFlow.Ops as TF +import qualified TensorFlow.Session as TF +import qualified TensorFlow.Tensor as TF +import qualified TensorFlow.Types as TF + +-- DynamicSplit is undone with DynamicStitch to get the original input +-- back. +testDynamicPartitionStitchInverse :: forall a. + (TF.TensorType a, Show a, Eq a) => StitchExample a -> Property +testDynamicPartitionStitchInverse (StitchExample numParts values partitions) = + let splitParts :: [TF.Tensor TF.Value a] = + CoreOps.dynamicPartition numParts (TF.vector values) partTensor + partTensor = TF.vector partitions + restitchIndices = CoreOps.dynamicPartition numParts + (TF.vector [0..genericLength values-1]) + partTensor + -- drop (numParts - 2) from both args to expose b/27343984 + restitch = CoreOps.dynamicStitch restitchIndices splitParts + in monadicIO $ run $ do + fromIntegral numParts @=? length splitParts + valuesOut <- TF.runSession $ TF.buildAnd TF.run $ return restitch + V.fromList values @=? valuesOut + +data StitchExample a = StitchExample Int64 [a] [Int32] + deriving Show + +instance Arbitrary a => Arbitrary (StitchExample a) where + arbitrary = do + -- Limits the size of the vector. + size <- choose (1, 100) + values <- vectorOf size arbitrary + numParts <- choose (2, 15) + partitions <- vectorOf size (choose (0, fromIntegral numParts - 1)) + return $ StitchExample numParts values partitions + +main :: IO () +main = googleTest + [ testProperty "DynamicPartitionStitchInverse" + (testDynamicPartitionStitchInverse :: StitchExample Int64 -> Property) + ] diff --git a/tensorflow-ops/tests/EmbeddingOpsTest.hs b/tensorflow-ops/tests/EmbeddingOpsTest.hs new file mode 100644 index 0000000..0a6b97d --- /dev/null +++ b/tensorflow-ops/tests/EmbeddingOpsTest.hs @@ -0,0 +1,88 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} + +-- | Tests for EmbeddingOps. 
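+--
+-- The sharding scheme exercised here, spelled out: a params tensor whose
+-- first dimension is n is split into p shards with
+-- @dynamicPartition p params ([0..n-1] `mod` p)@, so row i lands in shard
+-- (i `mod` p) at offset (i `div` p). 'embeddingLookup' is expected to undo
+-- exactly that bookkeeping, making it agree with a plain 'gather' on the
+-- unsharded tensor.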
+module Main where + +import Data.Int (Int32, Int64) +import Data.List (genericLength) +import Google.Test (googleTest) +import TensorFlow.EmbeddingOps (embeddingLookup) +import Test.Framework.Providers.QuickCheck2 (testProperty) +import Test.HUnit ((@=?)) +import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf) +import Test.QuickCheck.Monadic (monadicIO, run) + +import qualified Data.Vector as V +import qualified TensorFlow.GenOps.Core as CoreOps +import qualified TensorFlow.Ops as TF +import qualified TensorFlow.Session as TF +import qualified TensorFlow.Tensor as TF +import qualified TensorFlow.Types as TF + +-- Verifies that direct gather is the same as dynamic split into +-- partitions, followed by embedding lookup. +testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a) + => LookupExample a -> Property +testEmbeddingLookupUndoesSplit + (LookupExample numParts + shape@(TF.Shape (firstDim : restDims)) + values + indices) = + let modShardedValues :: [TF.Tensor TF.Value a] = + CoreOps.dynamicPartition numParts shapedValues cyclicCounter + cyclicCounter :: TF.Tensor TF.Value Int32 = + TF.vector [0..fromIntegral firstDim-1] + `CoreOps.mod` fromIntegral numParts + indicesVector = TF.vector indices + directs = CoreOps.gather shapedValues indicesVector + shapedValues = TF.constant shape values + in monadicIO $ run $ do + (shapeOut, got, want :: V.Vector a) <- + TF.runSession $ TF.buildAnd TF.run $ do + embeddings <- embeddingLookup modShardedValues indicesVector + return (TF.cast (TF.shape embeddings), embeddings, directs) + -- Checks the explicitly documented invariant of embeddingLookup. + shapeOut @=? V.fromList (genericLength indices : restDims) + got @=? want +testEmbeddingLookupUndoesSplit _ = error "Bug in Arbitrary (LookupExample)" + +-- | Consistent set of parameters for EmbeddingLookupUndoesSplit. +data LookupExample a = LookupExample + Int64 -- ^ number of ways to split. + TF.Shape -- ^ shape of the generated tensor + [a] -- ^ data for the tensor + [Int32] -- ^ indices to split the tensor by + deriving Show + +instance Arbitrary a => Arbitrary (LookupExample a) where + arbitrary = do + rank <- choose (1, 4) + -- Takes rank-th root of 100 to cap the tensor size. + let maxDim = fromIntegral $ ceiling $ 100 ** (1 / fromIntegral rank) + shape@(firstDim : _) <- vectorOf rank (choose (1, maxDim)) + values <- vectorOf (fromIntegral $ product shape) arbitrary + numParts <- choose (2, 15) + indSize <- choose (0, fromIntegral $ firstDim - 1) + indices <- vectorOf indSize (choose (0, fromIntegral firstDim - 1)) + return $ LookupExample numParts (TF.Shape shape) values indices + +main :: IO () +main = googleTest + [ testProperty "EmbeddingLookupUndoesSplit" + (testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property) + ] diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs new file mode 100644 index 0000000..037b309 --- /dev/null +++ b/tensorflow-ops/tests/GradientTest.hs @@ -0,0 +1,158 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} + +import Data.List (sort) +import Data.ProtoLens.TextFormat (showMessage) +import Google.Test (googleTest) +import Lens.Family2 ((^..)) +import Test.Framework (Test) +import Test.Framework.Providers.HUnit (testCase) +import Test.HUnit ((@=?)) + +import qualified TensorFlow.Build as TF +import qualified TensorFlow.Gradient as TF +import qualified TensorFlow.Nodes as TF +import qualified TensorFlow.Ops as TF +import qualified TensorFlow.Session as TF +import qualified TensorFlow.Tensor as TF +import qualified TensorFlow.Types as TF + +import Proto.Tensorflow.Core.Framework.Graph (node) +import Proto.Tensorflow.Core.Framework.NodeDef (op) + +testGradientSimple :: Test +testGradientSimple = testCase "testGradientSimple" $ do + let x = TF.scalar (3 :: Float) + b = TF.scalar (4 :: Float) + y = x*x + b + grads = TF.gradients y [x, b] + -- Assert that the gradients are right. + [dx, db] <- TF.runSession $ TF.buildAnd TF.run grads + 6 @=? TF.unScalar dx + 1 @=? TF.unScalar db + -- Assert that the graph has the expected ops. + let graphDef = TF.asGraphDef grads + putStrLn $ showMessage graphDef + let ops = graphDef ^.. node . traverse . op + expected = [ "Const" + , "Mul" + , "Const" + , "Add" + -- Default output gradient of y. + , "Shape" + , "Const" + , "Fill" + -- Add gradient. + , "Shape" + , "Shape" + , "BroadcastGradientArgs" + , "Sum" + , "Sum" + , "Reshape" + , "Reshape" + -- Mul gradient. + , "Shape" + -- This Op gets dedup'd because the inputs are the same. + -- TODO(fmayle): The same would happen to the Mul and Sum ops + -- below if the gradient function didn't multiply one as + -- 'dz * y' and the other as 'x * dz'. We could change the + -- order, but I'm going to keep it the same as the python + -- version for now. + -- + -- , "Shape" + , "BroadcastGradientArgs" + , "Mul" + , "Mul" + , "Sum" + , "Sum" + , "Reshape" + , "Reshape" + -- AddN to combine x's output gradients. + , "AddN" + ] + sort expected @=? sort ops + +testGradientDisconnected :: Test +testGradientDisconnected = testCase "testGradientDisconnected" $ do + let x = TF.scalar (3 :: Float) + b = TF.scalar (4 :: Float) + grads = TF.gradients x [x, b] + -- Assert that the gradients are right. + [dx, db] <- TF.runSession $ TF.buildAnd TF.run grads + 1 @=? TF.unScalar dx + 0 @=? TF.unScalar db + -- Assert that the graph has the expected ops. + let graphDef = TF.asGraphDef grads + putStrLn $ showMessage graphDef + let ops = graphDef ^.. node . traverse . op + expected = [ "Const" + , "Const" + -- Default output gradient of x. + , "Shape" + , "Const" + , "Fill" + -- Default output gradient of b. + , "ZerosLike" + ] + sort expected @=? sort ops + + +-- Test that identical "stateful" ops work with createGraph. +testCreateGraphStateful :: Test +testCreateGraphStateful = testCase "testCreateGraphStateful" $ do + [dx, dy] <- TF.runSession $ TF.buildAnd TF.run $ do + let shape = TF.constant (TF.Shape [1]) [1] + x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape + y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape + TF.gradients (x + y*3) [x, y] + -- If this test fails, it will likely be caused by an exception within + -- `TF.gradients`. These asserts are extra. + 1 @=? TF.unScalar dx + 3 @=? TF.unScalar dy + + +-- Test that name scopes work with createGraph. 
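+-- The truncatedNormal node is created under the "foo" scope, so 'gradients'
+-- must look the node up by its scoped name; since the gradient of x with
+-- respect to itself is all ones, the expected result is 1.
+--
+-- (A note for testDiamond further below: z = y*y with y = x*x is z = x^4,
+-- so dz/dx = 4*x^3, which is 4 at x = 1.)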
+testCreateGraphNameScopes :: Test +testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do + [dx] <- TF.runSession $ TF.buildAnd TF.run $ do + let shape = TF.constant (TF.Shape [1]) [1] + x :: TF.Tensor TF.Value Float <- + TF.withNameScope "foo" (TF.truncatedNormal shape) + TF.gradients x [x] + -- If this test fails, it will likely be caused by an exception within + -- `TF.gradients`. This assert is extra. + 1 @=? TF.unScalar dx + + +-- Test that createGraph can handle graphs with diamond shapes. +testDiamond :: Test +testDiamond = testCase "testDiamond" $ do + [dx] <- TF.runSession $ TF.buildAnd TF.run $ do + let x = TF.vector [1] + y = x*x + z = y*y + TF.gradients z [x] + (4 :: Float) @=? TF.unScalar dx + + +main :: IO () +main = googleTest [ testGradientSimple + , testGradientDisconnected + , testCreateGraphStateful + , testCreateGraphNameScopes + , testDiamond + ] diff --git a/tensorflow-ops/tests/MiscTest.hs b/tensorflow-ops/tests/MiscTest.hs new file mode 100644 index 0000000..5323191 --- /dev/null +++ b/tensorflow-ops/tests/MiscTest.hs @@ -0,0 +1,46 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE RankNTypes #-} + +module Main where + +import Control.Monad.IO.Class (liftIO) +import Data.Int (Int32) +import Test.Framework.Providers.HUnit (testCase) +import Test.HUnit ((@=?)) +import Google.Test +import qualified Data.Vector as V + +import TensorFlow.Ops +import TensorFlow.Session + +-- | Test fetching multiple outputs from an op. +testMultipleOutputs = testCase "testMultipleOutputs" $ + runSession $ do + (values, indices) <- run $ topK 2 $ constant [1, 4] [10, 40, 20, 30] + liftIO $ [40, 30] @=? V.toList (values :: V.Vector Float) + liftIO $ [1, 3] @=? V.toList (indices :: V.Vector Int32) + +-- | Test op with variable number of inputs. +testVarargs = testCase "testVarargs" $ + runSession $ do + xs <- run $ pack $ map scalar [1..8] + liftIO $ [1..8] @=? V.toList (xs :: V.Vector Float) + +main :: IO () +main = googleTest [ testMultipleOutputs + , testVarargs + ] diff --git a/tensorflow-ops/tests/OpsTest.hs b/tensorflow-ops/tests/OpsTest.hs new file mode 100644 index 0000000..3a5b7ae --- /dev/null +++ b/tensorflow-ops/tests/OpsTest.hs @@ -0,0 +1,70 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
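+-- Worked arithmetic for the reducedShape example below: reducing shape
+-- [2, 3, 5, 7] over axes [1, 2] keeps the rank but collapses each reduced
+-- axis to 1, giving [2, 1, 1, 7], exactly the shape needed to broadcast a
+-- keep_dims-style reduction back over the original tensor.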
+ +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE NoMonomorphismRestriction #-} + +module Main where + +import Control.Monad.IO.Class (liftIO) +import Data.Int (Int32, Int64) +import Google.Test (googleTest) +import System.IO.Temp (withSystemTempDirectory) +import Test.Framework.Providers.HUnit (testCase) +import Test.HUnit ((@=?)) +import qualified Data.ByteString.Char8 as B8 + +import qualified Data.Vector as V +import qualified TensorFlow.Build as TF +import qualified TensorFlow.ControlFlow as TF +import qualified TensorFlow.Nodes as TF +import qualified TensorFlow.Ops as TF +import qualified TensorFlow.Session as TF +import qualified TensorFlow.Tensor as TF + +-- | Test that one can easily determine number of elements in the tensor. +testSize = testCase "testSize" $ do + x <- eval $ TF.size (TF.constant [2, 3] [0..5 :: Float]) + TF.Scalar (2 * 3 :: Int32) @=? x + +eval = TF.runSession . TF.buildAnd TF.run . return + +-- | Confirms that the original example from Python code works. +testReducedShape = testCase "testReducedShape" $ do + x <- eval $ TF.reducedShape (TF.vector [2, 3, 5, 7 :: Int64]) + (TF.vector [1, 2 :: Int32]) + V.fromList [2, 1, 1, 7 :: Int32] @=? x + +testSaveRestore = testCase "testSaveRestore" $ + withSystemTempDirectory "" $ \dirPath -> do + let path = B8.pack $ dirPath ++ "/checkpoint" + var :: TF.Build (TF.Tensor TF.Ref Float) + var = TF.render =<< TF.named "a" <$> TF.zeroInitializedVariable [] + TF.runSession $ do + v <- TF.build var + TF.buildAnd TF.run_ $ TF.assign v 134 + TF.buildAnd TF.run_ $ TF.save path [v] + result <- TF.runSession $ do + v <- TF.build var + TF.buildAnd TF.run_ $ TF.restore path v + TF.run v + liftIO $ TF.Scalar 134 @=? result + + +main :: IO () +main = googleTest [ testSaveRestore + , testSize + , testReducedShape + ] diff --git a/tensorflow-ops/tests/TypesTest.hs b/tensorflow-ops/tests/TypesTest.hs new file mode 100644 index 0000000..91175ab --- /dev/null +++ b/tensorflow-ops/tests/TypesTest.hs @@ -0,0 +1,119 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
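+-- The round-trip law the QuickCheck properties below exercise, written out:
+-- for any shape s and vector v with product s == length v,
+--
+-- >  decodeTensorData (encodeTensorData (Shape s) v) == v
+--
+-- e.g. a [2, 3] Float tensor encoded from [1..6] decodes back to [1..6].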
+ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE NoMonomorphismRestriction #-} +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TypeFamilies #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} + +import Control.Monad (replicateM) +import Control.Monad.IO.Class (liftIO) +import Data.Int (Int64) +import Google.Test (googleTest) +import Test.Framework.Providers.HUnit (testCase) +import Test.Framework.Providers.QuickCheck2 (testProperty) +import Test.HUnit ((@=?)) +import Test.QuickCheck (Arbitrary(..), listOf, suchThat) +import qualified Data.ByteString as B +import qualified Data.ByteString.Char8 as B8 +import qualified Data.Vector as V + +import qualified TensorFlow.ControlFlow as TF +import qualified TensorFlow.Ops as TF +import qualified TensorFlow.Session as TF +import qualified TensorFlow.Tensor as TF +import qualified TensorFlow.Types as TF + +instance Arbitrary B.ByteString where + arbitrary = B.pack <$> arbitrary + +-- Test encoding tensors, feeding them through tensorflow, and decoding the +-- results. +testFFIRoundTrip = testCase "testFFIRoundTrip" $ + TF.runSession $ do + let floatData = V.fromList [1..6 :: Float] + stringData = V.fromList [B8.pack (show x) | x <- [1..6]] + f <- TF.build $ TF.placeholder [2,3] + s <- TF.build $ TF.placeholder [2,3] + let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData) + , TF.feed s (TF.encodeTensorData [2,3] stringData) + ] + -- It is an error to fetch a tensor that is being fed, so the tensors + -- are passed through identity. + (f', s') <- TF.runWithFeeds feeds (TF.identity f, TF.identity s) + liftIO $ do + floatData @=? f' + stringData @=? s' + + +data TensorDataInputs a = TensorDataInputs [Int64] (V.Vector a) + deriving Show + +instance Arbitrary a => Arbitrary (TensorDataInputs a) where + arbitrary = do + -- Limit the size of the final vector, and also guard against overflow + -- (i.e., p<0) when there are too many dimensions + let validProduct p = p > 0 && p < 100 + sizes <- listOf (arbitrary `suchThat` (>0)) + `suchThat` (validProduct . product) + elems <- replicateM (fromIntegral $ product sizes) arbitrary + return $ TensorDataInputs sizes (V.fromList elems) + +-- Test that a vector is unchanged after being encoded and decoded. +encodeDecodeProp :: (TF.TensorType a, Eq a) => TensorDataInputs a -> Bool +encodeDecodeProp (TensorDataInputs shape vec) = + TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec + +testEncodeDecodeQcFloat = testProperty "testEncodeDecodeQcFloat" + (encodeDecodeProp :: TensorDataInputs Float -> Bool) + +testEncodeDecodeQcInt64 = testProperty "testEncodeDecodeQcInt64" + (encodeDecodeProp :: TensorDataInputs Int64 -> Bool) + +testEncodeDecodeQcString = testProperty "testEncodeDecodeQcString" + (encodeDecodeProp :: TensorDataInputs B.ByteString -> Bool) + +doubleOrInt64Func :: TF.OneOf '[Double, Int64] a => a -> a +doubleOrInt64Func = id + +doubleOrFloatFunc :: TF.OneOf '[Double, Float] a => a -> a +doubleOrFloatFunc = id + +doubleFunc :: TF.OneOf '[Double] a => a -> a +doubleFunc = doubleOrFloatFunc . doubleOrInt64Func + +-- No explicit type signature; make sure it can be inferred automatically. +-- (Note: this would fail if we didn't have NoMonomorphismRestriction, since it +-- can't simplify the type all the way to `Double -> Double`. +doubleFuncNoSig = doubleOrFloatFunc . doubleOrInt64Func + +typeConstraintTests = testCase "type constraints" $ do + 42 @=? doubleOrInt64Func (42 :: Double) + 42 @=? 
doubleOrInt64Func (42 :: Int64) + 42 @=? doubleOrFloatFunc (42 :: Double) + 42 @=? doubleOrFloatFunc (42 :: Float) + 42 @=? doubleFunc (42 :: Double) + 42 @=? doubleFuncNoSig (42 :: Double) + + +main :: IO () +main = googleTest [ testFFIRoundTrip + , testEncodeDecodeQcFloat + , testEncodeDecodeQcInt64 + , testEncodeDecodeQcString + , typeConstraintTests + ] diff --git a/tensorflow-proto/Setup.hs b/tensorflow-proto/Setup.hs new file mode 100644 index 0000000..3621b82 --- /dev/null +++ b/tensorflow-proto/Setup.hs @@ -0,0 +1,17 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +import Data.ProtoLens.Setup + +main = defaultMainGeneratingProtos "../third_party/tensorflow" diff --git a/tensorflow-proto/tensorflow-proto.cabal b/tensorflow-proto/tensorflow-proto.cabal new file mode 100644 index 0000000..2cb3a69 --- /dev/null +++ b/tensorflow-proto/tensorflow-proto.cabal @@ -0,0 +1,40 @@ +name: tensorflow-proto +version: 0.1.0.0 +synopsis: TensorFlow protocol buffers. +description: Please see README.md +homepage: https://github.com/tensorflow/haskell#readme +license: Apache +author: TensorFlow authors +maintainer: tensorflow-haskell@googlegroups.com +copyright: Google Inc. +category: Machine Learning +build-type: Custom +cabal-version: >=1.22 +extra-source-files: ../third_party/tensorflow/tensorflow/core/framework/*.proto + , ../third_party/tensorflow/tensorflow/core/protobuf/config.proto + +library + exposed-modules: Proto.Tensorflow.Core.Framework.AttrValue + , Proto.Tensorflow.Core.Framework.Graph + , Proto.Tensorflow.Core.Framework.NodeDef + , Proto.Tensorflow.Core.Framework.OpDef + , Proto.Tensorflow.Core.Framework.ResourceHandle + , Proto.Tensorflow.Core.Framework.Tensor + , Proto.Tensorflow.Core.Framework.TensorShape + , Proto.Tensorflow.Core.Framework.Types + , Proto.Tensorflow.Core.Protobuf.Config + other-modules: Proto.Tensorflow.Core.Framework.AllocationDescription + , Proto.Tensorflow.Core.Framework.CostGraph + , Proto.Tensorflow.Core.Framework.Function + , Proto.Tensorflow.Core.Framework.StepStats + , Proto.Tensorflow.Core.Framework.TensorDescription + , Proto.Tensorflow.Core.Framework.Versions + build-depends: proto-lens == 0.1.* + , proto-lens-protoc == 0.1.* + , base >= 4.7 && < 5 + default-language: Haskell2010 + include-dirs: . + +source-repository head + type: git + location: https://github.com/tensorflow/haskell diff --git a/tensorflow-queue/Setup.hs b/tensorflow-queue/Setup.hs new file mode 100644 index 0000000..e8ef27d --- /dev/null +++ b/tensorflow-queue/Setup.hs @@ -0,0 +1,3 @@ +import Distribution.Simple + +main = defaultMain diff --git a/tensorflow-queue/src/TensorFlow/Queue.hs b/tensorflow-queue/src/TensorFlow/Queue.hs new file mode 100644 index 0000000..0d0ddca --- /dev/null +++ b/tensorflow-queue/src/TensorFlow/Queue.hs @@ -0,0 +1,78 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. 
+-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} + +-- | Queues in TensorFlow graph. Very limited support for now. +module TensorFlow.Queue (Queue2, makeQueue2, enqueue, dequeue) where + +import Data.ByteString (ByteString) +import Data.Int (Int64) +import Lens.Family2 ((.~), (&)) +import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef) +import TensorFlow.BuildOp (buildOp) +import TensorFlow.ControlFlow (group) +import TensorFlow.Tensor (Ref, Tensor) +import TensorFlow.Types (TensorType, tensorType) + +-- | A queue carrying tuples. The underlying structure is more +-- versatile and can be made to support arbitrary tuples. +data Queue2 a b = Queue2 { handle :: Handle } + +type Handle = Tensor Ref ByteString + +-- | Adds the given values to the queue. +enqueue :: forall a b v1 v2. (TensorType a, TensorType b) + => Queue2 a b + -> Tensor v1 a + -> Tensor v2 b + -> Build ControlNode +enqueue q = + buildOp (opDef "QueueEnqueue" + & opAttr "Tcomponents" .~ [ tensorType (undefined :: a) + , tensorType (undefined :: b)]) + (handle q) + +-- | Retrieves the values from the queue. +dequeue :: forall a b . (TensorType a, TensorType b) + => Queue2 a b + -> Build (Tensor Ref a, Tensor Ref b) + -- ^ Dequeued tensors. They are paired in a sense + -- that values appear together, even if they are + -- not consumed together. +dequeue q = + buildOp (opDef "QueueDequeue" + & opAttr "component_types" .~ [ tensorType (undefined :: a) + , tensorType (undefined :: b)]) + (handle q) + +-- | Creates a new queue with the given capacity and shared name. +makeQueue2 :: forall a b . (TensorType a, TensorType b) + => Int64 -- ^ The upper bound on the number of elements in + -- this queue. Negative numbers mean no limit. + -> ByteString -- ^ If non-empty, this queue will be shared + -- under the given name across multiple sessions. + -> Build (Queue2 a b) +makeQueue2 capacity sharedName = do + q <- buildOp (opDef "FIFOQueue" + & opAttr "component_types" .~ [ tensorType (undefined :: a) + , tensorType (undefined :: b)] + & opAttr "shared_name" .~ sharedName + & opAttr "capacity" .~ capacity + ) + group q >>= addInitializer + return (Queue2 q) + +-- TODO(gnezdo): Figure out the closing story for queues. diff --git a/tensorflow-queue/tensorflow-queue.cabal b/tensorflow-queue/tensorflow-queue.cabal new file mode 100644 index 0000000..dcf2c8a --- /dev/null +++ b/tensorflow-queue/tensorflow-queue.cabal @@ -0,0 +1,51 @@ +name: tensorflow-queue +version: 0.1.0.0 +synopsis: Basic access to TensorFlow queues. +description: Please see README.md +homepage: https://github.com/tensorflow/haskell#readme +license: Apache +author: TensorFlow authors +maintainer: tensorflow-haskell@googlegroups.com +copyright: Google Inc. 
+category: Machine Learning +build-type: Simple +cabal-version: >=1.22 + +library + hs-source-dirs: src + exposed-modules: TensorFlow.Queue + build-depends: proto-lens == 0.1.* + , base >= 4.7 && < 5 + , bytestring + , lens-family + , containers + , tensorflow-proto == 0.1.* + , tensorflow-core-ops == 0.1.* + , tensorflow + , text + default-language: Haskell2010 + +Test-Suite QueueTest + default-language: Haskell2010 + type: exitcode-stdio-1.0 + main-is: QueueTest.hs + hs-source-dirs: tests + -- Uses multiple threads and blocks without this option. + ghc-options: -threaded + build-depends: HUnit + , base + , bytestring + , proto-lens + , lens-family + , google-shim + , tensorflow + , tensorflow-ops + , tensorflow-queue + , test-framework + , test-framework-hunit + , transformers + , vector + +source-repository head + type: git + location: https://github.com/tensorflow/haskell diff --git a/tensorflow-queue/tests/QueueTest.hs b/tensorflow-queue/tests/QueueTest.hs new file mode 100644 index 0000000..4510a4d --- /dev/null +++ b/tensorflow-queue/tests/QueueTest.hs @@ -0,0 +1,79 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Main where + +import Control.Monad.IO.Class (liftIO) +import Data.Int (Int64) +import Google.Test (googleTest) +import TensorFlow.Nodes (Scalar(..)) +import TensorFlow.Ops (scalar) +import TensorFlow.Queue +import TensorFlow.Session + ( asyncProdNodes + , build + , buildAnd + , run + , runSession + , run_ + ) +import Test.Framework.Providers.HUnit (testCase) +import Test.HUnit ((@=?)) +import qualified Data.ByteString as BS + +-- | Test basic queue behaviors. +testBasic = testCase "testBasic" $ runSession $ do + (q :: Queue2 Int64 BS.ByteString) <- build $ makeQueue2 1 "" + buildAnd run_ (enqueue q 42 (scalar "Hi")) + x <- buildAnd run (dequeue q) + liftIO $ (Scalar 42, Scalar "Hi") @=? x + + buildAnd run_ (enqueue q 56 (scalar "Bar")) + y <- buildAnd run (dequeue q) + liftIO $ (Scalar 56, Scalar "Bar") @=? y + +-- | Test queue pumping. +testPump = testCase "testPump" $ runSession $ do + (deq, pump) <- build $ do + q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue" + (,) <$> dequeue q + <*> enqueue q 31 (scalar "Baz") + -- This is a realistic use. The pump inputs are pre-bound to some + -- nodes that produce values when pumped (e.g. read from a + -- file). + run_ (pump, pump) + + (x, y) <- run (deq, deq) + liftIO $ (Scalar 31, Scalar "Baz") @=? x + liftIO $ (Scalar 31, Scalar "Baz") @=? y + +testAsync = testCase "testAsync" $ runSession $ do + (deq, pump) <- build $ do + q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "" + (,) <$> dequeue q + <*> enqueue q 10 (scalar "Async") + -- Pumps the queue until canceled by runSession exiting. + asyncProdNodes pump + -- Picks up a couple values and verifies they are as expected. + run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?) + run deq >>= liftIO . 
((Scalar 10, Scalar "Async") @=?) + +main :: IO () +main = googleTest [ testBasic + , testPump + , testAsync + ] diff --git a/tensorflow/Setup.hs b/tensorflow/Setup.hs new file mode 100644 index 0000000..e8ef27d --- /dev/null +++ b/tensorflow/Setup.hs @@ -0,0 +1,3 @@ +import Distribution.Simple + +main = defaultMain diff --git a/tensorflow/src/TensorFlow/Build.hs b/tensorflow/src/TensorFlow/Build.hs new file mode 100644 index 0000000..2165c94 --- /dev/null +++ b/tensorflow/src/TensorFlow/Build.hs @@ -0,0 +1,376 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE Rank2Types #-} +{-# LANGUAGE OverloadedStrings #-} +module TensorFlow.Build + ( -- * Graph node types + ControlNode(..) + , Unique + -- * Ops + , explicitName + , implicitName + , opDef + , opDefWithName + , opName + , opType + , opAttr + , opInputs + , opControlInputs + -- * The Build monad + , GraphState + , render + , renderNodeName + , renderedNodeDefs + , BuildT + , Build + , addInitializer + , hoistBuildT + , evalBuildT + , runBuildT + , asGraphDef + , addGraphDef + , flushInitializers + , flushNodeBuffer + -- * Creating and looking up Ops + , getOrAddOp + , addNewOp + , renderOutput + -- * Modifying all nodes in a Build action + , colocateWith + , withStateLens + , withDevice + , withNameScope + , withNodeDependencies + -- * Internal Summary related bits. + , addSummary + , SummaryTensor + , collectAllSummaries + ) where + +import Control.Monad.IO.Class (MonadIO(..)) +import Control.Monad.Trans.Class (MonadTrans(..)) +import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT) +import Data.ByteString (ByteString) +import Data.Default (def) +import Data.Functor.Identity (Identity(..)) +import qualified Data.Map.Strict as Map +import Data.Monoid ((<>)) +import qualified Data.Set as Set +import Data.Set (Set) +import Data.String (IsString(..)) +import Data.Text (Text) +import qualified Data.Text as Text +import Lens.Family2 (Lens', (.~), (^.), (&)) +import Lens.Family2.State.Strict (MonadState, use, uses, (.=), (<>=), (%=)) +import Lens.Family2.Unchecked (lens) +import Proto.Tensorflow.Core.Framework.Graph + ( GraphDef + , node + ) +import Proto.Tensorflow.Core.Framework.NodeDef + ( NodeDef + , attr + , input + , device + , name + , op + ) + +import TensorFlow.Orphans () +import TensorFlow.Output +import TensorFlow.Tensor + +newtype Unique = Unique Int + deriving (Eq, Ord, Enum) + +-------------- + +implicitName :: PendingNodeName +implicitName = ImplicitName + +explicitName :: Text -> PendingNodeName +explicitName = ExplicitName + +newtype Scope = Scope {unScope :: Text} + deriving (Eq, Ord, IsString) + +instance Show Scope where + show = show . 
unScope + +opDef :: OpType -> OpDef +opDef = opDefWithName ImplicitName + +opDefWithName :: PendingNodeName -> OpType -> OpDef +opDefWithName n t = OpDef + { _opName = n + , _opType = t + , _opAttrs = Map.empty + , _opInputs = [] + , _opControlInputs = [] + } + +-- | Synonym for the tensors that return serialized Summary proto. +type SummaryTensor = Tensor Value ByteString + +data GraphState = GraphState + { _renderedNodes :: !(Map.Map PendingNode NodeDef) + -- ^ Nodes which have been rendered. Keeps track of the unique ID we + -- assign each implicitly-named node. Also prevents us from adding the + -- same node (implicit or explicit) more than once to the nodeBuffer. + , _renderedNodeDefs :: !(Map.Map NodeName NodeDef) + -- ^ The NodeDefs of nodes which have been rendered. Used by the + -- Gradient module to inspect the node graph. + , _nodeBuffer :: [NodeDef] + -- ^ A list of nodes that should be passed to TensorFlow during + -- the next call to Session.extend (TF_ExtendGraph). + , _nextUnique :: !Unique + -- ^ Unique ID for the next node + -- TODO(judahjacobson): watch for clashes between auto and user names. + , _defaultDevice :: !(Maybe Device) + , _currentScope :: [Scope] + , _defaultControlInputs :: !(Set NodeName) + , _initializationNodes :: [NodeName] + -- ^ The nodes to run next time a TF.run is issued, typically + -- variable initializers. + , _summaries :: [SummaryTensor] + -- ^ The tensors for summary + } + +-- | A node definition without its final name. Used as a key in the +-- "renderedNodes" map. +-- The NodeDef contained inside has an empty "name" field. +data PendingNode = PendingNode [Scope] !PendingNodeName !NodeDef + deriving (Eq, Ord) + +-- Returns an _incomplete_ NodeDef. The name is fixed by addNewOpFromPending. +pendingNodeDef :: PendingNode -> NodeDef +pendingNodeDef (PendingNode _ _ n) = n + +initGraphState :: GraphState +initGraphState = + GraphState Map.empty Map.empty [] (Unique 0) Nothing [] Set.empty [] [] + +renderedNodes :: Lens' GraphState (Map.Map PendingNode NodeDef) +renderedNodes = lens _renderedNodes (\g x -> g { _renderedNodes = x }) + +renderedNodeDefs :: Lens' GraphState (Map.Map NodeName NodeDef) +renderedNodeDefs = lens _renderedNodeDefs (\g x -> g { _renderedNodeDefs = x }) + +nodeBuffer :: Lens' GraphState [NodeDef] +nodeBuffer = lens _nodeBuffer (\g x -> g { _nodeBuffer = x }) + +nextUnique :: Lens' GraphState Unique +nextUnique = lens _nextUnique (\g x -> g { _nextUnique = x }) + +defaultDevice :: Lens' GraphState (Maybe Device) +defaultDevice = lens _defaultDevice (\g x -> g { _defaultDevice = x }) + +currentScope :: Lens' GraphState [Scope] +currentScope = lens _currentScope (\g x -> g { _currentScope = x }) + +defaultControlInputs :: Lens' GraphState (Set NodeName) +defaultControlInputs = lens _defaultControlInputs + (\g x -> g { _defaultControlInputs = x }) + +initializationNodes :: Lens' GraphState [NodeName] +initializationNodes = lens _initializationNodes (\g x -> g { _initializationNodes = x }) + +summaries :: Lens' GraphState [SummaryTensor] +summaries = lens _summaries (\g x -> g { _summaries = x }) + +-- | An action for building nodes in a TensorFlow graph. +-- Used to manage build state internally as part of the @Session@ monad. +newtype BuildT m a = BuildT (StateT GraphState m a) + deriving (Functor, Applicative, Monad, MonadIO, MonadTrans, + MonadState GraphState) + +-- | An action for building nodes in a TensorFlow graph. +type Build = BuildT Identity + +-- | This is Control.Monad.Morph.hoist sans the dependency. 
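+-- A typical instantiation (used as @liftBuild@ in the tests): lift a pure
+-- 'Build' into 'BuildT IO' by supplying the natural transformation out of
+-- Identity:
+--
+-- >  liftBuild :: Build a -> BuildT IO a
+-- >  liftBuild = hoistBuildT (return . runIdentity)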
+hoistBuildT :: (forall a . m a -> n a) -> BuildT m b -> BuildT n b +hoistBuildT f (BuildT m) = BuildT $ mapStateT f m + +runBuildT :: BuildT m a -> m (a, GraphState) +runBuildT (BuildT f) = runStateT f initGraphState + +evalBuildT :: Monad m => BuildT m a -> m a +evalBuildT (BuildT f) = evalStateT f initGraphState + +-- | Get all the NodeDefs that have accumulated so far, and clear that buffer. +flushNodeBuffer :: Monad m => BuildT m [NodeDef] +flushNodeBuffer = do + ns <- use nodeBuffer + nodeBuffer .= [] + return ns + +-- | Get all the initializers that have accumulated so far, and clear +-- that buffer. +flushInitializers :: Monad m => BuildT m [NodeName] +flushInitializers = do + ns <- use initializationNodes + initializationNodes .= [] + return ns + +-- | Registers the given node to be executed before the next +-- 'TensorFlow.Session.run'. +addInitializer :: ControlNode -> Build () +addInitializer (ControlNode o) = do + i <- getOrAddOp o + initializationNodes %= (i:) + +-- | Produce a GraphDef proto representation of the nodes that are rendered in +-- the given 'Build' action. +asGraphDef :: Build a -> GraphDef +asGraphDef b = def & node .~ gs ^. nodeBuffer + where + gs = snd $ runIdentity $ runBuildT b + +-- TODO: check against existing nodes for conflicts? +addGraphDef :: GraphDef -> Build () +addGraphDef g = nodeBuffer <>= g ^. node + +-- | Render the given op if it hasn't been rendered already, and return its +-- name. +getOrAddOp :: Op -> Build NodeName +getOrAddOp o = NodeName . (^. name) <$> resolveOp o + +resolveOp :: Op -> Build NodeDef +resolveOp (Rendered n) = return n +resolveOp (Unrendered o) = do + pending <- getPendingNode o + uses renderedNodes (Map.lookup pending) >>= \case + Just n -> return n + Nothing -> addNewOpFromPending pending + +-- | Add a new node for a given 'OpDef'. This is used for making "stateful" ops +-- which are not safe to dedup (e.g, "variable" and "assign"). +addNewOp :: OpDef -> Build NodeDef +addNewOp o = getPendingNode o >>= addNewOpFromPending + +addNewOpFromPending :: PendingNode -> Build NodeDef +addNewOpFromPending pending = do + nodeName <- renderPendingNode pending + let nodeDef = pendingNodeDef pending & name .~ unNodeName nodeName + nodeBuffer %= (nodeDef :) + renderedNodes %= Map.insert pending nodeDef + renderedNodeDefs %= Map.insert nodeName nodeDef + return nodeDef + +-- | Get the pending node corresponding to an OpDef, which may or may not have +-- been rendered before. Implicitly renders all of this node's inputs. +getPendingNode :: OpDef -> Build PendingNode +getPendingNode o = do + -- An empty string in the proto field means that no specific + -- device is specified. + dev <- maybe "" deviceName <$> use defaultDevice + inputs <- mapM getInput (o ^. opInputs) + scope <- use currentScope + controls <- use defaultControlInputs + let controlInputs + = map getDep (o ^. opControlInputs ++ Set.toList controls) + return $ PendingNode scope (o ^. opName) + $ def & op .~ (unOpType (o ^. opType) :: Text) + & attr .~ _opAttrs o + & input .~ (inputs ++ controlInputs) + & device .~ dev + where + getInput (Output (OutputIx k) subOp) + = (<> ":" <> Text.pack (show k)) . unNodeName <$> getOrAddOp subOp + getDep = ("^" <>) . unNodeName + +-- | Pick a name for a pending node. If it has an explicit name, just use that; +-- if the name is implicit, assign a new unique name based on the op type. +renderPendingNode :: PendingNode -> Build NodeName +renderPendingNode (PendingNode scope pendingName nodeDef) + = NodeName . 
(scopePrefix <>) <$> getName + where + scopePrefix = Text.concat $ fmap ((<> "/") . unScope) scope + getName = case pendingName of + ExplicitName n -> return n + ImplicitName -> do + u@(Unique k) <- use nextUnique + nextUnique .= succ u + return $ nodeDef ^. op <> "_" <> Text.pack (show k) + + +-- | Render an 'Output' and return a string representation for the TensorFlow +-- foreign APIs. +renderOutput :: Output -> Build Text +renderOutput (Output (OutputIx i) o) = do + n <- getOrAddOp o + return $ unNodeName n <> Text.pack (":" ++ show i) + +-- | Modify some part of the state, run an action, and restore the state +-- after that action is done. +withStateLens :: MonadState s m => Lens' s a -> (a -> a) -> m b -> m b +withStateLens accessor f act = do + old <- use accessor + accessor %= f + result <- act + accessor .= old + return result + +-- | Set a device for all nodes rendered in the given 'Build' action +-- (unless further overridden by another use of withDevice). +withDevice :: Maybe Device -> Build a -> Build a +withDevice d = withStateLens defaultDevice (const d) + +-- | Places all nodes rendered in the given 'Build' action on the same +-- device as the given Tensor (see also 'withDevice'). Make sure that +-- the action has side effects of rendering the desired tensors. A pure +-- return would not have the desired effect. +colocateWith :: forall a v b . Tensor v b -> Build a -> Build a +colocateWith t x = do + d <- Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp) + withDevice (Just d) x + +-- | Prepend a scope to all nodes rendered in the given 'Build' action. +withNameScope :: Text -> Build a -> Build a +withNameScope s = withStateLens currentScope (Scope s :) + +-- | Add control inputs to all nodes rendered in the given 'Build' action. +withNodeDependencies :: Set NodeName -> Build a -> Build a +withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes) + +-- | Render a 'Tensor', fixing its name, scope, device and control inputs from +-- the 'Build' context. Also renders any dependencies of the 'Tensor' that +-- weren't already rendered. +-- +-- This operation is idempotent; @render >=> render === render@. However, +-- rendering a (previously un-rendered) 'Tensor' in two different contexts +-- may result in two different 'Tensor's. +render :: Tensor v a -> Build (Tensor v a) +render = tensorOutput $ outputOp $ fmap Rendered . resolveOp + +-- | Render a 'Tensor' and get its node's name. +renderNodeName :: Tensor v a -> Build NodeName +renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp) + +-- | Records the given summary action in Build for retrieval with +-- 'collectAllSummaries'. The summary op is required to produce a +-- Summary protocol buffer in string form. For safety, use the +-- pre-composed functions: Logging.scalarSummary and +-- Logging.histogramSummary. +addSummary :: SummaryTensor -> Build () +addSummary t = summaries %= (t :) + +-- | Retrieves the summary ops collected thus far. Typically this only +-- happens once, but if 'TensorFlow.Session.buildWithSummary' is used +-- repeatedly, the values accumulate. +collectAllSummaries :: Monad m => BuildT m [SummaryTensor] +collectAllSummaries = use summaries diff --git a/tensorflow/src/TensorFlow/BuildOp.hs b/tensorflow/src/TensorFlow/BuildOp.hs new file mode 100644 index 0000000..9a96ced --- /dev/null +++ b/tensorflow/src/TensorFlow/BuildOp.hs @@ -0,0 +1,199 @@ +-- Copyright 2016 TensorFlow authors. 
+-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TupleSections #-} + +module TensorFlow.BuildOp + ( OpResult + , BuildOp + , buildOp + , buildListOp + , eqLengthGuard + ) + where + +import Control.Monad (replicateM) +import Control.Monad.Reader (ReaderT, runReaderT, ask) +import Control.Monad.State.Strict (State, runState, get, put) +import Data.Int (Int64) +import Lens.Family2 ((&), (<>~), (^.)) + +import TensorFlow.Build +import TensorFlow.Output +import TensorFlow.Tensor + +data ResultState = ResultState !OutputIx [Int64] deriving Show + +type Result = ReaderT Op (State ResultState) + +-- | Class of types that can be used as op outputs. +class OpResult a where + toResult :: Result a + +instance (OpResult a1, OpResult a2) => OpResult (a1, a2) where + toResult = (,) <$> toResult <*> toResult + +instance (OpResult a1, OpResult a2, OpResult a3) => OpResult (a1, a2, a3) where + toResult = (,,) <$> toResult <*> toResult <*> toResult + +instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4) + => OpResult (a1, a2, a3, a4) where + toResult = (,,,) <$> toResult <*> toResult <*> toResult <*> toResult + +instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4, OpResult a5) + => OpResult (a1, a2, a3, a4, a5) where + toResult = (,,,,) <$> toResult + <*> toResult + <*> toResult + <*> toResult + <*> toResult + +instance ( OpResult a1 + , OpResult a2 + , OpResult a3 + , OpResult a4 + , OpResult a5 + , OpResult a6 + ) + => OpResult (a1, a2, a3, a4, a5, a6) where + toResult = (,,,,,) + <$> toResult + <*> toResult + <*> toResult + <*> toResult + <*> toResult + <*> toResult + +tensorResult :: TensorKind v -> Result (Tensor v a) +tensorResult v = do + o <- ask + ResultState i ns <- get + put $! ResultState (i+1) ns + return $! Tensor v $ output i o + +instance OpResult (Tensor Value a) where + toResult = tensorResult ValueKind + +instance OpResult (Tensor Ref a) where + toResult = tensorResult RefKind + +instance OpResult ControlNode where + toResult = ControlNode <$> ask + +instance OpResult a => OpResult [a] where + toResult = do + ResultState i ns <- get + case ns of + [] -> error $ "Ran out of counts in toResult. " ++ + "Likely misuse of buildListOp." + (n : rest) -> do + put $! ResultState i rest + replicateM (fromIntegral n) toResult + +runResult :: OpResult a => [Int64] -> Op -> a +runResult ns o = + case runState (runReaderT toResult o) (ResultState 0 ns) of + (x, ResultState _ []) -> x + (_, ns') -> error $ "Ununsed length in runResult attributes: " ++ + show (ns, ns') + +-- | Make a new "pure" op, which may be deduped with identical ops within +-- the same scope. +pureResult :: OpResult a => [Int64] -> OpDef -> [Output] -> a +pureResult ns o ts = runResult ns $ Unrendered $ addReversedInputs o ts + +-- | Make a new "stateful" op, which will not be deduped with otherwise +-- identical ops. +buildResult :: OpResult a => [Int64] -> OpDef -> [Output] -> Build a +buildResult ns o ts + = runResult ns . 
Rendered <$> addNewOp (addReversedInputs o ts) + +addReversedInputs :: OpDef -> [Output] -> OpDef +addReversedInputs o ts = o & opInputs <>~ reverse ts + +-- | Class of types that can be used as op functions. +class BuildOp f where + buildOp' :: [Int64] -- ^ Sizes of list results (having number_attr) + -> OpDef + -> [Output] -- ^ Accumulator for inputs to the op. + -> f + +-- | Starts an operation that returns a structured set of tensors +-- (singletons or tuples). +buildOp :: BuildOp f => OpDef -> f +buildOp o = buildOp' [] o [] + +-- | Starts an operation that returns a list of tensors. +buildListOp :: BuildOp f => [Int64] + -- ^ Cardinality of the corresponding list of tensors output. + -> OpDef -> f +buildListOp counts o = buildOp' counts o [] + +instance BuildOp ControlNode where + buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts + +instance BuildOp (Tensor Value a) where + buildOp' = pureResult + +instance BuildOp (Tensor Ref a) where + buildOp' = pureResult + +instance BuildOp [Tensor Value a] where + buildOp' = pureResult + +instance (OpResult t1, OpResult t2) => BuildOp (t1, t2) where + buildOp' = pureResult + +instance (OpResult t1, OpResult t2, OpResult t3) => BuildOp (t1, t2, t3) where + buildOp' = pureResult + +instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4) + => BuildOp (t1, t2, t3, t4) where + buildOp' = pureResult + +instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4, OpResult t5) + => BuildOp (t1, t2, t3, t4, t5) where + buildOp' = pureResult + +instance ( OpResult t1 + , OpResult t2 + , OpResult t3 + , OpResult t4 + , OpResult t5 + , OpResult t6 + ) + => BuildOp (t1, t2, t3, t4, t5, t6) where + buildOp' = pureResult + +instance OpResult a => BuildOp (Build a) where + buildOp' = buildResult + +instance BuildOp f => BuildOp (Tensor v a -> f) where + buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts) + +instance BuildOp f => BuildOp ([Tensor v a] -> f) where + buildOp' rf o accum ts + = buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum) + +-- | Returns true if all the integers in each tuple are identical. +-- Throws an error with a descriptive message if not. +eqLengthGuard :: [(String, [(String, Int)])] -> Bool +eqLengthGuard = all eachOk + where + eachOk (_, []) = True + -- The next line has (== 1) . length . nub in disguise + eachOk (numberAttrName, pairs@((_, x) : zs)) = all (\z -> snd z == x) zs || + error ("number_attr " ++ numberAttrName ++ + " contains tensors with different length " ++ show pairs) diff --git a/tensorflow/src/TensorFlow/ControlFlow.hs b/tensorflow/src/TensorFlow/ControlFlow.hs new file mode 100644 index 0000000..9b3f112 --- /dev/null +++ b/tensorflow/src/TensorFlow/ControlFlow.hs @@ -0,0 +1,87 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
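+-- A minimal sketch of the dependency combinators below, in terms of the Ops
+-- API from this commit (initializedVariable, assign and add as defined in
+-- TensorFlow.Ops):
+--
+-- >  graph :: Build (Tensor Value Float)
+-- >  graph = do
+-- >      v <- initializedVariable 0
+-- >      w <- assign v 5
+-- >      -- ops rendered inside the block pick up the assign as a control
+-- >      -- input, so the addition sees the updated value:
+-- >      withControlDependencies w $ render (1 `add` v)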
+ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module TensorFlow.ControlFlow + ( -- * Dependencies + withControlDependencies + , group + -- * Operations + , identity + , noOp + , named + ) where + +import qualified Data.Set as Set +import Data.Text (Text) +import Lens.Family2 ((&), (^.), (.~)) + +import TensorFlow.BuildOp +import TensorFlow.Build +import TensorFlow.Nodes +import TensorFlow.Output +import TensorFlow.Tensor +import TensorFlow.Types + +-- | Modify a 'Build' action, such that all new ops rendered in it will depend +-- on the nodes in the first argument. +withControlDependencies :: Nodes t => t -> Build a -> Build a +withControlDependencies deps act = do + nodes <- getNodes deps + withNodeDependencies nodes act + +-- TODO(judahjacobson): Reimplement withDependencies. + +-- | Create an op that groups multiple operations. +-- +-- When this op finishes, all ops in the input @n@ have finished. This op has +-- no output. +group :: Nodes t => t -> Build ControlNode +group deps = do + nodes <- Set.toList <$> getNodes deps + -- TODO: slicker way + return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes + + +-- | Returns a 'Tensor' with the same shape and contents as the input. +identity :: TensorType a => Tensor v a -> Tensor v a +identity = namedIdentity implicitName + +-- | Returns a 'Tensor' with a given name and the same shape and contents as +-- the input. +-- +-- TODO(judahjacobson): This breaks when used with uninitialize @Tensor Ref@s, +-- since @RefIdentity@ doesn't have SetAllowsUninitializedInput(). Look into +-- whether we can change that op. +named :: TensorType a => Text -> Tensor v a -> Tensor v a +named = namedIdentity . explicitName + +-- | An internal version of "identity" that allows setting the name +-- of the output Tensor. +namedIdentity :: forall a v . TensorType a + => PendingNodeName -> Tensor v a -> Tensor v a +namedIdentity n t = case t ^. tensorKind of + ValueKind -> buildOp (opDefWithName n "Identity" & setTypeAttr) t + RefKind -> buildOp (opDefWithName n "RefIdentity" & setTypeAttr) t + where + setTypeAttr = opAttr "T" .~ tensorType (undefined :: a) + + +-- | Does nothing. Only useful as a placeholder for control edges. +noOp :: ControlNode +noOp = buildOp $ opDef "NoOp" diff --git a/tensorflow/src/TensorFlow/Internal/FFI.hs b/tensorflow/src/TensorFlow/Internal/FFI.hs new file mode 100644 index 0000000..05a33b6 --- /dev/null +++ b/tensorflow/src/TensorFlow/Internal/FFI.hs @@ -0,0 +1,243 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE OverloadedStrings #-} + +module TensorFlow.Internal.FFI + ( TensorFlowException(..) + , Raw.Session + , withSession + , extendGraph + , run + , TensorData(..) + , setSessionConfig + , setSessionTarget + , getAllOpList + -- * Internal helper. 
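+    -- ('useProtoAsVoidPtrLen' serializes a proto-lens message and passes the
+    -- resulting (pointer, length) pair straight to a C call; 'extendGraph'
+    -- and 'setSessionConfig' below are thin wrappers over it.)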
+ , useProtoAsVoidPtrLen + ) + where + +import Control.Concurrent.Async (Async, async, cancel, waitCatch) +import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar) +import Control.Exception (Exception, throwIO, bracket, finally, mask_) +import Control.Monad (when) +import Data.Int (Int64) +import Data.Typeable (Typeable) +import Data.Word (Word8) +import Foreign (Ptr, FunPtr, nullPtr, castPtr) +import Foreign.C.String (CString) +import Foreign.ForeignPtr (newForeignPtr, withForeignPtr) +import Foreign.Marshal.Alloc (free) +import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray) +import System.IO.Unsafe (unsafePerformIO) +import qualified Data.ByteString as B +import qualified Data.Text as T +import qualified Data.Text.Encoding as T +import qualified Data.Text.Encoding.Error as T +import qualified Data.Vector.Storable as S + +import Data.ProtoLens (Message, encodeMessage) +import Proto.Tensorflow.Core.Framework.Graph (GraphDef) +import Proto.Tensorflow.Core.Framework.Types (DataType(..)) +import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto) + +import qualified TensorFlow.Internal.Raw as Raw + +data TensorFlowException = TensorFlowException Raw.Code T.Text + deriving (Show, Eq, Typeable) + +instance Exception TensorFlowException + +-- | All of the data needed to represent a tensor. +data TensorData = TensorData + { tensorDataDimensions :: [Int64] + , tensorDataType :: !DataType + , tensorDataBytes :: !(S.Vector Word8) + } + deriving (Show, Eq) + +-- | Runs the given action after creating a session with options +-- populated by the given optionSetter. +withSession :: (Raw.SessionOptions -> IO ()) + -> ((IO () -> IO ()) -> Raw.Session -> IO a) + -- ^ The action can spawn concurrent tasks which will + -- be canceled before withSession returns. + -> IO a +withSession optionSetter action = do + drain <- newMVar [] + let cleanup s = + -- Closes the session to nudge the pending run calls to fail and exit. + finally (checkStatus (Raw.closeSession s)) $ do + runners <- takeMVar drain + -- Collects all runners before deleting the session. + mapM_ shutDownRunner runners + checkStatus (Raw.deleteSession s) + bracket Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do + optionSetter options + bracket + (checkStatus (Raw.newSession options)) + cleanup + (action (asyncCollector drain)) + +asyncCollector :: MVar [Async ()] -> IO () -> IO () +asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord + where + launchAndRecord restRunners = (: restRunners) <$> async runner + +shutDownRunner :: Async () -> IO () +shutDownRunner r = do + cancel r + -- TODO(gnezdo): manage exceptions better than print. + either print (const (return ())) =<< waitCatch r + +extendGraph :: Raw.Session -> GraphDef -> IO () +extendGraph session pb = + useProtoAsVoidPtrLen pb $ \ptr len -> + checkStatus $ Raw.extendGraph session ptr len + + +run :: Raw.Session + -> [(B.ByteString, TensorData)] -- ^ Feeds. + -> [B.ByteString] -- ^ Fetches. + -> [B.ByteString] -- ^ Targets. + -> IO [TensorData] +run session feeds fetches targets = do + let nullTensor = Raw.Tensor nullPtr + -- Use mask to avoid leaking input tensors before they are passed to 'run' + -- and output tensors before they are passed to 'createTensorData'. + mask_ $ + -- Feeds + withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames -> + mapM (createRawTensor . snd) feeds >>= \feedTensors -> + withArrayLen feedTensors $ \_ cFeedTensors -> + -- Fetches. 
+ withStringArrayLen fetches $ \fetchesLen fetchNames -> + -- tensorOuts is an array of null Tensor pointers that will be filled + -- by the call to Raw.run. + withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts -> + -- Targets. + withStringArrayLen targets $ \targetsLen ctargets -> do + checkStatus $ Raw.run + session + nullPtr + feedNames cFeedTensors (fromIntegral feedsLen) + fetchNames tensorOuts (fromIntegral fetchesLen) + ctargets (fromIntegral targetsLen) + nullPtr + outTensors <- peekArray fetchesLen tensorOuts + mapM createTensorData outTensors + + +-- Internal. + + +-- | Use a list of ByteString as a list of CString. +withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a +withStringList strings fn = go strings [] + where + go [] cs = fn (reverse cs) + -- TODO(fmayle): Is it worth using unsafeAsCString here? + go (x:xs) cs = B.useAsCString x $ \c -> go xs (c:cs) + + +-- | Use a list of ByteString as an array of CString. +withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a +withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn) + + +-- | Create a Raw.Tensor from a TensorData. +createRawTensor :: TensorData -> IO Raw.Tensor +createRawTensor (TensorData dims dt byteVec) = + withArrayLen (map fromIntegral dims) $ \cdimsLen cdims -> do + let len = S.length byteVec + dest <- mallocArray len + S.unsafeWith byteVec $ \x -> copyArray dest x len + Raw.newTensor (toEnum $ fromEnum dt) + cdims (fromIntegral cdimsLen) + (castPtr dest) (fromIntegral len) + tensorDeallocFunPtr nullPtr + +{-# NOINLINE tensorDeallocFunPtr #-} +tensorDeallocFunPtr :: FunPtr Raw.TensorDeallocFn +tensorDeallocFunPtr = unsafePerformIO $ Raw.wrapTensorDealloc $ \x _ _ -> free x + +-- | Create a TensorData from a Raw.Tensor. +-- +-- Takes ownership of the Raw.Tensor. +createTensorData :: Raw.Tensor -> IO TensorData +createTensorData t = do + -- Read dimensions. + numDims <- Raw.numDims t + dims <- mapM (Raw.dim t) [0..numDims-1] + -- Read type. + dtype <- toEnum . fromEnum <$> Raw.tensorType t + -- Read data. + len <- fromIntegral <$> Raw.tensorByteSize t + bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8) + -- TODO(fmayle): Don't copy the data. + v <- S.fromList <$> peekArray len bytes + -- Free tensor. + Raw.deleteTensor t + return $ TensorData (map fromIntegral dims) dtype v + +-- | Runs the given action which does FFI calls updating a provided +-- status object. If the status is not OK it is thrown as +-- TensorFlowException. +checkStatus :: (Raw.Status -> IO a) -> IO a +checkStatus fn = + bracket Raw.newStatus Raw.deleteStatus $ \status -> do + result <- fn status + code <- Raw.getCode status + when (code /= Raw.TF_OK) $ do + msg <- T.decodeUtf8With T.lenientDecode <$> + (Raw.message status >>= B.packCString) + throwIO $ TensorFlowException code msg + return result + +setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO () +setSessionConfig pb opt = + useProtoAsVoidPtrLen pb $ \ptr len -> + checkStatus (Raw.setConfig opt ptr len) + +setSessionTarget :: B.ByteString -> Raw.SessionOptions -> IO () +setSessionTarget target = B.useAsCString target . Raw.setTarget + +-- | Serializes the given msg and provides it as (ptr,len) argument +-- to the given action. +useProtoAsVoidPtrLen :: (Message msg, Num c) => + msg -> (Ptr b -> c -> IO a) -> IO a +useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $ + \(bytes, len) -> f (castPtr bytes) (fromIntegral len) + +-- | Returns the serialized OpList of all OpDefs defined in this +-- address space. 
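+-- +-- For example (a sketch mirroring the FFITest in this package), the result +-- can be decoded with proto-lens: +-- +-- > opList <- getAllOpList +-- > either fail (print . length . view op) +-- >     (decodeMessage opList :: Either String OpList)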
+getAllOpList :: IO B.ByteString +getAllOpList = do + foreignPtr <- + mask_ (newForeignPtr Raw.deleteBuffer =<< checkCall) + -- Makes a copy because it is more reliable than eviscerating + -- Buffer to steal its memory (including custom deallocator). + withForeignPtr foreignPtr $ + \ptr -> B.packCStringLen =<< (,) + <$> (castPtr <$> Raw.getBufferData ptr) + <*> (fromIntegral <$> Raw.getBufferLength ptr) + where + checkCall = do + p <- Raw.getAllOpList + when (p == nullPtr) (throwIO exception) + return p + exception = TensorFlowException + Raw.TF_UNKNOWN "GetAllOpList failure, check logs" diff --git a/tensorflow/src/TensorFlow/Internal/Raw.chs b/tensorflow/src/TensorFlow/Internal/Raw.chs new file mode 100644 index 0000000..94ce31b --- /dev/null +++ b/tensorflow/src/TensorFlow/Internal/Raw.chs @@ -0,0 +1,152 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE ForeignFunctionInterface #-} + +module TensorFlow.Internal.Raw where + +#include "third_party/tensorflow/c/c_api.h" + +import Foreign +import Foreign.C + +{#enum TF_DataType as DataType {} deriving (Show, Eq) #} +{#enum TF_Code as Code {} deriving (Show, Eq) #} + + +-- Status. +{#pointer *TF_Status as Status newtype #} + +newStatus :: IO Status +newStatus = {# call TF_NewStatus as ^ #} + +deleteStatus :: Status -> IO () +deleteStatus = {# call TF_DeleteStatus as ^ #} + +setStatus :: Status -> Code -> CString -> IO () +setStatus s c = {# call TF_SetStatus as ^ #} s (fromIntegral $ fromEnum c) + +getCode :: Status -> IO Code +getCode s = toEnum . fromIntegral <$> {# call TF_GetCode as ^ #} s + +message :: Status -> IO CString +message = {# call TF_Message as ^ #} + + +-- Buffer. +data Buffer +{#pointer *TF_Buffer as BufferPtr -> Buffer #} + +getBufferData :: BufferPtr -> IO (Ptr ()) +getBufferData = {#get TF_Buffer->data #} + +getBufferLength :: BufferPtr -> IO CULong +getBufferLength ={#get TF_Buffer->length #} + +-- Tensor. +{#pointer *TF_Tensor as Tensor newtype #} + +instance Storable Tensor where + sizeOf (Tensor t) = sizeOf t + alignment (Tensor t) = alignment t + peek p = fmap Tensor (peek (castPtr p)) + poke p (Tensor t) = poke (castPtr p) t + +newTensor :: DataType + -> Ptr CLong -- dimensions array + -> CInt -- num dimensions + -> Ptr () -- data + -> CULong -- data len + -> FunPtr (Ptr () -> CULong -> Ptr () -> IO ()) -- deallocator + -> Ptr () -- deallocator arg + -> IO Tensor +newTensor dt = {# call TF_NewTensor as ^ #} (fromIntegral $ fromEnum dt) + +deleteTensor :: Tensor -> IO () +deleteTensor = {# call TF_DeleteTensor as ^ #} + +tensorType :: Tensor -> IO DataType +tensorType t = toEnum . 
fromIntegral <$> {# call TF_TensorType as ^ #} t + +numDims :: Tensor -> IO CInt +numDims = {# call TF_NumDims as ^ #} + +dim :: Tensor -> CInt -> IO CLong +dim = {# call TF_Dim as ^ #} + +tensorByteSize :: Tensor -> IO CULong +tensorByteSize = {# call TF_TensorByteSize as ^ #} + +tensorData :: Tensor -> IO (Ptr ()) +tensorData = {# call TF_TensorData as ^ #} + + +-- Session Options. +{# pointer *TF_SessionOptions as SessionOptions newtype #} + +newSessionOptions :: IO SessionOptions +newSessionOptions = {# call TF_NewSessionOptions as ^ #} + +setTarget :: SessionOptions -> CString -> IO () +setTarget = {# call TF_SetTarget as ^ #} + +setConfig :: SessionOptions -> Ptr () -> CULong -> Status -> IO () +setConfig = {# call TF_SetConfig as ^ #} + +deleteSessionOptions :: SessionOptions -> IO () +deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #} + + +-- Session. +{# pointer *TF_Session as Session newtype #} + +newSession :: SessionOptions -> Status -> IO Session +newSession = {# call TF_NewSession as ^ #} + +closeSession :: Session -> Status -> IO () +closeSession = {# call TF_CloseSession as ^ #} + +deleteSession :: Session -> Status -> IO () +deleteSession = {# call TF_DeleteSession as ^ #} + +extendGraph :: Session -> Ptr () -> CULong -> Status -> IO () +extendGraph = {# call TF_ExtendGraph as ^ #} + +run :: Session + -> BufferPtr -- RunOptions proto. + -> Ptr CString -> Ptr Tensor -> CInt -- Input (names, tensors, count). + -> Ptr CString -> Ptr Tensor -> CInt -- Output (names, tensors, count). + -> Ptr CString -> CInt -- Target nodes (names, count). + -> BufferPtr -- RunMetadata proto. + -> Status + -> IO () +run = {# call TF_Run as ^ #} + +-- FFI helpers. +type TensorDeallocFn = Ptr () -> CULong -> Ptr () -> IO () +foreign import ccall "wrapper" + wrapTensorDealloc :: TensorDeallocFn -> IO (FunPtr TensorDeallocFn) + + +-- | Get the OpList of all OpDefs defined in this address space. +-- Returns a BufferPtr, ownership of which is transferred to the caller +-- (and can be freed using deleteBuffer). +-- +-- The data in the buffer will be the serialized OpList proto for ops registered +-- in this address space. +getAllOpList :: IO BufferPtr +getAllOpList = {# call TF_GetAllOpList as ^ #} + +foreign import ccall "&TF_DeleteBuffer" + deleteBuffer :: FunPtr (BufferPtr -> IO ()) diff --git a/tensorflow/src/TensorFlow/Internal/VarInt.hs b/tensorflow/src/TensorFlow/Internal/VarInt.hs new file mode 100644 index 0000000..a82bec9 --- /dev/null +++ b/tensorflow/src/TensorFlow/Internal/VarInt.hs @@ -0,0 +1,50 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE BangPatterns #-} + +{-| +Module : TensorFlow.Internal.VarInt +Description : Encoders and decoders for varint types. + +Originally taken from internal proto-lens code. 
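+ +The wire format is the standard base-128 varint used by protocol buffers: +each byte carries seven payload bits, least-significant group first, and the +high bit of a byte is set when more bytes follow. For example, 300 (0x12C) +encodes to the two bytes 0xAC 0x02, which is what 'putVarInt' emits and +'getVarInt' parses back.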
+-} +module TensorFlow.Internal.VarInt + ( getVarInt + , putVarInt + ) where + +import Data.Attoparsec.ByteString as Parse +import Data.Bits +import Data.ByteString.Lazy.Builder as Builder +import Data.Monoid ((<>)) +import Data.Word (Word64) + +-- | Decode an unsigned varint. +getVarInt :: Parser Word64 +getVarInt = loop 1 0 + where + loop !s !n = do + b <- anyWord8 + let n' = n + s * fromIntegral (b .&. 127) + if (b .&. 128) == 0 + then return n' + else loop (128*s) n' + +-- | Encode a Word64. +putVarInt :: Word64 -> Builder +putVarInt n + | n < 128 = Builder.word8 (fromIntegral n) + | otherwise = Builder.word8 (fromIntegral $ n .&. 127 .|. 128) + <> putVarInt (n `shiftR` 7) diff --git a/tensorflow/src/TensorFlow/Nodes.hs b/tensorflow/src/TensorFlow/Nodes.hs new file mode 100644 index 0000000..730c9e5 --- /dev/null +++ b/tensorflow/src/TensorFlow/Nodes.hs @@ -0,0 +1,141 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +module TensorFlow.Nodes where + +import Control.Applicative (liftA2, liftA3) +import Data.Map.Strict (Map) +import Data.Monoid ((<>)) +import Data.Set (Set) +import Data.String (IsString) +import Data.Text (Text) +import Lens.Family2 ((^.)) +import qualified Data.Map.Strict as Map +import qualified Data.Set as Set +import qualified Data.Vector as V + +import TensorFlow.Build +import TensorFlow.Output +import TensorFlow.Tensor +import TensorFlow.Types +import qualified TensorFlow.Internal.FFI as FFI + +-- | Types that contain ops which can be run. +class Nodes t where + getNodes :: t -> Build (Set NodeName) + +-- | Types that tensor representations (e.g. 'Tensor', 'ControlNode') can be +-- fetched into. +-- +-- Includes collections of tensors (e.g. tuples). +class Nodes t => Fetchable t a where + getFetch :: t -> Build (Fetch a) + +-- | Fetch action. Keeps track of what needs to be fetched and how to decode +-- the fetched data. +data Fetch a = Fetch + { -- | Nodes to fetch + fetches :: Set Text + -- | Function to create an 'a' from the fetched data. + , fetchRestore :: Map Text FFI.TensorData -> a + } + +instance Functor Fetch where + fmap f (Fetch fetch restore) = Fetch fetch (f . restore) + +instance Applicative Fetch where + pure x = Fetch Set.empty (const x) + Fetch fetch restore <*> Fetch fetch' restore' = + Fetch (fetch <> fetch') (restore <*> restore') + +nodesUnion :: (Monoid b, Traversable t, Applicative f) => t (f b) -> f b +nodesUnion = fmap (foldMap id) . 
sequenceA + +instance (Nodes t1, Nodes t2) => Nodes (t1, t2) where + getNodes (x, y) = nodesUnion [getNodes x, getNodes y] + +instance (Nodes t1, Nodes t2, Nodes t3) => Nodes (t1, t2, t3) where + getNodes (x, y, z) = nodesUnion [getNodes x, getNodes y, getNodes z] + +instance (Fetchable t1 a1, Fetchable t2 a2) => Fetchable (t1, t2) (a1, a2) where + getFetch (x, y) = liftA2 (,) <$> getFetch x <*> getFetch y + +instance (Fetchable t1 a1, Fetchable t2 a2, Fetchable t3 a3) + => Fetchable (t1, t2, t3) (a1, a2, a3) where + getFetch (x, y, z) = + liftA3 (,,) <$> getFetch x <*> getFetch y <*> getFetch z + +instance Nodes t => Nodes [t] where + getNodes = nodesUnion . map getNodes + +instance Fetchable t a => Fetchable [t] [a] where + getFetch ts = sequenceA <$> mapM getFetch ts + +instance Nodes ControlNode where + getNodes (ControlNode o) = Set.singleton <$> getOrAddOp o + +-- We use the constraint @(a ~ ())@ to help with type inference. For example, +-- if @t :: ControlNode@, then this constraint ensures that @run t :: Session +-- ()@. If we used @instance Fetchable ControlNode ()@ instead, then that +-- expression would be ambiguous without explicitly specifying the return type. +instance a ~ () => Fetchable ControlNode a where + getFetch _ = return $ pure () + +instance Nodes (Tensor v a) where + getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp) + +fetchTensorList :: TensorType a => Tensor v a -> Build (Fetch (Shape, [a])) +fetchTensorList t = fmap (fmap V.toList) <$> fetchTensorVector t + +fetchTensorVector :: forall a v . TensorType a + => Tensor v a -> Build (Fetch (Shape, V.Vector a)) +fetchTensorVector (Tensor _ o) = do + outputName <- renderOutput o + return $ Fetch (Set.singleton outputName) $ \tensors -> + let tensorData = tensors Map.! outputName + shape = Shape $ FFI.tensorDataDimensions tensorData + vec = decodeTensorData $ TensorData tensorData + + expectedType = tensorType (undefined :: a) + actualType = FFI.tensorDataType tensorData + badTypeError = error $ "Bad tensor type: expected " + ++ show expectedType + ++ ", got " + ++ show actualType + in if expectedType /= actualType + then badTypeError + else (shape, vec) + +-- The constraint "a ~ a'" means that the input/output of fetch can constrain +-- the TensorType of each other. +instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (V.Vector a') where + getFetch t = fmap snd <$> fetchTensorVector t + +newtype Scalar a = Scalar {unScalar :: a} + deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat, + RealFrac, IsString) + +instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (Scalar a') where + getFetch t = fmap (Scalar . headFromSingleton . snd) <$> fetchTensorList t + where + headFromSingleton [x] = x + headFromSingleton xs + = error $ "Unable to extract singleton from tensor of length " + ++ show (length xs) diff --git a/tensorflow/src/TensorFlow/Orphans.hs b/tensorflow/src/TensorFlow/Orphans.hs new file mode 100644 index 0000000..cf185e2 --- /dev/null +++ b/tensorflow/src/TensorFlow/Orphans.hs @@ -0,0 +1,46 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. 
+-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + + +{-# LANGUAGE StandaloneDeriving #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} +-- Orphan instances for certain proto messages/enums, used internally. +-- TODO(judahjacobson): consider making proto-lens generate some or all of +-- these automatically; or, alternately, make new Haskell datatypes. +module TensorFlow.Orphans() where + +import Proto.Tensorflow.Core.Framework.AttrValue + ( AttrValue(..) + , AttrValue'ListValue(..) + , NameAttrList(..) + ) +import Proto.Tensorflow.Core.Framework.NodeDef + ( NodeDef(..)) +import Proto.Tensorflow.Core.Framework.ResourceHandle + ( ResourceHandle(..)) +import Proto.Tensorflow.Core.Framework.Tensor + (TensorProto(..)) +import Proto.Tensorflow.Core.Framework.TensorShape + (TensorShapeProto(..), TensorShapeProto'Dim(..)) +import Proto.Tensorflow.Core.Framework.Types (DataType(..)) + +deriving instance Ord AttrValue +deriving instance Ord AttrValue'ListValue +deriving instance Ord DataType +deriving instance Ord NameAttrList +deriving instance Ord NodeDef +deriving instance Ord ResourceHandle +deriving instance Ord TensorProto +deriving instance Ord TensorShapeProto +deriving instance Ord TensorShapeProto'Dim diff --git a/tensorflow/src/TensorFlow/Output.hs b/tensorflow/src/TensorFlow/Output.hs new file mode 100644 index 0000000..6bee40a --- /dev/null +++ b/tensorflow/src/TensorFlow/Output.hs @@ -0,0 +1,156 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE Rank2Types #-} +{-# LANGUAGE OverloadedStrings #-} + +module TensorFlow.Output + ( ControlNode(..) + , Device(..) + -- * Ops + , NodeName(..) + , Op(..) + , opUnrendered + , OpDef(..) + , opName + , opType + , opAttr + , opInputs + , opControlInputs + , OpType(..) + , OutputIx(..) + , Output(..) + , output + , outputIndex + , outputOp + , PendingNodeName(..) + ) where + +import qualified Data.Map.Strict as Map +import Data.ProtoLens.TextFormat (showMessage) +import Data.String (IsString(..)) +import Data.Text (Text) +import qualified Data.Text as Text +import Lens.Family2 (Lens', Traversal', (.~), (&), (^.)) +import Lens.Family2.Unchecked (lens) +import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..)) +import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef(..), name) +import Data.Default (def) +import TensorFlow.Types (Attribute, attrLens) +import TensorFlow.Orphans () + +-- | A type of graph node which has no outputs. These nodes are +-- valuable for causing side effects when they are run. 
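+-- +-- 'TensorFlow.ControlFlow.noOp' is the simplest example. A sketch of typical +-- use inside 'Build', with @x@ and @y@ standing for arbitrary tensors: +-- +-- > cn <- group (x, y) +-- > withControlDependencies cn restOfBuild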
+newtype ControlNode = ControlNode { unControlNode :: Op } + +-- | The type of op of a node in the graph. This corresponds to the proto field +-- NodeDef.op. +newtype OpType = OpType { unOpType :: Text } + deriving (Eq, Ord, Show) + +instance IsString OpType where + fromString = OpType . Text.pack + +-- | An output of a TensorFlow node. +data Output = Output !OutputIx !Op + deriving (Eq, Ord, Show) + +output :: OutputIx -> Op -> Output +output = Output + +outputOp :: Lens' Output Op +outputOp = lens (\(Output _ o) -> o) (\(Output i _) o -> Output i o) + +outputIndex :: Lens' Output OutputIx +outputIndex = lens (\(Output i _) -> i) (\(Output _ o) i -> Output i o) + +newtype OutputIx = OutputIx { unOutputIx :: Int } + deriving (Eq, Ord, Num, Enum, Show) + +-- | A device that a node can be assigned to. +-- There's a naming convention where the device names +-- are constructed from job and replica names. +newtype Device = Device {deviceName :: Text} + deriving (Eq, Ord, IsString) + +instance Show Device where + show (Device d) = show d + +-- | The representation of a node in a TensorFlow graph. +data Op + = Rendered !NodeDef -- ^ Properties are fixed, including the + -- device, name, and scope. + | Unrendered !OpDef -- ^ Properties are not fixed, and may change depending + -- on which context this op is rendered in. + deriving (Eq, Ord) + +instance Show Op where + show (Rendered n) = "Rendered " ++ showMessage n + show (Unrendered o) = "Unrendered " ++ show (o ^. opName) + +-- | Traverse on the 'Unrendered' of an 'Op'. +-- +-- Same implementation as _Left. +opUnrendered :: Traversal' Op OpDef +opUnrendered f (Unrendered a) = Unrendered <$> f a +opUnrendered _ (Rendered b) = pure (Rendered b) + +-- | Op definition. This corresponds somewhat to the 'NodeDef' proto. +data OpDef = OpDef + { _opName :: !PendingNodeName + , _opType :: !OpType + , _opAttrs :: !(Map.Map Text AttrValue) + , _opInputs :: [Output] + , _opControlInputs :: [NodeName] + } deriving (Eq, Ord) + +-- | The name specified for an unrendered Op. If an Op has an +-- ImplicitName, it will be assigned based on the opType plus a +-- unique identifier. Does not contain the "scope" prefix. +data PendingNodeName = ExplicitName !Text | ImplicitName + deriving (Eq, Ord, Show) + +-- | The name of a node in the graph. This corresponds to the proto field +-- NodeDef.name. Includes the scope prefix (if any) and a unique identifier +-- (if the node was implicitly named). +newtype NodeName = NodeName { unNodeName :: Text } + deriving (Eq, Ord, Show) + +opName :: Lens' OpDef PendingNodeName +opName = lens _opName (\o x -> o {_opName = x}) + +opType :: Lens' OpDef OpType +opType = lens _opType (\o x -> o { _opType = x}) + +opAttr :: Attribute a => Text -> Lens' OpDef a +opAttr n = lens _opAttrs (\o x -> o {_opAttrs = x}) + . lens (Map.findWithDefault def n) (flip (Map.insert n)) + . 
attrLens + +opInputs :: Lens' OpDef [Output] +opInputs = lens _opInputs (\o x -> o {_opInputs = x}) + +opControlInputs :: Lens' OpDef [NodeName] +opControlInputs = lens _opControlInputs (\o x -> o {_opControlInputs = x}) + +-- TODO(gnezdo): IsString instance is weird and we should move that +-- code into a Build function +instance IsString Output where + fromString s = case break (==':') s of + (n, ':':ixStr) + | [(ix, "")] <- reads ixStr -> Output (fromInteger ix) $ assigned n + _ -> Output 0 $ assigned s + where assigned n = Rendered $ def & name .~ Text.pack n + diff --git a/tensorflow/src/TensorFlow/Session.hs b/tensorflow/src/TensorFlow/Session.hs new file mode 100644 index 0000000..39d5b18 --- /dev/null +++ b/tensorflow/src/TensorFlow/Session.hs @@ -0,0 +1,202 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE Rank2Types #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} + +module TensorFlow.Session ( + Session, + -- * Opaque value created via 'sessionConfig' and 'sessionTarget'. + SessionOption, + sessionConfig, + sessionTarget, + runSession, + runSessionWithOptions, + build, + buildAnd, + buildWithSummary, + extend, + addGraphDef, + run, + runWithFeeds, + run_, + runWithFeeds_, + asyncProdNodes, + ) where + +import Control.Monad (forever, unless, void) +import Control.Monad.IO.Class (MonadIO, liftIO) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.Reader (ReaderT(..), ask, asks) +import Data.ByteString (ByteString) +import Data.Functor.Identity (runIdentity) +import qualified Data.Map.Strict as Map +import qualified Data.Set as Set +import Data.Set (Set) +import Data.Text.Encoding (encodeUtf8) +import Data.ProtoLens (def) +import Lens.Family2 ((&), (.~)) +import Proto.Tensorflow.Core.Framework.Graph (node) +import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto) + +import TensorFlow.Build +import qualified TensorFlow.Internal.FFI as FFI +import qualified TensorFlow.Internal.Raw as Raw +import TensorFlow.Nodes +import TensorFlow.Output (NodeName, unNodeName) +import TensorFlow.Tensor + +-- Common state threaded through the session. +data SessionState + = SessionState { + rawSession :: FFI.Session + , asyncCollector :: IO () -> IO () + -- ^ Starts the given action concurrently. + } + +newtype Session a + = Session (ReaderT SessionState (BuildT IO) a) + deriving (Functor, Applicative, Monad, MonadIO) + +-- | Run 'Session' actions in a new TensorFlow session. +runSession :: Session a -> IO a +runSession = runSessionWithOptions [] + +-- | Setting of an option for the session (see 'runSessionWithOptions'). +newtype SessionOption = + SessionOption { unSessionOption :: Raw.SessionOptions -> IO () } + +-- | Target can be: "local", ip:port, host:port. +-- The set of supported factories depends on the linked-in libraries. +-- REQUIRES "//learning/brain/public:tensorflow_remote" dependency for the binary. 
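+-- +-- For example (a sketch; the address is hypothetical): +-- +-- > runSessionWithOptions [sessionTarget "localhost:2222"] action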
sessionTarget :: ByteString -> SessionOption +sessionTarget = SessionOption . FFI.setSessionTarget + +-- | Uses the specified config for the created session. +sessionConfig :: ConfigProto -> SessionOption +sessionConfig = SessionOption . FFI.setSessionConfig + +-- | Run 'Session' actions in a new TensorFlow session created with +-- the given option setter actions ('sessionTarget', 'sessionConfig'). +runSessionWithOptions :: [SessionOption] -> Session a -> IO a +runSessionWithOptions options (Session m) = + FFI.withSession applyOptions $ + \as rs -> evalBuildT (runReaderT m (SessionState rs as)) + where applyOptions opt = mapM_ (`unSessionOption` opt) options + +-- | Lift a 'Build' action into a 'Session', including any explicit op +-- renderings. +build :: Build a -> Session a +build = Session . lift . hoistBuildT (return . runIdentity) + +-- | Lift a 'Build' action into a 'Session', including any explicit op +-- renderings. Returns the merged summary ops which can be used for +-- logging; see 'TensorFlow.Logging.build' for a convenient wrapper. +buildWithSummary :: forall a . Build a -> Session (a, [SummaryTensor]) +buildWithSummary b = Session $ lift $ (,) <$> v <*> collectAllSummaries + where v :: BuildT IO a + v = hoistBuildT (return . runIdentity) b + +-- | Adds all pending rendered nodes to the TensorFlow graph and runs +-- any pending initializers. +-- +-- Note that run, runWithFeeds, etc. will all call this function implicitly. +extend :: Session () +extend = do + let withSessionWhen vs action = + unless (null vs) $ Session (asks rawSession) >>= action + nodesToExtend <- build flushNodeBuffer + withSessionWhen nodesToExtend $ \session -> + liftIO $ FFI.extendGraph session + $ def & node .~ nodesToExtend + -- Now that all the nodes are created, run the initializers. + initializers <- build flushInitializers + withSessionWhen initializers $ \session -> + void $ liftIO $ FFI.run session [] [] (toNodeNames initializers) + +-- | Helper combinator for doing something with the result of a 'Build' action. +-- Example usage: +-- +-- > buildAnd run :: Fetchable t a => Build t -> Session a +buildAnd :: (a -> Session b) -> Build a -> Session b +buildAnd f m = build m >>= f + +-- | Run a subgraph 't', rendering any dependent nodes that aren't already +-- rendered, and fetch the corresponding values for 'a'. +run :: Fetchable t a => t -> Session a +run = runWithFeeds [] + +-- | Run a subgraph 't', rendering any dependent nodes that aren't already +-- rendered, feed the given input values, and fetch the corresponding result +-- values for 'a'. +runWithFeeds :: Fetchable t a => [Feed] -> t -> Session a +runWithFeeds feeds t = do + ns <- build $ getNodes t + -- Note that this call to "fetch" shouldn't affect the following "extend" + -- call, since all nodes in t and its inputs/deps will be rendered by the + -- above call to getNodes. + fetch <- build $ getFetch t + runFetchWithFeeds feeds ns fetch + +runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a +runFetchWithFeeds feeds target (Fetch fetch restore) = do + extend + feeds' <- build $ fixFeeds feeds + let fetchNames = encodeUtf8 <$> Set.toList fetch + targetNames = toNodeNames $ Set.toList target + session <- Session (asks rawSession) + runResult <- liftIO $ FFI.run session + feeds' + fetchNames + targetNames + let resultTensorsMap = Map.fromList $ zip (Set.toList fetch) runResult + return $ restore resultTensorsMap + +toNodeNames :: [NodeName] -> [ByteString] +toNodeNames = map (encodeUtf8 . 
unNodeName) + +-- | Run a subgraph 't', rendering and extending any dependent nodes that aren't +-- already rendered. This behaves like 'run' except that it doesn't do any +-- fetches. +run_ :: Nodes t => t -> Session () +run_ = runWithFeeds_ [] + +-- | Run a subgraph 't', rendering any dependent nodes that aren't already +-- rendered, and feed the given input values. This behaves like 'runWithFeeds' +-- except that it doesn't do any fetches. +runWithFeeds_ :: Nodes t => [Feed] -> t -> Session () +runWithFeeds_ feeds t = do + ns <- build $ getNodes t + runFetchWithFeeds feeds ns (pure ()) + +fixFeeds :: [Feed] -> Build [(ByteString, FFI.TensorData)] +fixFeeds = mapM $ \(Feed o d) -> (,d) . encodeUtf8 <$> renderOutput o + +-- | Starts a concurrent thread which evaluates the given Nodes +-- forever until runSession exits or an exception occurs. Graph +-- extension happens synchronously, but the resulting run proceeds in +-- a separate thread. +asyncProdNodes :: Nodes t + => t -- ^ Node to evaluate concurrently. + -> Session () +asyncProdNodes nodes = do + target <- build (getNodes nodes) + extend + let targetNames = toNodeNames $ Set.toList target + state <- Session ask + let loop = forever (void (FFI.run (rawSession state) [] [] targetNames)) + liftIO (asyncCollector state loop) diff --git a/tensorflow/src/TensorFlow/Tensor.hs b/tensorflow/src/TensorFlow/Tensor.hs new file mode 100644 index 0000000..da03184 --- /dev/null +++ b/tensorflow/src/TensorFlow/Tensor.hs @@ -0,0 +1,85 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE Rank2Types #-} + +module TensorFlow.Tensor where + +import Data.String (IsString(..)) +import qualified Data.Text as Text +import Lens.Family2 (Lens', Traversal') +import Lens.Family2.Unchecked (lens) + +import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr) +import TensorFlow.Types (TensorData(..), Attribute) +import qualified TensorFlow.Internal.FFI as FFI + +-- | A named output of a TensorFlow operation. +-- +-- The type parameter @a@ is the type of the elements in the 'Tensor'. The +-- parameter @v@ is either 'Value' or 'Ref', depending on whether the graph is +-- treating this op output as an immutable 'Value' or a stateful 'Ref' (e.g., a +-- variable). Note that a @Tensor Ref@ can be cast into a @Tensor Value@ via +-- 'value'. +data Tensor v a = Tensor (TensorKind v) Output + +data Value +data Ref + +-- | This type provides a runtime switch on whether a 'Tensor' should be +-- treated as a 'Value' or as a 'Ref'. 
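+-- +-- Code that accepts both kinds can branch on the kind at runtime, as +-- 'TensorFlow.ControlFlow.identity' does internally; a sketch: +-- +-- > case t ^. tensorKind of +-- >     ValueKind -> valueCase +-- >     RefKind -> refCase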
+data TensorKind v where + ValueKind :: TensorKind Value + RefKind :: TensorKind Ref + +tensorKind :: Lens' (Tensor v a) (TensorKind v) +tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o) + +tensorOutput :: Lens' (Tensor v a) Output +tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o) + +-- TODO: Come up with a better API for handling attributes. +-- | Lens for the attributes of a tensor. +-- +-- Only valid if the tensor has not yet been rendered. If the tensor has been +-- rendered, the traversal will be over nothing (nothing can be read or +-- written). +tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr +tensorAttr x = tensorOutput . outputOp . opUnrendered . opAttr x + +-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a +-- Ref into Value. This behaves like a no-op. +value :: Tensor v a -> Tensor Value a +value (Tensor _ o) = Tensor ValueKind o + +-- | A pair of a 'Tensor' and some data that should be fed into that 'Tensor' +-- when running the graph. +data Feed = Feed Output FFI.TensorData + +-- | Create a 'Feed' for feeding the given data into a 'Tensor' when running +-- the graph. +-- +-- Note that if a 'Tensor' is rendered, its identity may change; so feeding the +-- rendered 'Tensor' may be different than feeding the original 'Tensor'. +feed :: Tensor v a -> TensorData a -> Feed +feed (Tensor _ o) (TensorData td) = Feed o td + +-- | Create a 'Tensor' for a given name. This can be used to reference nodes +-- in a 'GraphDef' that was loaded via 'addGraphDef'. +-- TODO(judahjacobson): add more safety checks here. +tensorFromName :: TensorKind v -> Text.Text -> Tensor v a +tensorFromName v = Tensor v . fromString . Text.unpack diff --git a/tensorflow/src/TensorFlow/Types.hs b/tensorflow/src/TensorFlow/Types.hs new file mode 100644 index 0000000..3d47f39 --- /dev/null +++ b/tensorflow/src/TensorFlow/Types.hs @@ -0,0 +1,382 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +-- We use UndecidableInstances for type families with recursive definitions +-- like "\\". Those instances will terminate since each equation unwraps one +-- cons cell of a type-level list. +{-# LANGUAGE UndecidableInstances #-} + +module TensorFlow.Types + ( TensorType(..) + , TensorData(..) + , Shape(..) + , protoShape + , Attribute(..) 
+ -- * Type constraints + , OneOf + , type (/=) + -- ** Implementation of constraints + , TypeError + , ExcludedCase + , TensorTypes + , NoneOf + , type (\\) + , Delete + , AllTensorTypes + ) where + +import Data.Complex (Complex) +import Data.Default (def) +import Data.Int (Int8, Int16, Int32, Int64) +import Data.Monoid ((<>)) +import Data.Word (Word8, Word16, Word64) +import Foreign.Storable (Storable) +import GHC.Exts (Constraint, IsList(..)) +import Lens.Family2 (Lens', view, (&), (.~)) +import Lens.Family2.Unchecked (iso) +import qualified Data.Attoparsec.ByteString as Atto +import Data.ByteString (ByteString) +import qualified Data.ByteString as B +import Data.ByteString.Builder (Builder) +import qualified Data.ByteString.Builder as Builder +import qualified Data.ByteString.Lazy as L +import qualified Data.Vector as V +import qualified Data.Vector.Storable as S +import Proto.Tensorflow.Core.Framework.AttrValue + ( AttrValue(..) + , AttrValue'ListValue(..) + , b + , f + , i + , s + , list + , type' + , shape + , tensor + ) +import Proto.Tensorflow.Core.Framework.Tensor as Tensor + ( TensorProto(..) + , floatVal + , doubleVal + , intVal + , stringVal + , int64Val + , stringVal + , boolVal + ) +import Proto.Tensorflow.Core.Framework.TensorShape + ( TensorShapeProto(..) + , dim + , size + ) +import Proto.Tensorflow.Core.Framework.Types (DataType(..)) + +import TensorFlow.Internal.VarInt (getVarInt, putVarInt) +import qualified TensorFlow.Internal.FFI as FFI + +-- | Data about a tensor that is encoded for the TensorFlow APIs. +newtype TensorData a = TensorData { unTensorData :: FFI.TensorData } + +-- | The class of scalar types supported by tensorflow. +class TensorType a where + tensorType :: a -> DataType + tensorRefType :: a -> DataType + tensorVal :: Lens' TensorProto [a] + -- | Decode the bytes of a TensorData into a Vector. + decodeTensorData :: TensorData a -> V.Vector a + -- | Encode a Vector into a TensorData. + -- + -- The values should be in row major order, e.g., + -- + -- element 0: index (0, ..., 0) + -- element 1: index (0, ..., 1) + -- ... + encodeTensorData :: Shape -> V.Vector a -> TensorData a + +-- All types, besides ByteString, are encoded as simple arrays and we can use +-- Vector.Storable to encode/decode by type casting pointers. + +-- TODO(fmayle): Assert that the data type matches the return type. +simpleDecode :: Storable a => TensorData a -> V.Vector a +simpleDecode = S.convert . S.unsafeCast . FFI.tensorDataBytes . unTensorData + +simpleEncode :: forall a . (TensorType a, Storable a) + => Shape -> V.Vector a -> TensorData a +simpleEncode (Shape xs) + = TensorData . FFI.TensorData xs dt . S.unsafeCast . 
S.convert + where + dt = tensorType (undefined :: a) + +instance TensorType Float where + tensorType _ = DT_FLOAT + tensorRefType _ = DT_FLOAT_REF + tensorVal = floatVal + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +instance TensorType Double where + tensorType _ = DT_DOUBLE + tensorRefType _ = DT_DOUBLE_REF + tensorVal = doubleVal + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +instance TensorType Int32 where + tensorType _ = DT_INT32 + tensorRefType _ = DT_INT32_REF + tensorVal = intVal + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +instance TensorType Int64 where + tensorType _ = DT_INT64 + tensorRefType _ = DT_INT64_REF + tensorVal = int64Val + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +integral :: Integral a => Lens' [Int32] [a] +integral = iso (fmap fromIntegral) (fmap fromIntegral) + +instance TensorType Word8 where + tensorType _ = DT_UINT8 + tensorRefType _ = DT_UINT8_REF + tensorVal = intVal . integral + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +instance TensorType Word16 where + tensorType _ = DT_UINT16 + tensorRefType _ = DT_UINT16_REF + tensorVal = intVal . integral + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +instance TensorType Int16 where + tensorType _ = DT_INT16 + tensorRefType _ = DT_INT16_REF + tensorVal = intVal . integral + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +instance TensorType Int8 where + tensorType _ = DT_INT8 + tensorRefType _ = DT_INT8_REF + tensorVal = intVal . integral + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +instance TensorType ByteString where + tensorType _ = DT_STRING + tensorRefType _ = DT_STRING_REF + tensorVal = stringVal + -- Encoded data layout (described in third_party/tensorflow/c/c_api.h): + -- table offsets for each element :: [Word64] + -- at each element offset: + -- string length :: VarInt64 + -- string data :: [Word8] + -- TODO(fmayle): Benchmark these functions. + decodeTensorData tensorData = + either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $ + if expected /= count + then Left $ "decodeTensorData for ByteString count mismatch " ++ + show (expected, count) + else V.mapM decodeString (S.convert offsets) + where + expected = S.length offsets + count = fromIntegral $ product $ FFI.tensorDataDimensions + $ unTensorData tensorData + bytes = FFI.tensorDataBytes $ unTensorData tensorData + offsets = S.take count $ S.unsafeCast bytes :: S.Vector Word64 + dataBytes = B.pack $ S.toList $ S.drop (count * 8) bytes + decodeString :: Word64 -> Either String ByteString + decodeString offset = + let stringDataStart = B.drop (fromIntegral offset) dataBytes + in Atto.eitherResult $ Atto.parse stringParser stringDataStart + stringParser :: Atto.Parser ByteString + stringParser = getVarInt >>= Atto.take . fromIntegral + encodeTensorData (Shape xs) vec = + TensorData $ FFI.TensorData xs dt byteVector + where + dt = tensorType (undefined :: ByteString) + -- Add a string to an offset table and data blob. 
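+ -- For example (illustrative): encoding the strings ["ab", "c"] produces + -- the offset table [0, 3] (little-endian Word64s) followed by the blob + -- 0x02 'a' 'b' 0x01 'c', i.e. a varint length before each string's bytes; + -- the second offset (3) is the size of the first length-plus-data entry.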
+ addString :: (Builder, Builder, Word64) + -> ByteString + -> (Builder, Builder, Word64) + addString (table, strings, offset) str = + ( table <> Builder.word64LE offset + , strings <> lengthBytes <> Builder.byteString str + , offset + lengthBytesLen + strLen + ) + where + strLen = fromIntegral $ B.length str + lengthBytes = putVarInt $ fromIntegral $ B.length str + lengthBytesLen = + fromIntegral $ L.length $ Builder.toLazyByteString lengthBytes + -- Encode all strings. + (table', strings', _) = V.foldl' addString (mempty, mempty, 0) vec + -- Concat offset table with data. + bytes = table' <> strings' + -- Convert to Vector Word8. + byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes + + +instance TensorType Bool where + tensorType _ = DT_BOOL + tensorRefType _ = DT_BOOL_REF + tensorVal = boolVal + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +instance TensorType (Complex Float) where + tensorType _ = DT_COMPLEX64 + tensorRefType _ = DT_COMPLEX64 + tensorVal = error "TODO (Complex Float)" + decodeTensorData = error "TODO (Complex Float)" + encodeTensorData = error "TODO (Complex Float)" + +instance TensorType (Complex Double) where + tensorType _ = DT_COMPLEX128 + tensorRefType _ = DT_COMPLEX128 + tensorVal = error "TODO (Complex Double)" + decodeTensorData = error "TODO (Complex Double)" + encodeTensorData = error "TODO (Complex Double)" + +-- | Shape (dimensions) of a tensor. +newtype Shape = Shape [Int64] deriving Show + +instance IsList Shape where + type Item Shape = Int64 + fromList = Shape . fromList + toList (Shape ss) = toList ss + +protoShape :: Lens' TensorShapeProto Shape +protoShape = iso protoToShape shapeToProto + where + protoToShape = Shape . fmap (view size) . view dim + shapeToProto (Shape ds) = def & dim .~ fmap (\d -> def & size .~ d) ds + + +class Attribute a where + attrLens :: Lens' AttrValue a + +instance Attribute Float where + attrLens = f + +instance Attribute ByteString where + attrLens = s + +instance Attribute Int64 where + attrLens = i + +instance Attribute DataType where + attrLens = type' + +instance Attribute TensorProto where + attrLens = tensor + +instance Attribute Bool where + attrLens = b + +instance Attribute Shape where + attrLens = shape . protoShape + +-- TODO(gnezdo): support generating list(Foo) from [Foo]. +instance Attribute AttrValue'ListValue where + attrLens = list + +instance Attribute [DataType] where + attrLens = list . type' + +instance Attribute [Int64] where + attrLens = list . i + +-- | A 'Constraint' specifying the possible choices of a 'TensorType'. +-- +-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the +-- natural representation as a conjunction, i.e., +-- +-- @ +-- a == Double || a == Float +-- @ +-- +-- into a disjunction like +-- +-- @ +-- a \/= Int32 && a \/= Int64 && a \/= ByteString && ... +-- @ +-- +-- using an enumeration of all the possible 'TensorType's. +type OneOf ts a + = (TensorType a, TensorTypes ts, NoneOf (AllTensorTypes \\ ts) a) + +-- | A 'Constraint' checking that the input is a list of 'TensorType's. +-- Helps improve error messages when using 'OneOf'. +type family TensorTypes ts :: Constraint where + TensorTypes '[] = () + TensorTypes (t ': ts) = (TensorType t, TensorTypes ts) + +-- | A constraint checking that two types are different. +type family a /= b :: Constraint where + a /= a = TypeError a ~ ExcludedCase + a /= b = () + +-- | Helper types to produce a reasonable type error message when the Constraint +-- "a /= a" fails. 
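+-- +-- When @a@ matches one of the excluded types, the first equation of @/=@ +-- applies and the constraint reduces to @TypeError a ~ ExcludedCase@, so the +-- unsatisfiable constraint GHC reports at least names the offending type.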
+-- TODO(judahjacobson): Use ghc-8's CustomTypeErrors for this. +data TypeError a +data ExcludedCase + +-- | An enumeration of all valid 'TensorType's. +type AllTensorTypes = + -- NOTE: This list should be kept in sync with + -- TensorFlow.OpGen.dtTypeToHaskell. + -- TODO: Add support for Complex Float/Double. + '[ Float + , Double + , Int8 + , Int16 + , Int32 + , Int64 + , Word8 + , Word16 + , ByteString + , Bool + ] + +-- | Removes a type from the given list of types. +type family Delete a as where + Delete a '[] = '[] + Delete a (a ': as) = Delete a as + Delete a (b ': as) = b ': Delete a as + +-- | Takes the difference of two lists of types. +type family as \\ bs where + as \\ '[] = as + as \\ b ': bs = Delete b as \\ bs + +-- | A constraint that the type @a@ doesn't appear in the type list @ts@. +-- Assumes that @a@ and each of the elements of @ts@ are 'TensorType's. +type family NoneOf ts a :: Constraint where + NoneOf '[] a = () + NoneOf (t ': ts) a = (a /= t, NoneOf ts a) diff --git a/tensorflow/tensorflow.cabal b/tensorflow/tensorflow.cabal new file mode 100644 index 0000000..b7ccdf3 --- /dev/null +++ b/tensorflow/tensorflow.cabal @@ -0,0 +1,84 @@ +name: tensorflow +version: 0.1.0.0 +synopsis: TensorFlow bindings. +description: Please see README.md +homepage: https://github.com/tensorflow/haskell#readme +license: Apache +author: TensorFlow authors +maintainer: tensorflow-haskell@googlegroups.com +copyright: Google Inc. +category: Machine Learning +build-type: Simple +cabal-version: >=1.22 + +library + hs-source-dirs: src + exposed-modules: TensorFlow.Build + , TensorFlow.BuildOp + , TensorFlow.ControlFlow + , TensorFlow.Internal.FFI + , TensorFlow.Nodes + , TensorFlow.Output + , TensorFlow.Session + , TensorFlow.Tensor + , TensorFlow.Types + , TensorFlow.Internal.VarInt + other-modules: TensorFlow.Internal.Raw + , TensorFlow.Orphans + build-tools: c2hs + build-depends: proto-lens == 0.1.* + -- Used by the custom Setup script (for the test-suite). + , proto-lens-protoc == 0.1.* + , tensorflow-proto == 0.1.* + , base >= 4.7 && < 5 + , async + , attoparsec + , bytestring + , containers + , data-default + , fgl + , lens-family + , mainland-pretty + , mtl + , semigroups + , split + , text + , temporary + , transformers + , vector + extra-libraries: tensorflow_c + default-language: Haskell2010 + include-dirs: . + +Test-Suite FFITest + default-language: Haskell2010 + type: exitcode-stdio-1.0 + main-is: FFITest.hs + hs-source-dirs: tests + build-depends: HUnit + , base + , bytestring + , lens-family + , proto-lens + , tensorflow + , tensorflow-proto + , test-framework + , test-framework-hunit + + +Test-Suite VarIntTest + default-language: Haskell2010 + type: exitcode-stdio-1.0 + main-is: VarIntTest.hs + hs-source-dirs: tests + build-depends: base + , attoparsec + , bytestring + , google-shim + , tensorflow + , test-framework + , test-framework-quickcheck2 + +source-repository head + type: git + location: https://github.com/tensorflow/haskell diff --git a/tensorflow/tests/FFITest.hs b/tensorflow/tests/FFITest.hs new file mode 100644 index 0000000..28cc6db --- /dev/null +++ b/tensorflow/tests/FFITest.hs @@ -0,0 +1,38 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. 
+-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- | Tests for FFI. + +module Main where + +import Data.ProtoLens (decodeMessage) +import Lens.Family2 (view) +import TensorFlow.Internal.FFI (getAllOpList) +import Test.HUnit (assertBool, assertFailure) +import Test.Framework (defaultMain) +import Test.Framework.Providers.HUnit (testCase) +import Proto.Tensorflow.Core.Framework.OpDef (OpList, op) + +testParseAll :: IO () +testParseAll = do + opList <- getAllOpList + either + assertFailure + (assertBool "Expected non-empty list of default Ops" + . not . null . view op) + (decodeMessage opList :: Either String OpList) + +main = defaultMain + [ testCase "ParseAllOps" testParseAll + ] diff --git a/tensorflow/tests/VarIntTest.hs b/tensorflow/tests/VarIntTest.hs new file mode 100644 index 0000000..fad63b0 --- /dev/null +++ b/tensorflow/tests/VarIntTest.hs @@ -0,0 +1,32 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +module Main where + +import Data.ByteString.Builder (toLazyByteString) +import Google.Test (googleTest) +import Test.Framework.Providers.QuickCheck2 (testProperty) +import qualified Data.Attoparsec.ByteString.Lazy as Atto + +import TensorFlow.Internal.VarInt + +testEncodeDecode = testProperty "testEncodeDecode" $ \x -> + let bytes = toLazyByteString (putVarInt x) + in case Atto.eitherResult $ Atto.parse getVarInt bytes of + Left _ -> False + Right y -> x == y + +main :: IO () +main = googleTest [ testEncodeDecode + ] diff --git a/tensorflow/third_party b/tensorflow/third_party new file mode 120000 index 0000000..20e9ecd --- /dev/null +++ b/tensorflow/third_party @@ -0,0 +1 @@ +../third_party/tensorflow \ No newline at end of file diff --git a/third_party/tensorflow b/third_party/tensorflow new file mode 160000 index 0000000..bac7faa --- /dev/null +++ b/third_party/tensorflow @@ -0,0 +1 @@ +Subproject commit bac7faa9a3eb5b60687a83336202cd3493de5385