mirror of https://github.com/tensorflow/haskell.git

Implemented ReluGradGrad and FillGrad (#102)

Added testReluGrad, testReluGradGrad and testFillGrad
Authored by Christian Berentsen on 2017-04-30 20:18:06 +02:00, committed by fkm3
parent 09c792b84c
commit eca4ff8981
2 changed files with 36 additions and 0 deletions

tensorflow-ops/src/TensorFlow/Gradient.hs

@@ -439,6 +439,7 @@ opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a
opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x]
opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
opGrad "ReluGrad" _ [_, toT -> x] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
opGrad "Square" _ [toT -> x] [dz] =
    -- TODO(fmayle): Handle complex numbers.
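For intuition: reluGrad dz x passes dz through wherever x > 0 and yields zero elsewhere, i.e. ReluGrad(dz, x) = dz * 1[x > 0]. Differentiating once more gives the new "ReluGrad" rule above: the gradient with respect to dz is gated by the same mask (again reluGrad), while the gradient with respect to x is zero almost everywhere, hence CoreOps.zerosLike x. A minimal list-level sketch of the same arithmetic (hypothetical helper names, not the library API):

    -- 1 where the input is positive, 0 elsewhere.
    step :: (Ord a, Num a) => a -> a
    step v = if v > 0 then 1 else 0

    -- Model of ReluGrad: gate the incoming gradient by the sign of x.
    reluGradModel :: (Ord a, Num a) => [a] -> [a] -> [a]
    reluGradModel = zipWith (\d v -> d * step v)

    -- Its own gradients: the dz input is gated by the same mask;
    -- the x input gets zero almost everywhere (the mask is flat).
    reluGradGradModel :: (Ord a, Num a) => [a] -> [a] -> ([a], [a])
    reluGradGradModel dz x = (reluGradModel dz x, map (const 0) x)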
@@ -667,6 +668,9 @@ opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
opGrad "LabelWeights" _ _ _ = [Nothing]
opGrad "Size" _ _ _ = [Nothing]
opGrad "ZerosLike" _ _ _ = [Nothing]
opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
  where
    rx = rangeOfRank dz
-- TODO(fmayle): These can go away if we properly prune the graph.
opGrad "Const" _ _ _ = [Nothing, Nothing]
@@ -706,6 +710,7 @@ numOutputs o =
        "OneHot" -> 1
        "RefIdentity" -> 1
        "Relu" -> 1
        "ReluGrad" -> 1
        "Reshape" -> 1
        "Select" -> 1
        "Size" -> 1
@@ -718,6 +723,7 @@ numOutputs o =
        "TruncatedNormal" -> 1
        "Variable" -> 1
        "ZerosLike" -> 1
        "Fill" -> 1
        _ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`

tensorflow-ops/tests/GradientTest.hs

@@ -160,6 +160,33 @@ testMaxGradient = testCase "testMaxGradient" $ do
    V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx

testReluGrad :: Test
testReluGrad = testCase "testReluGrad" $ do
    [dx] <- TF.runSession $ do
        x <- TF.render $ TF.vector [2 :: Float]
        let y = TF.relu x
        TF.gradients y [x] >>= TF.run
    V.fromList [1] @=? dx

testReluGradGrad :: Test
testReluGradGrad = testCase "testReluGradGrad" $ do
    [dx] <- TF.runSession $ do
        x <- TF.render $ TF.vector [2 :: Float]
        let y = TF.relu x
        [y'] <- TF.gradients y [x]
        TF.gradients y' [x] >>= TF.run
    V.fromList [0] @=? dx

testFillGrad :: Test
testFillGrad = testCase "testFillGrad" $ do
    [dx] <- TF.runSession $ do
        x <- TF.render $ TF.scalar (9 :: Float)
        let shape = TF.vector [2, 3 :: Int32]
        let y = TF.fill shape x
        TF.gradients y [x] >>= TF.run
    V.fromList [6] @=? dx

main :: IO ()
main = googleTest [ testGradientSimple
                  , testGradientDisconnected
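Sanity check on the expected vectors: relu x = max 0 x, so at x = 2 the first derivative is 1 (testReluGrad), and differentiating that step function again gives 0 almost everywhere (testReluGradGrad, which exercises the new "ReluGrad" rule). In testFillGrad, y = fill [2, 3] 9 has 2 * 3 = 6 elements, each with derivative 1 with respect to the scalar input, and the new "Fill" rule sums them, giving 6.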
@@ -167,4 +194,7 @@ main = googleTest [ testGradientSimple
                  , testCreateGraphNameScopes
                  , testDiamond
                  , testMaxGradient
                  , testReluGrad
                  , testReluGradGrad
                  , testFillGrad
                  ]