Added support for tanh activation function (#223)

This commit is contained in:
Rik 2018-11-14 18:08:05 +01:00 committed by fkm3
parent 61e58fd33f
commit 915015018c
3 changed files with 16 additions and 1 deletions

View File

@ -45,7 +45,7 @@ import Lens.Family2 (Lens', view, (&), (^.), (.~), (%~))
import Lens.Family2.State.Strict (uses)
import Lens.Family2.Stock (at, intAt)
import Lens.Family2.Unchecked (lens, iso)
import Prelude hiding (sum)
import Prelude hiding (sum, tanh)
import Text.Printf (printf)
import qualified Data.Graph.Inductive.Basic as FGL
import qualified Data.Graph.Inductive.Graph as FGL
@ -76,6 +76,8 @@ import TensorFlow.Ops
, matMul'
, reducedShape
, reluGrad
, tanh
, tanhGrad
, reshape
, scalar
, shape
@ -459,6 +461,7 @@ opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x]
opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
opGrad "Tanh" _ [toT -> x] [dz] = [Just $ tanhGrad (tanh x) dz]
opGrad "Concat" _ _ix [dy]
-- Concat concatenates input tensors
@ -833,6 +836,7 @@ numOutputs o =
"SparseSegmentSum" -> 1
"Sub" -> 1
"Sum" -> 1
"Tanh" -> 1
"Tile" -> 1
"Transpose" -> 1
"TruncatedNormal" -> 1

View File

@ -112,6 +112,8 @@ module TensorFlow.Ops
, CoreOps.relu'
, CoreOps.reluGrad
, CoreOps.reluGrad'
, CoreOps.tanh
, CoreOps.tanhGrad
, CoreOps.reshape
, CoreOps.reshape'
, restore

View File

@ -282,6 +282,14 @@ testReluGradGrad = testCase "testReluGradGrad" $ do
TF.gradients y' [x] >>= TF.run
V.fromList [0] @=? dx
testTanhGrad :: Test
testTanhGrad = testCase "testTanhGrad" $ do
[dx] <- TF.runSession $ do
x <- TF.render $ TF.vector [0 :: Float]
let y = TF.tanh x
TF.gradients y [x] >>= TF.run
V.fromList [1] @=? dx
testFillGrad :: Test
testFillGrad = testCase "testFillGrad" $ do
[dx] <- TF.runSession $ do
@ -427,6 +435,7 @@ main = defaultMain
, testMaximumGradGrad
, testReluGrad
, testReluGradGrad
, testTanhGrad
, testFillGrad
, testTileGrad
, testTile2DGrad