1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2025-01-27 03:05:01 +01:00

Add gradient for sqrt function (#236)

This commit is contained in:
erikabor 2019-03-19 02:08:08 +01:00 committed by fkm3
parent 896a0d31f7
commit 666dce94bd
2 changed files with 16 additions and 2 deletions

View file

@ -813,6 +813,10 @@ opGrad "Placeholder" _ _ _ = []
opGrad "VarHandleOp" _ _ _ = []
opGrad "Variable" _ _ _ = []
opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz]
where
sq' = scalar 1 `CoreOps.div` (scalar 2 `CoreOps.mul` CoreOps.sqrt x)
opGrad n nodeDef ins grads =
error $ "no gradient implemented for " ++
show (n, length ins, length grads, showMessage nodeDef, ins)
@ -863,6 +867,7 @@ numOutputs o =
"SparseSegmentSum" -> 1
"Square" -> 1
"Squeeze" -> 1
"Sqrt" -> 1
"Sub" -> 1
"Sum" -> 1
"Tanh" -> 1

View file

@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
import Control.Monad.IO.Class (liftIO)
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze)
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt)
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
import qualified TensorFlow.Output as TF
@ -324,6 +324,14 @@ testPad =
V.fromList [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @=? dx
V.fromList [2, 2, 3] @=? s
testSqrt :: Test
testSqrt = testCase "testSqrt" $ do
[dx] <- TF.runSession $ do
x <- TF.render $ TF.vector [0.0625 :: Float]
let y = TF.sqrt x
TF.gradients y [x] >>= TF.run
V.fromList [2] @=? dx
testBatchToSpaceND :: Test
testBatchToSpaceND =
testCase "testBatchToSpaceND" $ do
@ -517,6 +525,7 @@ main = defaultMain
, testExpandDims
, testReshape
, testPad
, testSqrt
, testBatchToSpaceND
, testSpaceToBatchND
, testSqueeze
@ -530,4 +539,4 @@ main = defaultMain
, matMulTransposeGradient (True, False)
, matMulTransposeGradient (True, True)
, testConv2DBackpropInputGrad
]
]