mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Added support for tanh activation function (#223)
This commit is contained in:
parent
61e58fd33f
commit
915015018c
3 changed files with 16 additions and 1 deletions
|
@ -45,7 +45,7 @@ import Lens.Family2 (Lens', view, (&), (^.), (.~), (%~))
|
||||||
import Lens.Family2.State.Strict (uses)
|
import Lens.Family2.State.Strict (uses)
|
||||||
import Lens.Family2.Stock (at, intAt)
|
import Lens.Family2.Stock (at, intAt)
|
||||||
import Lens.Family2.Unchecked (lens, iso)
|
import Lens.Family2.Unchecked (lens, iso)
|
||||||
import Prelude hiding (sum)
|
import Prelude hiding (sum, tanh)
|
||||||
import Text.Printf (printf)
|
import Text.Printf (printf)
|
||||||
import qualified Data.Graph.Inductive.Basic as FGL
|
import qualified Data.Graph.Inductive.Basic as FGL
|
||||||
import qualified Data.Graph.Inductive.Graph as FGL
|
import qualified Data.Graph.Inductive.Graph as FGL
|
||||||
|
@ -76,6 +76,8 @@ import TensorFlow.Ops
|
||||||
, matMul'
|
, matMul'
|
||||||
, reducedShape
|
, reducedShape
|
||||||
, reluGrad
|
, reluGrad
|
||||||
|
, tanh
|
||||||
|
, tanhGrad
|
||||||
, reshape
|
, reshape
|
||||||
, scalar
|
, scalar
|
||||||
, shape
|
, shape
|
||||||
|
@ -459,6 +461,7 @@ opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x]
|
||||||
opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
|
opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
|
||||||
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
|
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
|
||||||
opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
|
opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
|
||||||
|
opGrad "Tanh" _ [toT -> x] [dz] = [Just $ tanhGrad (tanh x) dz]
|
||||||
|
|
||||||
opGrad "Concat" _ _ix [dy]
|
opGrad "Concat" _ _ix [dy]
|
||||||
-- Concat concatenates input tensors
|
-- Concat concatenates input tensors
|
||||||
|
@ -833,6 +836,7 @@ numOutputs o =
|
||||||
"SparseSegmentSum" -> 1
|
"SparseSegmentSum" -> 1
|
||||||
"Sub" -> 1
|
"Sub" -> 1
|
||||||
"Sum" -> 1
|
"Sum" -> 1
|
||||||
|
"Tanh" -> 1
|
||||||
"Tile" -> 1
|
"Tile" -> 1
|
||||||
"Transpose" -> 1
|
"Transpose" -> 1
|
||||||
"TruncatedNormal" -> 1
|
"TruncatedNormal" -> 1
|
||||||
|
|
|
@ -112,6 +112,8 @@ module TensorFlow.Ops
|
||||||
, CoreOps.relu'
|
, CoreOps.relu'
|
||||||
, CoreOps.reluGrad
|
, CoreOps.reluGrad
|
||||||
, CoreOps.reluGrad'
|
, CoreOps.reluGrad'
|
||||||
|
, CoreOps.tanh
|
||||||
|
, CoreOps.tanhGrad
|
||||||
, CoreOps.reshape
|
, CoreOps.reshape
|
||||||
, CoreOps.reshape'
|
, CoreOps.reshape'
|
||||||
, restore
|
, restore
|
||||||
|
|
|
@ -282,6 +282,14 @@ testReluGradGrad = testCase "testReluGradGrad" $ do
|
||||||
TF.gradients y' [x] >>= TF.run
|
TF.gradients y' [x] >>= TF.run
|
||||||
V.fromList [0] @=? dx
|
V.fromList [0] @=? dx
|
||||||
|
|
||||||
|
testTanhGrad :: Test
|
||||||
|
testTanhGrad = testCase "testTanhGrad" $ do
|
||||||
|
[dx] <- TF.runSession $ do
|
||||||
|
x <- TF.render $ TF.vector [0 :: Float]
|
||||||
|
let y = TF.tanh x
|
||||||
|
TF.gradients y [x] >>= TF.run
|
||||||
|
V.fromList [1] @=? dx
|
||||||
|
|
||||||
testFillGrad :: Test
|
testFillGrad :: Test
|
||||||
testFillGrad = testCase "testFillGrad" $ do
|
testFillGrad = testCase "testFillGrad" $ do
|
||||||
[dx] <- TF.runSession $ do
|
[dx] <- TF.runSession $ do
|
||||||
|
@ -427,6 +435,7 @@ main = defaultMain
|
||||||
, testMaximumGradGrad
|
, testMaximumGradGrad
|
||||||
, testReluGrad
|
, testReluGrad
|
||||||
, testReluGradGrad
|
, testReluGradGrad
|
||||||
|
, testTanhGrad
|
||||||
, testFillGrad
|
, testFillGrad
|
||||||
, testTileGrad
|
, testTileGrad
|
||||||
, testTile2DGrad
|
, testTile2DGrad
|
||||||
|
|
Loading…
Reference in a new issue