mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-27 03:05:01 +01:00
Add gradient of 'maximum' and 'gradForBinaryCwise'
`maximum` gradient uses `gradForBinaryCwise` which may be useful for other binary componentwise op gradients
This commit is contained in:
parent
ea30577264
commit
bebc4aa7d9
2 changed files with 50 additions and 2 deletions
|
@ -431,6 +431,22 @@ flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])
|
|||
nodeDefName :: NodeDef -> NodeName
|
||||
nodeDefName = NodeName . view name
|
||||
|
||||
-- | Gradient helper for binary component wise operations
|
||||
-- See https://github.com/tensorflow/tensorflow/blob/e9de087fa7f59c39bbe12ac2c83c5547c83f746c/tensorflow/core/ops/math_grad.cc#L329
|
||||
gradForBinaryCwise :: ( OneOf '[ Int32, Int64, Float, Double, Complex Float, Complex Double ] t
|
||||
)
|
||||
=> (Tensor v1 t, Tensor v1 t)
|
||||
-> (Tensor v1 t, Tensor v1 t)
|
||||
-> [ Maybe (Tensor Build t) ]
|
||||
gradForBinaryCwise (x, gx) (y, gy) =
|
||||
[ Just dx
|
||||
, Just dy ]
|
||||
where
|
||||
dx = reshape (sum gx rx) sx
|
||||
dy = reshape (sum gy ry) sy
|
||||
sx = shape x -- (x :: Tensor Build t)
|
||||
sy = shape y -- (y :: Tensor Build t)
|
||||
(rx, ry) = broadcastGradientArgs sx sy
|
||||
|
||||
-- | The gradient function for an op type.
|
||||
--
|
||||
|
@ -483,6 +499,15 @@ opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
|
|||
-- Min and Max have identical gradient implementations.
|
||||
opGrad "Min" u v w = opGrad "Max" u v w
|
||||
|
||||
-- Element wise maximum gradient
|
||||
-- See https://github.com/tensorflow/tensorflow/blob/e9de087fa7f59c39bbe12ac2c83c5547c83f746c/tensorflow/core/ops/math_grad.cc#L473
|
||||
opGrad "Maximum" _ [toT -> x, toT -> y] [dz] =
|
||||
gradForBinaryCwise (x, gx) (y, gy)
|
||||
where
|
||||
xmask = CoreOps.greaterEqual x y
|
||||
gx = CoreOps.select xmask dz (CoreOps.zerosLike dz)
|
||||
gy = CoreOps.select (CoreOps.logicalNot xmask) dz (CoreOps.zerosLike dz)
|
||||
|
||||
opGrad "Sum" _ [toT -> x, toT -> indices] [dz] =
|
||||
[ Just $ CoreOps.tile grad tileScaling, Nothing ]
|
||||
where
|
||||
|
@ -731,6 +756,7 @@ numOutputs o =
|
|||
"Log" -> 1
|
||||
"MatMul" -> 1
|
||||
"Max" -> 1
|
||||
"Maximum" -> 1
|
||||
"MaxPool" -> 1
|
||||
"Mean" -> 1
|
||||
"Min" -> 1
|
||||
|
|
|
@ -29,7 +29,7 @@ import qualified Data.Vector as V
|
|||
import Control.Monad.IO.Class (liftIO)
|
||||
|
||||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.GenOps.Core as TF (max, tile)
|
||||
import qualified TensorFlow.GenOps.Core as TF (max, tile, maximum)
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
|
||||
import qualified TensorFlow.Output as TF
|
||||
|
@ -173,6 +173,27 @@ testMaxGradient = testCase "testMaxGradient" $ do
|
|||
TF.gradients y [x] >>= TF.run
|
||||
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
|
||||
|
||||
-- run single test like this:
|
||||
-- stack --docker --docker-image=$IMAGE_NAME test tensorflow-ops:GradientTest --test-arguments -t"*MaximumGrad*"
|
||||
testMaximumGrad :: Test
|
||||
testMaximumGrad = testCase "testMaximumGrad" $ do
|
||||
[gx, gy] <- TF.runSession $ do
|
||||
x <- TF.render $ TF.vector [0 :: Float]
|
||||
y <- TF.render $ TF.vector [0 :: Float]
|
||||
let z = TF.maximum x y
|
||||
TF.gradients z [x, y] >>= TF.run
|
||||
V.fromList [1] @=? gx
|
||||
V.fromList [1] @=? gy
|
||||
|
||||
testMaximumGradGrad :: Test
|
||||
testMaximumGradGrad = testCase "testMaximumGradGrad" $ do
|
||||
[ggx] <- TF.runSession $ do
|
||||
x <- TF.render $ TF.vector [2 :: Float]
|
||||
y <- TF.render $ TF.vector [1 :: Float]
|
||||
let z = TF.maximum x y
|
||||
[gx, _gy] <- TF.gradients z [x, y]
|
||||
TF.gradients gx [x] >>= TF.run
|
||||
V.fromList [0] @=? ggx
|
||||
|
||||
testReluGrad :: Test
|
||||
testReluGrad = testCase "testReluGrad" $ do
|
||||
|
@ -191,7 +212,6 @@ testReluGradGrad = testCase "testReluGradGrad" $ do
|
|||
TF.gradients y' [x] >>= TF.run
|
||||
V.fromList [0] @=? dx
|
||||
|
||||
|
||||
testFillGrad :: Test
|
||||
testFillGrad = testCase "testFillGrad" $ do
|
||||
[dx] <- TF.runSession $ do
|
||||
|
@ -309,6 +329,8 @@ main = defaultMain
|
|||
, testDiamond
|
||||
, testAddNGradient
|
||||
, testMaxGradient
|
||||
, testMaximumGrad
|
||||
, testMaximumGradGrad
|
||||
, testReluGrad
|
||||
, testReluGradGrad
|
||||
, testFillGrad
|
||||
|
|
Loading…
Add table
Reference in a new issue