mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 19:13:34 +02:00
Make 'mean' doubly differentiable
Use stopGradient on shape computations Add opGrad for StopGradient
This commit is contained in:
parent
96f1c88327
commit
943d23151c
|
@ -556,7 +556,7 @@ opGrad "Sum" _ [toT -> x, toT -> indices] [dz] =
|
|||
grad = reshape dz outputShapeKeptDims
|
||||
|
||||
opGrad "Mean" u v@[toT -> x, _] w =
|
||||
[Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing]
|
||||
[Just $ dz `CoreOps.div` (CoreOps.stopGradient $ CoreOps.cast $ factor), Nothing]
|
||||
where
|
||||
[Just dz, Nothing] = opGrad "Sum" u v w
|
||||
inputShape = shape (x :: Tensor Build a)
|
||||
|
@ -842,6 +842,7 @@ opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
|
|||
-- TODO(fmayle): These can go away if we properly prune the graph.
|
||||
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
||||
opGrad "Placeholder" _ _ _ = []
|
||||
opGrad "StopGradient" _ _ _ = [Nothing]
|
||||
opGrad "VarHandleOp" _ _ _ = []
|
||||
opGrad "Variable" _ _ _ = []
|
||||
|
||||
|
@ -886,6 +887,7 @@ numOutputs o =
|
|||
"Neg" -> 1
|
||||
"Pad" -> 1
|
||||
"Placeholder" -> 1
|
||||
"StopGradient" -> 1
|
||||
"OneHot" -> 1
|
||||
"ReadVariableOp" -> 1
|
||||
"RefIdentity" -> 1
|
||||
|
|
|
@ -170,6 +170,23 @@ testAddNGradient = testCase "testAddNGradient" $ do
|
|||
TF.gradients y [x] >>= TF.run
|
||||
V.fromList [2, 2, 2 :: Float] @=? dx
|
||||
|
||||
testMeanGradient :: Test
|
||||
testMeanGradient = testCase "testMeanGradient" $ do
|
||||
[dx] <- TF.runSession $ do
|
||||
x <- TF.render $ TF.vector [1, 2, 0 :: Float]
|
||||
let y = TF.mean x (TF.vector [0 :: Int32])
|
||||
TF.gradients y [x] >>= TF.run
|
||||
V.fromList [1, 1, 1 :: Float] @=? dx
|
||||
|
||||
testMeanGradGrad :: Test
|
||||
testMeanGradGrad = testCase "testMeanGradGrad" $ do
|
||||
[ddx] <- TF.runSession $ do
|
||||
x <- TF.render $ TF.vector [1, 2, 0 :: Float]
|
||||
let y = TF.mean x (TF.vector [0 :: Int32])
|
||||
[dx] <- TF.gradients y [x]
|
||||
TF.gradients dx [x] >>= TF.run
|
||||
|
||||
V.fromList [0, 0, 0 :: Float] @=? ddx
|
||||
|
||||
testMaxGradient :: Test
|
||||
testMaxGradient = testCase "testMaxGradient" $ do
|
||||
|
@ -549,6 +566,8 @@ main = defaultMain
|
|||
, testCreateGraphNameScopes
|
||||
, testDiamond
|
||||
, testAddNGradient
|
||||
, testMeanGradient
|
||||
, testMeanGradGrad
|
||||
, testMaxGradient
|
||||
, testConcatGradient
|
||||
, testConcatGradientSimple
|
||||
|
|
Loading…
Reference in New Issue
Block a user