diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index 51d995d..83914fc 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -813,6 +813,10 @@ opGrad "Placeholder" _ _ _ = [] opGrad "VarHandleOp" _ _ _ = [] opGrad "Variable" _ _ _ = [] +opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz] + where + sq' = scalar 1 `CoreOps.div` (scalar 2 `CoreOps.mul` CoreOps.sqrt x) + opGrad n nodeDef ins grads = error $ "no gradient implemented for " ++ show (n, length ins, length grads, showMessage nodeDef, ins) @@ -863,6 +867,7 @@ numOutputs o = "SparseSegmentSum" -> 1 "Square" -> 1 "Squeeze" -> 1 + "Sqrt" -> 1 "Sub" -> 1 "Sum" -> 1 "Tanh" -> 1 diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs index 931a146..0cbb6eb 100644 --- a/tensorflow-ops/tests/GradientTest.hs +++ b/tensorflow-ops/tests/GradientTest.hs @@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM) import Control.Monad.IO.Class (liftIO) import qualified TensorFlow.Core as TF -import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze) +import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt) import qualified TensorFlow.Gradient as TF import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable) import qualified TensorFlow.Output as TF @@ -324,6 +324,14 @@ testPad = V.fromList [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @=? dx V.fromList [2, 2, 3] @=? s +testSqrt :: Test +testSqrt = testCase "testSqrt" $ do + [dx] <- TF.runSession $ do + x <- TF.render $ TF.vector [0.0625 :: Float] + let y = TF.sqrt x + TF.gradients y [x] >>= TF.run + V.fromList [2] @=? dx + testBatchToSpaceND :: Test testBatchToSpaceND = testCase "testBatchToSpaceND" $ do @@ -517,6 +525,7 @@ main = defaultMain , testExpandDims , testReshape , testPad + , testSqrt , testBatchToSpaceND , testSpaceToBatchND , testSqueeze @@ -530,4 +539,4 @@ main = defaultMain , matMulTransposeGradient (True, False) , matMulTransposeGradient (True, True) , testConv2DBackpropInputGrad - ] \ No newline at end of file + ]