From 96f1c883273026530dfa5ca59a2ed28f00fdbe7b Mon Sep 17 00:00:00 2001 From: Christian Berentsen Date: Mon, 8 Apr 2019 19:43:17 +0200 Subject: [PATCH] Add gradient for ResizeBilinear (#239) --- tensorflow-ops/src/TensorFlow/Gradient.hs | 12 ++++++++++++ tensorflow-ops/tests/GradientTest.hs | 19 ++++++++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index b029998..cb12aee 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -817,6 +817,17 @@ opGrad "Tile" _ [toT -> x, toT -> multiples] [dz] = axes = CoreOps.range 0 (CoreOps.size splitShape) (2 :: Tensor Build Int32) reshapedDz = CoreOps.reshape dz splitShape +opGrad "ResizeBilinear" nodeDef [toT -> x, _] [dz] = + [ Just $ CoreOps.resizeBilinearGrad' + (opAttr "align_corners" .~ align) + (CoreOps.cast dz) + x + + , Nothing + ] + where + align = lookupAttr nodeDef "align_corners" :: Bool + opGrad "ZerosLike" _ _ _ = [Nothing] opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx] where @@ -894,6 +905,7 @@ numOutputs o = "Sum" -> 1 "Tanh" -> 1 "Tile" -> 1 + "ResizeBilinear" -> 1 "Transpose" -> 1 "TruncatedNormal" -> 1 "VarHandleOp" -> 1 diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs index 9af14e4..dd36480 100644 --- a/tensorflow-ops/tests/GradientTest.hs +++ b/tensorflow-ops/tests/GradientTest.hs @@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM) import Control.Monad.IO.Class (liftIO) import qualified TensorFlow.Core as TF -import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape) +import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape) import qualified TensorFlow.Gradient as TF import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape) import qualified TensorFlow.Output as TF @@ -429,6 +429,22 @@ testTile2DGrad = testCase "testTileGrad2D" $ do shapeX @=? (shapeDX :: V.Vector Int32) V.fromList [6, 6, 6, 6, 6, 6::Float] @=? (dx :: V.Vector Float) +testResizeBilinearGrad :: Test +testResizeBilinearGrad = testCase "testResizeBilinearGrad" $ do + (dx, shapeDX, shapeX) <- TF.runSession $ do + let shape = TF.vector [1, 2, 2, 1 :: Int32] + x <- TF.render $ TF.fill shape (TF.scalar (1 :: Float)) + let outSize = TF.vector [4, 4 :: Int32] + align = TF.opAttr "align_corners" .~ True + y = TF.resizeBilinear' align x outSize + + [dx] <- TF.gradients y [x] + TF.run (dx, TF.shape dx, TF.shape x) + shapeX @=? (shapeDX :: V.Vector Int32) + let expect = V.fromList [4, 4, 4, 4 :: Float] + near = 0.00001 > (V.sum $ V.zipWith (-) expect (dx :: V.Vector Float)) + near @=? True + matMulGradient :: Test matMulGradient = testCase "matMulGradients" $ do @@ -553,6 +569,7 @@ main = defaultMain , testFillGrad , testTileGrad , testTile2DGrad + , testResizeBilinearGrad , matMulGradient , matMulGradGrad , matMulTransposeGradient (False, False)