Added a wrapper module for convolution functions.

This commit is contained in:
jcmartin 2020-11-05 17:21:06 +00:00
parent 550ed59ad7
commit 4bca19db82
3 changed files with 308 additions and 13 deletions

View File

@ -0,0 +1,299 @@
-- Copyright 2020 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}
module TensorFlow.Convolution
( Padding(..)
, DataFormat(..)
, conv2D
, conv2D'
, conv2DBackpropFilter
, conv2DBackpropFilter'
, conv2DBackpropInput
, conv2DBackpropInput'
, conv3D
, conv3D'
, conv3DBackpropFilter
, conv3DBackpropFilter'
, conv3DBackpropFilterV2
, conv3DBackpropFilterV2'
, conv3DBackpropInput
, conv3DBackpropInput'
, conv3DBackpropInputV2
, conv3DBackpropInputV2'
, depthwiseConv2dNative
, depthwiseConv2dNative'
, depthwiseConv2dNativeBackpropFilter
, depthwiseConv2dNativeBackpropFilter'
, depthwiseConv2dNativeBackpropInput
, depthwiseConv2dNativeBackpropInput'
) where
import Data.Word(Word16)
import Data.Int(Int32,Int64)
import Data.ByteString(ByteString)
import Lens.Family2 ((.~), (&))
import qualified TensorFlow.BuildOp as TF
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF
-- TODO: Support other convolution parameters such as stride.
-- | Convolution padding.
data Padding = PaddingValid -- ^ output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])
| PaddingSame -- ^ output_spatial_shape[i] = ceil((input_spatial_shape[i] - (spatial_filter_shape[i]-1) * dilation_rate[i]) / strides[i])
paddingToByteString :: Padding -> ByteString
paddingToByteString x = case x of
PaddingValid -> "VALID"
PaddingSame -> "SAME"
-- | Matrix format.
data DataFormat = ChannelLast -- ^ Channel is the last dimension (e.g. NWC, NHWC, NDHWC)
| ChannelFirst -- ^ Channel is the first dimension after N (e.g. NCW, NCHW, NCDHW)
-- TODO: Address 1D convolution.
--dataFormat1D :: DataFormat -> ByteString
--dataFormat1D x = case x of
-- ChannelLast -> "NWC"
-- ChannelFirst -> "NCW"
dataFormat2D :: DataFormat -> ByteString
dataFormat2D x = case x of
ChannelLast -> "NHWC"
ChannelFirst -> "NCHW"
dataFormat3D :: DataFormat -> ByteString
dataFormat3D x = case x of
ChannelLast -> "NDHWC"
ChannelFirst -> "NCDHW"
-- | 2D Convolution with default parameters.
conv2D :: TF.OneOf '[Word16, Double, Float] t
=> TF.Tensor v1 t -- ^ input
-> TF.Tensor v2 t -- ^ filter
-> TF.Tensor TF.Build t -- ^ output
conv2D = conv2D' id PaddingValid ChannelLast
conv2D' :: TF.OneOf '[Word16, Double, Float] t
=> TF.OpParams
-> Padding
-> DataFormat
-> TF.Tensor v1 t -- ^ input
-> TF.Tensor v2 t -- ^ filter
-> TF.Tensor TF.Build t -- ^ output
conv2D' params padding dataformat = TF.conv2D'
(params . (TF.opAttr "data_format" .~ dataFormat2D dataformat))
(paddingToByteString padding)
-- | 2D convolution backpropagation filter with default parameters.
conv2DBackpropFilter :: TF.OneOf '[Word16, Double, Float] t
=> TF.Tensor v1 t -- ^ input
-> TF.Tensor v2 Int32 -- ^ filter_sizes
-> TF.Tensor v3 t -- ^ out_backprop
-> TF.Tensor TF.Build t -- ^ output
conv2DBackpropFilter = conv2DBackpropFilter' id PaddingValid ChannelLast
conv2DBackpropFilter' :: TF.OneOf '[Word16, Double, Float] t
=> TF.OpParams
-> Padding
-> DataFormat
-> TF.Tensor v1 t -- ^ input
-> TF.Tensor v2 Int32 -- ^ filter_sizes
-> TF.Tensor v3 t -- ^ out_backprop
-> TF.Tensor TF.Build t -- ^ output
conv2DBackpropFilter' params padding dataformat = TF.conv2DBackpropFilter'
(params . (TF.opAttr "data_format" .~ dataFormat2D dataformat))
(paddingToByteString padding)
-- | 2D convolution backpropagation input with default parameters.
conv2DBackpropInput :: TF.OneOf '[Word16, Double, Float] t
=> TF.Tensor v1 Int32 -- ^ input_sizes
-> TF.Tensor v2 t -- ^ filter
-> TF.Tensor v3 t -- ^ out_backprop
-> TF.Tensor TF.Build t -- ^ output
conv2DBackpropInput = conv2DBackpropInput' id PaddingValid ChannelLast
conv2DBackpropInput' :: TF.OneOf '[Word16, Double, Float] t
=> TF.OpParams
-> Padding
-> DataFormat
-> TF.Tensor v1 Int32 -- ^ input_sizes
-> TF.Tensor v2 t -- ^ filter
-> TF.Tensor v3 t -- ^ out_backprop
-> TF.Tensor TF.Build t -- ^ output
conv2DBackpropInput' params padding dataformat = TF.conv2DBackpropInput'
(params . (TF.opAttr "data_format" .~ dataFormat2D dataformat))
(paddingToByteString padding)
-- | 3D Convolution with default parameters.
conv3D :: TF.OneOf '[Word16, Double, Float] t
=> TF.Tensor v1 t -- ^ input
-> TF.Tensor v2 t -- ^ filter
-> TF.Tensor TF.Build t -- ^ output
conv3D = conv3D' id PaddingValid ChannelLast
conv3D' :: TF.OneOf '[Word16, Double, Float] t
=> TF.OpParams
-> Padding
-> DataFormat
-> TF.Tensor v1 t -- ^ input
-> TF.Tensor v2 t -- ^ filter
-> TF.Tensor TF.Build t -- ^ output
conv3D' params padding dataformat = TF.conv3D'
(params . (TF.opAttr "data_format" .~ dataFormat3D dataformat))
(paddingToByteString padding)
-- | 3D convolution backpropagation filter with default parameters.
conv3DBackpropFilter :: TF.OneOf '[Word16, Double, Float] t
=> TF.Tensor v1 t -- ^ input
-> TF.Tensor v2 t -- ^ filter
-> TF.Tensor v3 t -- ^ out_backprop
-> TF.Tensor TF.Build t -- ^ output
conv3DBackpropFilter = conv3DBackpropFilter' id PaddingValid ChannelLast
conv3DBackpropFilter' :: TF.OneOf '[Word16, Double, Float] t
=> TF.OpParams
-> Padding
-> DataFormat
-> TF.Tensor v1 t -- ^ input
-> TF.Tensor v2 t -- ^ filter
-> TF.Tensor v3 t -- ^ out_backprop
-> TF.Tensor TF.Build t -- ^ output
conv3DBackpropFilter' params padding dataformat = TF.conv3DBackpropFilter'
(params . (TF.opAttr "data_format" .~ dataFormat3D dataformat))
(paddingToByteString padding)
-- | 3D convolution backpropagation filter with default parameters.
conv3DBackpropFilterV2 :: TF.OneOf '[Word16, Double, Float] t
=> TF.Tensor v1 t -- ^ input
-> TF.Tensor v2 Int32 -- ^ filter_sizes
-> TF.Tensor v3 t -- ^ out_backprop
-> TF.Tensor TF.Build t -- ^ output
conv3DBackpropFilterV2 = conv3DBackpropFilterV2' id PaddingValid ChannelLast
conv3DBackpropFilterV2' :: TF.OneOf '[Word16, Double, Float] t
=> TF.OpParams
-> Padding
-> DataFormat
-> TF.Tensor v1 t -- ^ input
-> TF.Tensor v2 Int32 -- ^ filter_sizes
-> TF.Tensor v3 t -- ^ out_backprop
-> TF.Tensor TF.Build t -- ^ output
conv3DBackpropFilterV2' params padding dataformat = TF.conv3DBackpropFilterV2'
(params . (TF.opAttr "data_format" .~ dataFormat3D dataformat))
(paddingToByteString padding)
-- | 3D convolution backpropagation input with default parameters.
conv3DBackpropInput :: TF.OneOf '[Word16, Double, Float] t
=> TF.Tensor v1 t -- ^ input
-> TF.Tensor v2 t -- ^ filter
-> TF.Tensor v3 t -- ^ out_backprop
-> TF.Tensor TF.Build t -- ^ output
conv3DBackpropInput = conv3DBackpropInput' id PaddingValid ChannelLast
conv3DBackpropInput' :: TF.OneOf '[Word16, Double, Float] t
=> TF.OpParams
-> Padding
-> DataFormat
-> TF.Tensor v1 t -- ^ input
-> TF.Tensor v2 t -- ^ filter
-> TF.Tensor v3 t -- ^ out_backprop
-> TF.Tensor TF.Build t -- ^ output
conv3DBackpropInput' params padding dataformat = TF.conv3DBackpropInput'
(params . (TF.opAttr "data_format" .~ dataFormat3D dataformat))
(paddingToByteString padding)
-- | 3D convolution backpropagation input with default parameters.
conv3DBackpropInputV2 :: (TF.OneOf '[Word16, Double, Float] t, TF.OneOf '[Int32, Int64] tshape)
=> TF.Tensor v1 tshape -- ^ input_sizes
-> TF.Tensor v2 t -- ^ filter
-> TF.Tensor v3 t -- ^ out_backprop
-> TF.Tensor TF.Build t -- ^ output
conv3DBackpropInputV2 = conv3DBackpropInputV2' id PaddingValid ChannelLast
conv3DBackpropInputV2' :: (TF.OneOf '[Word16, Double, Float] t, TF.OneOf '[Int32, Int64] tshape)
=> TF.OpParams
-> Padding
-> DataFormat
-> TF.Tensor v1 tshape -- ^ input_sizes
-> TF.Tensor v2 t -- ^ filter
-> TF.Tensor v3 t -- ^ out_backprop
-> TF.Tensor TF.Build t -- ^ output
conv3DBackpropInputV2' params padding dataformat = TF.conv3DBackpropInputV2'
(params . (TF.opAttr "data_format" .~ dataFormat3D dataformat))
(paddingToByteString padding)
-- | Depth-wise 2D convolution native with default parameters.
depthwiseConv2dNative :: TF.OneOf '[Word16, Double, Float] t
=> TF.Tensor v1 t -- ^ input
-> TF.Tensor v2 t -- ^ filter
-> TF.Tensor TF.Build t -- ^ output
depthwiseConv2dNative = depthwiseConv2dNative' id PaddingValid ChannelLast
depthwiseConv2dNative' :: TF.OneOf '[Word16, Double, Float] t
=> TF.OpParams
-> Padding
-> DataFormat
-> TF.Tensor v1 t -- ^ input
-> TF.Tensor v2 t -- ^ filter
-> TF.Tensor TF.Build t -- ^ output
depthwiseConv2dNative' params padding dataformat = TF.depthwiseConv2dNative'
(params . (TF.opAttr "data_format" .~ dataFormat2D dataformat))
(paddingToByteString padding)
-- | Depth-wise 2D convolution native backpropagation filter with default parameters.
depthwiseConv2dNativeBackpropFilter :: TF.OneOf '[Word16, Double, Float] t
=> TF.Tensor v1 t -- ^ input
-> TF.Tensor v2 Int32 -- ^ filter_sizes
-> TF.Tensor v3 t -- ^ out_backprop
-> TF.Tensor TF.Build t -- ^ output
depthwiseConv2dNativeBackpropFilter = depthwiseConv2dNativeBackpropFilter' id PaddingValid ChannelLast
depthwiseConv2dNativeBackpropFilter' :: TF.OneOf '[Word16, Double, Float] t
=> TF.OpParams
-> Padding
-> DataFormat
-> TF.Tensor v1 t -- ^ input
-> TF.Tensor v2 Int32 -- ^ filter_sizes
-> TF.Tensor v3 t -- ^ out_backprop
-> TF.Tensor TF.Build t -- ^ output
depthwiseConv2dNativeBackpropFilter' params padding dataformat = TF.depthwiseConv2dNativeBackpropFilter'
(params . (TF.opAttr "data_format" .~ dataFormat2D dataformat))
(paddingToByteString padding)
-- | Depth-wise 2D convolution native backpropagation input with default parameters.
depthwiseConv2dNativeBackpropInput :: TF.OneOf '[Word16, Double, Float] t
=> TF.Tensor v1 Int32 -- ^ input_sizes
-> TF.Tensor v2 t -- ^ input
-> TF.Tensor v3 t -- ^ out_backprop
-> TF.Tensor TF.Build t -- ^ output
depthwiseConv2dNativeBackpropInput = depthwiseConv2dNativeBackpropInput' id PaddingValid ChannelLast
depthwiseConv2dNativeBackpropInput' :: TF.OneOf '[Word16, Double, Float] t
=> TF.OpParams
-> Padding
-> DataFormat
-> TF.Tensor v1 Int32 -- ^ input_sizes
-> TF.Tensor v2 t -- ^ input
-> TF.Tensor v3 t -- ^ out_backprop
-> TF.Tensor TF.Build t -- ^ output
depthwiseConv2dNativeBackpropInput' params padding dataformat = TF.depthwiseConv2dNativeBackpropInput'
(params . (TF.opAttr "data_format" .~ dataFormat2D dataformat))
(paddingToByteString padding)

