From 2dcc921f6e2119995dc2558295e3548c1d93f24f Mon Sep 17 00:00:00 2001 From: Christian Berentsen Date: Sun, 15 Oct 2017 20:49:44 +0200 Subject: [PATCH] Gradient of Conv2DBackpropInput (#155) --- tensorflow-ops/src/TensorFlow/Gradient.hs | 22 ++++++++++++++++++ tensorflow-ops/tensorflow-ops.cabal | 1 + tensorflow-ops/tests/GradientTest.hs | 28 +++++++++++++++++++++-- 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index 76f7c68..9fcbaaf 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -650,6 +650,27 @@ opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] = useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool dataFormat = lookupAttr nodeDef "data_format" :: ByteString +opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] = + [ Nothing + , Just $ CoreOps.conv2DBackpropFilter' + ((opAttr "strides" .~ strides) + . (opAttr "padding" .~ padding) + . (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu) + . (opAttr "data_format" .~ dataFormat)) + dz (shape x) y + , Just $ CoreOps.conv2D' + ((opAttr "strides" .~ strides) + . (opAttr "padding" .~ padding) + . (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu) + . (opAttr "data_format" .~ dataFormat)) + dz x + ] + where + strides = lookupAttr nodeDef "strides" :: [Int64] + padding = lookupAttr nodeDef "padding" :: ByteString + useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool + dataFormat = lookupAttr nodeDef "data_format" :: ByteString + opGrad "MaxPool" nodeDef [toT -> x] [dz] = [ Just $ CoreOps.maxPoolGrad' ((opAttr "ksize" .~ ksize) @@ -779,6 +800,7 @@ numOutputs o = "Const" -> 1 "Concat" -> 1 "Conv2D" -> 1 + "Conv2DBackpropInput" -> 1 "Div" -> 1 "DynamicStitch" -> 1 "DynamicPartition" -> diff --git a/tensorflow-ops/tensorflow-ops.cabal b/tensorflow-ops/tensorflow-ops.cabal index b13c283..f64aab5 100644 --- a/tensorflow-ops/tensorflow-ops.cabal +++ b/tensorflow-ops/tensorflow-ops.cabal @@ -188,6 +188,7 @@ Test-Suite GradientTest hs-source-dirs: tests build-depends: HUnit , base + , bytestring , proto-lens , lens-family , random diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs index ca4cedb..04391f7 100644 --- a/tensorflow-ops/tests/GradientTest.hs +++ b/tensorflow-ops/tests/GradientTest.hs @@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM) import Control.Monad.IO.Class (liftIO) import qualified TensorFlow.Core as TF -import qualified TensorFlow.GenOps.Core as TF (max, tile, maximum) +import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile) import qualified TensorFlow.Gradient as TF import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable) import qualified TensorFlow.Output as TF @@ -42,6 +42,8 @@ import qualified TensorFlow.Variable as TF import Proto.Tensorflow.Core.Framework.Graph (node) import Proto.Tensorflow.Core.Framework.NodeDef (op) +import qualified Data.ByteString.Char8 as BS + testGradientSimple :: Test testGradientSimple = testCase "testGradientSimple" $ do let grads = do @@ -313,7 +315,6 @@ testTile2DGrad = testCase "testTileGrad2D" $ do shapeX @=? (shapeDX :: V.Vector Int32) V.fromList [6, 6, 6, 6, 6, 6::Float] @=? (dx :: V.Vector Float) - matMulGradient :: Test matMulGradient = testCase "matMulGradients" $ do @@ -388,6 +389,28 @@ transAttrs :: (TF.Attribute a, transAttrs a b = (TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b) +testConv2DBackpropInputGrad :: Test +testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do + (dx, shapeDX, shapeX) <- TF.runSession $ do + let conv_input_shape = TF.vector [1, 2, 2, 1 :: Int32] -- [batch, h, w, in_channels] + let conv_out_shape = TF.vector [1, 1, 1, 1 :: Int32] -- [batch, h, w, out_channels] + x <- TF.render $ TF.fill conv_out_shape (TF.scalar (1::Float)) + + let filterShape = TF.vector [2, 2, 1, 1 :: Int32] -- [fh, fw, inc, out] + filter <- TF.render $ TF.fill filterShape (TF.scalar (1::Float)) + let y = TF.conv2DBackpropInput' + ( (TF.opAttr "strides" .~ [1::Int64, 1, 1, 1]) + . (TF.opAttr "padding" .~ (BS.pack "VALID")) + . (TF.opAttr "data_format" .~ (BS.pack "NHWC")) + ) + conv_input_shape filter x + + [dx] <- TF.gradients y [x] + TF.run (dx, TF.shape dx, TF.shape x) + shapeX @=? (shapeDX :: V.Vector Int32) + V.fromList [4::Float] @=? (dx :: V.Vector Float) + + main :: IO () main = defaultMain [ testGradientSimple @@ -413,4 +436,5 @@ main = defaultMain , matMulTransposeGradient (False, True) , matMulTransposeGradient (True, False) , matMulTransposeGradient (True, True) + , testConv2DBackpropInputGrad ]