mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Implemented ReluGradGrad and FillGrad (#102)
Added testReluGrad, testReluGradGrad and testFillGrad
This commit is contained in:
parent
09c792b84c
commit
eca4ff8981
2 changed files with 36 additions and 0 deletions
|
@ -439,6 +439,7 @@ opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a
|
||||||
opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x]
|
opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x]
|
||||||
opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
|
opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
|
||||||
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
|
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
|
||||||
|
opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
|
||||||
|
|
||||||
opGrad "Square" _ [toT -> x] [dz] =
|
opGrad "Square" _ [toT -> x] [dz] =
|
||||||
-- TODO(fmayle): Handle complex numbers.
|
-- TODO(fmayle): Handle complex numbers.
|
||||||
|
@ -667,6 +668,9 @@ opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
|
||||||
opGrad "LabelWeights" _ _ _ = [Nothing]
|
opGrad "LabelWeights" _ _ _ = [Nothing]
|
||||||
opGrad "Size" _ _ _ = [Nothing]
|
opGrad "Size" _ _ _ = [Nothing]
|
||||||
opGrad "ZerosLike" _ _ _ = [Nothing]
|
opGrad "ZerosLike" _ _ _ = [Nothing]
|
||||||
|
opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
|
||||||
|
where
|
||||||
|
rx = rangeOfRank dz
|
||||||
|
|
||||||
-- TODO(fmayle): These can go away if we properly prune the graph.
|
-- TODO(fmayle): These can go away if we properly prune the graph.
|
||||||
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
||||||
|
@ -706,6 +710,7 @@ numOutputs o =
|
||||||
"OneHot" -> 1
|
"OneHot" -> 1
|
||||||
"RefIdentity" -> 1
|
"RefIdentity" -> 1
|
||||||
"Relu" -> 1
|
"Relu" -> 1
|
||||||
|
"ReluGrad" -> 1
|
||||||
"Reshape" -> 1
|
"Reshape" -> 1
|
||||||
"Select" -> 1
|
"Select" -> 1
|
||||||
"Size" -> 1
|
"Size" -> 1
|
||||||
|
@ -718,6 +723,7 @@ numOutputs o =
|
||||||
"TruncatedNormal" -> 1
|
"TruncatedNormal" -> 1
|
||||||
"Variable" -> 1
|
"Variable" -> 1
|
||||||
"ZerosLike" -> 1
|
"ZerosLike" -> 1
|
||||||
|
"Fill" -> 1
|
||||||
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
|
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
|
||||||
|
|
||||||
-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
|
-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
|
||||||
|
|
|
@ -160,6 +160,33 @@ testMaxGradient = testCase "testMaxGradient" $ do
|
||||||
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
|
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
|
||||||
|
|
||||||
|
|
||||||
|
testReluGrad :: Test
|
||||||
|
testReluGrad = testCase "testReluGrad" $ do
|
||||||
|
[dx] <- TF.runSession $ do
|
||||||
|
x <- TF.render $ TF.vector [2 :: Float]
|
||||||
|
let y = TF.relu x
|
||||||
|
TF.gradients y [x] >>= TF.run
|
||||||
|
V.fromList [1] @=? dx
|
||||||
|
|
||||||
|
testReluGradGrad :: Test
|
||||||
|
testReluGradGrad = testCase "testReluGradGrad" $ do
|
||||||
|
[dx] <- TF.runSession $ do
|
||||||
|
x <- TF.render $ TF.vector [2 :: Float]
|
||||||
|
let y = TF.relu x
|
||||||
|
[y'] <- TF.gradients y [x]
|
||||||
|
TF.gradients y' [x] >>= TF.run
|
||||||
|
V.fromList [0] @=? dx
|
||||||
|
|
||||||
|
|
||||||
|
testFillGrad :: Test
|
||||||
|
testFillGrad = testCase "testFillGrad" $ do
|
||||||
|
[dx] <- TF.runSession $ do
|
||||||
|
x <- TF.render $ TF.scalar (9 :: Float)
|
||||||
|
let shape = TF.vector [2, 3 :: Int32]
|
||||||
|
let y = TF.fill shape x
|
||||||
|
TF.gradients y [x] >>= TF.run
|
||||||
|
V.fromList [6] @=? dx
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = googleTest [ testGradientSimple
|
main = googleTest [ testGradientSimple
|
||||||
, testGradientDisconnected
|
, testGradientDisconnected
|
||||||
|
@ -167,4 +194,7 @@ main = googleTest [ testGradientSimple
|
||||||
, testCreateGraphNameScopes
|
, testCreateGraphNameScopes
|
||||||
, testDiamond
|
, testDiamond
|
||||||
, testMaxGradient
|
, testMaxGradient
|
||||||
|
, testReluGrad
|
||||||
|
, testReluGradGrad
|
||||||
|
, testFillGrad
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in a new issue