View File

@ -16,6 +16,7 @@ library
hs-source-dirs: src
exposed-modules: TensorFlow.Gradient
, TensorFlow.Ops
, TensorFlow.Convolution
, TensorFlow.EmbeddingOps
, TensorFlow.Minimize
, TensorFlow.NN

View File

@ -32,7 +32,8 @@ import Control.Monad(forM_, replicateM, zipWithM)
import Control.Monad.IO.Class (liftIO)
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput', batchMatMul, batchMatMul', conjugateTranspose)
import qualified TensorFlow.GenOps.Core as TF (max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, batchMatMul, batchMatMul', conjugateTranspose)
import qualified TensorFlow.Convolution as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
import qualified TensorFlow.Output as TF
@ -715,10 +716,8 @@ testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
let filterShape = TF.vector [2, 2, 1, 1 :: Int32] -- [fh, fw, inc, out]
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1::Float))
let y = TF.conv2DBackpropInput'
( (TF.opAttr "strides" .~ [1::Int64, 1, 1, 1])
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
)
"VALID" conv_input_shape filter' x
(TF.opAttr "strides" .~ [1::Int64, 1, 1, 1])
TF.PaddingValid TF.ChannelLast conv_input_shape filter' x
[dx] <- TF.gradients y [x]
TF.run (dx, TF.shape dx, TF.shape x)
@ -734,10 +733,8 @@ testDepthwiseConv2dGrad = testCase "testDepthwiseConv2dGrad" $ do
let filterShape = TF.vector [2, 2, 1, 1 :: Int32]
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
let y = TF.depthwiseConv2dNative'
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
)
"VALID" x filter'
(TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
TF.PaddingValid TF.ChannelLast x filter'
[dx] <- TF.gradients y [x]
TF.run (dx, TF.shape dx, TF.shape x)
@ -755,10 +752,8 @@ testDepthwiseConv2dBackpropInputGrad = testCase "testDepthwiseConv2dBackpropInpu
let filterShape = TF.vector [2, 2, 1, 1 :: Int32]
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
let y = TF.depthwiseConv2dNativeBackpropInput'
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
)
"VALID" conv_input_shape filter' x
(TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
TF.PaddingValid TF.ChannelLast conv_input_shape filter' x
[dx] <- TF.gradients y [x]
TF.run (dx, TF.shape dx, TF.shape x)