mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Tensorflow 2.3.0 Support (#267)
* Tensorflow 2.3.0 building and passing tests. * Added einsum and test. * Added ByteString as a possible argument to a function. * Support more data types for Adam. * Move to later version of LTS on stackage. * Added a wrapper module for convolution functions. * Update ci build to use a later version of stack. * Removed a deprecated import in GradientTest.
This commit is contained in:
parent
568c9b6f03
commit
c66c912c32
21 changed files with 409 additions and 75 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -2,3 +2,4 @@
|
||||||
.stack/
|
.stack/
|
||||||
tensorflow-mnist-input-data/data/*.gz
|
tensorflow-mnist-input-data/data/*.gz
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
stack.yaml.lock
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
# stack to be installed on the host. This comes at the expense of
|
# stack to be installed on the host. This comes at the expense of
|
||||||
# flexibility.
|
# flexibility.
|
||||||
|
|
||||||
FROM tensorflow/tensorflow:1.14.0
|
FROM tensorflow/tensorflow:2.3.0
|
||||||
LABEL maintainer="TensorFlow authors <tensorflow-haskell@googlegroups.com>"
|
LABEL maintainer="TensorFlow authors <tensorflow-haskell@googlegroups.com>"
|
||||||
|
|
||||||
# The build context directory is the top of the tensorflow-haskell
|
# The build context directory is the top of the tensorflow-haskell
|
||||||
|
@ -30,14 +30,14 @@ RUN \
|
||||||
netbase \
|
netbase \
|
||||||
&& \
|
&& \
|
||||||
# Installs stack.
|
# Installs stack.
|
||||||
curl -O -L https://github.com/commercialhaskell/stack/releases/download/v2.1.3/stack-2.1.3-linux-x86_64.tar.gz && \
|
curl -O -L https://github.com/commercialhaskell/stack/releases/download/v2.5.1/stack-2.5.1-linux-x86_64.tar.gz && \
|
||||||
tar zxf stack-2.1.3-linux-x86_64.tar.gz -C /usr/local/bin stack-2.1.3-linux-x86_64/stack --strip 1 && \
|
tar zxf stack-2.5.1-linux-x86_64.tar.gz -C /usr/local/bin stack-2.5.1-linux-x86_64/stack --strip 1 && \
|
||||||
# Installs protoc and the libraries.
|
# Installs protoc and the libraries.
|
||||||
curl -O -L https://github.com/google/protobuf/releases/download/v3.9.1/protoc-3.9.1-linux-x86_64.zip && \
|
curl -O -L https://github.com/google/protobuf/releases/download/v3.13.0/protoc-3.13.0-linux-x86_64.zip && \
|
||||||
unzip -d /usr/local protoc-3.9.1-linux-x86_64.zip bin/protoc && \
|
unzip -d /usr/local protoc-3.13.0-linux-x86_64.zip bin/protoc && \
|
||||||
chmod 755 /usr/local/bin/protoc && \
|
chmod 755 /usr/local/bin/protoc && \
|
||||||
curl -O https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.14.0.tar.gz && \
|
curl -O https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.3.0.tar.gz && \
|
||||||
tar zxf libtensorflow-cpu-linux-x86_64-1.14.0.tar.gz -C /usr/local && \
|
tar zxf libtensorflow-cpu-linux-x86_64-2.3.0.tar.gz -C /usr/local && \
|
||||||
ldconfig && \
|
ldconfig && \
|
||||||
stack setup && \
|
stack setup && \
|
||||||
stack test --only-dependencies
|
stack test --only-dependencies
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Prepare the image with:
|
# Prepare the image with:
|
||||||
# docker build -t tensorflow/haskell:v0 docker
|
# docker build -t tensorflow/haskell:v0 docker
|
||||||
FROM tensorflow/tensorflow:1.14.0
|
FROM tensorflow/tensorflow:2.3.0
|
||||||
LABEL maintainer="TensorFlow authors <tensorflow-haskell@googlegroups.com>"
|
LABEL maintainer="TensorFlow authors <tensorflow-haskell@googlegroups.com>"
|
||||||
|
|
||||||
RUN apt-get update
|
RUN apt-get update
|
||||||
|
@ -28,11 +28,11 @@ RUN dpkg-reconfigure locales && \
|
||||||
|
|
||||||
# Installs protoc and the libraries.
|
# Installs protoc and the libraries.
|
||||||
RUN \
|
RUN \
|
||||||
curl -O -L https://github.com/google/protobuf/releases/download/v3.9.1/protoc-3.9.1-linux-x86_64.zip && \
|
curl -O -L https://github.com/google/protobuf/releases/download/v3.13.0/protoc-3.13.0-linux-x86_64.zip && \
|
||||||
unzip -d /usr/local protoc-3.9.1-linux-x86_64.zip bin/protoc && \
|
unzip -d /usr/local protoc-3.13.0-linux-x86_64.zip bin/protoc && \
|
||||||
chmod 755 /usr/local/bin/protoc && \
|
chmod 755 /usr/local/bin/protoc && \
|
||||||
curl -O https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.14.0.tar.gz && \
|
curl -O https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.3.0.tar.gz && \
|
||||||
tar zxf libtensorflow-cpu-linux-x86_64-1.14.0.tar.gz -C /usr/local && \
|
tar zxf libtensorflow-cpu-linux-x86_64-2.3.0.tar.gz -C /usr/local && \
|
||||||
ldconfig
|
ldconfig
|
||||||
|
|
||||||
ENV LANG en_US.UTF-8
|
ENV LANG en_US.UTF-8
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Prepare the image with:
|
# Prepare the image with:
|
||||||
# docker build -t tensorflow/haskell:1.14.0-gpu docker/gpu
|
# docker build -t tensorflow/haskell:1.14.0-gpu docker/gpu
|
||||||
FROM tensorflow/tensorflow:1.14.0-gpu
|
FROM tensorflow/tensorflow:2.3.0-gpu
|
||||||
LABEL maintainer="TensorFlow authors <tensorflow-haskell@googlegroups.com>"
|
LABEL maintainer="TensorFlow authors <tensorflow-haskell@googlegroups.com>"
|
||||||
|
|
||||||
RUN apt-get update
|
RUN apt-get update
|
||||||
|
@ -28,11 +28,11 @@ RUN dpkg-reconfigure locales && \
|
||||||
|
|
||||||
# Installs protoc and the libraries.
|
# Installs protoc and the libraries.
|
||||||
RUN \
|
RUN \
|
||||||
curl -O -L https://github.com/google/protobuf/releases/download/v3.9.1/protoc-3.9.1-linux-x86_64.zip && \
|
curl -O -L https://github.com/google/protobuf/releases/download/v3.13.0/protoc-3.13.0-linux-x86_64.zip && \
|
||||||
unzip -d /usr/local protoc-3.9.1-linux-x86_64.zip bin/protoc && \
|
unzip -d /usr/local protoc-3.13.0-linux-x86_64.zip bin/protoc && \
|
||||||
chmod 755 /usr/local/bin/protoc && \
|
chmod 755 /usr/local/bin/protoc && \
|
||||||
curl -O https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-1.14.0.tar.gz && \
|
curl -O https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.3.0.tar.gz && \
|
||||||
tar zxf libtensorflow-gpu-linux-x86_64-1.14.0.tar.gz -C /usr/local && \
|
tar zxf libtensorflow-gpu-linux-x86_64-2.3.0.tar.gz -C /usr/local && \
|
||||||
ldconfig
|
ldconfig
|
||||||
|
|
||||||
ENV LANG en_US.UTF-8
|
ENV LANG en_US.UTF-8
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
resolver: lts-15.11
|
resolver: lts-16.18
|
||||||
|
|
||||||
packages:
|
packages:
|
||||||
- tensorflow
|
- tensorflow
|
||||||
|
|
|
@ -105,6 +105,8 @@ blackList =
|
||||||
, "GroupByWindowDataset"
|
, "GroupByWindowDataset"
|
||||||
, "If"
|
, "If"
|
||||||
, "InterleaveDataset"
|
, "InterleaveDataset"
|
||||||
|
, "LegacyParallelInterleaveDatasetV2"
|
||||||
|
, "LoadDataset"
|
||||||
, "MapAndBatchDataset"
|
, "MapAndBatchDataset"
|
||||||
, "MapAndBatchDatasetV2"
|
, "MapAndBatchDatasetV2"
|
||||||
, "MapDataset"
|
, "MapDataset"
|
||||||
|
@ -112,16 +114,24 @@ blackList =
|
||||||
, "OneShotIterator"
|
, "OneShotIterator"
|
||||||
, "ParallelInterleaveDataset"
|
, "ParallelInterleaveDataset"
|
||||||
, "ParallelInterleaveDatasetV2"
|
, "ParallelInterleaveDatasetV2"
|
||||||
|
, "ParallelInterleaveDatasetV3"
|
||||||
|
, "ParallelInterleaveDatasetV4"
|
||||||
, "ParallelMapDataset"
|
, "ParallelMapDataset"
|
||||||
|
, "ParallelMapDatasetV2"
|
||||||
, "ParseSequenceExample"
|
, "ParseSequenceExample"
|
||||||
|
, "ParseSequenceExampleV2"
|
||||||
|
, "ParseSingleSequenceExample"
|
||||||
, "PartitionedCall"
|
, "PartitionedCall"
|
||||||
, "ReduceDataset"
|
, "ReduceDataset"
|
||||||
, "RemoteCall"
|
, "RemoteCall"
|
||||||
|
, "SaveDataset"
|
||||||
, "ScanDataset"
|
, "ScanDataset"
|
||||||
|
, "SnapshotDatasetV2"
|
||||||
, "StatefulPartitionedCall"
|
, "StatefulPartitionedCall"
|
||||||
, "StatelessIf"
|
, "StatelessIf"
|
||||||
, "StatelessWhile"
|
, "StatelessWhile"
|
||||||
, "SymbolicGradient"
|
, "SymbolicGradient"
|
||||||
|
, "TakeWhileDataset"
|
||||||
, "TPUPartitionedCall"
|
, "TPUPartitionedCall"
|
||||||
, "TPUReplicate"
|
, "TPUReplicate"
|
||||||
, "While"
|
, "While"
|
||||||
|
@ -130,6 +140,7 @@ blackList =
|
||||||
, "XlaReduce"
|
, "XlaReduce"
|
||||||
, "XlaReduceWindow"
|
, "XlaReduceWindow"
|
||||||
, "XlaSelectAndScatter"
|
, "XlaSelectAndScatter"
|
||||||
|
, "XlaScatter"
|
||||||
, "XlaWhile"
|
, "XlaWhile"
|
||||||
, "_If"
|
, "_If"
|
||||||
, "_TPUReplicate"
|
, "_TPUReplicate"
|
||||||
|
|
|
@ -16,7 +16,7 @@ library
|
||||||
exposed-modules: TensorFlow.GenOps.Core
|
exposed-modules: TensorFlow.GenOps.Core
|
||||||
autogen-modules: TensorFlow.GenOps.Core
|
autogen-modules: TensorFlow.GenOps.Core
|
||||||
build-depends: bytestring
|
build-depends: bytestring
|
||||||
, proto-lens == 0.6.*
|
, proto-lens == 0.7.*
|
||||||
, tensorflow == 0.2.*
|
, tensorflow == 0.2.*
|
||||||
, base >= 4.7 && < 5
|
, base >= 4.7 && < 5
|
||||||
, lens-family == 2.*
|
, lens-family == 2.*
|
||||||
|
@ -27,7 +27,7 @@ custom-setup
|
||||||
setup-depends: Cabal
|
setup-depends: Cabal
|
||||||
, bytestring
|
, bytestring
|
||||||
, directory
|
, directory
|
||||||
, proto-lens == 0.6.*
|
, proto-lens == 0.7.*
|
||||||
, tensorflow-opgen == 0.2.*
|
, tensorflow-opgen == 0.2.*
|
||||||
, tensorflow == 0.2.*
|
, tensorflow == 0.2.*
|
||||||
, base >= 4.7 && < 5
|
, base >= 4.7 && < 5
|
||||||
|
|
|
@ -24,7 +24,7 @@ library
|
||||||
, filepath
|
, filepath
|
||||||
, hostname
|
, hostname
|
||||||
, lens-family == 2.*
|
, lens-family == 2.*
|
||||||
, proto-lens == 0.6.*
|
, proto-lens == 0.7.*
|
||||||
, resourcet
|
, resourcet
|
||||||
, stm
|
, stm
|
||||||
, stm-chans
|
, stm-chans
|
||||||
|
|
|
@ -20,7 +20,7 @@ library
|
||||||
exposed-modules: TensorFlow.Examples.MNIST.Parse
|
exposed-modules: TensorFlow.Examples.MNIST.Parse
|
||||||
, TensorFlow.Examples.MNIST.TrainedGraph
|
, TensorFlow.Examples.MNIST.TrainedGraph
|
||||||
other-modules: Paths_tensorflow_mnist
|
other-modules: Paths_tensorflow_mnist
|
||||||
build-depends: proto-lens == 0.6.*
|
build-depends: proto-lens == 0.7.*
|
||||||
, base >= 4.7 && < 5
|
, base >= 4.7 && < 5
|
||||||
, binary
|
, binary
|
||||||
, bytestring
|
, bytestring
|
||||||
|
|
|
@ -266,7 +266,7 @@ getExplicitInputAttr o implicitAttrs a
|
||||||
, a ^. maybe'defaultValue == Nothing
|
, a ^. maybe'defaultValue == Nothing
|
||||||
, t <- parseAttrType o (a ^. type')
|
, t <- parseAttrType o (a ^. type')
|
||||||
, t `elem` map AttrSingle
|
, t `elem` map AttrSingle
|
||||||
[AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape]
|
[AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape, AttrBytes]
|
||||||
++ [AttrList AttrType] = Just t
|
++ [AttrList AttrType] = Just t
|
||||||
| otherwise = Nothing
|
| otherwise = Nothing
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ library
|
||||||
hs-source-dirs: src
|
hs-source-dirs: src
|
||||||
exposed-modules: TensorFlow.OpGen.ParsedOp
|
exposed-modules: TensorFlow.OpGen.ParsedOp
|
||||||
, TensorFlow.OpGen
|
, TensorFlow.OpGen
|
||||||
build-depends: proto-lens == 0.6.*
|
build-depends: proto-lens == 0.7.*
|
||||||
, tensorflow-proto == 0.2.*
|
, tensorflow-proto == 0.2.*
|
||||||
, base >= 4.7 && < 5
|
, base >= 4.7 && < 5
|
||||||
, bytestring
|
, bytestring
|
||||||
|
|
300
tensorflow-ops/src/TensorFlow/Convolution.hs
Normal file
300
tensorflow-ops/src/TensorFlow/Convolution.hs
Normal file
|
@ -0,0 +1,300 @@
|
||||||
|
-- Copyright 2020 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE ConstraintKinds #-}
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
|
module TensorFlow.Convolution
|
||||||
|
( Padding(..)
|
||||||
|
, DataFormat(..)
|
||||||
|
, conv2D
|
||||||
|
, conv2D'
|
||||||
|
, conv2DBackpropFilter
|
||||||
|
, conv2DBackpropFilter'
|
||||||
|
, conv2DBackpropInput
|
||||||
|
, conv2DBackpropInput'
|
||||||
|
, conv3D
|
||||||
|
, conv3D'
|
||||||
|
, conv3DBackpropFilter
|
||||||
|
, conv3DBackpropFilter'
|
||||||
|
, conv3DBackpropFilterV2
|
||||||
|
, conv3DBackpropFilterV2'
|
||||||
|
, conv3DBackpropInput
|
||||||
|
, conv3DBackpropInput'
|
||||||
|
, conv3DBackpropInputV2
|
||||||
|
, conv3DBackpropInputV2'
|
||||||
|
, depthwiseConv2dNative
|
||||||
|
, depthwiseConv2dNative'
|
||||||
|
, depthwiseConv2dNativeBackpropFilter
|
||||||
|
, depthwiseConv2dNativeBackpropFilter'
|
||||||
|
, depthwiseConv2dNativeBackpropInput
|
||||||
|
, depthwiseConv2dNativeBackpropInput'
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Data.Word (Word16)
|
||||||
|
import Data.Int (Int32,Int64)
|
||||||
|
import Data.ByteString (ByteString)
|
||||||
|
import Lens.Family2 ((.~))
|
||||||
|
|
||||||
|
import qualified TensorFlow.BuildOp as TF
|
||||||
|
import qualified TensorFlow.Core as TF
|
||||||
|
import qualified TensorFlow.GenOps.Core as TF
|
||||||
|
|
||||||
|
-- TODO: Support other convolution parameters such as stride.
|
||||||
|
|
||||||
|
-- | Convolution padding.
|
||||||
|
data Padding =
|
||||||
|
-- | output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])
|
||||||
|
PaddingValid
|
||||||
|
-- | output_spatial_shape[i] = ceil(
|
||||||
|
-- (input_spatial_shape[i] -
|
||||||
|
-- (spatial_filter_shape[i]-1) * dilation_rate[i]) / strides[i])
|
||||||
|
| PaddingSame
|
||||||
|
|
||||||
|
paddingToByteString :: Padding -> ByteString
|
||||||
|
paddingToByteString x = case x of
|
||||||
|
PaddingValid -> "VALID"
|
||||||
|
PaddingSame -> "SAME"
|
||||||
|
|
||||||
|
-- | Matrix format.
|
||||||
|
data DataFormat = ChannelLast -- ^ Channel is the last dimension (e.g. NWC, NHWC, NDHWC)
|
||||||
|
| ChannelFirst -- ^ Channel is the first dimension after N (e.g. NCW, NCHW, NCDHW)
|
||||||
|
|
||||||
|
-- TODO: Address 1D convolution.
|
||||||
|
|
||||||
|
dataFormat2D :: DataFormat -> ByteString
|
||||||
|
dataFormat2D x = case x of
|
||||||
|
ChannelLast -> "NHWC"
|
||||||
|
ChannelFirst -> "NCHW"
|
||||||
|
|
||||||
|
dataFormat3D :: DataFormat -> ByteString
|
||||||
|
dataFormat3D x = case x of
|
||||||
|
ChannelLast -> "NDHWC"
|
||||||
|
ChannelFirst -> "NCDHW"
|
||||||
|
|
||||||
|
-- | 2D Convolution with default parameters.
|
||||||
|
conv2D :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.Tensor v1 t -- ^ input
|
||||||
|
-> TF.Tensor v2 t -- ^ filter
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
conv2D = conv2D' id PaddingValid ChannelLast
|
||||||
|
|
||||||
|
conv2D' :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.OpParams
|
||||||
|
-> Padding
|
||||||
|
-> DataFormat
|
||||||
|
-> TF.Tensor v1 t -- ^ input
|
||||||
|
-> TF.Tensor v2 t -- ^ filter
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
conv2D' params padding dataformat = TF.conv2D'
|
||||||
|
(params . (TF.opAttr "data_format" .~ dataFormat2D dataformat))
|
||||||
|
(paddingToByteString padding)
|
||||||
|
|
||||||
|
-- | 2D convolution backpropagation filter with default parameters.
|
||||||
|
conv2DBackpropFilter :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.Tensor v1 t -- ^ input
|
||||||
|
-> TF.Tensor v2 Int32 -- ^ filter_sizes
|
||||||
|
-> TF.Tensor v3 t -- ^ out_backprop
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
conv2DBackpropFilter = conv2DBackpropFilter' id PaddingValid ChannelLast
|
||||||
|
|
||||||
|
conv2DBackpropFilter' :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.OpParams
|
||||||
|
-> Padding
|
||||||
|
-> DataFormat
|
||||||
|
-> TF.Tensor v1 t -- ^ input
|
||||||
|
-> TF.Tensor v2 Int32 -- ^ filter_sizes
|
||||||
|
-> TF.Tensor v3 t -- ^ out_backprop
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
conv2DBackpropFilter' params padding dataformat = TF.conv2DBackpropFilter'
|
||||||
|
(params . (TF.opAttr "data_format" .~ dataFormat2D dataformat))
|
||||||
|
(paddingToByteString padding)
|
||||||
|
|
||||||
|
-- | 2D convolution backpropagation input with default parameters.
|
||||||
|
conv2DBackpropInput :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.Tensor v1 Int32 -- ^ input_sizes
|
||||||
|
-> TF.Tensor v2 t -- ^ filter
|
||||||
|
-> TF.Tensor v3 t -- ^ out_backprop
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
conv2DBackpropInput = conv2DBackpropInput' id PaddingValid ChannelLast
|
||||||
|
|
||||||
|
conv2DBackpropInput' :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.OpParams
|
||||||
|
-> Padding
|
||||||
|
-> DataFormat
|
||||||
|
-> TF.Tensor v1 Int32 -- ^ input_sizes
|
||||||
|
-> TF.Tensor v2 t -- ^ filter
|
||||||
|
-> TF.Tensor v3 t -- ^ out_backprop
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
conv2DBackpropInput' params padding dataformat = TF.conv2DBackpropInput'
|
||||||
|
(params . (TF.opAttr "data_format" .~ dataFormat2D dataformat))
|
||||||
|
(paddingToByteString padding)
|
||||||
|
|
||||||
|
-- | 3D Convolution with default parameters.
|
||||||
|
conv3D :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.Tensor v1 t -- ^ input
|
||||||
|
-> TF.Tensor v2 t -- ^ filter
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
conv3D = conv3D' id PaddingValid ChannelLast
|
||||||
|
|
||||||
|
conv3D' :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.OpParams
|
||||||
|
-> Padding
|
||||||
|
-> DataFormat
|
||||||
|
-> TF.Tensor v1 t -- ^ input
|
||||||
|
-> TF.Tensor v2 t -- ^ filter
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
conv3D' params padding dataformat = TF.conv3D'
|
||||||
|
(params . (TF.opAttr "data_format" .~ dataFormat3D dataformat))
|
||||||
|
(paddingToByteString padding)
|
||||||
|
|
||||||
|
-- | 3D convolution backpropagation filter with default parameters.
|
||||||
|
conv3DBackpropFilter :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.Tensor v1 t -- ^ input
|
||||||
|
-> TF.Tensor v2 t -- ^ filter
|
||||||
|
-> TF.Tensor v3 t -- ^ out_backprop
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
conv3DBackpropFilter = conv3DBackpropFilter' id PaddingValid ChannelLast
|
||||||
|
|
||||||
|
conv3DBackpropFilter' :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.OpParams
|
||||||
|
-> Padding
|
||||||
|
-> DataFormat
|
||||||
|
-> TF.Tensor v1 t -- ^ input
|
||||||
|
-> TF.Tensor v2 t -- ^ filter
|
||||||
|
-> TF.Tensor v3 t -- ^ out_backprop
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
conv3DBackpropFilter' params padding dataformat = TF.conv3DBackpropFilter'
|
||||||
|
(params . (TF.opAttr "data_format" .~ dataFormat3D dataformat))
|
||||||
|
(paddingToByteString padding)
|
||||||
|
|
||||||
|
-- | 3D convolution backpropagation filter with default parameters.
|
||||||
|
conv3DBackpropFilterV2 :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.Tensor v1 t -- ^ input
|
||||||
|
-> TF.Tensor v2 Int32 -- ^ filter_sizes
|
||||||
|
-> TF.Tensor v3 t -- ^ out_backprop
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
conv3DBackpropFilterV2 = conv3DBackpropFilterV2' id PaddingValid ChannelLast
|
||||||
|
|
||||||
|
conv3DBackpropFilterV2' :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.OpParams
|
||||||
|
-> Padding
|
||||||
|
-> DataFormat
|
||||||
|
-> TF.Tensor v1 t -- ^ input
|
||||||
|
-> TF.Tensor v2 Int32 -- ^ filter_sizes
|
||||||
|
-> TF.Tensor v3 t -- ^ out_backprop
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
conv3DBackpropFilterV2' params padding dataformat = TF.conv3DBackpropFilterV2'
|
||||||
|
(params . (TF.opAttr "data_format" .~ dataFormat3D dataformat))
|
||||||
|
(paddingToByteString padding)
|
||||||
|
|
||||||
|
-- | 3D convolution backpropagation input with default parameters.
|
||||||
|
conv3DBackpropInput :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.Tensor v1 t -- ^ input
|
||||||
|
-> TF.Tensor v2 t -- ^ filter
|
||||||
|
-> TF.Tensor v3 t -- ^ out_backprop
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
conv3DBackpropInput = conv3DBackpropInput' id PaddingValid ChannelLast
|
||||||
|
|
||||||
|
conv3DBackpropInput' :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.OpParams
|
||||||
|
-> Padding
|
||||||
|
-> DataFormat
|
||||||
|
-> TF.Tensor v1 t -- ^ input
|
||||||
|
-> TF.Tensor v2 t -- ^ filter
|
||||||
|
-> TF.Tensor v3 t -- ^ out_backprop
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
conv3DBackpropInput' params padding dataformat = TF.conv3DBackpropInput'
|
||||||
|
(params . (TF.opAttr "data_format" .~ dataFormat3D dataformat))
|
||||||
|
(paddingToByteString padding)
|
||||||
|
|
||||||
|
-- | 3D convolution backpropagation input with default parameters.
|
||||||
|
conv3DBackpropInputV2 :: (TF.OneOf '[Word16, Double, Float] t, TF.OneOf '[Int32, Int64] tshape)
|
||||||
|
=> TF.Tensor v1 tshape -- ^ input_sizes
|
||||||
|
-> TF.Tensor v2 t -- ^ filter
|
||||||
|
-> TF.Tensor v3 t -- ^ out_backprop
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
conv3DBackpropInputV2 = conv3DBackpropInputV2' id PaddingValid ChannelLast
|
||||||
|
|
||||||
|
conv3DBackpropInputV2' :: (TF.OneOf '[Word16, Double, Float] t, TF.OneOf '[Int32, Int64] tshape)
|
||||||
|
=> TF.OpParams
|
||||||
|
-> Padding
|
||||||
|
-> DataFormat
|
||||||
|
-> TF.Tensor v1 tshape -- ^ input_sizes
|
||||||
|
-> TF.Tensor v2 t -- ^ filter
|
||||||
|
-> TF.Tensor v3 t -- ^ out_backprop
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
conv3DBackpropInputV2' params padding dataformat = TF.conv3DBackpropInputV2'
|
||||||
|
(params . (TF.opAttr "data_format" .~ dataFormat3D dataformat))
|
||||||
|
(paddingToByteString padding)
|
||||||
|
|
||||||
|
-- | Depth-wise 2D convolution native with default parameters.
|
||||||
|
depthwiseConv2dNative :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.Tensor v1 t -- ^ input
|
||||||
|
-> TF.Tensor v2 t -- ^ filter
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
depthwiseConv2dNative = depthwiseConv2dNative' id PaddingValid ChannelLast
|
||||||
|
|
||||||
|
depthwiseConv2dNative' :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.OpParams
|
||||||
|
-> Padding
|
||||||
|
-> DataFormat
|
||||||
|
-> TF.Tensor v1 t -- ^ input
|
||||||
|
-> TF.Tensor v2 t -- ^ filter
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
depthwiseConv2dNative' params padding dataformat = TF.depthwiseConv2dNative'
|
||||||
|
(params . (TF.opAttr "data_format" .~ dataFormat2D dataformat))
|
||||||
|
(paddingToByteString padding)
|
||||||
|
|
||||||
|
-- | Depth-wise 2D convolution native backpropagation filter with default parameters.
|
||||||
|
depthwiseConv2dNativeBackpropFilter :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.Tensor v1 t -- ^ input
|
||||||
|
-> TF.Tensor v2 Int32 -- ^ filter_sizes
|
||||||
|
-> TF.Tensor v3 t -- ^ out_backprop
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
depthwiseConv2dNativeBackpropFilter = depthwiseConv2dNativeBackpropFilter' id PaddingValid ChannelLast
|
||||||
|
|
||||||
|
depthwiseConv2dNativeBackpropFilter' :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.OpParams
|
||||||
|
-> Padding
|
||||||
|
-> DataFormat
|
||||||
|
-> TF.Tensor v1 t -- ^ input
|
||||||
|
-> TF.Tensor v2 Int32 -- ^ filter_sizes
|
||||||
|
-> TF.Tensor v3 t -- ^ out_backprop
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
depthwiseConv2dNativeBackpropFilter' params padding dataformat = TF.depthwiseConv2dNativeBackpropFilter'
|
||||||
|
(params . (TF.opAttr "data_format" .~ dataFormat2D dataformat))
|
||||||
|
(paddingToByteString padding)
|
||||||
|
|
||||||
|
-- | Depth-wise 2D convolution native backpropagation input with default parameters.
|
||||||
|
depthwiseConv2dNativeBackpropInput :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.Tensor v1 Int32 -- ^ input_sizes
|
||||||
|
-> TF.Tensor v2 t -- ^ input
|
||||||
|
-> TF.Tensor v3 t -- ^ out_backprop
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
depthwiseConv2dNativeBackpropInput = depthwiseConv2dNativeBackpropInput' id PaddingValid ChannelLast
|
||||||
|
|
||||||
|
|
||||||
|
depthwiseConv2dNativeBackpropInput' :: TF.OneOf '[Word16, Double, Float] t
|
||||||
|
=> TF.OpParams
|
||||||
|
-> Padding
|
||||||
|
-> DataFormat
|
||||||
|
-> TF.Tensor v1 Int32 -- ^ input_sizes
|
||||||
|
-> TF.Tensor v2 t -- ^ input
|
||||||
|
-> TF.Tensor v3 t -- ^ out_backprop
|
||||||
|
-> TF.Tensor TF.Build t -- ^ output
|
||||||
|
depthwiseConv2dNativeBackpropInput' params padding dataformat = TF.depthwiseConv2dNativeBackpropInput'
|
||||||
|
(params . (TF.opAttr "data_format" .~ dataFormat2D dataformat))
|
||||||
|
(paddingToByteString padding)
|
|
@ -678,16 +678,14 @@ opGrad "Transpose" _ [_, toT -> p] [dz] =
|
||||||
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
|
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
|
||||||
[ Just $ CoreOps.conv2DBackpropInput'
|
[ Just $ CoreOps.conv2DBackpropInput'
|
||||||
((opAttr "strides" .~ strides)
|
((opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
(shape x) y dz
|
padding (shape x) y dz
|
||||||
, Just $ CoreOps.conv2DBackpropFilter'
|
, Just $ CoreOps.conv2DBackpropFilter'
|
||||||
((opAttr "strides" .~ strides)
|
((opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
x (shape y) dz
|
padding x (shape y) dz
|
||||||
]
|
]
|
||||||
where
|
where
|
||||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||||
|
@ -699,16 +697,14 @@ opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
|
||||||
[ Nothing
|
[ Nothing
|
||||||
, Just $ CoreOps.conv2DBackpropFilter'
|
, Just $ CoreOps.conv2DBackpropFilter'
|
||||||
((opAttr "strides" .~ strides)
|
((opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
dz (shape x) y
|
padding dz (shape x) y
|
||||||
, Just $ CoreOps.conv2D'
|
, Just $ CoreOps.conv2D'
|
||||||
((opAttr "strides" .~ strides)
|
((opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
dz x
|
padding dz x
|
||||||
]
|
]
|
||||||
where
|
where
|
||||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||||
|
@ -719,14 +715,12 @@ opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
|
||||||
opGrad "DepthwiseConv2dNative" nodeDef [toT -> x, toT -> y] [dz] =
|
opGrad "DepthwiseConv2dNative" nodeDef [toT -> x, toT -> y] [dz] =
|
||||||
[ Just $ CoreOps.depthwiseConv2dNativeBackpropInput'
|
[ Just $ CoreOps.depthwiseConv2dNativeBackpropInput'
|
||||||
((opAttr "strides" .~ strides)
|
((opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
(shape x) y dz
|
padding (shape x) y dz
|
||||||
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
|
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
|
||||||
((opAttr "strides" .~ strides)
|
((opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
x (shape y) dz
|
padding x (shape y) dz
|
||||||
]
|
]
|
||||||
where
|
where
|
||||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||||
|
@ -737,14 +731,12 @@ opGrad "DepthwiseConv2dNativeBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz]
|
||||||
[ Nothing
|
[ Nothing
|
||||||
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
|
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
|
||||||
((opAttr "strides" .~ strides)
|
((opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
dz (shape x) y
|
padding dz (shape x) y
|
||||||
, Just $ CoreOps.depthwiseConv2dNative'
|
, Just $ CoreOps.depthwiseConv2dNative'
|
||||||
((opAttr "strides" .~ strides)
|
((opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
dz x
|
padding dz x
|
||||||
]
|
]
|
||||||
where
|
where
|
||||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||||
|
@ -755,9 +747,8 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] =
|
||||||
[ Just $ CoreOps.maxPoolGrad'
|
[ Just $ CoreOps.maxPoolGrad'
|
||||||
((opAttr "ksize" .~ ksize)
|
((opAttr "ksize" .~ ksize)
|
||||||
. (opAttr "strides" .~ strides)
|
. (opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
x output dz
|
padding x output dz
|
||||||
]
|
]
|
||||||
where
|
where
|
||||||
output :: Tensor Build a
|
output :: Tensor Build a
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE ConstraintKinds #-}
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE RankNTypes #-}
|
{-# LANGUAGE RankNTypes #-}
|
||||||
|
@ -22,11 +24,15 @@ module TensorFlow.Minimize
|
||||||
( Minimizer
|
( Minimizer
|
||||||
, minimizeWith
|
, minimizeWith
|
||||||
, gradientDescent
|
, gradientDescent
|
||||||
|
, OneOfAdamDataTypes
|
||||||
, AdamConfig(..)
|
, AdamConfig(..)
|
||||||
, adam
|
, adam
|
||||||
, adam'
|
, adam'
|
||||||
) where
|
) where
|
||||||
|
|
||||||
|
import Data.Complex (Complex)
|
||||||
|
import Data.Int (Int8,Int16,Int32,Int64)
|
||||||
|
import Data.Word (Word8,Word16,Word32,Word64)
|
||||||
import Control.Monad (zipWithM)
|
import Control.Monad (zipWithM)
|
||||||
import Data.Default (Default(..))
|
import Data.Default (Default(..))
|
||||||
import Data.List (zipWith4)
|
import Data.List (zipWith4)
|
||||||
|
@ -65,16 +71,19 @@ gradientDescent learningRate params grads = TF.withNameScope "gradientDescent" $
|
||||||
TF.assignAdd param (TF.scalar (-learningRate) `TF.mul` grad)
|
TF.assignAdd param (TF.scalar (-learningRate) `TF.mul` grad)
|
||||||
TF.group =<< zipWithM applyGrad params grads
|
TF.group =<< zipWithM applyGrad params grads
|
||||||
|
|
||||||
-- TODO: Support more than Float in adam.
|
data AdamConfig t = AdamConfig
|
||||||
|
{ adamLearningRate :: t
|
||||||
data AdamConfig = AdamConfig
|
, adamBeta1 :: t
|
||||||
{ adamLearningRate :: Float
|
, adamBeta2 :: t
|
||||||
, adamBeta1 :: Float
|
, adamEpsilon :: t
|
||||||
, adamBeta2 :: Float
|
|
||||||
, adamEpsilon :: Float
|
|
||||||
}
|
}
|
||||||
|
|
||||||
instance Default AdamConfig where
|
type OneOfAdamDataTypes t =
|
||||||
|
TF.OneOf '[ Complex Double, Complex Float
|
||||||
|
, Int16, Int32, Int64, Int8, Word16, Word32, Word64, Word8
|
||||||
|
, Double, Float] t
|
||||||
|
|
||||||
|
instance Fractional t => Default (AdamConfig t) where
|
||||||
-- Recommended defaults from the adam paper.
|
-- Recommended defaults from the adam paper.
|
||||||
def = AdamConfig 0.001 0.9 0.999 1e-8
|
def = AdamConfig 0.001 0.9 0.999 1e-8
|
||||||
|
|
||||||
|
@ -83,10 +92,10 @@ instance Default AdamConfig where
|
||||||
-- See https://arxiv.org/abs/1412.6980.
|
-- See https://arxiv.org/abs/1412.6980.
|
||||||
--
|
--
|
||||||
-- NOTE: Currently requires all 'TF.Variable's to have an 'TF.initializedValue'.
|
-- NOTE: Currently requires all 'TF.Variable's to have an 'TF.initializedValue'.
|
||||||
adam :: Minimizer Float
|
adam :: (OneOfAdamDataTypes t, Fractional t) => Minimizer t
|
||||||
adam = adam' def
|
adam = adam' def
|
||||||
|
|
||||||
adam' :: AdamConfig -> Minimizer Float
|
adam' :: OneOfAdamDataTypes t => AdamConfig t -> Minimizer t
|
||||||
adam' config params grads = TF.withNameScope "adam" $ do
|
adam' config params grads = TF.withNameScope "adam" $ do
|
||||||
let lr = TF.scalar (adamLearningRate config)
|
let lr = TF.scalar (adamLearningRate config)
|
||||||
beta1 = TF.scalar (adamBeta1 config)
|
beta1 = TF.scalar (adamBeta1 config)
|
||||||
|
|
|
@ -89,6 +89,8 @@ module TensorFlow.Ops
|
||||||
, CoreOps.identity'
|
, CoreOps.identity'
|
||||||
, CoreOps.matMul
|
, CoreOps.matMul
|
||||||
, CoreOps.matMul'
|
, CoreOps.matMul'
|
||||||
|
, CoreOps.einsum
|
||||||
|
, CoreOps.einsum'
|
||||||
, matTranspose
|
, matTranspose
|
||||||
, matTranspose'
|
, matTranspose'
|
||||||
, CoreOps.mean
|
, CoreOps.mean
|
||||||
|
|
|
@ -16,12 +16,13 @@ library
|
||||||
hs-source-dirs: src
|
hs-source-dirs: src
|
||||||
exposed-modules: TensorFlow.Gradient
|
exposed-modules: TensorFlow.Gradient
|
||||||
, TensorFlow.Ops
|
, TensorFlow.Ops
|
||||||
|
, TensorFlow.Convolution
|
||||||
, TensorFlow.EmbeddingOps
|
, TensorFlow.EmbeddingOps
|
||||||
, TensorFlow.Minimize
|
, TensorFlow.Minimize
|
||||||
, TensorFlow.NN
|
, TensorFlow.NN
|
||||||
, TensorFlow.Queue
|
, TensorFlow.Queue
|
||||||
, TensorFlow.Variable
|
, TensorFlow.Variable
|
||||||
build-depends: proto-lens == 0.6.*
|
build-depends: proto-lens == 0.7.*
|
||||||
, base >= 4.7 && < 5
|
, base >= 4.7 && < 5
|
||||||
, bytestring
|
, bytestring
|
||||||
, fgl
|
, fgl
|
||||||
|
|
|
@ -32,7 +32,8 @@ import Control.Monad(forM_, replicateM, zipWithM)
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
|
||||||
import qualified TensorFlow.Core as TF
|
import qualified TensorFlow.Core as TF
|
||||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput', batchMatMul, batchMatMul', conjugateTranspose)
|
import qualified TensorFlow.GenOps.Core as TF (max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, batchMatMul, batchMatMul', conjugateTranspose)
|
||||||
|
import qualified TensorFlow.Convolution as TF
|
||||||
import qualified TensorFlow.Gradient as TF
|
import qualified TensorFlow.Gradient as TF
|
||||||
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
|
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
|
||||||
import qualified TensorFlow.Output as TF
|
import qualified TensorFlow.Output as TF
|
||||||
|
@ -42,7 +43,6 @@ import qualified TensorFlow.Variable as TF
|
||||||
import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
|
import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
|
||||||
import Proto.Tensorflow.Core.Framework.NodeDef_Fields (op)
|
import Proto.Tensorflow.Core.Framework.NodeDef_Fields (op)
|
||||||
|
|
||||||
import qualified Data.ByteString.Char8 as BS
|
|
||||||
import TensorFlow.Session (SessionT)
|
import TensorFlow.Session (SessionT)
|
||||||
|
|
||||||
testGradientSimple :: Test
|
testGradientSimple :: Test
|
||||||
|
@ -715,11 +715,8 @@ testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
|
||||||
let filterShape = TF.vector [2, 2, 1, 1 :: Int32] -- [fh, fw, inc, out]
|
let filterShape = TF.vector [2, 2, 1, 1 :: Int32] -- [fh, fw, inc, out]
|
||||||
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1::Float))
|
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1::Float))
|
||||||
let y = TF.conv2DBackpropInput'
|
let y = TF.conv2DBackpropInput'
|
||||||
( (TF.opAttr "strides" .~ [1::Int64, 1, 1, 1])
|
(TF.opAttr "strides" .~ [1::Int64, 1, 1, 1])
|
||||||
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
|
TF.PaddingValid TF.ChannelLast conv_input_shape filter' x
|
||||||
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
|
|
||||||
)
|
|
||||||
conv_input_shape filter' x
|
|
||||||
|
|
||||||
[dx] <- TF.gradients y [x]
|
[dx] <- TF.gradients y [x]
|
||||||
TF.run (dx, TF.shape dx, TF.shape x)
|
TF.run (dx, TF.shape dx, TF.shape x)
|
||||||
|
@ -735,11 +732,8 @@ testDepthwiseConv2dGrad = testCase "testDepthwiseConv2dGrad" $ do
|
||||||
let filterShape = TF.vector [2, 2, 1, 1 :: Int32]
|
let filterShape = TF.vector [2, 2, 1, 1 :: Int32]
|
||||||
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
|
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
|
||||||
let y = TF.depthwiseConv2dNative'
|
let y = TF.depthwiseConv2dNative'
|
||||||
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
|
(TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
|
||||||
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
|
TF.PaddingValid TF.ChannelLast x filter'
|
||||||
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
|
|
||||||
)
|
|
||||||
x filter'
|
|
||||||
|
|
||||||
[dx] <- TF.gradients y [x]
|
[dx] <- TF.gradients y [x]
|
||||||
TF.run (dx, TF.shape dx, TF.shape x)
|
TF.run (dx, TF.shape dx, TF.shape x)
|
||||||
|
@ -757,11 +751,8 @@ testDepthwiseConv2dBackpropInputGrad = testCase "testDepthwiseConv2dBackpropInpu
|
||||||
let filterShape = TF.vector [2, 2, 1, 1 :: Int32]
|
let filterShape = TF.vector [2, 2, 1, 1 :: Int32]
|
||||||
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
|
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
|
||||||
let y = TF.depthwiseConv2dNativeBackpropInput'
|
let y = TF.depthwiseConv2dNativeBackpropInput'
|
||||||
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
|
(TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
|
||||||
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
|
TF.PaddingValid TF.ChannelLast conv_input_shape filter' x
|
||||||
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
|
|
||||||
)
|
|
||||||
conv_input_shape filter' x
|
|
||||||
|
|
||||||
[dx] <- TF.gradients y [x]
|
[dx] <- TF.gradients y [x]
|
||||||
TF.run (dx, TF.shape dx, TF.shape x)
|
TF.run (dx, TF.shape dx, TF.shape x)
|
||||||
|
|
|
@ -104,6 +104,20 @@ testRereadRef = testCase "testReRunAssign" $ TF.runSession $ do
|
||||||
f1 <- TF.run w
|
f1 <- TF.run w
|
||||||
liftIO $ (0.0, 0.1) @=? (TF.unScalar f0, TF.unScalar f1)
|
liftIO $ (0.0, 0.1) @=? (TF.unScalar f0, TF.unScalar f1)
|
||||||
|
|
||||||
|
-- | Test Einstein summation.
|
||||||
|
testEinsum :: Test
|
||||||
|
testEinsum = testCase "testEinsum" $ TF.runSession $ do
|
||||||
|
-- Matrix multiply
|
||||||
|
let matA = TF.constant (TF.Shape [3,3]) [1..9 :: Float]
|
||||||
|
let matB = TF.constant (TF.Shape [3,1]) [1..3 :: Float]
|
||||||
|
matMulOut <- TF.run $ TF.matMul matA matB
|
||||||
|
einsumOut <- TF.run $ TF.einsum "ij,jk->ik" [matA,matB]
|
||||||
|
liftIO $ (matMulOut :: V.Vector Float) @=? einsumOut
|
||||||
|
-- Hadamard multiply
|
||||||
|
hadMulOut <- TF.run $ TF.mul matA matA
|
||||||
|
einsumHad <- TF.run $ TF.einsum "ij,ij->ij" [matA,matA]
|
||||||
|
liftIO $ (hadMulOut :: V.Vector Float) @=? einsumHad
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = defaultMain
|
main = defaultMain
|
||||||
[ testSaveRestore
|
[ testSaveRestore
|
||||||
|
@ -112,4 +126,5 @@ main = defaultMain
|
||||||
, testPlaceholderCse
|
, testPlaceholderCse
|
||||||
, testScalarFeedCse
|
, testScalarFeedCse
|
||||||
, testRereadRef
|
, testRereadRef
|
||||||
|
, testEinsum
|
||||||
]
|
]
|
||||||
|
|
|
@ -73,6 +73,8 @@ library
|
||||||
, Proto.Tensorflow.Core.Protobuf.ControlFlow_Fields
|
, Proto.Tensorflow.Core.Protobuf.ControlFlow_Fields
|
||||||
, Proto.Tensorflow.Core.Protobuf.Debug
|
, Proto.Tensorflow.Core.Protobuf.Debug
|
||||||
, Proto.Tensorflow.Core.Protobuf.Debug_Fields
|
, Proto.Tensorflow.Core.Protobuf.Debug_Fields
|
||||||
|
, Proto.Tensorflow.Core.Protobuf.DeviceFilters
|
||||||
|
, Proto.Tensorflow.Core.Protobuf.ErrorCodes
|
||||||
, Proto.Tensorflow.Core.Protobuf.MetaGraph
|
, Proto.Tensorflow.Core.Protobuf.MetaGraph
|
||||||
, Proto.Tensorflow.Core.Protobuf.MetaGraph_Fields
|
, Proto.Tensorflow.Core.Protobuf.MetaGraph_Fields
|
||||||
, Proto.Tensorflow.Core.Protobuf.NamedTensor
|
, Proto.Tensorflow.Core.Protobuf.NamedTensor
|
||||||
|
@ -83,12 +85,16 @@ library
|
||||||
, Proto.Tensorflow.Core.Protobuf.RewriterConfig_Fields
|
, Proto.Tensorflow.Core.Protobuf.RewriterConfig_Fields
|
||||||
, Proto.Tensorflow.Core.Protobuf.SavedModel
|
, Proto.Tensorflow.Core.Protobuf.SavedModel
|
||||||
, Proto.Tensorflow.Core.Protobuf.SavedModel_Fields
|
, Proto.Tensorflow.Core.Protobuf.SavedModel_Fields
|
||||||
|
, Proto.Tensorflow.Core.Protobuf.SavedObjectGraph
|
||||||
, Proto.Tensorflow.Core.Protobuf.Saver
|
, Proto.Tensorflow.Core.Protobuf.Saver
|
||||||
, Proto.Tensorflow.Core.Protobuf.Saver_Fields
|
, Proto.Tensorflow.Core.Protobuf.Saver_Fields
|
||||||
|
, Proto.Tensorflow.Core.Protobuf.Struct
|
||||||
, Proto.Tensorflow.Core.Protobuf.TensorBundle
|
, Proto.Tensorflow.Core.Protobuf.TensorBundle
|
||||||
, Proto.Tensorflow.Core.Protobuf.TensorBundle_Fields
|
, Proto.Tensorflow.Core.Protobuf.TensorBundle_Fields
|
||||||
, Proto.Tensorflow.Core.Protobuf.TensorflowServer
|
, Proto.Tensorflow.Core.Protobuf.TensorflowServer
|
||||||
, Proto.Tensorflow.Core.Protobuf.TensorflowServer_Fields
|
, Proto.Tensorflow.Core.Protobuf.TensorflowServer_Fields
|
||||||
|
, Proto.Tensorflow.Core.Protobuf.TrackableObjectGraph
|
||||||
|
, Proto.Tensorflow.Core.Protobuf.VerifierConfig
|
||||||
, Proto.Tensorflow.Core.Util.Event
|
, Proto.Tensorflow.Core.Util.Event
|
||||||
, Proto.Tensorflow.Core.Util.Event_Fields
|
, Proto.Tensorflow.Core.Util.Event_Fields
|
||||||
, Proto.Tensorflow.Core.Util.MemmappedFileSystem
|
, Proto.Tensorflow.Core.Util.MemmappedFileSystem
|
||||||
|
@ -125,6 +131,7 @@ library
|
||||||
, Proto.Tensorflow.Core.Framework.OpDef_Fields
|
, Proto.Tensorflow.Core.Framework.OpDef_Fields
|
||||||
, Proto.Tensorflow.Core.Framework.ResourceHandle
|
, Proto.Tensorflow.Core.Framework.ResourceHandle
|
||||||
, Proto.Tensorflow.Core.Framework.ResourceHandle_Fields
|
, Proto.Tensorflow.Core.Framework.ResourceHandle_Fields
|
||||||
|
, Proto.Tensorflow.Core.Protobuf.SavedObjectGraph
|
||||||
, Proto.Tensorflow.Core.Framework.StepStats
|
, Proto.Tensorflow.Core.Framework.StepStats
|
||||||
, Proto.Tensorflow.Core.Framework.StepStats_Fields
|
, Proto.Tensorflow.Core.Framework.StepStats_Fields
|
||||||
, Proto.Tensorflow.Core.Framework.Summary
|
, Proto.Tensorflow.Core.Framework.Summary
|
||||||
|
@ -152,6 +159,8 @@ library
|
||||||
, Proto.Tensorflow.Core.Protobuf.ControlFlow_Fields
|
, Proto.Tensorflow.Core.Protobuf.ControlFlow_Fields
|
||||||
, Proto.Tensorflow.Core.Protobuf.Debug
|
, Proto.Tensorflow.Core.Protobuf.Debug
|
||||||
, Proto.Tensorflow.Core.Protobuf.Debug_Fields
|
, Proto.Tensorflow.Core.Protobuf.Debug_Fields
|
||||||
|
, Proto.Tensorflow.Core.Protobuf.DeviceFilters
|
||||||
|
, Proto.Tensorflow.Core.Protobuf.ErrorCodes
|
||||||
, Proto.Tensorflow.Core.Protobuf.MetaGraph
|
, Proto.Tensorflow.Core.Protobuf.MetaGraph
|
||||||
, Proto.Tensorflow.Core.Protobuf.MetaGraph_Fields
|
, Proto.Tensorflow.Core.Protobuf.MetaGraph_Fields
|
||||||
, Proto.Tensorflow.Core.Protobuf.NamedTensor
|
, Proto.Tensorflow.Core.Protobuf.NamedTensor
|
||||||
|
@ -162,12 +171,16 @@ library
|
||||||
, Proto.Tensorflow.Core.Protobuf.RewriterConfig_Fields
|
, Proto.Tensorflow.Core.Protobuf.RewriterConfig_Fields
|
||||||
, Proto.Tensorflow.Core.Protobuf.SavedModel
|
, Proto.Tensorflow.Core.Protobuf.SavedModel
|
||||||
, Proto.Tensorflow.Core.Protobuf.SavedModel_Fields
|
, Proto.Tensorflow.Core.Protobuf.SavedModel_Fields
|
||||||
|
, Proto.Tensorflow.Core.Protobuf.SavedObjectGraph
|
||||||
, Proto.Tensorflow.Core.Protobuf.Saver
|
, Proto.Tensorflow.Core.Protobuf.Saver
|
||||||
, Proto.Tensorflow.Core.Protobuf.Saver_Fields
|
, Proto.Tensorflow.Core.Protobuf.Saver_Fields
|
||||||
|
, Proto.Tensorflow.Core.Protobuf.Struct
|
||||||
, Proto.Tensorflow.Core.Protobuf.TensorBundle
|
, Proto.Tensorflow.Core.Protobuf.TensorBundle
|
||||||
, Proto.Tensorflow.Core.Protobuf.TensorBundle_Fields
|
, Proto.Tensorflow.Core.Protobuf.TensorBundle_Fields
|
||||||
, Proto.Tensorflow.Core.Protobuf.TensorflowServer
|
, Proto.Tensorflow.Core.Protobuf.TensorflowServer
|
||||||
, Proto.Tensorflow.Core.Protobuf.TensorflowServer_Fields
|
, Proto.Tensorflow.Core.Protobuf.TensorflowServer_Fields
|
||||||
|
, Proto.Tensorflow.Core.Protobuf.TrackableObjectGraph
|
||||||
|
, Proto.Tensorflow.Core.Protobuf.VerifierConfig
|
||||||
, Proto.Tensorflow.Core.Util.Event
|
, Proto.Tensorflow.Core.Util.Event
|
||||||
, Proto.Tensorflow.Core.Util.Event_Fields
|
, Proto.Tensorflow.Core.Util.Event_Fields
|
||||||
, Proto.Tensorflow.Core.Util.MemmappedFileSystem
|
, Proto.Tensorflow.Core.Util.MemmappedFileSystem
|
||||||
|
@ -176,9 +189,9 @@ library
|
||||||
, Proto.Tensorflow.Core.Util.SavedTensorSlice_Fields
|
, Proto.Tensorflow.Core.Util.SavedTensorSlice_Fields
|
||||||
, Proto.Tensorflow.Core.Util.TestLog
|
, Proto.Tensorflow.Core.Util.TestLog
|
||||||
, Proto.Tensorflow.Core.Util.TestLog_Fields
|
, Proto.Tensorflow.Core.Util.TestLog_Fields
|
||||||
build-depends: proto-lens == 0.6.*
|
build-depends: proto-lens == 0.7.*
|
||||||
, proto-lens-runtime == 0.6.*
|
, proto-lens-runtime == 0.7.*
|
||||||
, proto-lens-protobuf-types == 0.6.*
|
, proto-lens-protobuf-types == 0.7.*
|
||||||
, base >= 4.7 && < 5
|
, base >= 4.7 && < 5
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
include-dirs: .
|
include-dirs: .
|
||||||
|
|
|
@ -36,7 +36,7 @@ library
|
||||||
, TensorFlow.Types
|
, TensorFlow.Types
|
||||||
other-modules: TensorFlow.Internal.Raw
|
other-modules: TensorFlow.Internal.Raw
|
||||||
build-tools: c2hs
|
build-tools: c2hs
|
||||||
build-depends: proto-lens == 0.6.*
|
build-depends: proto-lens == 0.7.*
|
||||||
, tensorflow-proto == 0.2.*
|
, tensorflow-proto == 0.2.*
|
||||||
, base >= 4.7 && < 5
|
, base >= 4.7 && < 5
|
||||||
, async
|
, async
|
||||||
|
|
2
third_party/tensorflow
vendored
2
third_party/tensorflow
vendored
|
@ -1 +1 @@
|
||||||
Subproject commit 6612da89516247503f03ef76e974b51a434fb52e
|
Subproject commit b36436b087bd8e8701ef51718179037cccdfc26e
|
Loading…
Reference in a new issue