From ea3057726418bd9409c95c1874e43b2083ef8368 Mon Sep 17 00:00:00 2001 From: Christian Berentsen Date: Fri, 16 Jun 2017 13:42:33 +0200 Subject: [PATCH] Gradient for AddN --- tensorflow-ops/src/TensorFlow/Gradient.hs | 6 ++++++ tensorflow-ops/tests/GradientTest.hs | 10 ++++++++++ 2 files changed, 16 insertions(+) diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index d07e9e9..dde196d 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -511,6 +511,11 @@ opGrad "Add" _ [toT -> x, toT -> y] [dz] = sy = shape (y :: Tensor Build a) (rx, ry) = broadcastGradientArgs sx sy +-- Copies the gradients to all inputs +-- Not broadcasting +opGrad "AddN" _ inputs [dz] = + map ((const . Just . expr) dz) inputs + opGrad "Sub" u v w = [Just x, Just (-y)] where @@ -711,6 +716,7 @@ numOutputs o = case o ^. op of "Abs" -> 1 "Add" -> 1 + "AddN" -> 1 "Cast" -> 1 "Const" -> 1 "Conv2D" -> 1 diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs index ce3d284..2902297 100644 --- a/tensorflow-ops/tests/GradientTest.hs +++ b/tensorflow-ops/tests/GradientTest.hs @@ -156,6 +156,15 @@ testDiamond = testCase "testDiamond" $ do (4 :: Float) @=? TF.unScalar dx +testAddNGradient :: Test +testAddNGradient = testCase "testAddNGradient" $ do + [dx] <- TF.runSession $ do + x <- TF.render $ TF.vector [1, 2, 0 :: Float] + let y = TF.addN [x, x] + TF.gradients y [x] >>= TF.run + V.fromList [2, 2, 2 :: Float] @=? dx + + testMaxGradient :: Test testMaxGradient = testCase "testMaxGradient" $ do [dx] <- TF.runSession $ do @@ -298,6 +307,7 @@ main = defaultMain , testCreateGraphStateful , testCreateGraphNameScopes , testDiamond + , testAddNGradient , testMaxGradient , testReluGrad , testReluGradGrad