mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Add gradient for sqrt function (#236)
This commit is contained in:
parent
896a0d31f7
commit
666dce94bd
2 changed files with 16 additions and 2 deletions
|
@ -813,6 +813,10 @@ opGrad "Placeholder" _ _ _ = []
|
|||
opGrad "VarHandleOp" _ _ _ = []
|
||||
opGrad "Variable" _ _ _ = []
|
||||
|
||||
opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz]
|
||||
where
|
||||
sq' = scalar 1 `CoreOps.div` (scalar 2 `CoreOps.mul` CoreOps.sqrt x)
|
||||
|
||||
opGrad n nodeDef ins grads =
|
||||
error $ "no gradient implemented for " ++
|
||||
show (n, length ins, length grads, showMessage nodeDef, ins)
|
||||
|
@ -863,6 +867,7 @@ numOutputs o =
|
|||
"SparseSegmentSum" -> 1
|
||||
"Square" -> 1
|
||||
"Squeeze" -> 1
|
||||
"Sqrt" -> 1
|
||||
"Sub" -> 1
|
||||
"Sum" -> 1
|
||||
"Tanh" -> 1
|
||||
|
|
|
@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
|
|||
import Control.Monad.IO.Class (liftIO)
|
||||
|
||||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze)
|
||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt)
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
|
||||
import qualified TensorFlow.Output as TF
|
||||
|
@ -324,6 +324,14 @@ testPad =
|
|||
V.fromList [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @=? dx
|
||||
V.fromList [2, 2, 3] @=? s
|
||||
|
||||
testSqrt :: Test
|
||||
testSqrt = testCase "testSqrt" $ do
|
||||
[dx] <- TF.runSession $ do
|
||||
x <- TF.render $ TF.vector [0.0625 :: Float]
|
||||
let y = TF.sqrt x
|
||||
TF.gradients y [x] >>= TF.run
|
||||
V.fromList [2] @=? dx
|
||||
|
||||
testBatchToSpaceND :: Test
|
||||
testBatchToSpaceND =
|
||||
testCase "testBatchToSpaceND" $ do
|
||||
|
@ -517,6 +525,7 @@ main = defaultMain
|
|||
, testExpandDims
|
||||
, testReshape
|
||||
, testPad
|
||||
, testSqrt
|
||||
, testBatchToSpaceND
|
||||
, testSpaceToBatchND
|
||||
, testSqueeze
|
||||
|
|
Loading…
Reference in a new issue