Watchers: 1 · Stars: 0 · Forks: 0
Mirror of https://github.com/tensorflow/haskell.git (synced 2024-11-23 03:19:44 +01:00)

Added support for tanh activation function (#223)

This commit is contained in:
Rik 2018-11-14 18:08:05 +01:00 committed by fkm3
parent 61e58fd33f
commit 915015018c
3 changed files with 16 additions and 1 deletions

View file

@@ -45,7 +45,7 @@ import Lens.Family2 (Lens', view, (&), (^.), (.~), (%~))
 import Lens.Family2.State.Strict (uses)
 import Lens.Family2.Stock (at, intAt)
 import Lens.Family2.Unchecked (lens, iso)
-import Prelude hiding (sum)
+import Prelude hiding (sum, tanh)
 import Text.Printf (printf)
 import qualified Data.Graph.Inductive.Basic as FGL
 import qualified Data.Graph.Inductive.Graph as FGL
@@ -76,6 +76,8 @@ import TensorFlow.Ops
     , matMul'
     , reducedShape
     , reluGrad
+    , tanh
+    , tanhGrad
     , reshape
     , scalar
     , shape
@@ -459,6 +461,7 @@ opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x]
 opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
 opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
 opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
+opGrad "Tanh" _ [toT -> x] [dz] = [Just $ tanhGrad (tanh x) dz]
 opGrad "Concat" _ _ix [dy]
     -- Concat concatenates input tensors
@@ -833,6 +836,7 @@ numOutputs o =
         "SparseSegmentSum" -> 1
         "Sub" -> 1
         "Sum" -> 1
+        "Tanh" -> 1
         "Tile" -> 1
         "Transpose" -> 1
         "TruncatedNormal" -> 1

View file

@@ -112,6 +112,8 @@ module TensorFlow.Ops
     , CoreOps.relu'
     , CoreOps.reluGrad
     , CoreOps.reluGrad'
+    , CoreOps.tanh
+    , CoreOps.tanhGrad
     , CoreOps.reshape
     , CoreOps.reshape'
     , restore

View file

@@ -282,6 +282,14 @@ testReluGradGrad = testCase "testReluGradGrad" $ do
         TF.gradients y' [x] >>= TF.run
     V.fromList [0] @=? dx

+testTanhGrad :: Test
+testTanhGrad = testCase "testTanhGrad" $ do
+    [dx] <- TF.runSession $ do
+        x <- TF.render $ TF.vector [0 :: Float]
+        let y = TF.tanh x
+        TF.gradients y [x] >>= TF.run
+    V.fromList [1] @=? dx
+
 testFillGrad :: Test
 testFillGrad = testCase "testFillGrad" $ do
     [dx] <- TF.runSession $ do
@@ -427,6 +435,7 @@ main = defaultMain
     , testMaximumGradGrad
     , testReluGrad
     , testReluGradGrad
+    , testTanhGrad
     , testFillGrad
     , testTileGrad
     , testTile2DGrad