1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-22 19:09:43 +01:00

Initial commit

This commit is contained in:
Greg Steuck 2016-10-24 19:26:42 +00:00
commit 67690d1499
67 changed files with 6400 additions and 0 deletions

2
.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
**/.stack-work
.stack/

3
.gitmodules vendored Normal file
View file

@ -0,0 +1,3 @@
[submodule "third_party/tensorflow"]
path = third_party/tensorflow
url = https://github.com/tensorflow/tensorflow.git

25
CONTRIBUTING.md Normal file
View file

@ -0,0 +1,25 @@
Want to contribute? Great! First, read this page (including the small print at the end).
### Before you contribute
Before we can use your code, you must sign the
[Google Individual Contributor License Agreement](https://cla.developers.google.com/about/google-individual)
(CLA), which you can do online. The CLA is necessary mainly because you own the
copyright to your changes, even after your contribution becomes part of our
codebase, so we need your permission to use and distribute your code. We also
need to be sure of various other things—for instance that you'll tell us if you
know that your code infringes on other people's patents. You don't have to sign
the CLA until after you've submitted your code for review and a member has
approved it, but you must do it before we can put your code into our codebase.
Before you start working on a larger contribution, you should get in touch with
us first through the issue tracker with your idea so that we can help out and
possibly guide you. Coordinating up front makes it much easier to avoid
frustration later on.
### Code reviews
All submissions, including submissions by project members, require review. We
use Github pull requests for this purpose.
### The small print
Contributions made by corporations are covered by a different agreement than
the one above, the
[Software Grant and Corporate Contributor License Agreement](https://cla.developers.google.com/about/google-corporate).

203
LICENSE Normal file
View file

@ -0,0 +1,203 @@
Copyright 2016 The TensorFlow Authors. All rights reserved.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2016, The TensorFlow Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

24
README.md Normal file
View file

@ -0,0 +1,24 @@
The tensorflow-haskell package provides Haskell bindings to
[TensorFlow](https://www.tensorflow.org/).
This is not an official Google product.
# Instructions
## Build
For now [docker](https://www.docker.com/) is required. Once you have docker
working, the following commands will compile and run the tests.
git clone --recursive https://github.com/tensorflow/haskell.git tensorflow-haskell
cd tensorflow-haskell
IMAGE_NAME=tensorflow/haskell:v0
docker build -t $IMAGE_NAME docker
# TODO: move the setup step to the docker script.
stack --docker --docker-image=$IMAGE_NAME setup
stack --docker --docker-image=$IMAGE_NAME test
There is also a demo application:
cd tensorflow-mnist
stack --docker --docker-image=$IMAGE_NAME build --exec Main

29
docker/Dockerfile Normal file
View file

@ -0,0 +1,29 @@
# Prepare the image with:
# docker build -t gnezdo/tfhs:2016-10-14-1 docker
FROM gcr.io/tensorflow/tensorflow:latest-devel
MAINTAINER Greg Steuck <gnezdo+tfhs@google.com>
# Installs protoc and the libraries.
RUN \
cd /tensorflow && \
bazel --batch build -c opt '@protobuf//:protoc' && \
install -s bazel-bin/external/protobuf/protoc /usr/local/bin && \
bazel --batch build -c opt '//tensorflow:libtensorflow_c.so' && \
install bazel-bin/tensorflow/libtensorflow_c.so /usr/local/lib && \
bazel --batch clean
RUN apt-get update
RUN apt-get install -y \
# Avoids /usr/bin/ld: cannot find -ltinfo
libncurses5-dev \
# Makes stack viable in the container
libgmp-dev \
# Required for locales configuration.
locales
# Our MNIST demo program outputs Unicode characters.
RUN dpkg-reconfigure locales && \
locale-gen en_US.UTF-8 && \
update-locale LANG=en_US.UTF-8
ENV LANG en_US.UTF-8

3
google-shim/Setup.hs Normal file
View file

@ -0,0 +1,3 @@
import Distribution.Simple
main = defaultMain

View file

@ -0,0 +1,23 @@
name: google-shim
version: 0.1.0.0
synopsis: Adapters to externalize TensorFlow code.
description: Please see README.md
homepage: https://github.com/tensorflow/haskell#readme
license: Apache
author: TensorFlow authors
maintainer: tensorflow-haskell@googlegroups.com
copyright: Google Inc.
category: Machine Learning
build-type: Simple
cabal-version: >=1.22
library
hs-source-dirs: src
exposed-modules: Google.Test
build-depends: base >= 4.7 && < 5
, test-framework
default-language: Haskell2010
source-repository head
type: git
location: https://github.com/tensorflow/haskell

View file

@ -0,0 +1,22 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- | Alternative implementations to make dependent code work without
-- changes.
module Google.Test where
import Test.Framework (Test, defaultMain)
googleTest :: [Test] -> IO ()
googleTest = defaultMain

24
stack.yaml Normal file
View file

@ -0,0 +1,24 @@
resolver: lts-6.2
packages:
- google-shim
- tensorflow
- tensorflow-core-ops
- tensorflow-opgen
- tensorflow-ops
- tensorflow-proto
- tensorflow-mnist
- tensorflow-mnist-input-data
- tensorflow-queue
extra-deps:
# proto-lens is not yet in Stackage.
- proto-lens-0.1.0.4
- proto-lens-protoc-0.1.0.4
# Allow our custom Setup.hs scripts to import Data.ProtoLens.Setup from the version of
# `proto-lens-protoc` in stack's local DB. See:
# https://github.com/google/proto-lens/blob/master/README.md#using-cabal
explicit-setup-deps:
"*": true

View file

@ -0,0 +1,100 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- | Generates the wrappers for Ops shipped with tensorflow_c.
module Main where
import Distribution.Simple.BuildPaths (autogenModulesDir)
import Distribution.Simple.LocalBuildInfo (LocalBuildInfo)
import Distribution.Simple
( defaultMainWithHooks
, simpleUserHooks
, UserHooks(..)
)
import Data.List (intercalate)
import Data.ProtoLens (decodeMessage)
import System.Directory (createDirectoryIfMissing)
import System.Exit (exitFailure)
import System.FilePath ((</>))
import System.IO (hPutStrLn, stderr)
import TensorFlow.Internal.FFI (getAllOpList)
import TensorFlow.OpGen (docOpList, OpGenFlags(..))
import Text.PrettyPrint.Mainland (prettyLazyText)
import qualified Data.Text.Lazy.IO as Text
main = defaultMainWithHooks generatingOpsWrappers
-- TODO: Generalize for user libraries by replacing getAllOpList with
-- a wrapper around TF_LoadLibrary. The complicated part is interplay
-- between bazel and Haskell build system.
generatingOpsWrappers :: UserHooks
generatingOpsWrappers = hooks
{ buildHook = \p l h f -> generateSources l >> buildHook hooks p l h f
, haddockHook = \p l h f -> generateSources l >> haddockHook hooks p l h f
, replHook = \p l h f args -> generateSources l
>> replHook hooks p l h f args
}
where
flagsBuilder dir = OpGenFlags
{ outputFile = dir </> "Core.hs"
, prefix = "TensorFlow.GenOps"
, excludeList = intercalate "," blackList
}
hooks = simpleUserHooks
generateSources :: LocalBuildInfo -> IO ()
generateSources l = do
let dir = autogenModulesDir l </> "TensorFlow/GenOps"
createDirectoryIfMissing True dir
let flags = flagsBuilder dir
pb <- getAllOpList
case decodeMessage pb of
Left e -> hPutStrLn stderr e >> exitFailure
Right x -> Text.writeFile (outputFile flags)
(prettyLazyText 80 $ docOpList flags x)
blackList =
-- A few data flow ops take a list of heterogeneous
-- parameters which we don't support in general form.
[ "HashTable"
, "MutableDenseHashTable"
, "MutableHashTable"
, "MutableHashTableOfTensors"
, "QueueDequeue"
, "QueueDequeueMany"
, "QueueDequeueUpTo"
, "Stack"
, "TensorArray"
-- These should be possible to support by adding a bunch of
-- overloads with a variable number of tuple arguments.
, "Assert"
, "BarrierTakeMany"
, "Print"
, "QueueEnqueue"
, "QueueEnqueueMany"
-- These have type ambiguities because one of the type arguments
-- doesn't appear in the signature.
, "ConditionalAccumulator"
, "SparseConditionalAccumulator"
-- Need list of types support.
, "DecodeCSV"
, "ParseExample"
, "ParseSingleSequenceExample"
, "Save"
, "SaveSlices"
, "SymbolicGradient"
, "_ArrayToList"
, "_ListToArray"
-- Easy: support larger result tuples.
, "Skipgram"
]

View file

@ -0,0 +1,30 @@
name: tensorflow-core-ops
version: 0.1.0.0
synopsis: Haskell wrappers for Core Tensorflow Ops.
description: Code generated signatures for the Ops in libtensorflow_c.
homepage: https://github.com/tensorflow/haskell#readme
license: Apache
author: TensorFlow authors
maintainer: tensorflow-haskell@googlegroups.com
copyright: Google Inc.
category: Machine Learning
build-type: Custom
cabal-version: >=1.22
library
exposed-modules: TensorFlow.GenOps.Core
build-depends: Cabal >= 1.22 && < 1.25
, bytestring
, proto-lens == 0.1.*
, tensorflow-opgen == 0.1.*
, tensorflow == 0.1.*
, base >= 4.7 && < 5
, filepath
, mainland-pretty
, lens-family
, text
default-language: Haskell2010
source-repository head
type: git
location: https://github.com/tensorflow/haskell

View file

@ -0,0 +1,113 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE LambdaCase #-}
-- | Downloads the MNIST data set and packages them as data files.
module Main where
import Control.Monad (when)
import Data.Maybe (fromMaybe)
import Distribution.PackageDescription
( GenericPackageDescription(packageDescription)
, dataDir
)
import Distribution.Simple
( UserHooks(..)
, defaultMainWithHooks
, simpleUserHooks
)
import System.IO (hPutStrLn, stderr)
import System.FilePath ((</>))
import System.Directory (doesFileExist)
import qualified Crypto.Hash as Hash
import qualified Data.ByteString.Lazy as B
import qualified Network.HTTP as HTTP
import qualified Network.URI as URI
main :: IO ()
main = defaultMainWithHooks downloadingDataFiles
downloadingDataFiles :: UserHooks
downloadingDataFiles = hooks
{ confHook = \gh@(g, _) c -> downloadFiles g >> confHook hooks gh c
}
where
hooks = simpleUserHooks
downloadFiles :: GenericPackageDescription -> IO ()
downloadFiles g = do
let dir = dataDir (packageDescription g)
mapM_ (maybeDownload dir) fileInfos
maybeDownload :: FilePath -> (String, String) -> IO ()
maybeDownload dataDir (basename, sha256) = do
let filePath = dataDir </> basename
exists <- doesFileExist filePath
when (not exists) $ do
let url = urlPrefix ++ basename
hPutStrLn stderr ("Downloading " ++ url)
httpDownload url filePath
verify filePath sha256
httpDownload :: String -> FilePath -> IO ()
httpDownload url outFile = do
let uri = fromMaybe
(error ("Can't be: invalid URI " ++ url))
(URI.parseURI url)
result <- HTTP.simpleHTTP (HTTP.defaultGETRequest_ uri)
HTTP.getResponseCode result >>= \case
(2, 0, 0) -> HTTP.getResponseBody result >>= B.writeFile outFile
s -> error ( "Failed to download " ++ url ++ " error code " ++ show s
++ helpfulMessage
)
verify :: FilePath -> String -> IO ()
verify filePath hash = do
let sha256 = Hash.hashlazy :: B.ByteString -> Hash.Digest Hash.SHA256
computed <- show . sha256 <$> B.readFile filePath
when (hash /= computed) $
error ( "Incorrect checksum for " ++ filePath
++ "\nexpected " ++ hash
++ "\ncomputed " ++ computed
++ helpfulMessage
)
urlPrefix = "http://yann.lecun.com/exdb/mnist/"
-- | File names relative to 'urlPrefix' and their sha256.
fileInfos = [
( "train-images-idx3-ubyte.gz"
, "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
)
,
( "train-labels-idx1-ubyte.gz"
, "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
)
,
( "t10k-images-idx3-ubyte.gz"
, "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
)
,
( "t10k-labels-idx1-ubyte.gz"
, "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
)
]
helpfulMessage =
unlines
( ""
: ""
: "Please download the following URLs manually and put them in data/"
: [ urlPrefix ++ h | (h, _) <- fileInfos ]
)

View file

@ -0,0 +1,31 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
module TensorFlow.Examples.MNIST.InputData
( trainingImageData
, trainingLabelData
, testImageData
, testLabelData
) where
import Paths_tensorflow_mnist_input_data (getDataFileName)
-- | Download the files containing the canonical MNIST samples and labels.
trainingImageData, trainingLabelData :: IO FilePath
trainingImageData = getDataFileName "train-images-idx3-ubyte.gz"
trainingLabelData = getDataFileName "train-labels-idx1-ubyte.gz"
testImageData, testLabelData :: IO FilePath
testImageData = getDataFileName "t10k-images-idx3-ubyte.gz"
testLabelData = getDataFileName "t10k-labels-idx1-ubyte.gz"

View file

@ -0,0 +1,35 @@
name: tensorflow-mnist-input-data
version: 0.1.0.0
synopsis: Downloader of input data for training MNIST.
description: Please see README.md
homepage: https://github.com/tensorflow/haskell#readme
license: Apache
author: TensorFlow authors
maintainer: tensorflow-haskell@googlegroups.com
copyright: Google Inc.
category: Machine Learning
build-type: Custom
cabal-version: >=1.22
-- These files are downloaded automatically by Setup.hs. If the
-- automatic download fails, follow the instructions in error messages
-- displayed by Setup.hs.
data-dir: data
data-files: *.gz
library
hs-source-dirs: src
exposed-modules: TensorFlow.Examples.MNIST.InputData
other-modules: Paths_tensorflow_mnist_input_data
build-depends: Cabal
, HTTP
, base >= 4.7 && < 5
, bytestring
, cryptonite
, directory
, filepath
, network-uri
default-language: Haskell2010
source-repository head
type: git
location: https://github.com/tensorflow/haskell

View file

@ -0,0 +1,3 @@
import Distribution.Simple
main = defaultMain

View file

@ -0,0 +1,161 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedLists #-}
import Control.Monad (zipWithM, when, forM, forM_)
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64)
import qualified Data.Text.IO as T
import qualified Data.Vector as V
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.Build as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
import qualified TensorFlow.Gradient as TF
import TensorFlow.Examples.MNIST.InputData
import TensorFlow.Examples.MNIST.Parse
numPixels = 28^2 :: Int64
numLabels = 10 :: Int64
-- | Create tensor with random values where the stddev depends on the width.
randomParam :: Int64 -> TF.Shape -> TF.Build (TF.Tensor TF.Value Float)
randomParam width (TF.Shape shape) =
(* stddev) <$> TF.truncatedNormal (TF.vector shape)
where
stddev = TF.scalar (1 / sqrt (fromIntegral width))
-- Types must match due to model structure (sparseToDense requires
-- index types to match)
type LabelType = Int32
type BatchSize = Int32
-- | Convert scalar labels to one-hot vectors.
labelClasses :: TF.Tensor TF.Value LabelType
-> LabelType
-> BatchSize
-> TF.Tensor TF.Value Float
labelClasses labels numClasses batchSize =
let indices = TF.range 0 (TF.scalar batchSize) 1
concated = TF.concat 1 [TF.expandDims indices 1, TF.expandDims labels 1]
in TF.sparseToDense concated
(TF.constant [2] [batchSize, numClasses])
1 {- ON value -}
0 {- default (OFF) value -}
-- | Fraction of elements that differ between two vectors.
errorRate :: Eq a => V.Vector a -> V.Vector a -> Double
errorRate xs ys = fromIntegral (len - numCorrect) / fromIntegral len
where
numCorrect = V.length $ V.filter id $ V.zipWith (==) xs ys
len = V.length xs
data Model = Model {
train :: TF.TensorData Float -- ^ images
-> TF.TensorData LabelType
-> TF.Session ()
, infer :: TF.TensorData Float -- ^ images
-> TF.Session (V.Vector LabelType) -- ^ predictions
}
createModel :: Int64 -> TF.Build Model
createModel batchSize = do
-- Inputs.
images <- TF.placeholder [batchSize, numPixels]
-- Hidden layer.
let numUnits = 500
hiddenWeights <-
TF.initializedVariable =<< randomParam numPixels [numPixels, numUnits]
hiddenBiases <- TF.zeroInitializedVariable [numUnits]
let hiddenZ = (images `TF.matMul` hiddenWeights) `TF.add` hiddenBiases
let hidden = TF.relu hiddenZ
-- Logits.
logitWeights <-
TF.initializedVariable =<< randomParam numUnits [numUnits, numLabels]
logitBiases <- TF.zeroInitializedVariable [numLabels]
let logits = (hidden `TF.matMul` logitWeights) `TF.add` logitBiases
predict <- TF.render $ TF.cast $
TF.argMax (TF.softmax logits) (TF.scalar (1 :: LabelType))
-- Create training action.
labels <- TF.placeholder [batchSize]
let labelVecs = labelClasses labels 10 (fromIntegral batchSize)
loss = fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases]
grads <- TF.gradients loss params
let lr = TF.scalar $ 0.001 / fromIntegral batchSize
applyGrad param grad
= TF.assign param $ param `TF.sub` (lr * grad)
trainStep <- TF.group =<< zipWithM applyGrad params grads
return Model {
train = \imFeed lFeed -> TF.runWithFeeds_ [
TF.feed images imFeed
, TF.feed labels lFeed
] trainStep
, infer = \imFeed -> TF.runWithFeeds [TF.feed images imFeed] predict
}
main = TF.runSession $ do
-- Read training and test data.
trainingImages <- liftIO (readMNISTSamples =<< trainingImageData)
trainingLabels <- liftIO (readMNISTLabels =<< trainingLabelData)
testImages <- liftIO (readMNISTSamples =<< testImageData)
testLabels <- liftIO (readMNISTLabels =<< testLabelData)
let batchSize = 100 :: Int64
-- Create the model.
model <- TF.build $ createModel batchSize
-- Helpers for generate batches.
let selectBatch i xs = take size $ drop (i * size) $ cycle xs
where size = fromIntegral batchSize
let getImageBatch i xs = TF.encodeTensorData
[batchSize, numPixels]
$ fromIntegral <$> mconcat (selectBatch i xs)
let getExpectedLabelBatch i xs =
fromIntegral <$> V.fromList (selectBatch i xs)
-- Train.
forM_ ([0..1000] :: [Int]) $ \i -> do
let images = getImageBatch i trainingImages
labels = getExpectedLabelBatch i trainingLabels
train model images (TF.encodeTensorData [batchSize] labels)
when (i `mod` 100 == 0) $ do
preds <- infer model images
liftIO $ putStrLn $
"training error " ++ show (errorRate preds labels * 100)
liftIO $ putStrLn ""
-- Test.
let numTestBatches = length testImages `div` fromIntegral batchSize
testPreds <- fmap mconcat $ forM [0..numTestBatches] $ \i -> do
infer model (getImageBatch i testImages)
let testExpected = fromIntegral <$> V.fromList testLabels
liftIO $ putStrLn $
"test error " ++ show (errorRate testPreds testExpected * 100)
-- Show some predictions.
liftIO $ forM_ ([0..3] :: [Int]) $ \i -> do
putStrLn ""
T.putStrLn $ drawMNIST $ testImages !! i
putStrLn $ "expected " ++ show (testLabels !! i)
putStrLn $ " got " ++ show (testPreds V.! i)

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@ -0,0 +1,30 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
-- | Paths to test helper files.
module TensorFlow.Examples.MNIST.TrainedGraph where
import Paths_tensorflow_mnist (getDataFileName)
import Data.ByteString (ByteString)
import Data.ByteString.Char8 (pack)
-- | File containing a Tensorflow serialized proto of MNIST.
mnistPb :: IO FilePath
mnistPb = getDataFileName "data/MNIST.pb"
-- | Files containing pre-trained weights for MNIST.
wtsCkpt, biasCkpt :: IO ByteString
wtsCkpt = pack <$> getDataFileName "data/MNISTWts.ckpt"
biasCkpt = pack <$> getDataFileName "data/MNISTBias.ckpt"

View file

@ -0,0 +1,96 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ViewPatterns #-}
module TensorFlow.Examples.MNIST.Parse where
import Control.Monad (when, liftM)
import Data.Binary.Get (Get, runGet, getWord32be, getLazyByteString)
import Data.ByteString.Lazy (toStrict, readFile)
import Data.List.Split (chunksOf)
import Data.Monoid ((<>))
import Data.ProtoLens (Message, decodeMessageOrDie)
import Data.Text (Text)
import Data.Word (Word8, Word32)
import Prelude hiding (readFile)
import qualified Codec.Compression.GZip as GZip
import qualified Data.ByteString.Lazy as L
import qualified Data.Text as Text
import qualified Data.Vector as V
-- | Utilities specific to MNIST.
type MNIST = V.Vector Word8
-- | Produces a unicode rendering of the MNIST digit sample.
drawMNIST :: MNIST -> Text
drawMNIST = chunk . block
where
block :: V.Vector Word8 -> Text
block (V.splitAt 1 -> ([0], xs)) = " " <> block xs
block (V.splitAt 1 -> ([n], xs)) = c `Text.cons` block xs
where c = "\9617\9618\9619\9608" !! fromIntegral (n `div` 64)
block (V.splitAt 1 -> _) = ""
chunk :: Text -> Text
chunk "" = "\n"
chunk xs = Text.take 28 xs <> "\n" <> chunk (Text.drop 28 xs)
-- | Check's the file's endianess, throwing an error if it's not as expected.
checkEndian :: Get ()
checkEndian = do
magic <- getWord32be
when (magic `notElem` ([2049, 2051] :: [Word32])) $
fail "Expected big endian, but image file is little endian."
-- | Reads an MNIST file and returns a list of samples.
readMNISTSamples :: FilePath -> IO [MNIST]
readMNISTSamples path = do
raw <- GZip.decompress <$> readFile path
return $ runGet getMNIST raw
where
getMNIST :: Get [MNIST]
getMNIST = do
checkEndian
-- Parse header data.
cnt <- liftM fromIntegral getWord32be
rows <- liftM fromIntegral getWord32be
cols <- liftM fromIntegral getWord32be
-- Read all of the data, then split into samples.
pixels <- getLazyByteString $ fromIntegral $ cnt * rows * cols
return $ V.fromList <$> chunksOf (rows * cols) (L.unpack pixels)
-- | Reads a list of MNIST labels from a file and returns them.
readMNISTLabels :: FilePath -> IO [Word8]
readMNISTLabels path = do
raw <- GZip.decompress <$> readFile path
return $ runGet getLabels raw
where getLabels :: Get [Word8]
getLabels = do
checkEndian
-- Parse header data.
cnt <- liftM fromIntegral getWord32be
-- Read all of the labels.
L.unpack <$> getLazyByteString cnt
readMessageFromFileOrDie :: Message m => FilePath -> IO m
readMessageFromFileOrDie path = do
pb <- readFile path
return $ decodeMessageOrDie $ toStrict pb
-- TODO: Write a writeMessageFromFileOrDie and read/write non-lethal
-- versions.

View file

@ -0,0 +1,80 @@
name: tensorflow-mnist
version: 0.1.0.0
synopsis: TensorFlow demo application for learning MNIST model.
description: Please see README.md
homepage: https://github.com/tensorflow/haskell#readme
license: Apache
author: TensorFlow authors
maintainer: tensorflow-haskell@googlegroups.com
copyright: Google Inc.
category: Machine Learning
build-type: Simple
cabal-version: >=1.22
data-files: data/*.ckpt
, data/*.pb
library
hs-source-dirs: src
, src-data
exposed-modules: TensorFlow.Examples.MNIST.Parse
, TensorFlow.Examples.MNIST.TrainedGraph
other-modules: Paths_tensorflow_mnist
build-depends: proto-lens == 0.1.*
, base >= 4.7 && < 5
, binary
, bytestring
, filepath
, lens-family
, containers
, split
, tensorflow-proto == 0.1.*
, tensorflow-core-ops == 0.1.*
, tensorflow
, text
, vector
, zlib
default-language: Haskell2010
executable Main
default-language: Haskell2010
main-is: Main.hs
hs-source-dirs: app
build-depends: base
, bytestring
, filepath
, lens-family
, proto-lens
, tensorflow
, tensorflow-mnist
, tensorflow-mnist-input-data
, tensorflow-ops
, tensorflow-proto
, text
, transformers
, vector
Test-Suite ParseTest
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: ParseTest.hs
hs-source-dirs: tests
build-depends: HUnit
, base
, bytestring
, proto-lens
, lens-family
, google-shim
, tensorflow
, tensorflow-mnist
, tensorflow-mnist-input-data
, tensorflow-ops
, tensorflow-proto
, test-framework
, test-framework-hunit
, text
, transformers
, vector
source-repository head
type: git
location: https://github.com/tensorflow/haskell

View file

@ -0,0 +1,170 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import Data.Text (Text)
import qualified Data.Text.IO as Text
import Lens.Family2 ((&), (.~), (^.))
import Prelude hiding (abs)
import Proto.Tensorflow.Core.Framework.Graph
( GraphDef(..)
, version
, node )
import Proto.Tensorflow.Core.Framework.NodeDef
( NodeDef(..)
, op )
import System.IO as IO
import TensorFlow.Examples.MNIST.InputData
import TensorFlow.Examples.MNIST.Parse
import TensorFlow.Examples.MNIST.TrainedGraph
import TensorFlow.Build
( asGraphDef
, addGraphDef
, render
)
import TensorFlow.Tensor
( Tensor(..)
, Ref
, Value
, feed
, TensorKind(..)
, tensorFromName
)
import TensorFlow.Ops
import TensorFlow.Nodes (unScalar)
import TensorFlow.Session
(runSession, run, run_, runWithFeeds, build, buildAnd)
import TensorFlow.Types (TensorType(..), Shape(..))
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?), Assertion)
import Google.Test
import qualified Data.Vector as V
-- | Test that a file can be read and the GraphDef proto correctly parsed.
testReadMessageFromFileOrDie = testCase "testReadMessageFromFileOrDie" $ do
-- Check the function on a known well-formatted file.
mnist <- readMessageFromFileOrDie =<< mnistPb :: IO GraphDef
-- Simple field read.
1 @=? mnist^.version
-- Count the number of nodes.
let nodes :: [NodeDef]
nodes = mnist^.node
100 @=? length nodes
-- Check that the expected op is found at an arbitrary index.
"Variable" @=? nodes!!6^.op
-- | Parse the test set for label and image data. Will only fail if the file is
-- missing or incredibly corrupt.
testReadMNIST = testCase "testReadMNIST" $ do
imageData <- readMNISTSamples =<< testImageData
10000 @=? length imageData
labelData <- readMNISTLabels =<< testLabelData
10000 @=? length labelData
testNodeName :: Text -> Tensor v a -> Assertion
testNodeName n g = n @=? opName
where
opName = head (gDef^.node)^.op
gDef = asGraphDef $ render g
testGraphDefGen = testCase "testGraphDefGen" $ do
-- Test the inferred operation type.
let f0 :: Tensor Value Float
f0 = 0
testNodeName "Const" f0
testNodeName "Add" $ 1 + f0
testNodeName "Mul" $ 1 * f0
testNodeName "Sub" $ 1 - f0
testNodeName "Abs" $ abs f0
testNodeName "Sign" $ signum f0
testNodeName "Neg" $ -f0
-- Test the grouping.
testNodeName "Add" $ 1 + f0 * 2
testNodeName "Add" $ 1 + (f0 * 2)
testNodeName "Mul" $ (1 + f0) * 2
-- | Convert a simple graph to GraphDef, load it, run it, and check the output.
testGraphDefExec = testCase "testGraphDefExec" $ do
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
runSession $ do
build $ addGraphDef graphDef
x <- run $ tensorFromName ValueKind "Mul_2"
liftIO $ (50 :: Float) @=? unScalar x
-- | Load MNIST from a GraphDef and the weights from a checkpoint and run on
-- sample data.
testMNISTExec = testCase "testMNISTExec" $ do
-- Switch to unicode to enable pretty printing of MNIST digits.
IO.hSetEncoding IO.stdout IO.utf8
-- Parse the Graph definition, samples, & labels from files.
mnist <- readMessageFromFileOrDie =<< mnistPb :: IO GraphDef
mnistSamples <- readMNISTSamples =<< testImageData
mnistLabels <- readMNISTLabels =<< testLabelData
-- Select a sample to run on and convert it into a TensorData of Floats.
let idx = 12
sample :: MNIST
sample = mnistSamples !! idx
label = mnistLabels !! idx
tensorSample = encodeTensorData (Shape [1,784]) floatSample
where
floatSample :: V.Vector Float
floatSample = V.map fromIntegral sample
Text.putStrLn $ drawMNIST sample
-- Execute the graph on the sample data.
runSession $ do
-- The version of this session is 0, but the version of the graph is 1.
-- Change the graph version to 0 so they're compatible.
build $ addGraphDef $ mnist & version .~ 0
-- Define nodes that restore saved weights and biases.
let bias, wts :: Tensor Ref Float
bias = tensorFromName RefKind "Variable"
wts = tensorFromName RefKind "weights"
wtsCkptPath <- liftIO wtsCkpt
biasCkptPath <- liftIO biasCkpt
-- Run those restoring nodes on the graph in the current session.
buildAnd run_ $ (sequence :: Monad m => [m a] -> m [a])
[ restore wtsCkptPath wts
, restoreFromName biasCkptPath "bias" bias
]
-- Encode the expected sample data as one-hot data.
let ty = encodeTensorData [10] oneHotLabels
where oneHotLabels = V.replicate 10 (0 :: Float) V.// updates
updates = [(fromIntegral label, 1)]
let feeds = [ feed (tensorFromName ValueKind "x-input") tensorSample
, feed (tensorFromName ValueKind "y-input") ty
]
-- Run the graph with the input feeds and read the ArgMax'd result from
-- the test (not training) side of the evaluation.
x <- runWithFeeds feeds $ tensorFromName ValueKind "test/ArgMax"
-- Print the trained model's predicted outcome.
liftIO $ putStrLn $ "Expectation: " ++ show label ++ "\n"
++ "Prediction: " ++ show (unScalar x :: Int64)
-- Check whether the prediction matches the expectation.
liftIO $ (fromInteger . toInteger $ label :: Int64) @=? unScalar x
main :: IO ()
main = googleTest [ testReadMessageFromFileOrDie
, testReadMNIST
, testGraphDefGen
, testGraphDefExec
, testMNISTExec]

View file

@ -0,0 +1,3 @@
import Distribution.Simple
main = defaultMain

View file

@ -0,0 +1,457 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
-- | Rendering of TensorFlow operations as Haskell functions.
module TensorFlow.OpGen
( OpGenFlags(..)
, docOpList
, flagParser)
where
import Prelude hiding (head, tail)
import Control.Monad (guard)
import Data.Char (toLower, toUpper)
import Data.Foldable (toList)
import Data.Maybe (fromMaybe, maybeToList)
import Data.ProtoLens (def, showMessage)
import Data.List.NonEmpty (NonEmpty((:|)), head)
import qualified Data.List.NonEmpty as NE
import Lens.Family2 ((^.), (.~), (&), view)
import Options.Applicative (Parser, help, long, strOption, value)
import Proto.Tensorflow.Core.Framework.OpDef
( OpList
, OpDef
, OpDef'ArgDef
, attr
, description
, inputArg
, name
, numberAttr
, op
, outputArg
, summary
, type'
, typeAttr
)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import System.FilePath (takeBaseName)
import TensorFlow.OpGen.AttrVal
(AttrDef
, AttrCase(..)
, AttrTemplate(..)
, Template
, attrDef
, attrOriginal
, attrTemplate
, templateDefault
, templateRestrictions
)
import Text.PrettyPrint.Mainland
( Doc
, (<>)
, (<+>)
, (</>)
, (<+/>)
, brackets
, comma
, commasep
, dquotes
, empty
, enclose
, flatten
, folddoc
, hang
, indent
, int
, parens
, sep
, stack
, strictText
, tuple
)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Text as Text
import qualified Data.Semigroup as Semigroup
import Data.Text (Text)
data OpGenFlags = OpGenFlags
{ outputFile :: String
, prefix :: String
, excludeList :: String
}
flagParser :: Parser OpGenFlags
flagParser = OpGenFlags
<$> strOption (mconcat [ long "output"
, help "File to write."
])
<*> strOption (mconcat [ long "prefix"
, help "Haskell package prefix to use"
])
<*> strOption (mconcat [ long "exclude_list"
, value ""
, help "Comma separated Ops names to ignore"
])
docOpList :: OpGenFlags -> OpList -> Doc
docOpList flags opList =
stack [ "{-# LANGUAGE ConstraintKinds #-}"
, "{-# LANGUAGE DataKinds #-}"
, "{-# LANGUAGE FlexibleInstances #-}"
, "{-# LANGUAGE OverloadedStrings #-}"
, "{-# LANGUAGE RankNTypes #-}"
, "{-# LANGUAGE ScopedTypeVariables #-}"
, "module" <+> strictText moduleName <+> "where"
, empty
, imports
, empty
, folddoc (\x y -> x </> empty </> y)
(map renderDef $
filter (not . flip elem exclusions . view name) $
toList $ opList ^. op)
]
where moduleName =
Text.pack (prefix flags) <> "." <> camelCase
-- Discards the optional trailing _op_lib
(fromMaybe shortName (Text.stripSuffix "_op_lib" shortName))
shortName = Text.pack (takeBaseName $ outputFile flags)
exclusions = Text.splitOn "," $ Text.pack $ excludeList flags
camelCase s = Text.concat $ map upCase
$ filter (/= "ops")
$ Text.splitOn "_" s
-- | Upper-case the given text.
upCase :: Text -> Text
upCase = forceCase toUpper
-- | Lower-case the given name, and prevent it from overlapping with a reserved
-- Haskell name.
lowCase :: Text -> Text
lowCase = replaceReservedName . forceCase toLower
forceCase :: (Char -> Char) -> Text -> Text
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
(Text.uncons s)
imports = stack [
"import Data.ByteString (ByteString)"
, "import Data.Complex (Complex)"
, "import Data.Int (Int8, Int16, Int32, Int64)"
, "import Data.Word (Word8, Word16)"
, "import Lens.Family2 ((.~), (&))"
, "import TensorFlow.Build"
, "import TensorFlow.BuildOp"
, "import TensorFlow.Tensor"
, "import TensorFlow.Types"
]
renderDef :: OpDef -> Doc
renderDef d =
stack [
haddocks
, n <+> "::" <+> hang 0 (typeSig d)
, n <+> hang 0 args <+> "|" <+> funcGuard <+> "=" </> -- args are indented
-- the body needs to be indented wrt the name
indent indentation functionBody
, extras -- just for debug
]
where
n = strictText $ fixOpName (d ^. name)
args = sep $ [hsName | (_, hsName) <- mandatoryAttrs] ++ tensorArgs
tensorArgs = [strictText $ lowCase (a ^. name) | a <- d ^. inputArg]
fixOpName = lowCase
funcGuard = "eqLengthGuard" <+> brackets (commasep entries)
where
entries =
[ parens $ quotedText nAttr <> comma <+>
brackets (commasep $ toList $
NE.map renderTensorName tensorNames)
| (nAttr, tensorNames) <- Map.toList $ numberAttrMap d
]
renderTensorName x = parens $ quotedText x <> comma <+>
"length" <+> strictText x
-- Uses hang 0 to align the argument vertically on multiple lines.
functionBody = buildFunction <+> parens (hang 0 (stack buildOpParts))
</> indent indentation (sep tensorArgs)
buildFunction
| null outputListsSizes = "buildOp"
| otherwise = "buildListOp" <+> brackets (commasep outputListsSizes)
outputListsSizes = [ strictText numberAttrName
| o <- d ^. outputArg
, let numberAttrName = o ^. numberAttr
, not (Text.null numberAttrName) &&
numberAttrName `Map.member` mandatoryAttrMap d
]
buildOpParts =
"opDef" <+> quotedText (d ^. name) :
-- Renders tensor arguments.
[ "& opAttr" <+> quotedText tfName <+>
".~ tensorType (undefined ::" <+> strictText hsName <> ")"
| (tfName, (hsName, _)) <- Map.toList typeMap
] ++
-- Renders mandatory attributes as function parameters.
[ "& opAttr" <+> dquotes tfName <+> ".~" <+> hsName
| (tfName, hsName) <- mandatoryAttrs
] ++
-- Renders sizes of tensor list types having number_attr.
[ "& opAttr" <+> quotedText nAttr <+> ".~" <+>
"(fromIntegral (length" <+> strictText (head tensorNames) <> ") :: Int64)"
| (nAttr, tensorNames) <- Map.toList $ numberAttrMap d
]
mandatoryAttrs = [(strictText tf, strictText hs)
| (tf, (hs, _, _)) <- Map.toList (mandatoryAttrMap d)
]
haddocks = "-- |" <+> multilineComment (d ^. summary) (d ^. description)
extras = enclose "{-\n" "\n-}" $
strictText $ Text.pack $
showMessage ((def :: OpDef)
& inputArg .~ (d ^. inputArg)
& outputArg .~ (d ^. outputArg)
& attr .~ (d ^. attr))
typeMap = opDefTypeMap d
-- | Makes a quoted string doc out of the given text value.
quotedText :: Text.Text -> Doc
quotedText = dquotes . strictText
-- | typeSig renders the type signature of the given OpDef.
typeSig :: OpDef -> Doc
typeSig d =
foralls <+> constraints <+/>
signatureFold (mandatoryAttrInputs ++ tensorInputs ++ [outputs])
where
foralls | Map.null typeMap = empty
| otherwise =
"forall"
<+> sep (refTypes ++ map (strictText . fst) (Map.elems typeMap))
<+> "."
constraints | Map.null typeMap = empty
| otherwise =
tuple (concatMap
(\(t, aDef) ->
"TensorType" <+> strictText t
: maybeToList (oneOfRestrictions aDef t))
(Map.elems typeMap)) <+> "=>"
tensorInputs = zipWith tensorArg refTypes (d ^. inputArg)
refTypes = map (\x -> "v" <> int x) [1..length (d ^. inputArg)]
tensorArg refType arg = wrapArg refType arg <+>
hang 0 ("-- ^" <+> argComment arg)
-- Argument type is a list of tensors if number_attr is set;
-- otherwise it's a single Tensor.
wrapArg refType arg =
if Text.null (arg ^. numberAttr) then typ else brackets typ
where typ = tensorType refType arg
tensorType refType arg =
"Tensor" <+> refType <+> maybe directType strictText indirectType
where
indirectType = fmap fst (Map.lookup (arg ^. typeAttr) typeMap)
directType = dtTypeToDoc (arg ^. type')
outputs =
case d ^. outputArg of
[] -> "ControlNode"
[o] -> wrappedOutput o <+> "-- ^" <+> argComment o
os -> renderTupleResult os
wrappedOutput = wrapArg "Value"
-- Tuple result case is rendered differently to give
-- individual elements their own comments.
renderTupleResult os =
stack $ [ tuple (map wrappedOutput os)
, flatten commentSummary
] ++ map commentDetails os
where
commentSummary = "-- ^" <+> tuple [bold (o ^. name) | o <- os]
commentDetails o =
stack [ "--"
, "-- *" <+> argComment o
]
signatureFold = folddoc (\x y -> x </> "->" <+> y)
mandatoryAttrInputs = [
dtTypeToDoc dtType <+>
hang 0 ("-- ^" <+> argComment' tfName descr)
| (tfName, (_, dtType, descr)) <- Map.toList $ mandatoryAttrMap d]
typeMap = opDefTypeMap d
-- | Returns the type restriction for the given tensor type if the
-- set of allowed types is not empty (i.e. restricted).
oneOfRestrictions :: AttrDef -> Text -> Maybe Doc
oneOfRestrictions aDef tName = do
typs <- onAttrType (^. templateRestrictions) aDef
guard $ not $ null typs
let typeList = commasep $ map strictText $
Set.toList $ Set.fromList $
map dtTypeToHaskell typs
return $ "OneOf" <+> "'" <> brackets typeList <+> strictText tName
-- | Identifies the attributes used as tensor cardinalities. In such
-- cases a list of tensors is supplied as an input_arg. The number of
-- such inputs is communicated as a separate opAttr.
-- The result key is TensorFlow attribute name and the value is the
-- tensor names which have number_attr set to the result key.
numberAttrMap :: OpDef -> Map.Map Text.Text (NonEmpty Text.Text)
numberAttrMap d = Map.fromListWith (Semigroup.<>) [
(nAttr, replaceReservedName (inp ^. name) :| [])
| inp <- d ^. inputArg
, nAttr <- [inp ^. numberAttr]
, not (Text.null nAttr)
]
argComment :: OpDef'ArgDef -> Doc
argComment arg = argComment' (arg ^. name) (arg ^. description)
argComment' :: Text.Text -> Text.Text -> Doc
argComment' argName argDesc =
bold argName <> splitMultilineText (":" <+>) argDesc
bold :: Text.Text -> Doc
bold n = strictText ("__" <> n <> "__")
opDefTypeMap :: OpDef -> Map.Map Text.Text (Text.Text, AttrDef)
opDefTypeMap d =
Map.fromList [(n, (lowCase n, a)) | (n, a) <- attrList d, isType a]
attrList :: OpDef -> [(Text.Text, AttrDef)]
attrList d = [(a ^. name, attrDef a) | a <- d ^. attr]
isType :: AttrDef -> Bool
isType = fromMaybe False . onAttrType (const True)
-- | Applies the given function to the data type. Is this a Prism?
onAttrType :: (Template DataType -> a) -> AttrDef -> Maybe a
onAttrType f x = case x ^. attrTemplate of
AttrSingle (AttrType a) -> Just (f a)
_ -> Nothing
-- | mandatoryAttrMap contains the attributes chosen by
-- isMandatoryAttr, excluding those which are derived from list of
-- tensor arguments. The key is the TF name of the attribute. The
-- value tuple is (haskell name, TF type, attribute comment).
mandatoryAttrMap :: OpDef -> Map.Map Text.Text (Text.Text, DataType, Text.Text)
mandatoryAttrMap d =
Map.fromList [ (n, (lowCase n, dtType, a ^. attrOriginal.description))
| (n, a) <- attrList d
, Just dtType <- [isMandatoryAttr a]
-- Excludes the attributes rendered as list lengths.
, n `Map.notMember` numberAttrMap d
]
-- | Inspects the attribute and if it is one of the implemented
-- non-tensor values lacking default, then returns Just the TF type.
isMandatoryAttr :: AttrDef -> Maybe DataType
isMandatoryAttr x =
case x ^. attrTemplate of
AttrSingle (AttrBool y) -> noDefault DT_BOOL y
AttrSingle (AttrInt64 y) -> noDefault DT_INT64 y
AttrSingle (AttrFloat y) -> noDefault DT_FLOAT y
_ -> Nothing
where
noDefault typ y = maybe (Just typ) (const Nothing) (y ^. templateDefault)
dtTypeToDoc :: DataType -> Doc
dtTypeToDoc = strictText . dtTypeToHaskell
-- NOTE: The cases of this function should be kept in sync with
-- TensorFlow.Types.AllTensorTypes.
dtTypeToHaskell :: DataType -> Text.Text
dtTypeToHaskell DT_BOOL = "Bool"
dtTypeToHaskell DT_BFLOAT16 = "Data.Word.Word16"
dtTypeToHaskell DT_COMPLEX128 = "(Data.Complex.Complex Double)"
dtTypeToHaskell DT_COMPLEX64 = "(Data.Complex.Complex Float)"
dtTypeToHaskell DT_DOUBLE = "Double"
dtTypeToHaskell DT_FLOAT = "Float"
dtTypeToHaskell DT_INT16 = "Data.Int.Int16"
dtTypeToHaskell DT_INT32 = "Data.Int.Int32"
dtTypeToHaskell DT_INT64 = "Data.Int.Int64"
dtTypeToHaskell DT_INT8 = "Data.Int.Int8"
dtTypeToHaskell DT_QINT32 = "Data.Int.Int32" -- TODO(gnezdo): make unique
dtTypeToHaskell DT_QINT8 = "Data.Word.Word8" -- TODO(gnezdo): make unique
dtTypeToHaskell DT_QINT16 = "Data.Int.Int16" -- TODO(gnezdo): make unique
dtTypeToHaskell DT_QUINT16 = "Data.Word.Word16" -- TODO(gnezdo): make unique
dtTypeToHaskell DT_QUINT8 = "Data.Word.Word8" -- TODO(gnezdo): make unique
dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString"
dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique
dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
dtTypeToHaskell x =
Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x
-- | haddockComment escapes TensorFlow doc strings into haddock.
-- TODO(gnezdo): deal with the markup.
haddockComment :: Text.Text -> Doc
haddockComment = strictText
multilineComment :: Text.Text -> Text.Text -> Doc
multilineComment summary' detail =
haddockComment summary' </>
splitMultilineText insertParagraphAndComment detail
where insertParagraphAndComment x = "--" </> "--" <+> x
-- | Converts the given multi-line detail string into
-- a multi-line haddock. Applies the given lead to the
-- first line. Returns an empty document for empty detail.
splitMultilineText :: (Doc -> Doc) -> Text.Text -> Doc
splitMultilineText lead detail =
case Text.lines detail of
[] -> empty
(l : ls) -> stack $ lead (haddockComment l)
: map (("--" <+>) . haddockComment) ls
replaceReservedName :: Text -> Text
replaceReservedName n
| n `Set.member` reservedKeywords = n <> "'"
| otherwise = n
indentation = 4
reservedKeywords :: Set.Set Text
reservedKeywords = Set.fromList $
-- Haskell2010 keywords:
-- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
-- We don't include keywords that are allowed to be variable names,
-- in particular: "as", "forall", and "hiding".
[ "case"
, "class"
, "data"
, "default"
, "deriving"
, "do"
, "else"
, "foreign"
, "if"
, "import"
, "in"
, "infix"
, "infixl"
, "infixr"
, "instance"
, "let"
, "module"
, "newtype"
, "of"
, "then"
, "type"
, "where"
]
++ -- Nonstandard extensions
[ "mdo" -- RecursiveDo
, "rec" -- Arrows, RecursiveDo
, "proc" -- Arrows
]

View file

@ -0,0 +1,120 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
-- | Wrapping of TensorFlow attributes into Haskell entities.
module TensorFlow.OpGen.AttrVal
(AttrDef
, AttrCase(..)
, AttrTemplate(..)
, Template
, attrDef
, attrOriginal
, attrTemplate
, templateDefault
, templateRestrictions
) where
import Data.Int (Int64)
import Data.Monoid ((<>))
import Lens.Family2 (Lens', (^.))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.AttrValue as AttrValue
import Proto.Tensorflow.Core.Framework.OpDef as OpDef
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import Proto.Tensorflow.Core.Framework.TensorShape (TensorShapeProto)
import qualified Data.ByteString as B
import qualified Data.Text as Text
-- | Specifies the optional default value and a set of allowed values
-- for the given type.
data Template a = Template {
_templateDefault :: Maybe a
-- ^ The default value (mandatory if unspecified)
, _templateRestrictions :: [a]
-- ^ The allowed set of values, empty if no restrictions
}
templateDefault :: Lens' (Template a) (Maybe a)
templateDefault = lens _templateDefault (\g x -> g { _templateDefault = x })
templateRestrictions :: Lens' (Template a) [a]
templateRestrictions = lens _templateRestrictions
(\g x -> g { _templateRestrictions = x })
data UnusedTensor
data AttrCase f
= AttrBytes (f B.ByteString) -- bytes s = 2; // "string"
| AttrInt64 (f Int64) -- int64 i = 3; // "int"
| AttrFloat (f Float) -- float f = 4; // "float"
| AttrBool (f Bool) -- bool b = 5; // "bool"
| AttrType (f DataType) -- type = 6; // "type"
-- To be translated into TensorFlow.Types.Shape before use.
-- Leaving as a proto to reduce dependencies.
| AttrShape (f TensorShapeProto) -- shape = 7; // "shape"
-- | Type-reified representation of TensorFlow AttrDef.
-- Initially limited to just the types in Op descriptors.
data AttrTemplate
= AttrSingle (AttrCase Template)
| AttrList (AttrCase [])
| AttrTensor UnusedTensor -- tensor = 8; // "tensor"
data AttrDef = AttrDef {
_attrOriginal :: OpDef'AttrDef -- ^ the proto this value was created from
, _attrTemplate :: AttrTemplate -- ^ the type of the attribute
}
attrTemplate :: Lens' AttrDef AttrTemplate
attrTemplate = lens _attrTemplate (\g x -> g { _attrTemplate = x })
attrOriginal :: Lens' AttrDef OpDef'AttrDef
attrOriginal = lens _attrOriginal (\g x -> g { _attrOriginal = x })
attrDef :: OpDef'AttrDef -> AttrDef
attrDef a = AttrDef a
$ translate (a^.OpDef.type')
(a^.OpDef.defaultValue)
(a^.allowedValues)
-- | Converts the given AttrValue with the type given by the string
-- into the AttrVal if the type is known.
translate :: Text.Text -- ^ one of the TensorFlow type strings
-> AttrValue -- ^ default value
-> AttrValue -- ^ allowed values
-> AttrTemplate
translate t defaults allowed
| t == "string" = makeVal AttrBytes maybe's s
| t == "int" = makeVal AttrInt64 maybe'i i
| t == "float" = makeVal AttrFloat maybe'f f
| t == "bool" = makeVal AttrBool maybe'b b
| t == "type" = makeVal AttrType AttrValue.maybe'type' AttrValue.type'
| t == "shape" = makeVal AttrShape maybe'shape shape
| t == "tensor" = AttrTensor $ error "tensor is unimplemented"
| t == "list(string)" = makeList AttrBytes $ list.s
| t == "list(int)" = makeList AttrInt64 $ list.i
| t == "list(float)" = makeList AttrFloat $ list.f
| t == "list(bool)" = makeList AttrBool $ list.b
| t == "list(type)" = makeList AttrType $ list.AttrValue.type'
| t == "list(shape)" = makeList AttrShape $ list.shape
| t == "list(tensor)" = AttrTensor $ error "list(tensor) is unimplemented"
| t == "func" = AttrTensor $ error "func is unimplemented"
| otherwise = error $ show ("Unknown attribute type " <> t) ++
"," ++ show defaults ++
"," ++ show allowed
where makeVal c x y = AttrSingle $ c $
Template (defaults^.x) (allowed^.list.y)
makeList c y = AttrList $ c $ defaults^.y

View file

@ -0,0 +1,33 @@
name: tensorflow-opgen
version: 0.1.0.0
synopsis: Code generation for TensorFlow operations.
description: Please see README.md
homepage: https://github.com/tensorflow/haskell#readme
license: Apache
author: TensorFlow authors
maintainer: tensorflow-haskell@googlegroups.com
copyright: Google Inc.
category: Machine Learning
build-type: Simple
cabal-version: >=1.22
library
hs-source-dirs: src
exposed-modules: TensorFlow.OpGen.AttrVal
, TensorFlow.OpGen
build-depends: proto-lens == 0.1.*
, tensorflow-proto == 0.1.*
, base >= 4.7 && < 5
, bytestring
, containers
, filepath
, lens-family
, mainland-pretty
, optparse-applicative
, semigroups
, text
default-language: Haskell2010
source-repository head
type: git
location: https://github.com/tensorflow/haskell

View file

@ -0,0 +1 @@
../third_party/tensorflow

3
tensorflow-ops/Setup.hs Normal file
View file

@ -0,0 +1,3 @@
import Distribution.Simple
main = defaultMain

View file

@ -0,0 +1,76 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
-- | Parallel lookups on the list of tensors.
module TensorFlow.EmbeddingOps where
import Control.Monad (zipWithM)
import Data.Int (Int32, Int64)
import Data.List (genericLength)
import TensorFlow.Build (Build, colocateWith, render)
import TensorFlow.Ops () -- Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value)
import TensorFlow.Types (OneOf, TensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
-- | Looks up `ids` in a list of embedding tensors.
--
-- This function is used to perform parallel lookups on the list of
-- tensors in `params`. It is a generalization of `TF.gather`, where
-- `params` is interpreted as a partition of a larger embedding
-- tensor.
--
-- The partition_strategy is "mod", we assign each id to partition
-- `p = id % len(params)`. For instance,
-- 13 ids are split across 5 partitions as:
-- `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`
--
-- The results of the lookup are concatenated into a dense
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
embeddingLookup :: forall a b v .
( TensorType a
, OneOf '[Int64, Int32] b
, Num b
)
=> [Tensor v a]
-- ^ A list of tensors which can be concatenated along
-- dimension 0. Each `Tensor` must be appropriately
-- sized for `mod` partition strategy.
-> Tensor Value b
-- ^ A `Tensor` with type `int32` or `int64`
-- containing the ids to be looked up in `params`.
-- The ids are required to be flat on entry and have
-- fewer than 2^31 entries.
-> Build (Tensor Value a)
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
embeddingLookup params ids =
CoreOps.dynamicStitch pindices <$> partitionedResult
where np = genericLength params
pAssignments = CoreOps.cast (ids `CoreOps.mod` np)
newIds = ids `CoreOps.div` np
originalIndices = CoreOps.range 0 (CoreOps.size ids) 1
-- Partition list of ids based on assignments into np separate lists
gatherIds = CoreOps.dynamicPartition np newIds pAssignments
-- Similarly, partition the original indices.
pindices = CoreOps.dynamicPartition np originalIndices pAssignments
-- Do np separate lookups, finding embeddings for plist[p] in params[p]
partitionedResult = zipWithM
(\p g -> colocateWith p $ render $ CoreOps.gather p g)
params gatherIds

View file

@ -0,0 +1,697 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
module TensorFlow.Gradient
( gradients
) where
import Control.Monad (forM, zipWithM)
import Control.Monad.State.Strict (State, evalState, gets, modify)
import Data.ByteString (ByteString)
import Data.Complex (Complex)
import Data.Default (def)
import Data.Int (Int32, Int64)
import Data.List (foldl', sortBy)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
import Data.Ord (comparing)
import Data.ProtoLens.TextFormat (showMessage)
import Data.Set (Set)
import Data.Text (Text)
import Data.Tuple (swap)
import Lens.Family2 (Lens', (&), (^.), (.~), (%~))
import Lens.Family2.State.Strict (uses)
import Lens.Family2.Stock (at, intAt)
import Lens.Family2.Unchecked (lens, iso)
import Prelude hiding (sum)
import Text.Printf (printf)
import qualified Data.Graph.Inductive.Basic as FGL
import qualified Data.Graph.Inductive.Graph as FGL
import qualified Data.Graph.Inductive.PatriciaTree as FGL
import qualified Data.Graph.Inductive.Query.DFS as FGL
import qualified Data.IntMap.Strict as IntMap
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Text as Text
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Build
( Build
, render
, renderNodeName
, renderedNodeDefs
, opDef
, opAttr
)
import TensorFlow.BuildOp
import TensorFlow.Ops
( addN
, broadcastGradientArgs
, expandDims
, fill
, matMul
, reducedShape
, reluGrad
, reshape
, scalar
, shape
, softmaxCrossEntropyWithLogits
, sum
, vector
, zerosLike
)
import TensorFlow.Output
( NodeName(..)
, Op (Rendered)
, Output(..)
, OutputIx(..)
, outputIndex
)
import TensorFlow.Tensor
( Tensor(..)
, TensorKind (ValueKind)
, Value
, tensorOutput
, tensorAttr
)
import TensorFlow.Types (OneOf, TensorType, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef
(NodeDef, attr, input, op, name)
type GradientCompatible a =
-- TODO(fmayle): MaxPoolGrad doesn't support Double for some reason.
(Num a, OneOf '[ Float, Complex Float, Complex Double ] a)
-- TODO(fmayle): Support control flow.
-- TODO(fmayle): Support gate_gradients-like option to avoid race conditions.
-- TODO(fmayle): Do we need to consider control inputs? See _PendingCount in
-- tensorflow/python/ops/gradients.py.
-- TODO(fmayle): Maybe store the gradient functions and numOutputs on the OpDef.
-- | Gradient of @y@ w.r.t. each element of @xs@.
gradients :: forall a v1 v2 . ( Num (Tensor v1 a)
-- TODO(gnezdo): remove indirect constraint.
-- It's a wart inherited from Num instance.
, v1 ~ Value
, GradientCompatible a
)
=> Tensor v1 a -- ^ The output of the graph.
-> [Tensor v2 a] -- ^ Tensors for which gradients are computed.
-> Build [Tensor Value a]
gradients y xs = do
-- The gradients are computed using "reverse accumulation", similarly to
-- what is described here:
-- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation
--
-- The code is summarised as follows:
--
-- 1. Create an fgl graph of the relevant nodes (ops) and edges (tensors).
-- 2. Initialize the gradient of y to 1 (∂y/∂y = 1) and the rest of tensor's
-- gradients to nothing.
-- 3. Process the nodes in reverse topological order (i.e. each node comes
-- after all of its outputs so that the output gradients for a node have
-- been completely calculated before it is processed):
-- a. Record the gradient for each of the node's output tensors (∂y/∂w
-- for each output tensor w).
-- b. Calculate the gradient of y w.r.t. each of the node's input
-- tensors using the gradients of the node's output tensors.
--
-- Written differently, for each output tensor w and input tensor v:
-- ∂y/∂w = ... (calculated in previous steps)
-- ∂w/∂v = ... (op specific)
-- ∂y/∂v = ∂y/∂w * ∂w/∂v (technically, if tensor v is an input
-- to multiple nodes, then this is only
-- part of ∂y/∂v)
--
-- 4. Lookup the recorded gradient for each x in xs.
yName <- renderNodeName y
-- TODO(fmayle): Move this into Build.hs and call it unsafeNodeDefFromName?
nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
. flip Map.lookup
let (gr, nodeMap) = createGraph yName nodeDefLookup
-- Set gradient of y to one.
let initPending :: Map.Map FGL.Node (PendingGradients a)
initPending = Map.empty & at (nodeMap Map.! yName)
. nonEmpty
. outputIxAt (y ^. tensorOutput . outputIndex)
. nonEmpty
.~ [fill (shape y) (scalar 1)]
-- Calculate the gradients of y w.r.t. each node in the graph.
gradientMap <- graphGrads gr initPending
-- Lookup the gradients for each x.
forM xs $ \x -> do
xName <- renderNodeName x
render $ fromMaybe (zerosLike x) $ do
n <- nodeMap ^. at xName
let i = x ^. tensorOutput . outputIndex
gradientMap ^. at n . nonEmpty . outputIxAt i
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
outputIxAt = intAt . unOutputIx
-- | Incomplete gradients of a node's outputs.
--
-- The lists represent partial sums. The key is an OutputIx sans newtype.
type PendingGradients a = IntMap.IntMap [Tensor Value a]
-- | Gradients of a node's outputs. The key is an OutputIx sans newtype.
type Gradients a = IntMap.IntMap (Tensor Value a)
-- | Graph of TensorFlow operations.
type Graph = FGL.Gr NodeDef EdgeLabel
-- | Data associated with an edge.
--
-- Pair of
-- 1. Output index of a tensor from the source node.
-- 2. Input index that the tensor connects to on the destination node.
type EdgeLabel = (OutputIx, OutputIx)
-- | State used for calculating gradients.
data GradientsState a = GradientsState
{ _gradientsPending :: !(Map FGL.Node (PendingGradients a))
, _gradientsResult :: !(Map FGL.Node (Gradients a))
}
gradientsPending :: Lens' (GradientsState a) (Map FGL.Node (PendingGradients a))
gradientsPending = lens _gradientsPending (\x y -> x { _gradientsPending = y })
gradientsResult :: Lens' (GradientsState a) (Map FGL.Node (Gradients a))
gradientsResult = lens _gradientsResult (\x y -> x { _gradientsResult = y })
-- TODO(fmayle): Use something like Data.List.Safe.
-- | Safe version of (!!).
safeIndex :: [a] -> Int -> Maybe a
_ `safeIndex` n | n < 0 = Nothing
[] `safeIndex` _ = Nothing
(x:_) `safeIndex` 0 = Just x
(_:xs) `safeIndex` n = xs `safeIndex` (n-1)
-- Copy of http://hackage.haskell.org/package/lens-3.9.0.2/docs/Control-Lens-Iso.html#v%3anon
anon :: a -> (a -> Bool) -> Lens' (Maybe a) a
anon a p = iso (fromMaybe a) go where
go b | p b = Nothing
| otherwise = Just b
non :: Eq a => a -> Lens' (Maybe a) a
non a = anon a (a==)
-- | Lens that defaults Nothing to mempty.
nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v)
nonEmpty = anon mempty null
-- | Calculate the gradients for every node in a graph.
graphGrads :: forall a. GradientCompatible a
=> Graph
-> Map FGL.Node (PendingGradients a)
-- ^ Initial gradients (usually just 1 for the node of interest).
-> Build (Map FGL.Node (Gradients a))
graphGrads gr initPending = pure (foldl' go initState nodeOrder ^. gradientsResult)
where
initState = GradientsState initPending Map.empty
-- Reverse topological sort.
-- TODO(fmayle): Filter out nodes that are not successors of any x in xs to
-- avoid calculating gradients that won't be used.
nodeOrder = FGL.topsort $ FGL.grev gr
go state node =
-- Aggregate the accumulated gradients for this node.
let outputGrads =
sumPendingGradient (state ^. gradientsPending . at node . nonEmpty)
in if null outputGrads
then state
else
-- Calculate the gradients for each of the node's inputs.
let nextState = state & gradientsResult %~ Map.insert node outputGrads
ctx = FGL.context gr node
in updatePendingGradients
ctx
(calculateInputGrads ctx outputGrads gr)
nextState
-- | Reduce accumulated gradients for each output to one Tensor.
sumPendingGradient :: GradientCompatible a
=> PendingGradients a -> Gradients a
sumPendingGradient = IntMap.mapMaybe f
where
f [] = Nothing
f [x] = Just x
f xs = Just (addN xs)
-- | Calculate the gradients of a node's input tensors.
--
-- This is mostly just a wrapper around opGrad.
calculateInputGrads :: forall a. GradientCompatible a
=> FGL.Context NodeDef EdgeLabel
-> Gradients a -- ^ Output gradients of the node.
-> Graph
-> [Maybe (Tensor Value a)]
calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr =
opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads
where
fullOutGrads =
fullOutputGrads (numOutputs nodeDef) (Rendered nodeDef) outputGrads
-- Create a tensor from an edge (technically an Output, but it seems less
-- confusing to refer to it as a tensor here).
edgeToTensor :: (EdgeLabel, FGL.Node) -> Output
edgeToTensor ((i, _), n) =
case FGL.lab gr n of
Just edgeNodeDef -> Output i (Rendered edgeNodeDef)
Nothing -> error $ "calculateInputGrads: missing input node for "
++ Text.unpack (nodeDef ^. name)
-- Input tensors, sorted by input index.
inputTensors = map edgeToTensor $ sortBy (comparing (snd . fst)) inputEdges
-- | Convert a Map of gradients to a list, with zeros for missing outputs.
fullOutputGrads :: (TensorType a, Num a)
=> OutputIx -- ^ Number of outputs.
-> Op
-> Gradients a
-> [Tensor Value a]
fullOutputGrads n o gs =
map (\i -> fromMaybe (zero i) (gs ^. outputIxAt i)) [0..n-1]
where
-- A tensor of zeros with the same shape as the i'th output.
zero i = zerosLike $ toT (Output i o)
-- | Update the pending gradients of a node's inputs.
updatePendingGradients :: forall a. (TensorType a, Num a)
=> FGL.Context NodeDef EdgeLabel
-> [Maybe (Tensor Value a)]
-- ^ Gradient of each input tensor.
-> GradientsState a
-> GradientsState a
updatePendingGradients (inputEdges, _, nodeDef, _) inputGrads initState =
foldl' go initState inputEdges
where
go :: GradientsState a -> (EdgeLabel, FGL.Node) -> GradientsState a
go state ((outIndex, OutputIx inIndex), node) =
case maybeGradient of
Nothing -> state
Just g ->
-- Add to the list of pending gradients for this tensor.
state & gradientsPending
. at node
. nonEmpty
. outputIxAt outIndex
. nonEmpty
%~ (g:)
where
badSizeErr = error $ printf "updatePendingGradients: bad input index \
\%d for inputGrads of length %d in %s"
inIndex (length inputGrads)
(show (nodeDef ^. name))
maybeGradient = fromMaybe badSizeErr (safeIndex inputGrads inIndex)
-- | Create a graph that includes a node and its transitive dependencies.
createGraph :: NodeName -> (NodeName -> NodeDef)
-> (Graph, Map NodeName FGL.Node)
createGraph nodeName nodeDefLookup = (FGL.nmap nodeDefLookup graph, nodeMap)
where
-- Parse a tensor name.
parseTensorName :: Text -> Maybe (NodeName, OutputIx)
parseTensorName n
| Text.null n = error "parseTensorName: empty name"
| Text.head n == '^' = Nothing -- Control edge
| otherwise =
let (nm, indexStr) = Text.breakOn ":" n
index | Text.null indexStr = 0
| otherwise = read $ Text.unpack $ Text.tail indexStr
in Just (NodeName nm, OutputIx index)
-- Build a map from node name to outward edges.
--
-- The state is the set of visited nodes.
collect :: Maybe (NodeName, OutputIx, OutputIx)
-> NodeName
-> State (Set NodeName)
(Map NodeName [(NodeName, OutputIx, OutputIx)])
collect outgoingEdge nm = do
let nextLookup = Map.singleton nm (maybeToList outgoingEdge)
seen <- gets (Set.member nm)
modify (Set.insert nm)
if seen
then pure nextLookup
else do
let inputs = nodeDefLookup nm ^. input
recurse inIndex (parentName, outIndex) =
collect (Just (nm, outIndex, inIndex)) parentName
subEdgeLookups <-
zipWithM recurse [0..] $ mapMaybe parseTensorName inputs
pure $ Map.unionsWith (++) (nextLookup:subEdgeLookups)
edgeLookup = evalState (collect Nothing nodeName) Set.empty
-- Associate an ID with each node name.
nodeMap = Map.fromList $ zip (Map.keys edgeLookup) [0..]
-- Create the graph.
graph = FGL.mkGraph (swap <$> Map.toList nodeMap)
[ (nodeMap Map.! n, nodeMap Map.! m, (i, j))
| (n, edges) <- Map.toList edgeLookup
, (m, i, j) <- edges
]
-- | Function to compute the gradient of y w.r.t. each input.
--
-- Let y be an arbitrary tensor
-- and [w_0, ..., w_n] be the output tensors of a node
-- and [v_0, ..., v_n] be the input tensors of the same node.
--
-- Given [∂y/∂w_0, ..., ∂y/∂w_n] and [v_0, ..., v_n], a GradientFunc computes
-- [∂y/∂v_0, ..., ∂y/∂v_n] for a particular op type.
--
-- A Nothing gradient is equivalent to zero (but allows for short circuiting
-- computation when all the gradients for something are Nothing).
type GradientFunc a = NodeDef
-> [Output]
-- ^ Input tensors.
-> [Tensor Value a]
-- ^ Gradient of y w.r.t. each output tensor.
-> [Maybe (Tensor Value a)]
-- ^ Gradient of y w.r.t. each input tensor.
-- TODO(fmayle): Assert the type is correct.
-- | Create a Tensor from an Output.
toT :: Output -> Tensor Value a
toT = Tensor ValueKind
-- | The gradient function for an op type.
--
-- These implementations should match their python counterparts in:
-- third_party/tensorflow/python/ops/*_grad.py
opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a
opGrad "Abs" _ [toT -> x] [dz] = [Just $ dz * signum x]
opGrad "Neg" _ [_] [dz] = [Just $ -dz]
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
opGrad "Square" _ [toT -> x] [dz] =
-- TODO(fmayle): Handle complex numbers.
-- TODO(fmayle): The python code makes dz a control dependency of the 2*x
-- (for performance reasons?). Will need to put these functions in the Build
-- monad to replicate that.
[Just $ dz * (2 * x)]
opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
-- TODO(fmayle): The python version uses a better performance implementation
-- when the shape is known without having to run the graph.
-- TODO(fmayle): We shouldn't convert the result to a dense tensor. Sparse
-- tensor support will require some thinking.
[ Just $ CoreOps.unsortedSegmentSum values indices' numRows
, Nothing
]
where
-- TODO(gnezdo): Use colocateWith but it requires Build monad.
denseShape = shape (x :: Tensor Value a)
numRows = CoreOps.slice denseShape 0 (1 :: Tensor Value Int32)
valuesShape = CoreOps.concat 0 [
allDimensions
, CoreOps.slice denseShape 1 (-1 :: Tensor Value Int32)
]
values = reshape dz valuesShape
-- TODO(fmayle): This could be either Int32 or Int64.
indices' = reshape indices allDimensions :: Tensor Value Int32
opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
[Just $ indicators `CoreOps.div` numSelected * dz', Nothing]
where
sx = shape (x :: Tensor Value a)
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
x' = reshape x outputShapeKeptDims
dz' = reshape dz outputShapeKeptDims
indicators = CoreOps.cast $ CoreOps.equal x' x
numSelected = reshape (sum indicators indices) outputShapeKeptDims
-- Min and Max have identical gradient implementations.
opGrad "Min" u v w = opGrad "Max" u v w
opGrad "Sum" _ [toT -> x, toT -> indices] [dz] =
[ Just $ CoreOps.tile grad tileScaling, Nothing ]
where
-- TODO(gnezdo): Implement the fast-path from math_grad._SumGrad.
sx = shape (x :: Tensor Value a)
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
tileScaling = safeShapeDiv sx outputShapeKeptDims
grad = reshape dz outputShapeKeptDims
opGrad "Mean" u v@[toT -> x, _] w =
[Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing]
where
[Just dz, Nothing] = opGrad "Sum" u v w
inputShape = shape (x :: Tensor Value a)
outputShape = shape (dz :: Tensor Value a)
-- TODO(fmayle): Add fast path when shape is known.
inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape
outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape
factor = safeShapeDiv inputSize outputSize
opGrad "Add" _ [toT -> x, toT -> y] [dz] =
[ Just $ reshape (sum dz rx) sx
, Just $ reshape (sum dz ry) sy ]
where
sx = shape (x :: Tensor Value a)
sy = shape (y :: Tensor Value a)
(rx, ry) = broadcastGradientArgs sx sy
opGrad "Sub" u v w =
[Just x, Just (-y)]
where
[Just x, Just y] = opGrad "Add" u v w
opGrad "SoftmaxCrossEntropyWithLogits" _ [toT -> x, toT -> y] [dz, _] =
[ Just $ expandDims dz (-1) * snd (softmaxCrossEntropyWithLogits x y)
, Nothing ]
opGrad "Mul" _ [toT -> x, toT -> y] [dz] =
-- TODO(fmayle): Handle complex numbers.
[ Just $ reshape (sum (dz * y) rx) sx
, Just $ reshape (sum (x * dz) ry) sy ]
where
sx = shape (x :: Tensor Value a)
sy = shape (y :: Tensor Value a)
(rx, ry) = broadcastGradientArgs sx sy
opGrad "Div" _ [toT -> x, toT -> y] [dz] =
-- TODO(fmayle): Handle complex numbers.
-- TODO(gnezdo): Provide Fractional instance and use '/' instead of div.
[ Just $ reshape (sum (dz `CoreOps.div` y) rx) sx
, Just $ reshape (sum (dz * (negate x `CoreOps.div` (y * y))) ry) sy
]
where
sx = shape (x :: Tensor Value a)
sy = shape (y :: Tensor Value a)
(rx, ry) = broadcastGradientArgs sx sy
opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
let transposeA = lookupAttr nodeDef "transpose_a"
transposeB = lookupAttr nodeDef "transpose_b"
transAttrs a b =
(tensorAttr "transpose_a" .~ a) . (tensorAttr "transpose_b" .~ b)
in case (transposeA, transposeB) of
(False, False) ->
[ Just $ (dz `matMul` y) & transAttrs False True
, Just $ (x `matMul` dz) & transAttrs True False ]
(False, True) ->
[ Just $ dz `matMul` y
, Just $ (x `matMul` dz) & transAttrs True False ]
(True, False) ->
[ Just $ (dz `matMul` y) & transAttrs False True
, Just $ x `matMul` dz ]
(True, True) ->
[ Just $ (dz `matMul` y) & transAttrs True True
, Just $ (x `matMul` dz) & transAttrs True True ]
opGrad "Transpose" _ [_, toT -> p] [dz] =
[ Just $ CoreOps.transpose dz
(CoreOps.invertPermutation p :: Tensor Value Int32)
, Nothing
]
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
[ Just $ CoreOps.conv2DBackpropInput (shape x) y dz
& tensorAttr "strides" .~ strides
& tensorAttr "padding" .~ padding
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
& tensorAttr "data_format" .~ dataFormat
, Just $ CoreOps.conv2DBackpropFilter x (shape y) dz
& tensorAttr "strides" .~ strides
& tensorAttr "padding" .~ padding
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
& tensorAttr "data_format" .~ dataFormat
]
where
strides = lookupAttr nodeDef "strides" :: [Int64]
padding = lookupAttr nodeDef "padding" :: ByteString
useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
opGrad "MaxPool" nodeDef [toT -> x] [dz] =
[ Just $ CoreOps.maxPoolGrad x output dz
& tensorAttr "ksize" .~ ksize
& tensorAttr "strides" .~ strides
& tensorAttr "padding" .~ padding
& tensorAttr "data_format" .~ dataFormat
]
where
output :: Tensor Value a
output = toT $ Output 0 (Rendered nodeDef)
ksize = lookupAttr nodeDef "ksize" :: [Int64]
strides = lookupAttr nodeDef "strides" :: [Int64]
padding = lookupAttr nodeDef "padding" :: ByteString
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
opGrad "Reshape" _ [toT -> x, _] [dz] =
[Just $ reshape dz $ shape (x :: Tensor Value a), Nothing]
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
opGrad "TruncatedNormal" _ _ _ = [Nothing]
opGrad "RefIdentity" _ _ [dz] = [Just dz]
opGrad "Cast" nodeDef _ [dz] = [Just reverseCast]
where
-- TODO(gnezdo): too permissive, python only allows float types as src_type.
reverseCast =
buildOp (opDef "Cast"
& opAttr "DstT" .~ (lookupAttr nodeDef "SrcT" :: ByteString)
& opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString))
dz
opGrad "DynamicStitch" nodeDef inputs [dz] =
replicate halfLen Nothing ++ valuesGrads
where
halfLen =
let len = length inputs
half = len `div` 2
in if 2 * half == len
then half
else error ("Uneven input size " ++ show (len, showMessage nodeDef))
valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Value Int32)
| idx <- take halfLen inputs
]
opGrad "DynamicPartition" nodeDef [toT -> xs, toT -> indices] dz =
[ Just reconstructed, Nothing ]
where
reconstructed = CoreOps.reshape stitched
(CoreOps.shape (xs :: Tensor Value a) :: Tensor Value Int32)
stitched = CoreOps.dynamicStitch partitionedIndices dz
partitionedIndices = CoreOps.dynamicPartition np originalIndices indices
np = lookupAttr nodeDef "num_partitions" :: Int64
originalIndices =
CoreOps.reshape (CoreOps.range 0 (CoreOps.size indices) 1) prefixShape
prefixShape = shapeInt32 indices
shapeInt32 = CoreOps.shape :: Tensor Value Int32 -> Tensor Value Int32
opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
[ Nothing
, Just $ CoreOps.select c dz zeros
, Just $ CoreOps.select c zeros dz
]
where zeros = CoreOps.zerosLike x
-- TODO(gnezdo): Unlike Python, no control dependency on dz.
opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.inv x ]
-- TODO(gnezdo): Reuse the output instead of doing another exp,
-- though, it is probably CSE'd away anyway.
opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.exp x ]
opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
[ Just $ CoreOps.unsortedSegmentSum
(CoreOps.gather dz (t :: Tensor Value Int32))
(y :: Tensor Value Int32) inputRows
, Nothing
, Nothing
]
where inputRows = CoreOps.slice (shape (x :: Tensor Value a)) (scalar (0 :: Int32)) (scalar 1)
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
opGrad "LabelWeights" _ _ _ = [Nothing]
opGrad "Size" _ _ _ = [Nothing]
opGrad "ZerosLike" _ _ _ = [Nothing]
-- TODO(fmayle): These can go away if we properly prune the graph.
opGrad "Const" _ _ _ = [Nothing, Nothing]
opGrad "Placeholder" _ _ _ = []
opGrad "Variable" _ _ _ = []
opGrad n nodeDef ins grads =
error $ "no gradient implemented for " ++
show (n, length ins, length grads, showMessage nodeDef, ins)
-- | The number of outputs for an op type.
numOutputs :: NodeDef -> OutputIx
numOutputs o =
case o ^. op of
"Abs" -> 1
"Add" -> 1
"Cast" -> 1
"Const" -> 1
"Conv2D" -> 1
"Div" -> 1
"DynamicStitch" -> 1
"DynamicPartition" ->
fromIntegral (lookupAttr o "num_partitions" :: Int64)
"Exp" -> 1
"Gather" -> 1
"LabelClasses" -> 1
"LabelWeights" -> 1
"Log" -> 1
"MatMul" -> 1
"Max" -> 1
"MaxPool" -> 1
"Mean" -> 1
"Min" -> 1
"Mul" -> 1
"Neg" -> 1
"Placeholder" -> 1
"OneHot" -> 1
"RefIdentity" -> 1
"Relu" -> 1
"Reshape" -> 1
"Select" -> 1
"Size" -> 1
"SoftmaxCrossEntropyWithLogits" -> 2
"Square" -> 1
"SparseSegmentSum" -> 1
"Sub" -> 1
"Sum" -> 1
"Transpose" -> 1
"TruncatedNormal" -> 1
"Variable" -> 1
"ZerosLike" -> 1
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
allDimensions = vector [-1 :: Int32]
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1
lookupAttr nodeDef attrName = nodeDef ^. attr . at attrName . non def . attrLens

View file

@ -0,0 +1,296 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- | This module contains definitions for some built-in TensorFlow operations.
--
-- Note that certain, "stateful" ops like 'variable' and 'assign' return a
-- 'Build' action (e.g., @Build (Tensor Ref a)@ instead of a pure value; the
-- returned 'Tensor's are always rendered in the current 'Build' context. This
-- approach helps us avoid problems with inlining or common subexpression
-- elimination, by writing
--
-- > do
-- > v <- variable []
-- > w <- assign v 3
-- > render $ w * w
--
-- instead of
--
-- > let
-- > v = variable []
-- > w = assign v 3
-- > in w * w
--
-- since the latter could be reasonably transformed by the compiler into (or
-- vice versa)
--
-- > let
-- > v = variable []
-- > w = assign v 3
-- > w' = assign v 3
-- > in w * w'
--
-- Ops should return a 'Build' action if their original 'OpDef' marks them as
-- stateful, or if they take any Refs as input. (This mirrors the rules that
-- TensorFlow uses to avoid common subexpression elimination.)
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module TensorFlow.Ops
( CoreOps.add
, CoreOps.abs
, CoreOps.addN
, CoreOps.argMax
, assign
, CoreOps.broadcastGradientArgs
, CoreOps.cast
, CoreOps.concat
, constant
, expandDims
, initializedVariable
, zeroInitializedVariable
, CoreOps.fill
, CoreOps.matMul
, matTranspose
, CoreOps.mul
, CoreOps.neg
, CoreOps.pack
, placeholder
, CoreOps.range
, reducedShape
, CoreOps.relu
, CoreOps.reluGrad
, CoreOps.reshape
, restore
, restoreFromName
, save
, scalar
, shape
, CoreOps.sign
, CoreOps.size
, CoreOps.softmax
, CoreOps.softmaxCrossEntropyWithLogits
, CoreOps.sparseToDense
, CoreOps.sub
, CoreOps.sum
, CoreOps.topK
, CoreOps.transpose
, truncatedNormal
, variable
, vector
, zeros
, CoreOps.zerosLike
) where
import Data.ByteString (ByteString)
import Data.Complex (Complex)
import Data.Int (Int32, Int64)
import Prelude hiding (abs, sum, concat)
import Data.ProtoLens (def)
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 ((.~), (&))
import Text.Printf (printf)
import Proto.Tensorflow.Core.Framework.Tensor
( TensorProto
, dtype
, tensorShape
)
import qualified Proto.Tensorflow.Core.Framework.TensorShape
as TensorShape
import TensorFlow.Build
import TensorFlow.BuildOp
import TensorFlow.ControlFlow (group)
import TensorFlow.Output (unNodeName)
import TensorFlow.Tensor
import TensorFlow.Types
import qualified TensorFlow.GenOps.Core as CoreOps
import qualified Prelude (abs)
-- TODO: Look into hs-boot refactoring to allow mutually recursive imports.
-- | Must be defined as an orphan because of the dependency order between Ops
-- and Tensor.
--
-- The indirect constraint "v ~ Value" helps disambiguate types, for example in
-- "neg 1 :: Tensor Value Float", it helps find the type of the subexpression
-- "1".
instance ( TensorType a
, Num a
, v ~ Value
, OneOf '[ Double, Float, Int32, Int64
, Complex Float, Complex Double] a) => Num (Tensor v a) where
(+) = CoreOps.add
(*) = CoreOps.mul
(-) = CoreOps.sub
abs = CoreOps.abs
fromInteger = scalar . fromInteger
signum = CoreOps.sign
negate = CoreOps.neg
matTranspose :: forall a v . TensorType a
=> Tensor v a -> Tensor Value a
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
-- | Create a new, uninitialized stateful Tensor of the given shape.
variable :: forall a . TensorType a => Shape -> Build (Tensor Ref a)
variable shape' = buildOp $ opDef "Variable"
& opAttr "shape" .~ shape'
& opAttr "dtype" .~ tensorType (undefined :: a)
placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a)
placeholder shape' =
buildOp $ opDef "Placeholder"
& opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "shape" .~ shape'
-- Assign returns the input ref.
assign :: forall a v . TensorType a
=> Tensor Ref a -> Tensor v a -> Build (Tensor Ref a)
assign = buildOp $ opDef "Assign"
& opAttr "T" .~ tensorType (undefined :: a)
& opAttr "use_locking" .~ True
-- | Creates a variable initialized to the given value.
-- Initialization happens next time session runs.
initializedVariable :: forall a . TensorType a
=> Tensor Value a -> Build (Tensor Ref a)
initializedVariable initializer = do
v <- variable [] -- The shape is not known initially.
(i :: Tensor Ref a) <-
buildOp (opDef "Assign"
& opAttr "T" .~ tensorType (undefined :: a)
& opAttr "use_locking" .~ True
& opAttr "validate_shape" .~ False
)
v initializer
addInitializer =<< group i
return v
-- | Creates a zero-initialized variable with the given shape.
zeroInitializedVariable
:: (TensorType a, Num a) =>
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
zeroInitializedVariable = initializedVariable . zeros
-- TODO: Support heterogeneous list of tensors.
save :: forall a v . TensorType a
=> ByteString -- ^ File path.
-> [Tensor v a] -- ^ Tensors to save.
-> Build ControlNode
save path xs = do
let toByteStringTensor = scalar . encodeUtf8 . unNodeName
names <- mapM (fmap toByteStringTensor . renderNodeName) xs
let types = replicate (length xs) (tensorType (undefined :: a))
let saveOp = buildOp $ opDef "Save"
& opAttr "T" .~ types
saveOp (scalar path) (CoreOps.pack names) xs
-- | Restore a tensor's value from a checkpoint file.
--
-- This version allows restoring from a checkpoint file that uses a different
-- tensor name than the variable.
restoreFromName :: forall a . TensorType a
=> ByteString -- ^ File path.
-> ByteString -- ^ Tensor name override.
-> Tensor Ref a -- ^ Tensor to restore.
-> Build ControlNode
restoreFromName path name x = do
let restoreOp = buildOp $ opDef "Restore"
& opAttr "dt" .~ tensorType (undefined :: a)
group =<< assign x (restoreOp (scalar path) (scalar name) :: Tensor Value a)
-- | Restore a tensor's value from a checkpoint file.
restore :: forall a . TensorType a
=> ByteString -- ^ File path.
-> Tensor Ref a -- ^ Tensor to restore.
-> Build ControlNode
restore path x = do
name <- encodeUtf8 . unNodeName <$> renderNodeName x
restoreFromName path name x
-- | Create a constant tensor.
--
-- The values should be in row major order, e.g.,
--
-- element 0: index (0, ..., 0)
-- element 1: index (0, ..., 1)
-- ...
constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a
constant (Shape shape') values
| invalidLength = error invalidLengthMsg
| otherwise = buildOp $ opDef "Const"
& opAttr "value" .~ typedNode
& opAttr "dtype" .~ nodeType
where
invalidLength = product shape' /= fromIntegral (length values)
invalidLengthMsg = printf "invalid tensor length: expected %d got %d"
(product shape')
(length values)
nodeType = tensorType (undefined :: a)
typedNode :: TensorProto
typedNode = def
& dtype .~ nodeType
& tensorShape.TensorShape.dim .~
[def & TensorShape.size .~ x | x <- shape']
& tensorVal .~ values
-- | Create a constant vector.
vector :: TensorType a => [a] -> Tensor Value a
vector xs = constant [fromIntegral $ length xs] xs
-- | Create a constant scalar.
scalar :: forall a . TensorType a => a -> Tensor Value a
scalar x = constant [] [x]
-- Random tensor from the unit normal distribution with bounded values.
truncatedNormal :: forall a v . TensorType a
=> Tensor v Int64 -- ^ Shape.
-> Build (Tensor Value a)
truncatedNormal = buildOp $ opDef "TruncatedNormal"
& opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "T" .~ tensorType (undefined :: Int64)
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32
shape = CoreOps.shape
expandDims :: (TensorType t) => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
expandDims = CoreOps.expandDims
-- | Helper function for reduction ops (translation of math_ops.reduced_shape).
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32
reducedShape inputShape axes =
let inputShape32 = toInt32 inputShape -- [2, 3, 5, 7]
axes32 = toInt32 axes -- [1, 2]
toInt32 x = CoreOps.cast x :: Tensor Value Int32
inputRank = CoreOps.size inputShape32 -- 4
axesMod = (axes32 + inputRank) `CoreOps.mod` inputRank
axesShape = shape axesMod -- [2]
in CoreOps.dynamicStitch -- [2, 1, 1, 7]
[CoreOps.range 0 inputRank 1, -- [0, 1, 2, 3]
axesMod] -- [1, 2]
[inputShape32, -- [2, 3, 5, 7]
CoreOps.fill axesShape 1] -- [1, 1]

View file

@ -0,0 +1,191 @@
name: tensorflow-ops
version: 0.1.0.0
synopsis: Friendly layer around TensorFlow bindings.
description: Please see README.md
homepage: https://github.com/tensorflow/haskell#readme
license: Apache
author: TensorFlow authors
maintainer: tensorflow-haskell@googlegroups.com
copyright: Google Inc.
category: Machine Learning
build-type: Simple
cabal-version: >=1.22
library
hs-source-dirs: src
exposed-modules: TensorFlow.Gradient
, TensorFlow.Ops
, TensorFlow.EmbeddingOps
build-depends: proto-lens == 0.1.*
, base >= 4.7 && < 5
, bytestring
, fgl
, mtl
, data-default
, lens-family
, containers
, tensorflow == 0.1.*
, tensorflow-proto == 0.1.*
, tensorflow-core-ops == 0.1.*
, text
default-language: Haskell2010
Test-Suite BuildTest
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: BuildTest.hs
hs-source-dirs: tests
build-depends: HUnit
, base
, proto-lens
, lens-family
, google-shim
, tensorflow
, tensorflow-ops
, tensorflow-proto
, test-framework
, test-framework-hunit
, transformers
, vector
Test-Suite EmbeddingOpsTest
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: EmbeddingOpsTest.hs
hs-source-dirs: tests
build-depends: HUnit
, QuickCheck
, base
, proto-lens
, lens-family
, google-shim
, tensorflow
, tensorflow-core-ops
, tensorflow-ops
, tensorflow-proto
, test-framework
, test-framework-hunit
, test-framework-quickcheck2
, vector
Test-Suite ArrayOpsTest
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: ArrayOpsTest.hs
hs-source-dirs: tests
build-depends: HUnit
, QuickCheck
, base
, proto-lens
, lens-family
, google-shim
, tensorflow
, tensorflow-core-ops
, tensorflow-ops
, tensorflow-proto
, test-framework
, test-framework-hunit
, test-framework-quickcheck2
, transformers
, vector
Test-Suite OpsTest
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: OpsTest.hs
hs-source-dirs: tests
build-depends: HUnit
, QuickCheck
, base
, bytestring
, proto-lens
, lens-family
, google-shim
, temporary
, tensorflow
, tensorflow-core-ops
, tensorflow-ops
, tensorflow-proto
, test-framework
, test-framework-hunit
, test-framework-quickcheck2
, transformers
, vector
Test-Suite DataFlowOpsTest
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: DataFlowOpsTest.hs
hs-source-dirs: tests
build-depends: HUnit
, QuickCheck
, base
, proto-lens
, lens-family
, google-shim
, tensorflow
, tensorflow-core-ops
, tensorflow-ops
, tensorflow-proto
, test-framework
, test-framework-hunit
, test-framework-quickcheck2
, vector
Test-Suite GradientTest
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: GradientTest.hs
hs-source-dirs: tests
build-depends: HUnit
, base
, proto-lens
, lens-family
, google-shim
, tensorflow
, tensorflow-ops
, tensorflow-proto
, test-framework
, test-framework-hunit
Test-Suite MiscTest
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: MiscTest.hs
hs-source-dirs: tests
build-depends: HUnit
, base
, bytestring
, vector
, google-shim
, transformers
, tensorflow
, tensorflow-ops
, tensorflow-proto
, test-framework
, test-framework-hunit
Test-Suite TypesTest
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: TypesTest.hs
hs-source-dirs: tests
build-depends: HUnit
, QuickCheck
, base
, bytestring
, proto-lens
, lens-family
, google-shim
, tensorflow
, tensorflow-ops
, tensorflow-proto
, transformers
, test-framework
, test-framework-hunit
, test-framework-quickcheck2
, vector
source-repository head
type: git
location: https://github.com/tensorflow/haskell

View file

@ -0,0 +1,42 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedLists #-}
module Main where
import Control.Monad.IO.Class (liftIO)
import Google.Test (googleTest)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
import qualified Data.Vector as V
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.GenOps.Core as CoreOps
-- | Test split and concat are inverses.
testSplit = testCase "testSplit" $ TF.runSession $ do
let original = TF.constant [2, 3] [0..5 :: Float]
splitList = CoreOps.split 3 dim original
restored = CoreOps.concat dim splitList
dim = 1 -- dimension to split along (with size of 3 in original)
liftIO $ 3 @=? length splitList
(x, y, z) <-
TF.buildAnd TF.run $ return (original, restored, splitList !! 1)
liftIO $ x @=? (y :: V.Vector Float)
liftIO $ V.fromList [1, 4] @=? z
main :: IO ()
main = googleTest [ testSplit
]

View file

@ -0,0 +1,181 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Functor.Identity (runIdentity)
import Lens.Family2 ((^.))
import Data.List (sort)
import Proto.Tensorflow.Core.Framework.Graph
( node )
import Proto.Tensorflow.Core.Framework.NodeDef
( NodeDef
, device
, name
, op )
import TensorFlow.Build
( Build
, BuildT
, asGraphDef
, evalBuildT
, flushNodeBuffer
, hoistBuildT
, render
, withDevice
, colocateWith
, withNameScope
)
import TensorFlow.ControlFlow (named)
import TensorFlow.Nodes (unScalar)
import TensorFlow.Ops
( add
, assign
, constant
, initializedVariable
, variable
)
import TensorFlow.Output (Device(..))
import TensorFlow.Tensor (Tensor, Value, Ref)
import TensorFlow.Session
( build
, buildAnd
, run
, runSession
, run_
)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
import Google.Test (googleTest)
import qualified Data.Vector as V
-- | Test named behavior.
testNamed = testCase "testNamed" $ do
let graph = named "foo" <$> variable [] >>= render :: Build (Tensor Ref Float)
nodeDef :: NodeDef
nodeDef = head $ asGraphDef graph ^. node
"RefIdentity" @=? (nodeDef ^. op)
"foo" @=? (nodeDef ^. name)
-- | Test named deRef behavior.
testNamedDeRef = testCase "testNamedDeRef" $ do
let graph = named "foo" <$> do
v :: Tensor Ref Float <- variable []
assign v 5
-- TODO: Implement TensorFlow get_variable and test it.
runSession $ do
out <- buildAnd run graph
liftIO $ 5 @=? (unScalar out :: Float)
-- | Test that "run" will render and extend any pure ops that haven't already
-- been rendered.
testPureRender = testCase "testPureRender" $ runSession $ do
result <- run $ 2 `add` 2
liftIO $ 4 @=? (unScalar result :: Float)
-- | Test that "run" assigns any previously accumulated initializers.
testInitializedVariable =
testCase "testInitializedVariable" $ runSession $ do
(formula, reset) <- build $ do
v <- initializedVariable 42
r <- assign v 24
return (1 `add` v, r)
result <- run formula
liftIO $ 43 @=? (unScalar result :: Float)
run_ reset -- Updates v to a different value
rerunResult <- run formula
liftIO $ 25 @=? (unScalar rerunResult :: Float)
testInitializedVariableShape =
testCase "testInitializedVariableShape" $ runSession $ do
vector <- build $ initializedVariable (constant [1] [42 :: Float])
result <- run vector
liftIO $ [42] @=? (result :: V.Vector Float)
-- | Test nameScoped behavior.
testNameScoped = testCase "testNameScoped" $ do
let graph = withNameScope "foo" $ variable [] :: Build (Tensor Ref Float)
nodeDef :: NodeDef
[nodeDef] = asGraphDef graph ^. node
"foo/Variable_0" @=? (nodeDef ^. name) -- TODO: Check prefix.
"Variable" @=? (nodeDef ^. op)
-- | Test combined named and nameScoped behavior.
testNamedAndScoped = testCase "testNamedAndScoped" $ do
let graph :: Build (Tensor Ref Float)
graph = withNameScope "foo1" ((named "bar1" <$> variable []) >>= render)
nodeDef :: NodeDef
nodeDef = head $ asGraphDef graph ^. node
"RefIdentity" @=? (nodeDef ^. op)
"foo1/bar1" @=? (nodeDef ^. name)
-- | Lift a Build action into a context for HUnit to run.
liftBuild :: Build a -> BuildT IO a
liftBuild = hoistBuildT (return . runIdentity)
-- | Flush the node buffer and sort the nodes by name (for more stable tests).
flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a]
flushed field = sort . map field <$> liftBuild flushNodeBuffer
-- | Test the interaction of rendering, CSE and scoping.
testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
liftBuild renderNodes
names <- flushed (^. name)
liftIO $ ["Const_1", "Variable_0", "Variable_2"] @=? names
-- Render the nodes in a different scope, which should cause them
-- to be distinct from the previous ones.
liftBuild $ withNameScope "foo" renderNodes
scopedNames <- flushed (^. name)
liftIO $ ["foo/Const_4", "foo/Variable_3", "foo/Variable_5"] @=? scopedNames
where
renderNodes = do
-- A stateful op and a pure op.
_ :: Tensor Ref Float <- variable []
_ :: Tensor Value Float <- render 3
-- Another stateful op, and a pure op which should be
-- deduped with the previous one.
_ :: Tensor Ref Float <- variable []
_ :: Tensor Value Float <- render 3
return ()
-- | Test the interaction of rendering, CSE and scoping.
testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do
liftBuild renderNodes
devices <- flushed (\x -> (x ^. name, x ^. device))
liftIO $ [ ("Add_2","dev0")
, ("Const_1","dev0")
, ("Variable_0","dev0")] @=? devices
where
renderNodes = do
-- A stateful op and a pure op.
var :: Tensor Ref Float <- withDevice (Just $ Device "dev0") $ variable []
-- Uses render to cause the expression be added to the graph.
_ <- colocateWith var $ render $ 3 `add` var
return ()
main :: IO ()
main = googleTest [ testInitializedVariable
, testInitializedVariableShape
, testDeviceColocation
, testNamed
, testNamedDeRef
, testNameScoped
, testNamedAndScoped
, testPureRender
, testRenderDedup
]

View file

@ -0,0 +1,66 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE ScopedTypeVariables #-}
import Data.Int (Int32, Int64)
import Data.List (genericLength)
import Google.Test (googleTest)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.HUnit ((@=?))
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
import Test.QuickCheck.Monadic (monadicIO, run)
import qualified Data.Vector as V
import qualified TensorFlow.GenOps.Core as CoreOps
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
-- DynamicSplit is undone with DynamicStitch to get the original input
-- back.
testDynamicPartitionStitchInverse :: forall a.
(TF.TensorType a, Show a, Eq a) => StitchExample a -> Property
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
let splitParts :: [TF.Tensor TF.Value a] =
CoreOps.dynamicPartition numParts (TF.vector values) partTensor
partTensor = TF.vector partitions
restitchIndices = CoreOps.dynamicPartition numParts
(TF.vector [0..genericLength values-1])
partTensor
-- drop (numParts - 2) from both args to expose b/27343984
restitch = CoreOps.dynamicStitch restitchIndices splitParts
in monadicIO $ run $ do
fromIntegral numParts @=? length splitParts
valuesOut <- TF.runSession $ TF.buildAnd TF.run $ return restitch
V.fromList values @=? valuesOut
data StitchExample a = StitchExample Int64 [a] [Int32]
deriving Show
instance Arbitrary a => Arbitrary (StitchExample a) where
arbitrary = do
-- Limits the size of the vector.
size <- choose (1, 100)
values <- vectorOf size arbitrary
numParts <- choose (2, 15)
partitions <- vectorOf size (choose (0, fromIntegral numParts - 1))
return $ StitchExample numParts values partitions
main :: IO ()
main = googleTest
[ testProperty "DynamicPartitionStitchInverse"
(testDynamicPartitionStitchInverse :: StitchExample Int64 -> Property)
]

View file

@ -0,0 +1,88 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Tests for EmbeddingOps.
module Main where
import Data.Int (Int32, Int64)
import Data.List (genericLength)
import Google.Test (googleTest)
import TensorFlow.EmbeddingOps (embeddingLookup)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.HUnit ((@=?))
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
import Test.QuickCheck.Monadic (monadicIO, run)
import qualified Data.Vector as V
import qualified TensorFlow.GenOps.Core as CoreOps
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
-- Verifies that direct gather is the same as dynamic split into
-- partitions, followed by embedding lookup.
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
=> LookupExample a -> Property
testEmbeddingLookupUndoesSplit
(LookupExample numParts
shape@(TF.Shape (firstDim : restDims))
values
indices) =
let modShardedValues :: [TF.Tensor TF.Value a] =
CoreOps.dynamicPartition numParts shapedValues cyclicCounter
cyclicCounter :: TF.Tensor TF.Value Int32 =
TF.vector [0..fromIntegral firstDim-1]
`CoreOps.mod` fromIntegral numParts
indicesVector = TF.vector indices
directs = CoreOps.gather shapedValues indicesVector
shapedValues = TF.constant shape values
in monadicIO $ run $ do
(shapeOut, got, want :: V.Vector a) <-
TF.runSession $ TF.buildAnd TF.run $ do
embeddings <- embeddingLookup modShardedValues indicesVector
return (TF.cast (TF.shape embeddings), embeddings, directs)
-- Checks the explicitly documented invariant of embeddingLookup.
shapeOut @=? V.fromList (genericLength indices : restDims)
got @=? want
testEmbeddingLookupUndoesSplit _ = error "Bug in Arbitrary (LookupExample)"
-- | Consistent set of parameters for EmbeddingLookupUndoesSplit.
data LookupExample a = LookupExample
Int64 -- ^ number of ways to split.
TF.Shape -- ^ shape of the generated tensor
[a] -- ^ data for the tensor
[Int32] -- ^ indices to split the tensor by
deriving Show
instance Arbitrary a => Arbitrary (LookupExample a) where
arbitrary = do
rank <- choose (1, 4)
-- Takes rank-th root of 100 to cap the tensor size.
let maxDim = fromIntegral $ ceiling $ 100 ** (1 / fromIntegral rank)
shape@(firstDim : _) <- vectorOf rank (choose (1, maxDim))
values <- vectorOf (fromIntegral $ product shape) arbitrary
numParts <- choose (2, 15)
indSize <- choose (0, fromIntegral $ firstDim - 1)
indices <- vectorOf indSize (choose (0, fromIntegral firstDim - 1))
return $ LookupExample numParts (TF.Shape shape) values indices
main :: IO ()
main = googleTest
[ testProperty "EmbeddingLookupUndoesSplit"
(testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
]

View file

@ -0,0 +1,158 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
import Data.List (sort)
import Data.ProtoLens.TextFormat (showMessage)
import Google.Test (googleTest)
import Lens.Family2 ((^..))
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
import qualified TensorFlow.Build as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
import Proto.Tensorflow.Core.Framework.Graph (node)
import Proto.Tensorflow.Core.Framework.NodeDef (op)
testGradientSimple :: Test
testGradientSimple = testCase "testGradientSimple" $ do
let x = TF.scalar (3 :: Float)
b = TF.scalar (4 :: Float)
y = x*x + b
grads = TF.gradients y [x, b]
-- Assert that the gradients are right.
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads
6 @=? TF.unScalar dx
1 @=? TF.unScalar db
-- Assert that the graph has the expected ops.
let graphDef = TF.asGraphDef grads
putStrLn $ showMessage graphDef
let ops = graphDef ^.. node . traverse . op
expected = [ "Const"
, "Mul"
, "Const"
, "Add"
-- Default output gradient of y.
, "Shape"
, "Const"
, "Fill"
-- Add gradient.
, "Shape"
, "Shape"
, "BroadcastGradientArgs"
, "Sum"
, "Sum"
, "Reshape"
, "Reshape"
-- Mul gradient.
, "Shape"
-- This Op gets dedup'd because the inputs are the same.
-- TODO(fmayle): The same would happen to the Mul and Sum ops
-- below if the gradient function didn't multiply one as
-- 'dz * y' and the other as 'x * dz'. We could change the
-- order, but I'm going to keep it the same as the python
-- version for now.
--
-- , "Shape"
, "BroadcastGradientArgs"
, "Mul"
, "Mul"
, "Sum"
, "Sum"
, "Reshape"
, "Reshape"
-- AddN to combine x's output gradients.
, "AddN"
]
sort expected @=? sort ops
testGradientDisconnected :: Test
testGradientDisconnected = testCase "testGradientDisconnected" $ do
let x = TF.scalar (3 :: Float)
b = TF.scalar (4 :: Float)
grads = TF.gradients x [x, b]
-- Assert that the gradients are right.
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads
1 @=? TF.unScalar dx
0 @=? TF.unScalar db
-- Assert that the graph has the expected ops.
let graphDef = TF.asGraphDef grads
putStrLn $ showMessage graphDef
let ops = graphDef ^.. node . traverse . op
expected = [ "Const"
, "Const"
-- Default output gradient of x.
, "Shape"
, "Const"
, "Fill"
-- Default output gradient of b.
, "ZerosLike"
]
sort expected @=? sort ops
-- Test that identical "stateful" ops work with createGraph.
testCreateGraphStateful :: Test
testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
[dx, dy] <- TF.runSession $ TF.buildAnd TF.run $ do
let shape = TF.constant (TF.Shape [1]) [1]
x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
TF.gradients (x + y*3) [x, y]
-- If this test fails, it will likely be caused by an exception within
-- `TF.gradients`. These asserts are extra.
1 @=? TF.unScalar dx
3 @=? TF.unScalar dy
-- Test that name scopes work with createGraph.
testCreateGraphNameScopes :: Test
testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
let shape = TF.constant (TF.Shape [1]) [1]
x :: TF.Tensor TF.Value Float <-
TF.withNameScope "foo" (TF.truncatedNormal shape)
TF.gradients x [x]
-- If this test fails, it will likely be caused by an exception within
-- `TF.gradients`. This assert is extra.
1 @=? TF.unScalar dx
-- Test that createGraph can handle graphs with diamond shapes.
testDiamond :: Test
testDiamond = testCase "testDiamond" $ do
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
let x = TF.vector [1]
y = x*x
z = y*y
TF.gradients z [x]
(4 :: Float) @=? TF.unScalar dx
main :: IO ()
main = googleTest [ testGradientSimple
, testGradientDisconnected
, testCreateGraphStateful
, testCreateGraphNameScopes
, testDiamond
]

View file

@ -0,0 +1,46 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE RankNTypes #-}
module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
import Google.Test
import qualified Data.Vector as V
import TensorFlow.Ops
import TensorFlow.Session
-- | Test fetching multiple outputs from an op.
testMultipleOutputs = testCase "testMultipleOutputs" $
runSession $ do
(values, indices) <- run $ topK 2 $ constant [1, 4] [10, 40, 20, 30]
liftIO $ [40, 30] @=? V.toList (values :: V.Vector Float)
liftIO $ [1, 3] @=? V.toList (indices :: V.Vector Int32)
-- | Test op with variable number of inputs.
testVarargs = testCase "testVarargs" $
runSession $ do
xs <- run $ pack $ map scalar [1..8]
liftIO $ [1..8] @=? V.toList (xs :: V.Vector Float)
main :: IO ()
main = googleTest [ testMultipleOutputs
, testVarargs
]

View file

@ -0,0 +1,70 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64)
import Google.Test (googleTest)
import System.IO.Temp (withSystemTempDirectory)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
import qualified Data.ByteString.Char8 as B8
import qualified Data.Vector as V
import qualified TensorFlow.Build as TF
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
-- | Test that one can easily determine number of elements in the tensor.
testSize = testCase "testSize" $ do
x <- eval $ TF.size (TF.constant [2, 3] [0..5 :: Float])
TF.Scalar (2 * 3 :: Int32) @=? x
eval = TF.runSession . TF.buildAnd TF.run . return
-- | Confirms that the original example from Python code works.
testReducedShape = testCase "testReducedShape" $ do
x <- eval $ TF.reducedShape (TF.vector [2, 3, 5, 7 :: Int64])
(TF.vector [1, 2 :: Int32])
V.fromList [2, 1, 1, 7 :: Int32] @=? x
testSaveRestore = testCase "testSaveRestore" $
withSystemTempDirectory "" $ \dirPath -> do
let path = B8.pack $ dirPath ++ "/checkpoint"
var :: TF.Build (TF.Tensor TF.Ref Float)
var = TF.render =<< TF.named "a" <$> TF.zeroInitializedVariable []
TF.runSession $ do
v <- TF.build var
TF.buildAnd TF.run_ $ TF.assign v 134
TF.buildAnd TF.run_ $ TF.save path [v]
result <- TF.runSession $ do
v <- TF.build var
TF.buildAnd TF.run_ $ TF.restore path v
TF.run v
liftIO $ TF.Scalar 134 @=? result
main :: IO ()
main = googleTest [ testSaveRestore
, testSize
, testReducedShape
]

View file

@ -0,0 +1,119 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
import Control.Monad (replicateM)
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import Google.Test (googleTest)
import Test.Framework.Providers.HUnit (testCase)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.HUnit ((@=?))
import Test.QuickCheck (Arbitrary(..), listOf, suchThat)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.Vector as V
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
instance Arbitrary B.ByteString where
arbitrary = B.pack <$> arbitrary
-- Test encoding tensors, feeding them through tensorflow, and decoding the
-- results.
testFFIRoundTrip = testCase "testFFIRoundTrip" $
TF.runSession $ do
let floatData = V.fromList [1..6 :: Float]
stringData = V.fromList [B8.pack (show x) | x <- [1..6]]
f <- TF.build $ TF.placeholder [2,3]
s <- TF.build $ TF.placeholder [2,3]
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
, TF.feed s (TF.encodeTensorData [2,3] stringData)
]
-- It is an error to fetch a tensor that is being fed, so the tensors
-- are passed through identity.
(f', s') <- TF.runWithFeeds feeds (TF.identity f, TF.identity s)
liftIO $ do
floatData @=? f'
stringData @=? s'
data TensorDataInputs a = TensorDataInputs [Int64] (V.Vector a)
deriving Show
instance Arbitrary a => Arbitrary (TensorDataInputs a) where
arbitrary = do
-- Limit the size of the final vector, and also guard against overflow
-- (i.e., p<0) when there are too many dimensions
let validProduct p = p > 0 && p < 100
sizes <- listOf (arbitrary `suchThat` (>0))
`suchThat` (validProduct . product)
elems <- replicateM (fromIntegral $ product sizes) arbitrary
return $ TensorDataInputs sizes (V.fromList elems)
-- Test that a vector is unchanged after being encoded and decoded.
encodeDecodeProp :: (TF.TensorType a, Eq a) => TensorDataInputs a -> Bool
encodeDecodeProp (TensorDataInputs shape vec) =
TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec
testEncodeDecodeQcFloat = testProperty "testEncodeDecodeQcFloat"
(encodeDecodeProp :: TensorDataInputs Float -> Bool)
testEncodeDecodeQcInt64 = testProperty "testEncodeDecodeQcInt64"
(encodeDecodeProp :: TensorDataInputs Int64 -> Bool)
testEncodeDecodeQcString = testProperty "testEncodeDecodeQcString"
(encodeDecodeProp :: TensorDataInputs B.ByteString -> Bool)
doubleOrInt64Func :: TF.OneOf '[Double, Int64] a => a -> a
doubleOrInt64Func = id
doubleOrFloatFunc :: TF.OneOf '[Double, Float] a => a -> a
doubleOrFloatFunc = id
doubleFunc :: TF.OneOf '[Double] a => a -> a
doubleFunc = doubleOrFloatFunc . doubleOrInt64Func
-- No explicit type signature; make sure it can be inferred automatically.
-- (Note: this would fail if we didn't have NoMonomorphismRestriction, since it
-- can't simplify the type all the way to `Double -> Double`.
doubleFuncNoSig = doubleOrFloatFunc . doubleOrInt64Func
typeConstraintTests = testCase "type constraints" $ do
42 @=? doubleOrInt64Func (42 :: Double)
42 @=? doubleOrInt64Func (42 :: Int64)
42 @=? doubleOrFloatFunc (42 :: Double)
42 @=? doubleOrFloatFunc (42 :: Float)
42 @=? doubleFunc (42 :: Double)
42 @=? doubleFuncNoSig (42 :: Double)
main :: IO ()
main = googleTest [ testFFIRoundTrip
, testEncodeDecodeQcFloat
, testEncodeDecodeQcInt64
, testEncodeDecodeQcString
, typeConstraintTests
]

17
tensorflow-proto/Setup.hs Normal file
View file

@ -0,0 +1,17 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
import Data.ProtoLens.Setup
main = defaultMainGeneratingProtos "../third_party/tensorflow"

View file

@ -0,0 +1,40 @@
name: tensorflow-proto
version: 0.1.0.0
synopsis: TensorFlow protocol buffers.
description: Please see README.md
homepage: https://github.com/tensorflow/haskell#readme
license: Apache
author: TensorFlow authors
maintainer: tensorflow-haskell@googlegroups.com
copyright: Google Inc.
category: Machine Learning
build-type: Custom
cabal-version: >=1.22
extra-source-files: ../third_party/tensorflow/tensorflow/core/framework/*.proto
, ../third_party/tensorflow/tensorflow/core/protobuf/config.proto
library
exposed-modules: Proto.Tensorflow.Core.Framework.AttrValue
, Proto.Tensorflow.Core.Framework.Graph
, Proto.Tensorflow.Core.Framework.NodeDef
, Proto.Tensorflow.Core.Framework.OpDef
, Proto.Tensorflow.Core.Framework.ResourceHandle
, Proto.Tensorflow.Core.Framework.Tensor
, Proto.Tensorflow.Core.Framework.TensorShape
, Proto.Tensorflow.Core.Framework.Types
, Proto.Tensorflow.Core.Protobuf.Config
other-modules: Proto.Tensorflow.Core.Framework.AllocationDescription
, Proto.Tensorflow.Core.Framework.CostGraph
, Proto.Tensorflow.Core.Framework.Function
, Proto.Tensorflow.Core.Framework.StepStats
, Proto.Tensorflow.Core.Framework.TensorDescription
, Proto.Tensorflow.Core.Framework.Versions
build-depends: proto-lens == 0.1.*
, proto-lens-protoc == 0.1.*
, base >= 4.7 && < 5
default-language: Haskell2010
include-dirs: .
source-repository head
type: git
location: https://github.com/tensorflow/haskell

View file

@ -0,0 +1,3 @@
import Distribution.Simple
main = defaultMain

View file

@ -0,0 +1,78 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Queues in TensorFlow graph. Very limited support for now.
module TensorFlow.Queue (Queue2, makeQueue2, enqueue, dequeue) where
import Data.ByteString (ByteString)
import Data.Int (Int64)
import Lens.Family2 ((.~), (&))
import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef)
import TensorFlow.BuildOp (buildOp)
import TensorFlow.ControlFlow (group)
import TensorFlow.Tensor (Ref, Tensor)
import TensorFlow.Types (TensorType, tensorType)
-- | A queue carrying tuples. The underlying structure is more
-- versatile and can be made to support arbitrary tuples.
data Queue2 a b = Queue2 { handle :: Handle }
type Handle = Tensor Ref ByteString
-- | Adds the given values to the queue.
enqueue :: forall a b v1 v2. (TensorType a, TensorType b)
=> Queue2 a b
-> Tensor v1 a
-> Tensor v2 b
-> Build ControlNode
enqueue q =
buildOp (opDef "QueueEnqueue"
& opAttr "Tcomponents" .~ [ tensorType (undefined :: a)
, tensorType (undefined :: b)])
(handle q)
-- | Retrieves the values from the queue.
dequeue :: forall a b . (TensorType a, TensorType b)
=> Queue2 a b
-> Build (Tensor Ref a, Tensor Ref b)
-- ^ Dequeued tensors. They are paired in a sense
-- that values appear together, even if they are
-- not consumed together.
dequeue q =
buildOp (opDef "QueueDequeue"
& opAttr "component_types" .~ [ tensorType (undefined :: a)
, tensorType (undefined :: b)])
(handle q)
-- | Creates a new queue with the given capacity and shared name.
makeQueue2 :: forall a b . (TensorType a, TensorType b)
=> Int64 -- ^ The upper bound on the number of elements in
-- this queue. Negative numbers mean no limit.
-> ByteString -- ^ If non-empty, this queue will be shared
-- under the given name across multiple sessions.
-> Build (Queue2 a b)
makeQueue2 capacity sharedName = do
q <- buildOp (opDef "FIFOQueue"
& opAttr "component_types" .~ [ tensorType (undefined :: a)
, tensorType (undefined :: b)]
& opAttr "shared_name" .~ sharedName
& opAttr "capacity" .~ capacity
)
group q >>= addInitializer
return (Queue2 q)
-- TODO(gnezdo): Figure out the closing story for queues.

View file

@ -0,0 +1,51 @@
name: tensorflow-queue
version: 0.1.0.0
synopsis: Basic access to TensorFlow queues.
description: Please see README.md
homepage: https://github.com/tensorflow/haskell#readme
license: Apache
author: TensorFlow authors
maintainer: tensorflow-haskell@googlegroups.com
copyright: Google Inc.
category: Machine Learning
build-type: Simple
cabal-version: >=1.22
library
hs-source-dirs: src
exposed-modules: TensorFlow.Queue
build-depends: proto-lens == 0.1.*
, base >= 4.7 && < 5
, bytestring
, lens-family
, containers
, tensorflow-proto == 0.1.*
, tensorflow-core-ops == 0.1.*
, tensorflow
, text
default-language: Haskell2010
Test-Suite QueueTest
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: QueueTest.hs
hs-source-dirs: tests
-- Uses multiple threads and blocks without this option.
ghc-options: -threaded
build-depends: HUnit
, base
, bytestring
, proto-lens
, lens-family
, google-shim
, tensorflow
, tensorflow-ops
, tensorflow-queue
, test-framework
, test-framework-hunit
, transformers
, vector
source-repository head
type: git
location: https://github.com/tensorflow/haskell

View file

@ -0,0 +1,79 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import Google.Test (googleTest)
import TensorFlow.Nodes (Scalar(..))
import TensorFlow.Ops (scalar)
import TensorFlow.Queue
import TensorFlow.Session
( asyncProdNodes
, build
, buildAnd
, run
, runSession
, run_
)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
import qualified Data.ByteString as BS
-- | Test basic queue behaviors.
testBasic = testCase "testBasic" $ runSession $ do
(q :: Queue2 Int64 BS.ByteString) <- build $ makeQueue2 1 ""
buildAnd run_ (enqueue q 42 (scalar "Hi"))
x <- buildAnd run (dequeue q)
liftIO $ (Scalar 42, Scalar "Hi") @=? x
buildAnd run_ (enqueue q 56 (scalar "Bar"))
y <- buildAnd run (dequeue q)
liftIO $ (Scalar 56, Scalar "Bar") @=? y
-- | Test queue pumping.
testPump = testCase "testPump" $ runSession $ do
(deq, pump) <- build $ do
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue"
(,) <$> dequeue q
<*> enqueue q 31 (scalar "Baz")
-- This is a realistic use. The pump inputs are pre-bound to some
-- nodes that produce values when pumped (e.g. read from a
-- file).
run_ (pump, pump)
(x, y) <- run (deq, deq)
liftIO $ (Scalar 31, Scalar "Baz") @=? x
liftIO $ (Scalar 31, Scalar "Baz") @=? y
testAsync = testCase "testAsync" $ runSession $ do
(deq, pump) <- build $ do
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 ""
(,) <$> dequeue q
<*> enqueue q 10 (scalar "Async")
-- Pumps the queue until canceled by runSession exiting.
asyncProdNodes pump
-- Picks up a couple values and verifies they are as expected.
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?)
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?)
main :: IO ()
main = googleTest [ testBasic
, testPump
, testAsync
]

3
tensorflow/Setup.hs Normal file
View file

@ -0,0 +1,3 @@
import Distribution.Simple
main = defaultMain

View file

@ -0,0 +1,376 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE OverloadedStrings #-}
module TensorFlow.Build
( -- * Graph node types
ControlNode(..)
, Unique
-- * Ops
, explicitName
, implicitName
, opDef
, opDefWithName
, opName
, opType
, opAttr
, opInputs
, opControlInputs
-- * The Build monad
, GraphState
, render
, renderNodeName
, renderedNodeDefs
, BuildT
, Build
, addInitializer
, hoistBuildT
, evalBuildT
, runBuildT
, asGraphDef
, addGraphDef
, flushInitializers
, flushNodeBuffer
-- * Creating and looking up Ops
, getOrAddOp
, addNewOp
, renderOutput
-- * Modifying all nodes in a Build action
, colocateWith
, withStateLens
, withDevice
, withNameScope
, withNodeDependencies
-- * Internal Summary related bits.
, addSummary
, SummaryTensor
, collectAllSummaries
) where
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
import Data.ByteString (ByteString)
import Data.Default (def)
import Data.Functor.Identity (Identity(..))
import qualified Data.Map.Strict as Map
import Data.Monoid ((<>))
import qualified Data.Set as Set
import Data.Set (Set)
import Data.String (IsString(..))
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 (Lens', (.~), (^.), (&))
import Lens.Family2.State.Strict (MonadState, use, uses, (.=), (<>=), (%=))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.Graph
( GraphDef
, node
)
import Proto.Tensorflow.Core.Framework.NodeDef
( NodeDef
, attr
, input
, device
, name
, op
)
import TensorFlow.Orphans ()
import TensorFlow.Output
import TensorFlow.Tensor
newtype Unique = Unique Int
deriving (Eq, Ord, Enum)
--------------
implicitName :: PendingNodeName
implicitName = ImplicitName
explicitName :: Text -> PendingNodeName
explicitName = ExplicitName
newtype Scope = Scope {unScope :: Text}
deriving (Eq, Ord, IsString)
instance Show Scope where
show = show . unScope
opDef :: OpType -> OpDef
opDef = opDefWithName ImplicitName
opDefWithName :: PendingNodeName -> OpType -> OpDef
opDefWithName n t = OpDef
{ _opName = n
, _opType = t
, _opAttrs = Map.empty
, _opInputs = []
, _opControlInputs = []
}
-- | Synonym for the tensors that return serialized Summary proto.
type SummaryTensor = Tensor Value ByteString
data GraphState = GraphState
{ _renderedNodes :: !(Map.Map PendingNode NodeDef)
-- ^ Nodes which have been rendered. Keeps track of the unique ID we
-- assign each implicitly-named node. Also prevents us from adding the
-- same node (implicit or explicit) more than once to the nodeBuffer.
, _renderedNodeDefs :: !(Map.Map NodeName NodeDef)
-- ^ The NodeDefs of nodes which have been rendered. Used by the
-- Gradient module to inspect the node graph.
, _nodeBuffer :: [NodeDef]
-- ^ A list of nodes that should be passed to TensorFlow during
-- the next call to Session.extend (TF_ExtendGraph).
, _nextUnique :: !Unique
-- ^ Unique ID for the next node
-- TODO(judahjacobson): watch for clashes between auto and user names.
, _defaultDevice :: !(Maybe Device)
, _currentScope :: [Scope]
, _defaultControlInputs :: !(Set NodeName)
, _initializationNodes :: [NodeName]
-- ^ The nodes to run next time a TF.run is issued, typically
-- variable initializers.
, _summaries :: [SummaryTensor]
-- ^ The tensors for summary
}
-- | A node definition without its final name. Used as a key in the
-- "renderedNodes" map.
-- The NodeDef contained inside has an empty "name" field.
data PendingNode = PendingNode [Scope] !PendingNodeName !NodeDef
deriving (Eq, Ord)
-- Returns an _incomplete_ NodeDef. The name is fixed by addNewOpFromPending.
pendingNodeDef :: PendingNode -> NodeDef
pendingNodeDef (PendingNode _ _ n) = n
initGraphState :: GraphState
initGraphState =
GraphState Map.empty Map.empty [] (Unique 0) Nothing [] Set.empty [] []
renderedNodes :: Lens' GraphState (Map.Map PendingNode NodeDef)
renderedNodes = lens _renderedNodes (\g x -> g { _renderedNodes = x })
renderedNodeDefs :: Lens' GraphState (Map.Map NodeName NodeDef)
renderedNodeDefs = lens _renderedNodeDefs (\g x -> g { _renderedNodeDefs = x })
nodeBuffer :: Lens' GraphState [NodeDef]
nodeBuffer = lens _nodeBuffer (\g x -> g { _nodeBuffer = x })
nextUnique :: Lens' GraphState Unique
nextUnique = lens _nextUnique (\g x -> g { _nextUnique = x })
defaultDevice :: Lens' GraphState (Maybe Device)
defaultDevice = lens _defaultDevice (\g x -> g { _defaultDevice = x })
currentScope :: Lens' GraphState [Scope]
currentScope = lens _currentScope (\g x -> g { _currentScope = x })
defaultControlInputs :: Lens' GraphState (Set NodeName)
defaultControlInputs = lens _defaultControlInputs
(\g x -> g { _defaultControlInputs = x })
initializationNodes :: Lens' GraphState [NodeName]
initializationNodes = lens _initializationNodes (\g x -> g { _initializationNodes = x })
summaries :: Lens' GraphState [SummaryTensor]
summaries = lens _summaries (\g x -> g { _summaries = x })
-- | An action for building nodes in a TensorFlow graph.
-- Used to manage build state internally as part of the @Session@ monad.
newtype BuildT m a = BuildT (StateT GraphState m a)
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
MonadState GraphState)
-- | An action for building nodes in a TensorFlow graph.
type Build = BuildT Identity
-- | This is Control.Monad.Morph.hoist sans the dependency.
hoistBuildT :: (forall a . m a -> n a) -> BuildT m b -> BuildT n b
hoistBuildT f (BuildT m) = BuildT $ mapStateT f m
runBuildT :: BuildT m a -> m (a, GraphState)
runBuildT (BuildT f) = runStateT f initGraphState
evalBuildT :: Monad m => BuildT m a -> m a
evalBuildT (BuildT f) = evalStateT f initGraphState
-- | Get all the NodeDefs that have accumulated so far, and clear that buffer.
flushNodeBuffer :: Monad m => BuildT m [NodeDef]
flushNodeBuffer = do
ns <- use nodeBuffer
nodeBuffer .= []
return ns
-- | Get all the initializers that have accumulated so far, and clear
-- that buffer.
flushInitializers :: Monad m => BuildT m [NodeName]
flushInitializers = do
ns <- use initializationNodes
initializationNodes .= []
return ns
-- | Registers the given node to be executed before the next
-- 'TensorFlow.Session.run'.
addInitializer :: ControlNode -> Build ()
addInitializer (ControlNode o) = do
i <- getOrAddOp o
initializationNodes %= (i:)
-- | Produce a GraphDef proto representation of the nodes that are rendered in
-- the given 'Build' action.
asGraphDef :: Build a -> GraphDef
asGraphDef b = def & node .~ gs ^. nodeBuffer
where
gs = snd $ runIdentity $ runBuildT b
-- TODO: check against existing nodes for conflicts?
addGraphDef :: GraphDef -> Build ()
addGraphDef g = nodeBuffer <>= g ^. node
-- | Render the given op if it hasn't been rendered already, and return its
-- name.
getOrAddOp :: Op -> Build NodeName
getOrAddOp o = NodeName . (^. name) <$> resolveOp o
resolveOp :: Op -> Build NodeDef
resolveOp (Rendered n) = return n
resolveOp (Unrendered o) = do
pending <- getPendingNode o
uses renderedNodes (Map.lookup pending) >>= \case
Just n -> return n
Nothing -> addNewOpFromPending pending
-- | Add a new node for a given 'OpDef'. This is used for making "stateful" ops
-- which are not safe to dedup (e.g, "variable" and "assign").
addNewOp :: OpDef -> Build NodeDef
addNewOp o = getPendingNode o >>= addNewOpFromPending
addNewOpFromPending :: PendingNode -> Build NodeDef
addNewOpFromPending pending = do
nodeName <- renderPendingNode pending
let nodeDef = pendingNodeDef pending & name .~ unNodeName nodeName
nodeBuffer %= (nodeDef :)
renderedNodes %= Map.insert pending nodeDef
renderedNodeDefs %= Map.insert nodeName nodeDef
return nodeDef
-- | Get the pending node corresponding to an OpDef, which may or may not have
-- been rendered before. Implicitly renders all of this node's inputs.
getPendingNode :: OpDef -> Build PendingNode
getPendingNode o = do
-- An empty string in the proto field means that no specific
-- device is specified.
dev <- maybe "" deviceName <$> use defaultDevice
inputs <- mapM getInput (o ^. opInputs)
scope <- use currentScope
controls <- use defaultControlInputs
let controlInputs
= map getDep (o ^. opControlInputs ++ Set.toList controls)
return $ PendingNode scope (o ^. opName)
$ def & op .~ (unOpType (o ^. opType) :: Text)
& attr .~ _opAttrs o
& input .~ (inputs ++ controlInputs)
& device .~ dev
where
getInput (Output (OutputIx k) subOp)
= (<> ":" <> Text.pack (show k)) . unNodeName <$> getOrAddOp subOp
getDep = ("^" <>) . unNodeName
-- | Pick a name for a pending node. If it has an explicit name, just use that;
-- if the name is implicit, assign a new unique name based on the op type.
renderPendingNode :: PendingNode -> Build NodeName
renderPendingNode (PendingNode scope pendingName nodeDef)
= NodeName . (scopePrefix <>) <$> getName
where
scopePrefix = Text.concat $ fmap ((<> "/") . unScope) scope
getName = case pendingName of
ExplicitName n -> return n
ImplicitName -> do
u@(Unique k) <- use nextUnique
nextUnique .= succ u
return $ nodeDef ^. op <> "_" <> Text.pack (show k)
-- | Render an 'Output' and return a string representation for the TensorFlow
-- foreign APIs.
renderOutput :: Output -> Build Text
renderOutput (Output (OutputIx i) o) = do
n <- getOrAddOp o
return $ unNodeName n <> Text.pack (":" ++ show i)
-- | Modify some part of the state, run an action, and restore the state
-- after that action is done.
withStateLens :: MonadState s m => Lens' s a -> (a -> a) -> m b -> m b
withStateLens accessor f act = do
old <- use accessor
accessor %= f
result <- act
accessor .= old
return result
-- | Set a device for all nodes rendered in the given 'Build' action
-- (unless further overridden by another use of withDevice).
withDevice :: Maybe Device -> Build a -> Build a
withDevice d = withStateLens defaultDevice (const d)
-- | Places all nodes rendered in the given 'Build' action on the same
-- device as the given Tensor (see also 'withDevice'). Make sure that
-- the action has side effects of rendering the desired tensors. A pure
-- return would not have the desired effect.
colocateWith :: forall a v b . Tensor v b -> Build a -> Build a
colocateWith t x = do
d <- Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
withDevice (Just d) x
-- | Prepend a scope to all nodes rendered in the given 'Build' action.
withNameScope :: Text -> Build a -> Build a
withNameScope s = withStateLens currentScope (Scope s :)
-- | Add control inputs to all nodes rendered in the given 'Build' action.
withNodeDependencies :: Set NodeName -> Build a -> Build a
withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
-- the 'Build' context. Also renders any dependencies of the 'Tensor' that
-- weren't already rendered.
--
-- This operation is idempotent; @render >=> render === render@. However,
-- rendering a (previously un-rendered) 'Tensor' in two different contexts
-- may result in two different 'Tensor's.
render :: Tensor v a -> Build (Tensor v a)
render = tensorOutput $ outputOp $ fmap Rendered . resolveOp
-- | Render a 'Tensor' and get its node's name.
renderNodeName :: Tensor v a -> Build NodeName
renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp)
-- | Records the given summary action in Build for retrieval with
-- 'collectAllSummaries'. The summary op is required to produce a
-- Summary protocol buffer in string form. For safety, use the
-- pre-composed functions: Logging.scalarSummary and
-- Logging.histogramSummary.
addSummary :: SummaryTensor -> Build ()
addSummary t = summaries %= (t :)
-- | Retrieves the summary ops collected thus far. Typically this only
-- happens once, but if 'TensorFlow.Session.buildWithSummary' is used
-- repeatedly, the values accumulate.
collectAllSummaries :: Monad m => BuildT m [SummaryTensor]
collectAllSummaries = use summaries

View file

@ -0,0 +1,199 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TupleSections #-}
module TensorFlow.BuildOp
( OpResult
, BuildOp
, buildOp
, buildListOp
, eqLengthGuard
)
where
import Control.Monad (replicateM)
import Control.Monad.Reader (ReaderT, runReaderT, ask)
import Control.Monad.State.Strict (State, runState, get, put)
import Data.Int (Int64)
import Lens.Family2 ((&), (<>~), (^.))
import TensorFlow.Build
import TensorFlow.Output
import TensorFlow.Tensor
data ResultState = ResultState !OutputIx [Int64] deriving Show
type Result = ReaderT Op (State ResultState)
-- | Class of types that can be used as op outputs.
class OpResult a where
toResult :: Result a
instance (OpResult a1, OpResult a2) => OpResult (a1, a2) where
toResult = (,) <$> toResult <*> toResult
instance (OpResult a1, OpResult a2, OpResult a3) => OpResult (a1, a2, a3) where
toResult = (,,) <$> toResult <*> toResult <*> toResult
instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4)
=> OpResult (a1, a2, a3, a4) where
toResult = (,,,) <$> toResult <*> toResult <*> toResult <*> toResult
instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4, OpResult a5)
=> OpResult (a1, a2, a3, a4, a5) where
toResult = (,,,,) <$> toResult
<*> toResult
<*> toResult
<*> toResult
<*> toResult
instance ( OpResult a1
, OpResult a2
, OpResult a3
, OpResult a4
, OpResult a5
, OpResult a6
)
=> OpResult (a1, a2, a3, a4, a5, a6) where
toResult = (,,,,,)
<$> toResult
<*> toResult
<*> toResult
<*> toResult
<*> toResult
<*> toResult
tensorResult :: TensorKind v -> Result (Tensor v a)
tensorResult v = do
o <- ask
ResultState i ns <- get
put $! ResultState (i+1) ns
return $! Tensor v $ output i o
instance OpResult (Tensor Value a) where
toResult = tensorResult ValueKind
instance OpResult (Tensor Ref a) where
toResult = tensorResult RefKind
instance OpResult ControlNode where
toResult = ControlNode <$> ask
instance OpResult a => OpResult [a] where
toResult = do
ResultState i ns <- get
case ns of
[] -> error $ "Ran out of counts in toResult. " ++
"Likely misuse of buildListOp."
(n : rest) -> do
put $! ResultState i rest
replicateM (fromIntegral n) toResult
runResult :: OpResult a => [Int64] -> Op -> a
runResult ns o =
case runState (runReaderT toResult o) (ResultState 0 ns) of
(x, ResultState _ []) -> x
(_, ns') -> error $ "Ununsed length in runResult attributes: " ++
show (ns, ns')
-- | Make a new "pure" op, which may be deduped with identical ops within
-- the same scope.
pureResult :: OpResult a => [Int64] -> OpDef -> [Output] -> a
pureResult ns o ts = runResult ns $ Unrendered $ addReversedInputs o ts
-- | Make a new "stateful" op, which will not be deduped with otherwise
-- identical ops.
buildResult :: OpResult a => [Int64] -> OpDef -> [Output] -> Build a
buildResult ns o ts
= runResult ns . Rendered <$> addNewOp (addReversedInputs o ts)
addReversedInputs :: OpDef -> [Output] -> OpDef
addReversedInputs o ts = o & opInputs <>~ reverse ts
-- | Class of types that can be used as op functions.
class BuildOp f where
buildOp' :: [Int64] -- ^ Sizes of list results (having number_attr)
-> OpDef
-> [Output] -- ^ Accumulator for inputs to the op.
-> f
-- | Starts an operation that returns a structured set of tensors
-- (singletons or tuples).
buildOp :: BuildOp f => OpDef -> f
buildOp o = buildOp' [] o []
-- | Starts an operation that returns a list of tensors.
buildListOp :: BuildOp f => [Int64]
-- ^ Cardinality of the corresponding list of tensors output.
-> OpDef -> f
buildListOp counts o = buildOp' counts o []
instance BuildOp ControlNode where
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
instance BuildOp (Tensor Value a) where
buildOp' = pureResult
instance BuildOp (Tensor Ref a) where
buildOp' = pureResult
instance BuildOp [Tensor Value a] where
buildOp' = pureResult
instance (OpResult t1, OpResult t2) => BuildOp (t1, t2) where
buildOp' = pureResult
instance (OpResult t1, OpResult t2, OpResult t3) => BuildOp (t1, t2, t3) where
buildOp' = pureResult
instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4)
=> BuildOp (t1, t2, t3, t4) where
buildOp' = pureResult
instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4, OpResult t5)
=> BuildOp (t1, t2, t3, t4, t5) where
buildOp' = pureResult
instance ( OpResult t1
, OpResult t2
, OpResult t3
, OpResult t4
, OpResult t5
, OpResult t6
)
=> BuildOp (t1, t2, t3, t4, t5, t6) where
buildOp' = pureResult
instance OpResult a => BuildOp (Build a) where
buildOp' = buildResult
instance BuildOp f => BuildOp (Tensor v a -> f) where
buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts)
instance BuildOp f => BuildOp ([Tensor v a] -> f) where
buildOp' rf o accum ts
= buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum)
-- | Returns true if all the integers in each tuple are identical.
-- Throws an error with a descriptive message if not.
eqLengthGuard :: [(String, [(String, Int)])] -> Bool
eqLengthGuard = all eachOk
where
eachOk (_, []) = True
-- The next line has (== 1) . length . nub in disguise
eachOk (numberAttrName, pairs@((_, x) : zs)) = all (\z -> snd z == x) zs ||
error ("number_attr " ++ numberAttrName ++
" contains tensors with different length " ++ show pairs)

View file

@ -0,0 +1,87 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module TensorFlow.ControlFlow
( -- * Dependencies
withControlDependencies
, group
-- * Operations
, identity
, noOp
, named
) where
import qualified Data.Set as Set
import Data.Text (Text)
import Lens.Family2 ((&), (^.), (.~))
import TensorFlow.BuildOp
import TensorFlow.Build
import TensorFlow.Nodes
import TensorFlow.Output
import TensorFlow.Tensor
import TensorFlow.Types
-- | Modify a 'Build' action, such that all new ops rendered in it will depend
-- on the nodes in the first argument.
withControlDependencies :: Nodes t => t -> Build a -> Build a
withControlDependencies deps act = do
nodes <- getNodes deps
withNodeDependencies nodes act
-- TODO(judahjacobson): Reimplement withDependencies.
-- | Create an op that groups multiple operations.
--
-- When this op finishes, all ops in the input @n@ have finished. This op has
-- no output.
group :: Nodes t => t -> Build ControlNode
group deps = do
nodes <- Set.toList <$> getNodes deps
-- TODO: slicker way
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
-- | Returns a 'Tensor' with the same shape and contents as the input.
identity :: TensorType a => Tensor v a -> Tensor v a
identity = namedIdentity implicitName
-- | Returns a 'Tensor' with a given name and the same shape and contents as
-- the input.
--
-- TODO(judahjacobson): This breaks when used with uninitialize @Tensor Ref@s,
-- since @RefIdentity@ doesn't have SetAllowsUninitializedInput(). Look into
-- whether we can change that op.
named :: TensorType a => Text -> Tensor v a -> Tensor v a
named = namedIdentity . explicitName
-- | An internal version of "identity" that allows setting the name
-- of the output Tensor.
namedIdentity :: forall a v . TensorType a
=> PendingNodeName -> Tensor v a -> Tensor v a
namedIdentity n t = case t ^. tensorKind of
ValueKind -> buildOp (opDefWithName n "Identity" & setTypeAttr) t
RefKind -> buildOp (opDefWithName n "RefIdentity" & setTypeAttr) t
where
setTypeAttr = opAttr "T" .~ tensorType (undefined :: a)
-- | Does nothing. Only useful as a placeholder for control edges.
noOp :: ControlNode
noOp = buildOp $ opDef "NoOp"

View file

@ -0,0 +1,243 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
module TensorFlow.Internal.FFI
( TensorFlowException(..)
, Raw.Session
, withSession
, extendGraph
, run
, TensorData(..)
, setSessionConfig
, setSessionTarget
, getAllOpList
-- * Internal helper.
, useProtoAsVoidPtrLen
)
where
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
import Control.Monad (when)
import Data.Int (Int64)
import Data.Typeable (Typeable)
import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
import Foreign.C.String (CString)
import Foreign.ForeignPtr (newForeignPtr, withForeignPtr)
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString as B
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import qualified Data.Vector.Storable as S
import Data.ProtoLens (Message, encodeMessage)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import qualified TensorFlow.Internal.Raw as Raw
data TensorFlowException = TensorFlowException Raw.Code T.Text
deriving (Show, Eq, Typeable)
instance Exception TensorFlowException
-- | All of the data needed to represent a tensor.
data TensorData = TensorData
{ tensorDataDimensions :: [Int64]
, tensorDataType :: !DataType
, tensorDataBytes :: !(S.Vector Word8)
}
deriving (Show, Eq)
-- | Runs the given action after creating a session with options
-- populated by the given optionSetter.
withSession :: (Raw.SessionOptions -> IO ())
-> ((IO () -> IO ()) -> Raw.Session -> IO a)
-- ^ The action can spawn concurrent tasks which will
-- be canceled before withSession returns.
-> IO a
withSession optionSetter action = do
drain <- newMVar []
let cleanup s =
-- Closes the session to nudge the pending run calls to fail and exit.
finally (checkStatus (Raw.closeSession s)) $ do
runners <- takeMVar drain
-- Collects all runners before deleting the session.
mapM_ shutDownRunner runners
checkStatus (Raw.deleteSession s)
bracket Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
optionSetter options
bracket
(checkStatus (Raw.newSession options))
cleanup
(action (asyncCollector drain))
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord
where
launchAndRecord restRunners = (: restRunners) <$> async runner
shutDownRunner :: Async () -> IO ()
shutDownRunner r = do
cancel r
-- TODO(gnezdo): manage exceptions better than print.
either print (const (return ())) =<< waitCatch r
extendGraph :: Raw.Session -> GraphDef -> IO ()
extendGraph session pb =
useProtoAsVoidPtrLen pb $ \ptr len ->
checkStatus $ Raw.extendGraph session ptr len
run :: Raw.Session
-> [(B.ByteString, TensorData)] -- ^ Feeds.
-> [B.ByteString] -- ^ Fetches.
-> [B.ByteString] -- ^ Targets.
-> IO [TensorData]
run session feeds fetches targets = do
let nullTensor = Raw.Tensor nullPtr
-- Use mask to avoid leaking input tensors before they are passed to 'run'
-- and output tensors before they are passed to 'createTensorData'.
mask_ $
-- Feeds
withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
mapM (createRawTensor . snd) feeds >>= \feedTensors ->
withArrayLen feedTensors $ \_ cFeedTensors ->
-- Fetches.
withStringArrayLen fetches $ \fetchesLen fetchNames ->
-- tensorOuts is an array of null Tensor pointers that will be filled
-- by the call to Raw.run.
withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
-- Targets.
withStringArrayLen targets $ \targetsLen ctargets -> do
checkStatus $ Raw.run
session
nullPtr
feedNames cFeedTensors (fromIntegral feedsLen)
fetchNames tensorOuts (fromIntegral fetchesLen)
ctargets (fromIntegral targetsLen)
nullPtr
outTensors <- peekArray fetchesLen tensorOuts
mapM createTensorData outTensors
-- Internal.
-- | Use a list of ByteString as a list of CString.
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
withStringList strings fn = go strings []
where
go [] cs = fn (reverse cs)
-- TODO(fmayle): Is it worth using unsafeAsCString here?
go (x:xs) cs = B.useAsCString x $ \c -> go xs (c:cs)
-- | Use a list of ByteString as an array of CString.
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
-- | Create a Raw.Tensor from a TensorData.
createRawTensor :: TensorData -> IO Raw.Tensor
createRawTensor (TensorData dims dt byteVec) =
withArrayLen (map fromIntegral dims) $ \cdimsLen cdims -> do
let len = S.length byteVec
dest <- mallocArray len
S.unsafeWith byteVec $ \x -> copyArray dest x len
Raw.newTensor (toEnum $ fromEnum dt)
cdims (fromIntegral cdimsLen)
(castPtr dest) (fromIntegral len)
tensorDeallocFunPtr nullPtr
{-# NOINLINE tensorDeallocFunPtr #-}
tensorDeallocFunPtr :: FunPtr Raw.TensorDeallocFn
tensorDeallocFunPtr = unsafePerformIO $ Raw.wrapTensorDealloc $ \x _ _ -> free x
-- | Create a TensorData from a Raw.Tensor.
--
-- Takes ownership of the Raw.Tensor.
createTensorData :: Raw.Tensor -> IO TensorData
createTensorData t = do
-- Read dimensions.
numDims <- Raw.numDims t
dims <- mapM (Raw.dim t) [0..numDims-1]
-- Read type.
dtype <- toEnum . fromEnum <$> Raw.tensorType t
-- Read data.
len <- fromIntegral <$> Raw.tensorByteSize t
bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
-- TODO(fmayle): Don't copy the data.
v <- S.fromList <$> peekArray len bytes
-- Free tensor.
Raw.deleteTensor t
return $ TensorData (map fromIntegral dims) dtype v
-- | Runs the given action which does FFI calls updating a provided
-- status object. If the status is not OK it is thrown as
-- TensorFlowException.
checkStatus :: (Raw.Status -> IO a) -> IO a
checkStatus fn =
bracket Raw.newStatus Raw.deleteStatus $ \status -> do
result <- fn status
code <- Raw.getCode status
when (code /= Raw.TF_OK) $ do
msg <- T.decodeUtf8With T.lenientDecode <$>
(Raw.message status >>= B.packCString)
throwIO $ TensorFlowException code msg
return result
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
setSessionConfig pb opt =
useProtoAsVoidPtrLen pb $ \ptr len ->
checkStatus (Raw.setConfig opt ptr len)
setSessionTarget :: B.ByteString -> Raw.SessionOptions -> IO ()
setSessionTarget target = B.useAsCString target . Raw.setTarget
-- | Serializes the given msg and provides it as (ptr,len) argument
-- to the given action.
useProtoAsVoidPtrLen :: (Message msg, Num c) =>
msg -> (Ptr b -> c -> IO a) -> IO a
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
\(bytes, len) -> f (castPtr bytes) (fromIntegral len)
-- | Returns the serialized OpList of all OpDefs defined in this
-- address space.
getAllOpList :: IO B.ByteString
getAllOpList = do
foreignPtr <-
mask_ (newForeignPtr Raw.deleteBuffer =<< checkCall)
-- Makes a copy because it is more reliable than eviscerating
-- Buffer to steal its memory (including custom deallocator).
withForeignPtr foreignPtr $
\ptr -> B.packCStringLen =<< (,)
<$> (castPtr <$> Raw.getBufferData ptr)
<*> (fromIntegral <$> Raw.getBufferLength ptr)
where
checkCall = do
p <- Raw.getAllOpList
when (p == nullPtr) (throwIO exception)
return p
exception = TensorFlowException
Raw.TF_UNKNOWN "GetAllOpList failure, check logs"

View file

@ -0,0 +1,152 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE ForeignFunctionInterface #-}
module TensorFlow.Internal.Raw where
#include "third_party/tensorflow/c/c_api.h"
import Foreign
import Foreign.C
{#enum TF_DataType as DataType {} deriving (Show, Eq) #}
{#enum TF_Code as Code {} deriving (Show, Eq) #}
-- Status.
{#pointer *TF_Status as Status newtype #}
newStatus :: IO Status
newStatus = {# call TF_NewStatus as ^ #}
deleteStatus :: Status -> IO ()
deleteStatus = {# call TF_DeleteStatus as ^ #}
setStatus :: Status -> Code -> CString -> IO ()
setStatus s c = {# call TF_SetStatus as ^ #} s (fromIntegral $ fromEnum c)
getCode :: Status -> IO Code
getCode s = toEnum . fromIntegral <$> {# call TF_GetCode as ^ #} s
message :: Status -> IO CString
message = {# call TF_Message as ^ #}
-- Buffer.
data Buffer
{#pointer *TF_Buffer as BufferPtr -> Buffer #}
getBufferData :: BufferPtr -> IO (Ptr ())
getBufferData = {#get TF_Buffer->data #}
getBufferLength :: BufferPtr -> IO CULong
getBufferLength ={#get TF_Buffer->length #}
-- Tensor.
{#pointer *TF_Tensor as Tensor newtype #}
instance Storable Tensor where
sizeOf (Tensor t) = sizeOf t
alignment (Tensor t) = alignment t
peek p = fmap Tensor (peek (castPtr p))
poke p (Tensor t) = poke (castPtr p) t
newTensor :: DataType
-> Ptr CLong -- dimensions array
-> CInt -- num dimensions
-> Ptr () -- data
-> CULong -- data len
-> FunPtr (Ptr () -> CULong -> Ptr () -> IO ()) -- deallocator
-> Ptr () -- deallocator arg
-> IO Tensor
newTensor dt = {# call TF_NewTensor as ^ #} (fromIntegral $ fromEnum dt)
deleteTensor :: Tensor -> IO ()
deleteTensor = {# call TF_DeleteTensor as ^ #}
tensorType :: Tensor -> IO DataType
tensorType t = toEnum . fromIntegral <$> {# call TF_TensorType as ^ #} t
numDims :: Tensor -> IO CInt
numDims = {# call TF_NumDims as ^ #}
dim :: Tensor -> CInt -> IO CLong
dim = {# call TF_Dim as ^ #}
tensorByteSize :: Tensor -> IO CULong
tensorByteSize = {# call TF_TensorByteSize as ^ #}
tensorData :: Tensor -> IO (Ptr ())
tensorData = {# call TF_TensorData as ^ #}
-- Session Options.
{# pointer *TF_SessionOptions as SessionOptions newtype #}
newSessionOptions :: IO SessionOptions
newSessionOptions = {# call TF_NewSessionOptions as ^ #}
setTarget :: SessionOptions -> CString -> IO ()
setTarget = {# call TF_SetTarget as ^ #}
setConfig :: SessionOptions -> Ptr () -> CULong -> Status -> IO ()
setConfig = {# call TF_SetConfig as ^ #}
deleteSessionOptions :: SessionOptions -> IO ()
deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #}
-- Session.
{# pointer *TF_Session as Session newtype #}
newSession :: SessionOptions -> Status -> IO Session
newSession = {# call TF_NewSession as ^ #}
closeSession :: Session -> Status -> IO ()
closeSession = {# call TF_CloseSession as ^ #}
deleteSession :: Session -> Status -> IO ()
deleteSession = {# call TF_DeleteSession as ^ #}
extendGraph :: Session -> Ptr () -> CULong -> Status -> IO ()
extendGraph = {# call TF_ExtendGraph as ^ #}
run :: Session
-> BufferPtr -- RunOptions proto.
-> Ptr CString -> Ptr Tensor -> CInt -- Input (names, tensors, count).
-> Ptr CString -> Ptr Tensor -> CInt -- Output (names, tensors, count).
-> Ptr CString -> CInt -- Target nodes (names, count).
-> BufferPtr -- RunMetadata proto.
-> Status
-> IO ()
run = {# call TF_Run as ^ #}
-- FFI helpers.
type TensorDeallocFn = Ptr () -> CULong -> Ptr () -> IO ()
foreign import ccall "wrapper"
wrapTensorDealloc :: TensorDeallocFn -> IO (FunPtr TensorDeallocFn)
-- | Get the OpList of all OpDefs defined in this address space.
-- Returns a BufferPtr, ownership of which is transferred to the caller
-- (and can be freed using deleteBuffer).
--
-- The data in the buffer will be the serialized OpList proto for ops registered
-- in this address space.
getAllOpList :: IO BufferPtr
getAllOpList = {# call TF_GetAllOpList as ^ #}
foreign import ccall "&TF_DeleteBuffer"
deleteBuffer :: FunPtr (BufferPtr -> IO ())

View file

@ -0,0 +1,50 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE BangPatterns #-}
{-|
Module : TensorFlow.Internal.VarInt
Description : Encoders and decoders for varint types.
Originally taken from internal proto-lens code.
-}
module TensorFlow.Internal.VarInt
( getVarInt
, putVarInt
) where
import Data.Attoparsec.ByteString as Parse
import Data.Bits
import Data.ByteString.Lazy.Builder as Builder
import Data.Monoid ((<>))
import Data.Word (Word64)
-- | Decode an unsigned varint.
getVarInt :: Parser Word64
getVarInt = loop 1 0
where
loop !s !n = do
b <- anyWord8
let n' = n + s * fromIntegral (b .&. 127)
if (b .&. 128) == 0
then return n'
else loop (128*s) n'
-- | Encode a Word64.
putVarInt :: Word64 -> Builder
putVarInt n
| n < 128 = Builder.word8 (fromIntegral n)
| otherwise = Builder.word8 (fromIntegral $ n .&. 127 .|. 128)
<> putVarInt (n `shiftR` 7)

View file

@ -0,0 +1,141 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module TensorFlow.Nodes where
import Control.Applicative (liftA2, liftA3)
import Data.Map.Strict (Map)
import Data.Monoid ((<>))
import Data.Set (Set)
import Data.String (IsString)
import Data.Text (Text)
import Lens.Family2 ((^.))
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Vector as V
import TensorFlow.Build
import TensorFlow.Output
import TensorFlow.Tensor
import TensorFlow.Types
import qualified TensorFlow.Internal.FFI as FFI
-- | Types that contain ops which can be run.
class Nodes t where
getNodes :: t -> Build (Set NodeName)
-- | Types that tensor representations (e.g. 'Tensor', 'ControlNode') can be
-- fetched into.
--
-- Includes collections of tensors (e.g. tuples).
class Nodes t => Fetchable t a where
getFetch :: t -> Build (Fetch a)
-- | Fetch action. Keeps track of what needs to be fetched and how to decode
-- the fetched data.
data Fetch a = Fetch
{ -- | Nodes to fetch
fetches :: Set Text
-- | Function to create an 'a' from the fetched data.
, fetchRestore :: Map Text FFI.TensorData -> a
}
instance Functor Fetch where
fmap f (Fetch fetch restore) = Fetch fetch (f . restore)
instance Applicative Fetch where
pure x = Fetch Set.empty (const x)
Fetch fetch restore <*> Fetch fetch' restore' =
Fetch (fetch <> fetch') (restore <*> restore')
nodesUnion :: (Monoid b, Traversable t, Applicative f) => t (f b) -> f b
nodesUnion = fmap (foldMap id) . sequenceA
instance (Nodes t1, Nodes t2) => Nodes (t1, t2) where
getNodes (x, y) = nodesUnion [getNodes x, getNodes y]
instance (Nodes t1, Nodes t2, Nodes t3) => Nodes (t1, t2, t3) where
getNodes (x, y, z) = nodesUnion [getNodes x, getNodes y, getNodes z]
instance (Fetchable t1 a1, Fetchable t2 a2) => Fetchable (t1, t2) (a1, a2) where
getFetch (x, y) = liftA2 (,) <$> getFetch x <*> getFetch y
instance (Fetchable t1 a1, Fetchable t2 a2, Fetchable t3 a3)
=> Fetchable (t1, t2, t3) (a1, a2, a3) where
getFetch (x, y, z) =
liftA3 (,,) <$> getFetch x <*> getFetch y <*> getFetch z
instance Nodes t => Nodes [t] where
getNodes = nodesUnion . map getNodes
instance Fetchable t a => Fetchable [t] [a] where
getFetch ts = sequenceA <$> mapM getFetch ts
instance Nodes ControlNode where
getNodes (ControlNode o) = Set.singleton <$> getOrAddOp o
-- We use the constraint @(a ~ ())@ to help with type inference. For example,
-- if @t :: ControlNode@, then this constraint ensures that @run t :: Session
-- ()@. If we used @instance Fetchable ControlNode ()@ instead, then that
-- expression would be ambiguous without explicitly specifying the return type.
instance a ~ () => Fetchable ControlNode a where
getFetch _ = return $ pure ()
instance Nodes (Tensor v a) where
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)
fetchTensorList :: TensorType a => Tensor v a -> Build (Fetch (Shape, [a]))
fetchTensorList t = fmap (fmap V.toList) <$> fetchTensorVector t
fetchTensorVector :: forall a v . TensorType a
=> Tensor v a -> Build (Fetch (Shape, V.Vector a))
fetchTensorVector (Tensor _ o) = do
outputName <- renderOutput o
return $ Fetch (Set.singleton outputName) $ \tensors ->
let tensorData = tensors Map.! outputName
shape = Shape $ FFI.tensorDataDimensions tensorData
vec = decodeTensorData $ TensorData tensorData
expectedType = tensorType (undefined :: a)
actualType = FFI.tensorDataType tensorData
badTypeError = error $ "Bad tensor type: expected "
++ show expectedType
++ ", got "
++ show actualType
in if expectedType /= actualType
then badTypeError
else (shape, vec)
-- The constraint "a ~ a'" means that the input/output of fetch can constrain
-- the TensorType of each other.
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (V.Vector a') where
getFetch t = fmap snd <$> fetchTensorVector t
newtype Scalar a = Scalar {unScalar :: a}
deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat,
RealFrac, IsString)
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (Scalar a') where
getFetch t = fmap (Scalar . headFromSingleton . snd) <$> fetchTensorList t
where
headFromSingleton [x] = x
headFromSingleton xs
= error $ "Unable to extract singleton from tensor of length "
++ show (length xs)

View file

@ -0,0 +1,46 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- Orphan instances for certain proto messages/enums, used internally.
-- TODO(judahjacobson): consider making proto-lens generate some or all of
-- these automatically; or, alternately, make new Haskell datatypes.
module TensorFlow.Orphans() where
import Proto.Tensorflow.Core.Framework.AttrValue
( AttrValue(..)
, AttrValue'ListValue(..)
, NameAttrList(..)
)
import Proto.Tensorflow.Core.Framework.NodeDef
( NodeDef(..))
import Proto.Tensorflow.Core.Framework.ResourceHandle
( ResourceHandle(..))
import Proto.Tensorflow.Core.Framework.Tensor
(TensorProto(..))
import Proto.Tensorflow.Core.Framework.TensorShape
(TensorShapeProto(..), TensorShapeProto'Dim(..))
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
deriving instance Ord AttrValue
deriving instance Ord AttrValue'ListValue
deriving instance Ord DataType
deriving instance Ord NameAttrList
deriving instance Ord NodeDef
deriving instance Ord ResourceHandle
deriving instance Ord TensorProto
deriving instance Ord TensorShapeProto
deriving instance Ord TensorShapeProto'Dim

View file

@ -0,0 +1,156 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE OverloadedStrings #-}
module TensorFlow.Output
( ControlNode(..)
, Device(..)
-- * Ops
, NodeName(..)
, Op(..)
, opUnrendered
, OpDef(..)
, opName
, opType
, opAttr
, opInputs
, opControlInputs
, OpType(..)
, OutputIx(..)
, Output(..)
, output
, outputIndex
, outputOp
, PendingNodeName(..)
) where
import qualified Data.Map.Strict as Map
import Data.ProtoLens.TextFormat (showMessage)
import Data.String (IsString(..))
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 (Lens', Traversal', (.~), (&), (^.))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..))
import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef(..), name)
import Data.Default (def)
import TensorFlow.Types (Attribute, attrLens)
import TensorFlow.Orphans ()
-- | A type of graph node which has no outputs. These nodes are
-- valuable for causing side effects when they are run.
newtype ControlNode = ControlNode { unControlNode :: Op }
-- | The type of op of a node in the graph. This corresponds to the proto field
-- NodeDef.op.
newtype OpType = OpType { unOpType :: Text }
deriving (Eq, Ord, Show)
instance IsString OpType where
fromString = OpType . Text.pack
-- | An output of a TensorFlow node.
data Output = Output !OutputIx !Op
deriving (Eq, Ord, Show)
output :: OutputIx -> Op -> Output
output = Output
outputOp :: Lens' Output Op
outputOp = lens (\(Output _ o) -> o) (\(Output i _) o -> Output i o)
outputIndex :: Lens' Output OutputIx
outputIndex = lens (\(Output i _) -> i) (\(Output _ o) i -> Output i o)
newtype OutputIx = OutputIx { unOutputIx :: Int }
deriving (Eq, Ord, Num, Enum, Show)
-- | A device that a node can be assigned to.
-- There's a naming convention where the device names
-- are constructed from job and replica names.
newtype Device = Device {deviceName :: Text}
deriving (Eq, Ord, IsString)
instance Show Device where
show (Device d) = show d
-- | The representation of a node in a TensorFlow graph.
data Op
= Rendered !NodeDef -- ^ Properties are fixed, including the
-- device, name, and scope.
| Unrendered !OpDef -- ^ Properties are not fixed, and may change depending
-- on which context this op is rendered in.
deriving (Eq, Ord)
instance Show Op where
show (Rendered n) = "Rendered " ++ showMessage n
show (Unrendered o) = "Unrendered " ++ show (o ^. opName)
-- | Traverse on the 'Unrendered' of an 'Op'.
--
-- Same implementation as _Left.
opUnrendered :: Traversal' Op OpDef
opUnrendered f (Unrendered a) = Unrendered <$> f a
opUnrendered _ (Rendered b) = pure (Rendered b)
-- | Op definition. This corresponds somewhat to the 'NodeDef' proto.
data OpDef = OpDef
{ _opName :: !PendingNodeName
, _opType :: !OpType
, _opAttrs :: !(Map.Map Text AttrValue)
, _opInputs :: [Output]
, _opControlInputs :: [NodeName]
} deriving (Eq, Ord)
-- | The name specified for an unrendered Op. If an Op has an
-- ImplicitName, it will be assigned based on the opType plus a
-- unique identifier. Does not contain the "scope" prefix.
data PendingNodeName = ExplicitName !Text | ImplicitName
deriving (Eq, Ord, Show)
-- | The name of a node in the graph. This corresponds to the proto field
-- NodeDef.name. Includes the scope prefix (if any) and a unique identifier
-- (if the node was implicitly named).
newtype NodeName = NodeName { unNodeName :: Text }
deriving (Eq, Ord, Show)
opName :: Lens' OpDef PendingNodeName
opName = lens _opName (\o x -> o {_opName = x})
opType :: Lens' OpDef OpType
opType = lens _opType (\o x -> o { _opType = x})
opAttr :: Attribute a => Text -> Lens' OpDef a
opAttr n = lens _opAttrs (\o x -> o {_opAttrs = x})
. lens (Map.findWithDefault def n) (flip (Map.insert n))
. attrLens
opInputs :: Lens' OpDef [Output]
opInputs = lens _opInputs (\o x -> o {_opInputs = x})
opControlInputs :: Lens' OpDef [NodeName]
opControlInputs = lens _opControlInputs (\o x -> o {_opControlInputs = x})
-- TODO(gnezdo): IsString instance is weird and we should move that
-- code into a Build function
instance IsString Output where
fromString s = case break (==':') s of
(n, ':':ixStr)
| [(ix, "")] <- read ixStr -> Output (fromInteger ix) $ assigned n
_ -> Output 0 $ assigned s
where assigned n = Rendered $ def & name .~ Text.pack n

View file

@ -0,0 +1,202 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module TensorFlow.Session (
Session,
-- * Opaque value created via 'sessionConfig' and 'sessionTarget'.
SessionOption,
sessionConfig,
sessionTarget,
runSession,
runSessionWithOptions,
build,
buildAnd,
buildWithSummary,
extend,
addGraphDef,
run,
runWithFeeds,
run_,
runWithFeeds_,
asyncProdNodes,
) where
import Control.Monad (forever, unless, void)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
import Data.ByteString (ByteString)
import Data.Functor.Identity (runIdentity)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Data.Set (Set)
import Data.Text.Encoding (encodeUtf8)
import Data.ProtoLens (def)
import Lens.Family2 ((&), (.~))
import Proto.Tensorflow.Core.Framework.Graph (node)
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import TensorFlow.Build
import qualified TensorFlow.Internal.FFI as FFI
import qualified TensorFlow.Internal.Raw as Raw
import TensorFlow.Nodes
import TensorFlow.Output (NodeName, unNodeName)
import TensorFlow.Tensor
-- Common state threaded through the session.
data SessionState
= SessionState {
rawSession :: FFI.Session
, asyncCollector :: IO () -> IO ()
-- ^ Starts the given action concurrently.
}
newtype Session a
= Session (ReaderT SessionState (BuildT IO) a)
deriving (Functor, Applicative, Monad, MonadIO)
-- | Run 'Session' actions in a new TensorFlow session.
runSession :: Session a -> IO a
runSession = runSessionWithOptions []
-- | Setting of an option for the session (see 'runSessionWithOptions').
newtype SessionOption =
SessionOption { unSesssionOption :: Raw.SessionOptions -> IO () }
-- | Target can be: "local", ip:port, host:port.
-- The set of supported factories depends on the linked in libraries.
-- REQUIRES "//learning/brain/public:tensorflow_remote" dependency for the binary.
sessionTarget :: ByteString -> SessionOption
sessionTarget = SessionOption . FFI.setSessionTarget
-- | Uses the specified config for the created session.
sessionConfig :: ConfigProto -> SessionOption
sessionConfig = SessionOption . FFI.setSessionConfig
-- | Run 'Session' actions in a new TensorFlow session created with
-- the given option setter actions ('sessionTarget', 'sessionConfig').
runSessionWithOptions :: [SessionOption] -> Session a -> IO a
runSessionWithOptions options (Session m) =
FFI.withSession applyOptions $
\as rs -> evalBuildT (runReaderT m (SessionState rs as))
where applyOptions opt = mapM_ (`unSesssionOption` opt) options
-- | Lift a 'Build' action into a 'Session', including any explicit op
-- renderings.
build :: Build a -> Session a
build = Session . lift . hoistBuildT (return . runIdentity)
-- | Lift a 'Build' action into a 'Session', including any explicit op
-- renderings. Returns the merged summary ops which can be used for
-- logging, see 'TensorFlow.Logging.build' for a convenient wrapper.
buildWithSummary :: forall a . Build a -> Session (a, [SummaryTensor])
buildWithSummary b = Session $ lift $ (,) <$> v <*> collectAllSummaries
where v :: BuildT IO a
v = hoistBuildT (return . runIdentity) b
-- | Add all pending rendered nodes to the TensorFlow graph and runs
-- any pending initializers.
--
-- Note that run, runWithFeeds, etc. will all call this function implicitly.
extend :: Session ()
extend = do
let withSessionWhen vs action =
unless (null vs) $ Session (asks rawSession) >>= action
nodesToExtend <- build flushNodeBuffer
withSessionWhen nodesToExtend $ \session ->
liftIO $ FFI.extendGraph session
$ def & node .~ nodesToExtend
-- Now that all the nodes are created, run the initializers.
initializers <- build flushInitializers
withSessionWhen initializers $ \session ->
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
-- | Helper combinator for doing something with the result of a 'Build' action.
-- Example usage:
--
-- > buildAnd run :: Fetchable t a => Build t -> Session a
buildAnd :: (a -> Session b) -> Build a -> Session b
buildAnd f m = build m >>= f
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered, and fetch the corresponding values for 'a'.
run :: Fetchable t a => t -> Session a
run = runWithFeeds []
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered, feed the given input values, and fetch the corresponding result
-- values for 'a'.
runWithFeeds :: Fetchable t a => [Feed] -> t -> Session a
runWithFeeds feeds t = do
ns <- build $ getNodes t
-- Note that this call to "fetch" shouldn't affect the following "extend"
-- call, since all nodes in t and its inputs/deps will be rendered by the
-- above call to getNodes.
fetch <- build $ getFetch t
runFetchWithFeeds feeds ns fetch
runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a
runFetchWithFeeds feeds target (Fetch fetch restore) = do
extend
feeds' <- build $ fixFeeds feeds
let fetchNames = encodeUtf8 <$> Set.toList fetch
targetNames = toNodeNames $ Set.toList target
session <- Session (asks rawSession)
runResult <- liftIO $ FFI.run session
feeds'
fetchNames
targetNames
let resultTensorsMap = Map.fromList $ zip (Set.toList fetch) runResult
return $ restore resultTensorsMap
toNodeNames :: [NodeName] -> [ByteString]
toNodeNames = map (encodeUtf8 . unNodeName)
-- | Run a subgraph 't', rendering and extending any dependent nodes that aren't
-- already rendered. This behaves like 'run' except that it doesn't do any
-- fetches.
run_ :: Nodes t => t -> Session ()
run_ = runWithFeeds_ []
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered, feed the given input values, and fetch the corresponding result
-- values for 'a'. This behaves like 'runWithFeeds' except that it doesn't do
-- any fetches.
runWithFeeds_ :: Nodes t => [Feed] -> t -> Session ()
runWithFeeds_ feeds t = do
ns <- build $ getNodes t
runFetchWithFeeds feeds ns (pure ())
fixFeeds :: [Feed] -> Build [(ByteString, FFI.TensorData)]
fixFeeds = mapM $ \(Feed o d) -> (,d) . encodeUtf8 <$> renderOutput o
-- | Starts a concurrent thread which evaluates the given Nodes
-- forever until runSession exits or an exception occurs. Graph
-- extension happens synchronously, but the resultant run proceeds as
-- a separate thread.
asyncProdNodes :: Nodes t
=> t -- ^ Node to evaluate concurrently.
-> Session ()
asyncProdNodes nodes = do
target <- build (getNodes nodes)
extend
let targetNames = toNodeNames $ Set.toList target
state <- Session ask
let loop = forever (void (FFI.run (rawSession state) [] [] targetNames))
liftIO (asyncCollector state loop)

View file

@ -0,0 +1,85 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
module TensorFlow.Tensor where
import Data.String (IsString(..))
import qualified Data.Text as Text
import Lens.Family2 (Lens', Traversal')
import Lens.Family2.Unchecked (lens)
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr)
import TensorFlow.Types (TensorData(..), Attribute)
import qualified TensorFlow.Internal.FFI as FFI
-- | A named output of a TensorFlow operation.
--
-- The type parameter @a@ is the type of the elements in the 'Tensor'. The
-- parameter @v@ is either 'Value' or 'Ref', depending on whether the graph is
-- treating this op output as an immutable 'Value' or a stateful 'Ref' (e.g., a
-- variable). Note that a @Tensor Ref@ can be casted into a @Tensor Value@ via
-- 'value'.
data Tensor v a = Tensor (TensorKind v) Output
data Value
data Ref
-- | This class provides a runtime switch on whether a 'Tensor' should be
-- treated as a 'Value' or as a 'Ref'.
data TensorKind v where
ValueKind :: TensorKind Value
RefKind :: TensorKind Ref
tensorKind :: Lens' (Tensor v a) (TensorKind v)
tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o)
tensorOutput :: Lens' (Tensor v a) Output
tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o)
-- TODO: Come up with a better API for handling attributes.
-- | Lens for the attributes of a tensor.
--
-- Only valid if the tensor has not yet been rendered. If the tensor has been
-- rendered, the traversal will be over nothing (nothing can be read or
-- written).
tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr
tensorAttr x = tensorOutput . outputOp . opUnrendered . opAttr x
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
-- Ref into Value. This behaves like a no-op.
value :: Tensor v a -> Tensor Value a
value (Tensor _ o) = Tensor ValueKind o
-- | A pair of a 'Tensor' and some data that should be fed into that 'Tensor'
-- when running the graph.
data Feed = Feed Output FFI.TensorData
-- | Create a 'Feed' for feeding the given data into a 'Tensor' when running
-- the graph.
--
-- Note that if a 'Tensor' is rendered, its identity may change; so feeding the
-- rendered 'Tensor' may be different than feeding the original 'Tensor'.
feed :: Tensor v a -> TensorData a -> Feed
feed (Tensor _ o) (TensorData td) = Feed o td
-- | Create a 'Tensor' for a given name. This can be used to reference nodes
-- in a 'GraphDef' that was loaded via 'addGraphDef'.
-- TODO(judahjacobson): add more safety checks here.
tensorFromName :: TensorKind v -> Text.Text -> Tensor v a
tensorFromName v = Tensor v . fromString . Text.unpack

View file

@ -0,0 +1,382 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- We use UndecidableInstances for type families with recursive definitions
-- like "\\". Those instances will terminate since each equation unwraps one
-- cons cell of a type-level list.
{-# LANGUAGE UndecidableInstances #-}
module TensorFlow.Types
( TensorType(..)
, TensorData(..)
, Shape(..)
, protoShape
, Attribute(..)
-- * Type constraints
, OneOf
, type (/=)
-- ** Implementation of constraints
, TypeError
, ExcludedCase
, TensorTypes
, NoneOf
, type (\\)
, Delete
, AllTensorTypes
) where
import Data.Complex (Complex)
import Data.Default (def)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Monoid ((<>))
import Data.Word (Word8, Word16, Word64)
import Foreign.Storable (Storable)
import GHC.Exts (Constraint, IsList(..))
import Lens.Family2 (Lens', view, (&), (.~))
import Lens.Family2.Unchecked (iso)
import qualified Data.Attoparsec.ByteString as Atto
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString.Lazy as L
import qualified Data.Vector as V
import qualified Data.Vector.Storable as S
import Proto.Tensorflow.Core.Framework.AttrValue
( AttrValue(..)
, AttrValue'ListValue(..)
, b
, f
, i
, s
, list
, type'
, shape
, tensor
)
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
( TensorProto(..)
, floatVal
, doubleVal
, intVal
, stringVal
, int64Val
, stringVal
, boolVal
)
import Proto.Tensorflow.Core.Framework.TensorShape
( TensorShapeProto(..)
, dim
, size
)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
import qualified TensorFlow.Internal.FFI as FFI
-- | Data about a tensor that is encoded for the TensorFlow APIs.
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
-- | The class of scalar types supported by tensorflow.
class TensorType a where
tensorType :: a -> DataType
tensorRefType :: a -> DataType
tensorVal :: Lens' TensorProto [a]
-- | Decode the bytes of a TensorData into a Vector.
decodeTensorData :: TensorData a -> V.Vector a
-- | Encode a Vector into a TensorData.
--
-- The values should be in row major order, e.g.,
--
-- element 0: index (0, ..., 0)
-- element 1: index (0, ..., 1)
-- ...
encodeTensorData :: Shape -> V.Vector a -> TensorData a
-- All types, besides ByteString, are encoded as simple arrays and we can use
-- Vector.Storable to encode/decode by type casting pointers.
-- TODO(fmayle): Assert that the data type matches the return type.
simpleDecode :: Storable a => TensorData a -> V.Vector a
simpleDecode = S.convert . S.unsafeCast . FFI.tensorDataBytes . unTensorData
simpleEncode :: forall a . (TensorType a, Storable a)
=> Shape -> V.Vector a -> TensorData a
simpleEncode (Shape xs)
= TensorData . FFI.TensorData xs dt . S.unsafeCast . S.convert
where
dt = tensorType (undefined :: a)
instance TensorType Float where
tensorType _ = DT_FLOAT
tensorRefType _ = DT_FLOAT_REF
tensorVal = floatVal
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Double where
tensorType _ = DT_DOUBLE
tensorRefType _ = DT_DOUBLE_REF
tensorVal = doubleVal
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Int32 where
tensorType _ = DT_INT32
tensorRefType _ = DT_INT32_REF
tensorVal = intVal
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Int64 where
tensorType _ = DT_INT64
tensorRefType _ = DT_INT64_REF
tensorVal = int64Val
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
integral :: Integral a => Lens' [Int32] [a]
integral = iso (fmap fromIntegral) (fmap fromIntegral)
instance TensorType Word8 where
tensorType _ = DT_UINT8
tensorRefType _ = DT_UINT8_REF
tensorVal = intVal . integral
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Word16 where
tensorType _ = DT_UINT16
tensorRefType _ = DT_UINT16_REF
tensorVal = intVal . integral
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Int16 where
tensorType _ = DT_INT16
tensorRefType _ = DT_INT16_REF
tensorVal = intVal . integral
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Int8 where
tensorType _ = DT_INT8
tensorRefType _ = DT_INT8_REF
tensorVal = intVal . integral
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType ByteString where
tensorType _ = DT_STRING
tensorRefType _ = DT_STRING_REF
tensorVal = stringVal
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
-- table offsets for each element :: [Word64]
-- at each element offset:
-- string length :: VarInt64
-- string data :: [Word8]
-- TODO(fmayle): Benchmark these functions.
decodeTensorData tensorData =
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
if expected /= count
then Left $ "decodeTensorData for ByteString count mismatch " ++
show (expected, count)
else V.mapM decodeString (S.convert offsets)
where
expected = S.length offsets
count = fromIntegral $ product $ FFI.tensorDataDimensions
$ unTensorData tensorData
bytes = FFI.tensorDataBytes $ unTensorData tensorData
offsets = S.take count $ S.unsafeCast bytes :: S.Vector Word64
dataBytes = B.pack $ S.toList $ S.drop (count * 8) bytes
decodeString :: Word64 -> Either String ByteString
decodeString offset =
let stringDataStart = B.drop (fromIntegral offset) dataBytes
in Atto.eitherResult $ Atto.parse stringParser stringDataStart
stringParser :: Atto.Parser ByteString
stringParser = getVarInt >>= Atto.take . fromIntegral
encodeTensorData (Shape xs) vec =
TensorData $ FFI.TensorData xs dt byteVector
where
dt = tensorType (undefined :: ByteString)
-- Add a string to an offset table and data blob.
addString :: (Builder, Builder, Word64)
-> ByteString
-> (Builder, Builder, Word64)
addString (table, strings, offset) str =
( table <> Builder.word64LE offset
, strings <> lengthBytes <> Builder.byteString str
, offset + lengthBytesLen + strLen
)
where
strLen = fromIntegral $ B.length str
lengthBytes = putVarInt $ fromIntegral $ B.length str
lengthBytesLen =
fromIntegral $ L.length $ Builder.toLazyByteString lengthBytes
-- Encode all strings.
(table', strings', _) = V.foldl' addString (mempty, mempty, 0) vec
-- Concat offset table with data.
bytes = table' <> strings'
-- Convert to Vector Word8.
byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes
instance TensorType Bool where
tensorType _ = DT_BOOL
tensorRefType _ = DT_BOOL_REF
tensorVal = boolVal
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType (Complex Float) where
tensorType _ = DT_COMPLEX64
tensorRefType _ = DT_COMPLEX64
tensorVal = error "TODO (Complex Float)"
decodeTensorData = error "TODO (Complex Float)"
encodeTensorData = error "TODO (Complex Float)"
instance TensorType (Complex Double) where
tensorType _ = DT_COMPLEX128
tensorRefType _ = DT_COMPLEX128
tensorVal = error "TODO (Complex Double)"
decodeTensorData = error "TODO (Complex Double)"
encodeTensorData = error "TODO (Complex Double)"
-- | Shape (dimensions) of a tensor.
newtype Shape = Shape [Int64] deriving Show
instance IsList Shape where
type Item Shape = Int64
fromList = Shape . fromList
toList (Shape ss) = toList ss
protoShape :: Lens' TensorShapeProto Shape
protoShape = iso protoToShape shapeToProto
where
protoToShape = Shape . fmap (view size) . view dim
shapeToProto (Shape ds) = def & dim .~ fmap (\d -> def & size .~ d) ds
class Attribute a where
attrLens :: Lens' AttrValue a
instance Attribute Float where
attrLens = f
instance Attribute ByteString where
attrLens = s
instance Attribute Int64 where
attrLens = i
instance Attribute DataType where
attrLens = type'
instance Attribute TensorProto where
attrLens = tensor
instance Attribute Bool where
attrLens = b
instance Attribute Shape where
attrLens = shape . protoShape
-- TODO(gnezdo): support generating list(Foo) from [Foo].
instance Attribute AttrValue'ListValue where
attrLens = list
instance Attribute [DataType] where
attrLens = list . type'
instance Attribute [Int64] where
attrLens = list . i
-- | A 'Constraint' specifying the possible choices of a 'TensorType'.
--
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the
-- natural representation as a conjunction, i.e.,
--
-- @
-- a == Double || a == Float
-- @
--
-- into a disjunction like
--
-- @
-- a \/= Int32 && a \/= Int64 && a \/= ByteString && ...
-- @
--
-- using an enumeration of all the possible 'TensorType's.
type OneOf ts a
= (TensorType a, TensorTypes ts, NoneOf (AllTensorTypes \\ ts) a)
-- | A 'Constraint' checking that the input is a list of 'TensorType's.
-- Helps improve error messages when using 'OneOf'.
type family TensorTypes ts :: Constraint where
TensorTypes '[] = ()
TensorTypes (t ': ts) = (TensorType t, TensorTypes ts)
-- | A constraint checking that two types are different.
type family a /= b :: Constraint where
a /= a = TypeError a ~ ExcludedCase
a /= b = ()
-- | Helper types to produce a reasonable type error message when the Constraint
-- "a /= a" fails.
-- TODO(judahjacobson): Use ghc-8's CustomTypeErrors for this.
data TypeError a
data ExcludedCase
-- | An enumeration of all valid 'TensorType's.
type AllTensorTypes =
-- NOTE: This list should be kept in sync with
-- TensorFlow.OpGen.dtTypeToHaskell.
-- TODO: Add support for Complex Float/Double.
'[ Float
, Double
, Int8
, Int16
, Int32
, Int64
, Word8
, Word16
, ByteString
, Bool
]
-- | Removes a type from the given list of types.
type family Delete a as where
Delete a '[] = '[]
Delete a (a ': as) = Delete a as
Delete a (b ': as) = b ': Delete a as
-- | Takes the difference of two lists of types.
type family as \\ bs where
as \\ '[] = as
as \\ b ': bs = Delete b as \\ bs
-- | A constraint that the type @a@ doesn't appear in the type list @ts@.
-- Assumes that @a@ and each of the elements of @ts@ are 'TensorType's.
type family NoneOf ts a :: Constraint where
NoneOf '[] a = ()
NoneOf (t ': ts) a = (a /= t, NoneOf ts a)

View file

@ -0,0 +1,84 @@
name: tensorflow
version: 0.1.0.0
synopsis: TensorFlow bindings.
description: Please see README.md
homepage: https://github.com/tensorflow/haskell#readme
license: Apache
author: TensorFlow authors
maintainer: tensorflow-haskell@googlegroups.com
copyright: Google Inc.
category: Machine Learning
build-type: Simple
cabal-version: >=1.22
library
hs-source-dirs: src
exposed-modules: TensorFlow.Build
, TensorFlow.BuildOp
, TensorFlow.ControlFlow
, TensorFlow.Internal.FFI
, TensorFlow.Nodes
, TensorFlow.Output
, TensorFlow.Session
, TensorFlow.Tensor
, TensorFlow.Types
, TensorFlow.Internal.VarInt
other-modules: TensorFlow.Internal.Raw
, TensorFlow.Orphans
build-tools: c2hs
build-depends: proto-lens == 0.1.*
-- Used by the custom Setup script (for the test-suite).
, proto-lens-protoc == 0.1.*
, tensorflow-proto == 0.1.*
, base >= 4.7 && < 5
, async
, attoparsec
, bytestring
, containers
, data-default
, fgl
, lens-family
, mainland-pretty
, mtl
, semigroups
, split
, text
, temporary
, transformers
, vector
extra-libraries: tensorflow_c
default-language: Haskell2010
include-dirs: .
Test-Suite FFITest
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: FFITest.hs
hs-source-dirs: tests
build-depends: HUnit
, base
, bytestring
, lens-family
, proto-lens
, tensorflow
, tensorflow-proto
, test-framework
, test-framework-hunit
Test-Suite VarIntTest
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: VarIntTest.hs
hs-source-dirs: tests
build-depends: base
, attoparsec
, bytestring
, google-shim
, tensorflow
, test-framework
, test-framework-quickcheck2
source-repository head
type: git
location: https://github.com/tensorflow/haskell

View file

@ -0,0 +1,38 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- | Tests for FFI.
module Main where
import Data.ProtoLens (decodeMessage)
import Lens.Family2 (view)
import TensorFlow.Internal.FFI (getAllOpList)
import Test.HUnit (assertBool, assertFailure)
import Test.Framework (defaultMain)
import Test.Framework.Providers.HUnit (testCase)
import Proto.Tensorflow.Core.Framework.OpDef (OpList, op)
testParseAll :: IO ()
testParseAll = do
opList <- getAllOpList
either
assertFailure
(assertBool "Expected non-empty list of default Ops"
. not . null . view op)
(decodeMessage opList :: Either String OpList)
main = defaultMain
[ testCase "ParseAllOps" testParseAll
]

View file

@ -0,0 +1,32 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
module Main where
import Data.ByteString.Builder (toLazyByteString)
import Google.Test (googleTest)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import qualified Data.Attoparsec.ByteString.Lazy as Atto
import TensorFlow.Internal.VarInt
testEncodeDecode = testProperty "testEncodeDecode" $ \x ->
let bytes = toLazyByteString (putVarInt x)
in case Atto.eitherResult $ Atto.parse getVarInt bytes of
Left _ -> False
Right y -> x == y
main :: IO ()
main = googleTest [ testEncodeDecode
]

1
tensorflow/third_party Symbolic link
View file

@ -0,0 +1 @@
../third_party/tensorflow

1
third_party/tensorflow vendored Submodule

@ -0,0 +1 @@
Subproject commit bac7faa9a3eb5b60687a83336202cd3493de5385