Mirror of https://github.com/tensorflow/haskell.git, synced 2024-12-24 10:39:45 +01:00
Implemented ReluGradGrad and FillGrad (#102)
Added testReluGrad, testReluGradGrad and testFillGrad
commit eca4ff8981
parent 09c792b84c
2 changed files with 36 additions and 0 deletions
@@ -439,6 +439,7 @@ opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a
 opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x]
 opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
 opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
+opGrad "ReluGrad" _ [_, toT -> x] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
 opGrad "Square" _ [toT -> x] [dz] =
     -- TODO(fmayle): Handle complex numbers.
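For context on this hunk in the gradients module: `ReluGrad dz x` is the backward op of `Relu`, passing each element of `dz` through where `x` is positive and zeroing it elsewhere. Differentiating it again applies the same mask to the new incoming gradient, while the gradient with respect to the saved input `x` is zero almost everywhere, which is what the `CoreOps.zerosLike x` slot records. A minimal pure-Haskell model of that semantics (an illustrative sketch with a hypothetical name, not the TensorFlow kernel):

    -- Forward gradients only where the original input was positive;
    -- the derivative of this w.r.t. x is 0 wherever it is defined,
    -- matching the zerosLike slot above.
    reluGradModel :: [Double] -> [Double] -> [Double]
    reluGradModel dz x = zipWith (\d v -> if v > 0 then d else 0) dz x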
@@ -667,6 +668,9 @@ opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
 opGrad "LabelWeights" _ _ _ = [Nothing]
 opGrad "Size" _ _ _ = [Nothing]
 opGrad "ZerosLike" _ _ _ = [Nothing]
+opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
+  where
+    rx = rangeOfRank dz
 
 -- TODO(fmayle): These can go away if we properly prune the graph.
 opGrad "Const" _ _ _ = [Nothing, Nothing]
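`Fill shape value` broadcasts a single scalar into every element of the output, so by the chain rule the gradient with respect to that scalar is the incoming gradient summed over every axis; `rangeOfRank dz` supplies the axis indices `[0 .. rank dz - 1]` to reduce over, and the `shape` argument gets no gradient (`Nothing`). The same reduction in plain Haskell, with a flat list standing in for the tensor (hypothetical helper name):

    -- Every output element received one copy of the scalar, so all of
    -- the incoming gradient flows back into it.
    fillGradModel :: [Double] -> Double
    fillGradModel = sum

For a [2, 3] fill with an all-ones incoming gradient this yields 6, which is exactly what `testFillGrad` below asserts.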
@@ -706,6 +710,7 @@ numOutputs o =
         "OneHot" -> 1
         "RefIdentity" -> 1
         "Relu" -> 1
+        "ReluGrad" -> 1
         "Reshape" -> 1
         "Select" -> 1
         "Size" -> 1
@@ -718,6 +723,7 @@ numOutputs o =
         "TruncatedNormal" -> 1
         "Variable" -> 1
         "ZerosLike" -> 1
+        "Fill" -> 1
         _ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
 
 -- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
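`numOutputs` is bookkeeping for the gradient pass: it records how many output tensors each op produces so the pass can walk the right number of output slots, and any op that can now appear inside a gradient graph, here `ReluGrad` and `Fill`, must be registered or the lookup falls through to the error branch. A toy model of the lookup (illustrative only, not the module's actual types):

    numOutputsModel :: String -> Int
    numOutputsModel opName = case opName of
        "ReluGrad" -> 1
        "Fill"     -> 1
        _          -> error ("numOutputs not implemented for " ++ opName)

The remaining hunks add the three new cases to the gradient test module.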
@@ -160,6 +160,33 @@ testMaxGradient = testCase "testMaxGradient" $ do
     V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
 
+
+testReluGrad :: Test
+testReluGrad = testCase "testReluGrad" $ do
+    [dx] <- TF.runSession $ do
+        x <- TF.render $ TF.vector [2 :: Float]
+        let y = TF.relu x
+        TF.gradients y [x] >>= TF.run
+    V.fromList [1] @=? dx
+
+testReluGradGrad :: Test
+testReluGradGrad = testCase "testReluGradGrad" $ do
+    [dx] <- TF.runSession $ do
+        x <- TF.render $ TF.vector [2 :: Float]
+        let y = TF.relu x
+        [y'] <- TF.gradients y [x]
+        TF.gradients y' [x] >>= TF.run
+    V.fromList [0] @=? dx
+
+
+testFillGrad :: Test
+testFillGrad = testCase "testFillGrad" $ do
+    [dx] <- TF.runSession $ do
+        x <- TF.render $ TF.scalar (9 :: Float)
+        let shape = TF.vector [2, 3 :: Int32]
+        let y = TF.fill shape x
+        TF.gradients y [x] >>= TF.run
+    V.fromList [6] @=? dx
 
 main :: IO ()
 main = googleTest [ testGradientSimple
                   , testGradientDisconnected
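The expected values follow directly from the math: `relu` is the identity at `x = 2`, so the first derivative there is 1; `relu` is piecewise linear, so its second derivative is 0 away from the kink at the origin; and a `[2, 3]` fill routes one unit of incoming gradient per element back to the scalar, `2 * 3 = 6` in total. A standalone sanity check of those numbers (plain Haskell, independent of the test suite):

    -- Closed-form gradients of relu and fill, evaluated at the test points.
    reluGrad', reluGradGrad' :: Double -> Double
    reluGrad' x = if x > 0 then 1 else 0   -- first derivative of relu
    reluGradGrad' _ = 0                    -- second derivative, zero a.e.

    main :: IO ()
    main = print (reluGrad' 2, reluGradGrad' 2, sum (replicate (2 * 3) (1 :: Double)))
    -- prints (1.0, 0.0, 6.0), matching the three assertions above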
@@ -167,4 +194,7 @@ main = googleTest [ testGradientSimple
                   , testCreateGraphNameScopes
                   , testDiamond
                   , testMaxGradient
+                  , testReluGrad
+                  , testReluGradGrad
+                  , testFillGrad
                   ]