mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-27 05:19:45 +01:00
Gradient for AddN
This commit is contained in:
parent
41f4c8a235
commit
ea30577264
2 changed files with 16 additions and 0 deletions
|
@ -511,6 +511,11 @@ opGrad "Add" _ [toT -> x, toT -> y] [dz] =
|
||||||
sy = shape (y :: Tensor Build a)
|
sy = shape (y :: Tensor Build a)
|
||||||
(rx, ry) = broadcastGradientArgs sx sy
|
(rx, ry) = broadcastGradientArgs sx sy
|
||||||
|
|
||||||
|
-- Copies the gradients to all inputs
|
||||||
|
-- Not broadcasting
|
||||||
|
opGrad "AddN" _ inputs [dz] =
|
||||||
|
map ((const . Just . expr) dz) inputs
|
||||||
|
|
||||||
opGrad "Sub" u v w =
|
opGrad "Sub" u v w =
|
||||||
[Just x, Just (-y)]
|
[Just x, Just (-y)]
|
||||||
where
|
where
|
||||||
|
@ -711,6 +716,7 @@ numOutputs o =
|
||||||
case o ^. op of
|
case o ^. op of
|
||||||
"Abs" -> 1
|
"Abs" -> 1
|
||||||
"Add" -> 1
|
"Add" -> 1
|
||||||
|
"AddN" -> 1
|
||||||
"Cast" -> 1
|
"Cast" -> 1
|
||||||
"Const" -> 1
|
"Const" -> 1
|
||||||
"Conv2D" -> 1
|
"Conv2D" -> 1
|
||||||
|
|
|
@ -156,6 +156,15 @@ testDiamond = testCase "testDiamond" $ do
|
||||||
(4 :: Float) @=? TF.unScalar dx
|
(4 :: Float) @=? TF.unScalar dx
|
||||||
|
|
||||||
|
|
||||||
|
testAddNGradient :: Test
|
||||||
|
testAddNGradient = testCase "testAddNGradient" $ do
|
||||||
|
[dx] <- TF.runSession $ do
|
||||||
|
x <- TF.render $ TF.vector [1, 2, 0 :: Float]
|
||||||
|
let y = TF.addN [x, x]
|
||||||
|
TF.gradients y [x] >>= TF.run
|
||||||
|
V.fromList [2, 2, 2 :: Float] @=? dx
|
||||||
|
|
||||||
|
|
||||||
testMaxGradient :: Test
|
testMaxGradient :: Test
|
||||||
testMaxGradient = testCase "testMaxGradient" $ do
|
testMaxGradient = testCase "testMaxGradient" $ do
|
||||||
[dx] <- TF.runSession $ do
|
[dx] <- TF.runSession $ do
|
||||||
|
@ -298,6 +307,7 @@ main = defaultMain
|
||||||
, testCreateGraphStateful
|
, testCreateGraphStateful
|
||||||
, testCreateGraphNameScopes
|
, testCreateGraphNameScopes
|
||||||
, testDiamond
|
, testDiamond
|
||||||
|
, testAddNGradient
|
||||||
, testMaxGradient
|
, testMaxGradient
|
||||||
, testReluGrad
|
, testReluGrad
|
||||||
, testReluGradGrad
|
, testReluGradGrad
|
||||||
|
|
Loading…
Reference in a new issue