mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Make 'mean' doubly differentiable (#241)
Use stopGradient on shape computations Add opGrad for StopGradient
This commit is contained in:
parent
7316062c10
commit
4a2e46ba57
2 changed files with 22 additions and 1 deletions
|
@ -574,7 +574,7 @@ opGrad "Sum" _ [toT -> x, toT -> indices] [dz] =
|
||||||
grad = reshape dz outputShapeKeptDims
|
grad = reshape dz outputShapeKeptDims
|
||||||
|
|
||||||
opGrad "Mean" u v@[toT -> x, _] w =
|
opGrad "Mean" u v@[toT -> x, _] w =
|
||||||
[Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing]
|
[Just $ dz `CoreOps.div` (CoreOps.stopGradient $ CoreOps.cast $ factor), Nothing]
|
||||||
where
|
where
|
||||||
[Just dz, Nothing] = opGrad "Sum" u v w
|
[Just dz, Nothing] = opGrad "Sum" u v w
|
||||||
inputShape = shape (x :: Tensor Build a)
|
inputShape = shape (x :: Tensor Build a)
|
||||||
|
@ -858,6 +858,7 @@ opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
|
||||||
opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
|
opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
|
||||||
|
|
||||||
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
||||||
|
opGrad "StopGradient" _ _ _ = [Nothing]
|
||||||
opGrad "VarHandleOp" _ _ _ = []
|
opGrad "VarHandleOp" _ _ _ = []
|
||||||
|
|
||||||
opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz]
|
opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz]
|
||||||
|
@ -901,6 +902,7 @@ numOutputs o =
|
||||||
"Neg" -> 1
|
"Neg" -> 1
|
||||||
"Pad" -> 1
|
"Pad" -> 1
|
||||||
"Placeholder" -> 1
|
"Placeholder" -> 1
|
||||||
|
"StopGradient" -> 1
|
||||||
"OneHot" -> 1
|
"OneHot" -> 1
|
||||||
"ReadVariableOp" -> 1
|
"ReadVariableOp" -> 1
|
||||||
"RefIdentity" -> 1
|
"RefIdentity" -> 1
|
||||||
|
|
|
@ -230,6 +230,23 @@ testAddNGradient = testCase "testAddNGradient" $ do
|
||||||
TF.gradients y [x] >>= TF.run
|
TF.gradients y [x] >>= TF.run
|
||||||
V.fromList [2, 2, 2 :: Float] @=? dx
|
V.fromList [2, 2, 2 :: Float] @=? dx
|
||||||
|
|
||||||
|
testMeanGradient :: Test
|
||||||
|
testMeanGradient = testCase "testMeanGradient" $ do
|
||||||
|
[dx] <- TF.runSession $ do
|
||||||
|
x <- TF.render $ TF.vector [1, 2, 0 :: Float]
|
||||||
|
let y = TF.mean x (TF.vector [0 :: Int32])
|
||||||
|
TF.gradients y [x] >>= TF.run
|
||||||
|
V.fromList [1, 1, 1 :: Float] @=? dx
|
||||||
|
|
||||||
|
testMeanGradGrad :: Test
|
||||||
|
testMeanGradGrad = testCase "testMeanGradGrad" $ do
|
||||||
|
[ddx] <- TF.runSession $ do
|
||||||
|
x <- TF.render $ TF.vector [1, 2, 0 :: Float]
|
||||||
|
let y = TF.mean x (TF.vector [0 :: Int32])
|
||||||
|
[dx] <- TF.gradients y [x]
|
||||||
|
TF.gradients dx [x] >>= TF.run
|
||||||
|
|
||||||
|
V.fromList [0, 0, 0 :: Float] @=? ddx
|
||||||
|
|
||||||
testMaxGradient :: Test
|
testMaxGradient :: Test
|
||||||
testMaxGradient = testCase "testMaxGradient" $ do
|
testMaxGradient = testCase "testMaxGradient" $ do
|
||||||
|
@ -611,6 +628,8 @@ main = defaultMain
|
||||||
, testCreateGraphNameScopes
|
, testCreateGraphNameScopes
|
||||||
, testDiamond
|
, testDiamond
|
||||||
, testAddNGradient
|
, testAddNGradient
|
||||||
|
, testMeanGradient
|
||||||
|
, testMeanGradGrad
|
||||||
, testMaxGradient
|
, testMaxGradient
|
||||||
, testConcatGradient
|
, testConcatGradient
|
||||||
, testConcatGradientSimple
|
, testConcatGradientSimple
|
||||||
|
|
Loading…
Reference in a new issue