1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 19:13:34 +02:00

Add gradient for sqrt function

This commit is contained in:
Erika Bor 2019-03-11 16:31:15 +01:00
parent 896a0d31f7
commit c0e88742cb
2 changed files with 16 additions and 2 deletions

View File

@ -813,6 +813,10 @@ opGrad "Placeholder" _ _ _ = []
opGrad "VarHandleOp" _ _ _ = []
opGrad "Variable" _ _ _ = []
opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz]
where
sq' = vector [1] `CoreOps.div` (vector [2] `CoreOps.mul` CoreOps.sqrt x)
opGrad n nodeDef ins grads =
error $ "no gradient implemented for " ++
show (n, length ins, length grads, showMessage nodeDef, ins)
@ -863,6 +867,7 @@ numOutputs o =
"SparseSegmentSum" -> 1
"Square" -> 1
"Squeeze" -> 1
"Sqrt" -> 1
"Sub" -> 1
"Sum" -> 1
"Tanh" -> 1

View File

@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
import Control.Monad.IO.Class (liftIO)
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze)
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt)
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
import qualified TensorFlow.Output as TF
@ -324,6 +324,14 @@ testPad =
V.fromList [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @=? dx
V.fromList [2, 2, 3] @=? s
testSqrt :: Test
testSqrt = testCase "testSqrt" $ do
[dx] <- TF.runSession $ do
x <- TF.render $ TF.vector [0.25 :: Float]
let y = TF.sqrt x
TF.gradients y [x] >>= TF.run
V.fromList [1] @=? dx
testBatchToSpaceND :: Test
testBatchToSpaceND =
testCase "testBatchToSpaceND" $ do
@ -517,6 +525,7 @@ main = defaultMain
, testExpandDims
, testReshape
, testPad
, testSqrt
, testBatchToSpaceND
, testSpaceToBatchND
, testSqueeze
@ -530,4 +539,4 @@ main = defaultMain
, matMulTransposeGradient (True, False)
, matMulTransposeGradient (True, True)
, testConv2DBackpropInputGrad
]
]