1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-26 21:09:44 +01:00

Add gradient for sqrt function (#236)

This commit is contained in:
erikabor 2019-03-19 02:08:08 +01:00 committed by fkm3
parent 896a0d31f7
commit 666dce94bd
2 changed files with 16 additions and 2 deletions

View file

@ -813,6 +813,10 @@ opGrad "Placeholder" _ _ _ = []
opGrad "VarHandleOp" _ _ _ = [] opGrad "VarHandleOp" _ _ _ = []
opGrad "Variable" _ _ _ = [] opGrad "Variable" _ _ _ = []
opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz]
where
sq' = scalar 1 `CoreOps.div` (scalar 2 `CoreOps.mul` CoreOps.sqrt x)
opGrad n nodeDef ins grads = opGrad n nodeDef ins grads =
error $ "no gradient implemented for " ++ error $ "no gradient implemented for " ++
show (n, length ins, length grads, showMessage nodeDef, ins) show (n, length ins, length grads, showMessage nodeDef, ins)
@ -863,6 +867,7 @@ numOutputs o =
"SparseSegmentSum" -> 1 "SparseSegmentSum" -> 1
"Square" -> 1 "Square" -> 1
"Squeeze" -> 1 "Squeeze" -> 1
"Sqrt" -> 1
"Sub" -> 1 "Sub" -> 1
"Sum" -> 1 "Sum" -> 1
"Tanh" -> 1 "Tanh" -> 1

View file

@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import qualified TensorFlow.Core as TF import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze) import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt)
import qualified TensorFlow.Gradient as TF import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable) import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
import qualified TensorFlow.Output as TF import qualified TensorFlow.Output as TF
@ -324,6 +324,14 @@ testPad =
V.fromList [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @=? dx V.fromList [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @=? dx
V.fromList [2, 2, 3] @=? s V.fromList [2, 2, 3] @=? s
testSqrt :: Test
testSqrt = testCase "testSqrt" $ do
[dx] <- TF.runSession $ do
x <- TF.render $ TF.vector [0.0625 :: Float]
let y = TF.sqrt x
TF.gradients y [x] >>= TF.run
V.fromList [2] @=? dx
testBatchToSpaceND :: Test testBatchToSpaceND :: Test
testBatchToSpaceND = testBatchToSpaceND =
testCase "testBatchToSpaceND" $ do testCase "testBatchToSpaceND" $ do
@ -517,6 +525,7 @@ main = defaultMain
, testExpandDims , testExpandDims
, testReshape , testReshape
, testPad , testPad
, testSqrt
, testBatchToSpaceND , testBatchToSpaceND
, testSpaceToBatchND , testSpaceToBatchND
, testSqueeze , testSqueeze