From 4a2e46ba5706226cd6b46f8c09468843ea019179 Mon Sep 17 00:00:00 2001
From: Christian Berentsen
Date: Mon, 22 Apr 2019 06:46:01 +0200
Subject: [PATCH] Make 'mean' doubly differentiable (#241)

Use stopGradient on shape computations

Add opGrad for StopGradient
---
 tensorflow-ops/src/TensorFlow/Gradient.hs |  4 +++-
 tensorflow-ops/tests/GradientTest.hs      | 19 +++++++++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs
index 6f7494c..c2d274c 100644
--- a/tensorflow-ops/src/TensorFlow/Gradient.hs
+++ b/tensorflow-ops/src/TensorFlow/Gradient.hs
@@ -574,7 +574,7 @@ opGrad "Sum" _ [toT -> x, toT -> indices] [dz] =
     grad = reshape dz outputShapeKeptDims
 
 opGrad "Mean" u v@[toT -> x, _] w =
-    [Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing]
+    [Just $ dz `CoreOps.div` (CoreOps.stopGradient $ CoreOps.cast $ factor), Nothing]
   where
     [Just dz, Nothing] = opGrad "Sum" u v w
     inputShape = shape (x :: Tensor Build a)
@@ -858,6 +858,7 @@ opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
 opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
 
 opGrad "Const" _ _ _ = [Nothing, Nothing]
+opGrad "StopGradient" _ _ _ = [Nothing]
 opGrad "VarHandleOp" _ _ _ = []
 
 opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz]
@@ -901,6 +902,7 @@ numOutputs o =
         "Neg" -> 1
         "Pad" -> 1
         "Placeholder" -> 1
+        "StopGradient" -> 1
         "OneHot" -> 1
         "ReadVariableOp" -> 1
         "RefIdentity" -> 1
diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs
index a2207ad..56b2f67 100644
--- a/tensorflow-ops/tests/GradientTest.hs
+++ b/tensorflow-ops/tests/GradientTest.hs
@@ -230,6 +230,23 @@ testAddNGradient = testCase "testAddNGradient" $ do
         TF.gradients y [x] >>= TF.run
     V.fromList [2, 2, 2 :: Float] @=? dx
 
+testMeanGradient :: Test
+testMeanGradient = testCase "testMeanGradient" $ do
+    [dx] <- TF.runSession $ do
+        x <- TF.render $ TF.vector [1, 2, 0 :: Float]
+        let y = TF.mean x (TF.vector [0 :: Int32])
+        TF.gradients y [x] >>= TF.run
+    V.fromList [1, 1, 1 :: Float] @=? dx
+
+testMeanGradGrad :: Test
+testMeanGradGrad = testCase "testMeanGradGrad" $ do
+    [ddx] <- TF.runSession $ do
+        x <- TF.render $ TF.vector [1, 2, 0 :: Float]
+        let y = TF.mean x (TF.vector [0 :: Int32])
+        [dx] <- TF.gradients y [x]
+        TF.gradients dx [x] >>= TF.run
+
+    V.fromList [0, 0, 0 :: Float] @=? ddx
 
 testMaxGradient :: Test
 testMaxGradient = testCase "testMaxGradient" $ do
@@ -611,6 +628,8 @@ main = defaultMain
     , testCreateGraphNameScopes
     , testDiamond
     , testAddNGradient
    , testMeanGradient
+    , testMeanGradGrad
     , testMaxGradient
     , testConcatGradient
     , testConcatGradientSimple
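
Illustrative usage (not part of the patch): a minimal sketch of the kind of second-order computation this change enables, differentiating through the gradient of 'mean' as in a gradient-penalty style term. The module imports mirror those used by GradientTest.hs; the 'penalty' expression and the standalone 'main' wrapper are hypothetical additions of mine, not code from the repository.

-- Second-order differentiation through 'mean' (sketch, assumes the
-- tensorflow-haskell API as exercised by GradientTest.hs above).
import Data.Int (Int32)
import qualified Data.Vector as V
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (mean, square)
import qualified TensorFlow.Gradient as TF (gradients)
import qualified TensorFlow.Ops as TF (vector)

main :: IO ()
main = do
    [ddx] <- TF.runSession $ do
        x <- TF.render $ TF.vector [1, 2, 0 :: Float]
        -- y = mean(x^2) over axis 0
        let y = TF.mean (TF.square x) (TF.vector [0 :: Int32])
        -- First gradient: dy/dx = 2*x / 3
        [dx] <- TF.gradients y [x]
        -- Differentiating through dx requires the 'Mean' gradient itself to
        -- be differentiable, which the stopGradient on the size factor and
        -- the registered opGrad "StopGradient" above make possible.
        let penalty = TF.mean (TF.square dx) (TF.vector [0 :: Int32])
        TF.gradients penalty [x] >>= TF.run
    print (ddx :: V.Vector Float)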