From 3cfd96ef087008da5c15383dc76111c68c709e6e Mon Sep 17 00:00:00 2001 From: erikabor <48091676+erikabor@users.noreply.github.com> Date: Tue, 26 Mar 2019 21:30:50 +0100 Subject: [PATCH] Add gradient for slice function (#234) --- tensorflow-ops/src/TensorFlow/Gradient.hs | 22 ++++++++++++++++++++ tensorflow-ops/tests/GradientTest.hs | 25 +++++++++++++++++++++-- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index 83914fc..b029998 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -710,6 +710,27 @@ opGrad "Pad" _ [toT -> x, toT -> padPattern] [dz] = gradientSliceBegin = CoreOps.reshape padPatternSliced rankx gradientSliceSize = shape (x :: Tensor Build Float) +-- Gradient for Slice: pad dz with zeros back to the shape of the sliced input (begin and size get no gradient). +-- Create an Nx2 padding where N is the rank of the input of Slice and the first +-- column represents how many zeros are to be prepended for each dimension, and the second +-- column indicates how many zeros are appended. +-- The number of zeros to prepend is given by the values of beginvec. +-- The number of zeros to append is the shape of the inputvec +-- elementwise-subtracted by both the beginvec and the shape of dz (sizevec itself is not read). +-- Both are reshaped to padShape (rank x 1) and joined with CoreOps.concat 1 to assemble the paddings. +-- NOTE(review): inputvec is annotated as Float below; confirm this is intended for non-Float gradients. +opGrad "Slice" _ [toT -> inputvec, toT -> beginvec, _] [dz] = + [Just $ CoreOps.pad dz paddings, Nothing, Nothing] + where + v1 = vector [1 :: Int32] + inputRank' = CoreOps.rank (inputvec :: Tensor Build Float) + -- inputRank' has an empty shape (it is a scalar), so reshape it to a 1-element vector for concat + inputRank = CoreOps.reshape inputRank' v1 + padShape = CoreOps.concat 0 [inputRank, v1] + beforePad = CoreOps.reshape beginvec padShape + afterPad = CoreOps.reshape (shape inputvec - shape dz - beginvec) padShape + paddings = CoreOps.concat 1 [beforePad, afterPad] + -- TODO: This could be either Int32 or Int64. 
opGrad "BatchToSpaceND" _ [_, toT @Int32 -> blockShape, toT @Int32 -> crops] [dz] = [Just $ CoreOps.spaceToBatchND dz blockShape crops, Nothing, Nothing] @@ -862,6 +883,7 @@ numOutputs o = "Reshape" -> 1 "Select" -> 1 "Size" -> 1 + "Slice" -> 1 "SoftmaxCrossEntropyWithLogits" -> 2 "SpaceToBatchND" -> 1 "SparseSegmentSum" -> 1 diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs index 0cbb6eb..9af14e4 100644 --- a/tensorflow-ops/tests/GradientTest.hs +++ b/tensorflow-ops/tests/GradientTest.hs @@ -32,9 +32,9 @@ import Control.Monad(forM_, replicateM, zipWithM) import Control.Monad.IO.Class (liftIO) import qualified TensorFlow.Core as TF -import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt) +import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape) import qualified TensorFlow.Gradient as TF -import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable) +import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape) import qualified TensorFlow.Output as TF import qualified TensorFlow.Types as TF import qualified TensorFlow.Variable as TF @@ -324,6 +324,7 @@ testPad = V.fromList [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @=? dx V.fromList [2, 2, 3] @=? s + testSqrt :: Test testSqrt = testCase "testSqrt" $ do [dx] <- TF.runSession $ do @@ -332,6 +333,25 @@ testSqrt = testCase "testSqrt" $ do TF.gradients y [x] >>= TF.run V.fromList [2] @=? 
dx +testSlice :: Test +testSlice = + testCase "testSlice" $ do + ([dx], [s]) <- + TF.runSession $ do + (x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [2, 3, 4 :: Int64] + (z :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [1, 2, 2 :: Int64] + let y = TF.slice x (TF.constant (TF.Shape [3]) [1, 1, 1 :: Int32]) (TF.shape z) + calculateGradWithShape y x + let expected = + [0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 1, 1, 0, + 0, 1, 1, 0] + V.fromList expected @=? dx + V.fromList [2, 3, 4] @=? s + testBatchToSpaceND :: Test testBatchToSpaceND = testCase "testBatchToSpaceND" $ do @@ -526,6 +546,7 @@ main = defaultMain , testReshape , testPad , testSqrt + , testSlice , testBatchToSpaceND , testSpaceToBatchND , testSqueeze