mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Add gradient for sqrt function (#236)
This commit is contained in:
parent
896a0d31f7
commit
666dce94bd
2 changed files with 16 additions and 2 deletions
|
@ -813,6 +813,10 @@ opGrad "Placeholder" _ _ _ = []
|
||||||
opGrad "VarHandleOp" _ _ _ = []
|
opGrad "VarHandleOp" _ _ _ = []
|
||||||
opGrad "Variable" _ _ _ = []
|
opGrad "Variable" _ _ _ = []
|
||||||
|
|
||||||
|
opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz]
|
||||||
|
where
|
||||||
|
sq' = scalar 1 `CoreOps.div` (scalar 2 `CoreOps.mul` CoreOps.sqrt x)
|
||||||
|
|
||||||
opGrad n nodeDef ins grads =
|
opGrad n nodeDef ins grads =
|
||||||
error $ "no gradient implemented for " ++
|
error $ "no gradient implemented for " ++
|
||||||
show (n, length ins, length grads, showMessage nodeDef, ins)
|
show (n, length ins, length grads, showMessage nodeDef, ins)
|
||||||
|
@ -863,6 +867,7 @@ numOutputs o =
|
||||||
"SparseSegmentSum" -> 1
|
"SparseSegmentSum" -> 1
|
||||||
"Square" -> 1
|
"Square" -> 1
|
||||||
"Squeeze" -> 1
|
"Squeeze" -> 1
|
||||||
|
"Sqrt" -> 1
|
||||||
"Sub" -> 1
|
"Sub" -> 1
|
||||||
"Sum" -> 1
|
"Sum" -> 1
|
||||||
"Tanh" -> 1
|
"Tanh" -> 1
|
||||||
|
|
|
@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
|
||||||
import qualified TensorFlow.Core as TF
|
import qualified TensorFlow.Core as TF
|
||||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze)
|
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt)
|
||||||
import qualified TensorFlow.Gradient as TF
|
import qualified TensorFlow.Gradient as TF
|
||||||
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
|
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
|
||||||
import qualified TensorFlow.Output as TF
|
import qualified TensorFlow.Output as TF
|
||||||
|
@ -324,6 +324,14 @@ testPad =
|
||||||
V.fromList [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @=? dx
|
V.fromList [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @=? dx
|
||||||
V.fromList [2, 2, 3] @=? s
|
V.fromList [2, 2, 3] @=? s
|
||||||
|
|
||||||
|
testSqrt :: Test
|
||||||
|
testSqrt = testCase "testSqrt" $ do
|
||||||
|
[dx] <- TF.runSession $ do
|
||||||
|
x <- TF.render $ TF.vector [0.0625 :: Float]
|
||||||
|
let y = TF.sqrt x
|
||||||
|
TF.gradients y [x] >>= TF.run
|
||||||
|
V.fromList [2] @=? dx
|
||||||
|
|
||||||
testBatchToSpaceND :: Test
|
testBatchToSpaceND :: Test
|
||||||
testBatchToSpaceND =
|
testBatchToSpaceND =
|
||||||
testCase "testBatchToSpaceND" $ do
|
testCase "testBatchToSpaceND" $ do
|
||||||
|
@ -517,6 +525,7 @@ main = defaultMain
|
||||||
, testExpandDims
|
, testExpandDims
|
||||||
, testReshape
|
, testReshape
|
||||||
, testPad
|
, testPad
|
||||||
|
, testSqrt
|
||||||
, testBatchToSpaceND
|
, testBatchToSpaceND
|
||||||
, testSpaceToBatchND
|
, testSpaceToBatchND
|
||||||
, testSqueeze
|
, testSqueeze
|
||||||
|
|
Loading…
Reference in a new issue