mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-11 19:39:49 +01:00
Added support for tanh activation function (#223)
This commit is contained in:
parent
61e58fd33f
commit
915015018c
3 changed files with 16 additions and 1 deletions
|
@ -45,7 +45,7 @@ import Lens.Family2 (Lens', view, (&), (^.), (.~), (%~))
|
|||
import Lens.Family2.State.Strict (uses)
|
||||
import Lens.Family2.Stock (at, intAt)
|
||||
import Lens.Family2.Unchecked (lens, iso)
|
||||
import Prelude hiding (sum)
|
||||
import Prelude hiding (sum, tanh)
|
||||
import Text.Printf (printf)
|
||||
import qualified Data.Graph.Inductive.Basic as FGL
|
||||
import qualified Data.Graph.Inductive.Graph as FGL
|
||||
|
@ -76,6 +76,8 @@ import TensorFlow.Ops
|
|||
, matMul'
|
||||
, reducedShape
|
||||
, reluGrad
|
||||
, tanh
|
||||
, tanhGrad
|
||||
, reshape
|
||||
, scalar
|
||||
, shape
|
||||
|
@ -459,6 +461,7 @@ opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x]
|
|||
opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
|
||||
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
|
||||
opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
|
||||
opGrad "Tanh" _ [toT -> x] [dz] = [Just $ tanhGrad (tanh x) dz]
|
||||
|
||||
opGrad "Concat" _ _ix [dy]
|
||||
-- Concat concatenates input tensors
|
||||
|
@ -833,6 +836,7 @@ numOutputs o =
|
|||
"SparseSegmentSum" -> 1
|
||||
"Sub" -> 1
|
||||
"Sum" -> 1
|
||||
"Tanh" -> 1
|
||||
"Tile" -> 1
|
||||
"Transpose" -> 1
|
||||
"TruncatedNormal" -> 1
|
||||
|
|
|
@ -112,6 +112,8 @@ module TensorFlow.Ops
|
|||
, CoreOps.relu'
|
||||
, CoreOps.reluGrad
|
||||
, CoreOps.reluGrad'
|
||||
, CoreOps.tanh
|
||||
, CoreOps.tanhGrad
|
||||
, CoreOps.reshape
|
||||
, CoreOps.reshape'
|
||||
, restore
|
||||
|
|
|
@ -282,6 +282,14 @@ testReluGradGrad = testCase "testReluGradGrad" $ do
|
|||
TF.gradients y' [x] >>= TF.run
|
||||
V.fromList [0] @=? dx
|
||||
|
||||
testTanhGrad :: Test
|
||||
testTanhGrad = testCase "testTanhGrad" $ do
|
||||
[dx] <- TF.runSession $ do
|
||||
x <- TF.render $ TF.vector [0 :: Float]
|
||||
let y = TF.tanh x
|
||||
TF.gradients y [x] >>= TF.run
|
||||
V.fromList [1] @=? dx
|
||||
|
||||
testFillGrad :: Test
|
||||
testFillGrad = testCase "testFillGrad" $ do
|
||||
[dx] <- TF.runSession $ do
|
||||
|
@ -427,6 +435,7 @@ main = defaultMain
|
|||
, testMaximumGradGrad
|
||||
, testReluGrad
|
||||
, testReluGradGrad
|
||||
, testTanhGrad
|
||||
, testFillGrad
|
||||
, testTileGrad
|
||||
, testTile2DGrad
|
||||
|
|
Loading…
Reference in a new issue