mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Add gradient for sigmoid (#245)
This commit is contained in:
parent
1fbd5d41dd
commit
c811037cb9
3 changed files with 15 additions and 0 deletions
|
@ -85,6 +85,8 @@ import TensorFlow.Ops
|
|||
, shape
|
||||
, softmaxCrossEntropyWithLogits
|
||||
, sum
|
||||
, sigmoid
|
||||
, sigmoidGrad
|
||||
, scalarize
|
||||
, vector
|
||||
, zerosLike
|
||||
|
@ -481,6 +483,7 @@ opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
|
|||
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
|
||||
opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
|
||||
opGrad "Tanh" _ [toT -> x] [dz] = [Just $ tanhGrad (tanh x) dz]
|
||||
opGrad "Sigmoid" _ [toT -> x] [dz] = [Just $ sigmoidGrad (sigmoid x) dz]
|
||||
|
||||
opGrad "Concat" _ _ix [dy]
|
||||
-- Concat concatenates input tensors
|
||||
|
@ -947,6 +950,7 @@ numOutputs o =
|
|||
"ReluGrad" -> 1
|
||||
"Reshape" -> 1
|
||||
"Select" -> 1
|
||||
"Sigmoid" -> 1
|
||||
"Size" -> 1
|
||||
"Slice" -> 1
|
||||
"SoftmaxCrossEntropyWithLogits" -> 2
|
||||
|
|
|
@ -123,6 +123,8 @@ module TensorFlow.Ops
|
|||
, scalar'
|
||||
, shape
|
||||
, shape'
|
||||
, CoreOps.sigmoid
|
||||
, CoreOps.sigmoidGrad
|
||||
, CoreOps.sign
|
||||
, CoreOps.sign'
|
||||
, CoreOps.size
|
||||
|
|
|
@ -368,6 +368,14 @@ testTanhGrad = testCase "testTanhGrad" $ do
|
|||
TF.gradients y [x] >>= TF.run
|
||||
V.fromList [1] @=? dx
|
||||
|
||||
testSigmoidGrad :: Test
|
||||
testSigmoidGrad = testCase "testSigmoidGrad" $ do
|
||||
[dx] <- TF.runSession $ do
|
||||
x <- TF.render $ TF.vector [0 :: Float]
|
||||
let y = TF.sigmoid x
|
||||
TF.gradients y [x] >>= TF.run
|
||||
V.fromList [0.25] @=? dx
|
||||
|
||||
testExpandDims :: Test
|
||||
testExpandDims =
|
||||
testCase "testExpandDims" $ do
|
||||
|
@ -681,6 +689,7 @@ main = defaultMain
|
|||
, testReluGrad
|
||||
, testReluGradGrad
|
||||
, testTanhGrad
|
||||
, testSigmoidGrad
|
||||
, testExpandDims
|
||||
, testReshape
|
||||
, testPad
|
||||
|
|
Loading…
Reference in a new issue