diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index 84085a0..476222d 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -85,6 +85,8 @@ import TensorFlow.Ops , shape , softmaxCrossEntropyWithLogits , sum + , sigmoid + , sigmoidGrad , scalarize , vector , zerosLike @@ -481,6 +483,7 @@ opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz] opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x] opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x] opGrad "Tanh" _ [toT -> x] [dz] = [Just $ tanhGrad (tanh x) dz] +opGrad "Sigmoid" _ [toT -> x] [dz] = [Just $ sigmoidGrad (sigmoid x) dz] opGrad "Concat" _ _ix [dy] -- Concat concatenates input tensors @@ -947,6 +950,7 @@ numOutputs o = "ReluGrad" -> 1 "Reshape" -> 1 "Select" -> 1 + "Sigmoid" -> 1 "Size" -> 1 "Slice" -> 1 "SoftmaxCrossEntropyWithLogits" -> 2 diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs index ea36c90..9fcad6e 100644 --- a/tensorflow-ops/src/TensorFlow/Ops.hs +++ b/tensorflow-ops/src/TensorFlow/Ops.hs @@ -123,6 +123,8 @@ module TensorFlow.Ops , scalar' , shape , shape' + , CoreOps.sigmoid + , CoreOps.sigmoidGrad , CoreOps.sign , CoreOps.sign' , CoreOps.size diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs index caf7629..b306bb9 100644 --- a/tensorflow-ops/tests/GradientTest.hs +++ b/tensorflow-ops/tests/GradientTest.hs @@ -368,6 +368,14 @@ testTanhGrad = testCase "testTanhGrad" $ do TF.gradients y [x] >>= TF.run V.fromList [1] @=? dx +testSigmoidGrad :: Test +testSigmoidGrad = testCase "testSigmoidGrad" $ do + [dx] <- TF.runSession $ do + x <- TF.render $ TF.vector [0 :: Float] + let y = TF.sigmoid x + TF.gradients y [x] >>= TF.run + V.fromList [0.25] @=? dx + testExpandDims :: Test testExpandDims = testCase "testExpandDims" $ do @@ -681,6 +689,7 @@ main = defaultMain , testReluGrad , testReluGradGrad , testTanhGrad + , testSigmoidGrad , testExpandDims , testReshape , testPad