mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Fix gradients calculation for min and max (#48)
This commit is contained in:
parent
1539783ee5
commit
cc08520dc7
3 changed files with 19 additions and 7 deletions
|
@ -459,9 +459,10 @@ opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
|
||||||
where
|
where
|
||||||
sx = shape (x :: Tensor Value a)
|
sx = shape (x :: Tensor Value a)
|
||||||
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
|
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
|
||||||
x' = reshape x outputShapeKeptDims
|
y = CoreOps.max x indices
|
||||||
|
y' = reshape y outputShapeKeptDims
|
||||||
dz' = reshape dz outputShapeKeptDims
|
dz' = reshape dz outputShapeKeptDims
|
||||||
indicators = CoreOps.cast $ CoreOps.equal x' x
|
indicators = CoreOps.cast $ CoreOps.equal y' x
|
||||||
numSelected = reshape (sum indicators indices) outputShapeKeptDims
|
numSelected = reshape (sum indicators indices) outputShapeKeptDims
|
||||||
|
|
||||||
-- Min and Max have identical gradient implementations.
|
-- Min and Max have identical gradient implementations.
|
||||||
|
|
|
@ -144,10 +144,12 @@ Test-Suite GradientTest
|
||||||
, lens-family
|
, lens-family
|
||||||
, google-shim
|
, google-shim
|
||||||
, tensorflow
|
, tensorflow
|
||||||
|
, tensorflow-core-ops
|
||||||
, tensorflow-ops
|
, tensorflow-ops
|
||||||
, tensorflow-proto
|
, tensorflow-proto
|
||||||
, test-framework
|
, test-framework
|
||||||
, test-framework-hunit
|
, test-framework-hunit
|
||||||
|
, vector
|
||||||
|
|
||||||
Test-Suite MiscTest
|
Test-Suite MiscTest
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
|
import Data.Int (Int32)
|
||||||
import Data.List (sort)
|
import Data.List (sort)
|
||||||
import Data.ProtoLens.TextFormat (showMessage)
|
import Data.ProtoLens.TextFormat (showMessage)
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
|
@ -22,14 +23,12 @@ import Lens.Family2 ((^..))
|
||||||
import Test.Framework (Test)
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.HUnit ((@=?))
|
import Test.HUnit ((@=?))
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
import qualified TensorFlow.Build as TF
|
import qualified TensorFlow.Core as TF
|
||||||
|
import qualified TensorFlow.GenOps.Core as TF (max)
|
||||||
import qualified TensorFlow.Gradient as TF
|
import qualified TensorFlow.Gradient as TF
|
||||||
import qualified TensorFlow.Nodes as TF
|
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
import qualified TensorFlow.Session as TF
|
|
||||||
import qualified TensorFlow.Tensor as TF
|
|
||||||
import qualified TensorFlow.Types as TF
|
|
||||||
|
|
||||||
import Proto.Tensorflow.Core.Framework.Graph (node)
|
import Proto.Tensorflow.Core.Framework.Graph (node)
|
||||||
import Proto.Tensorflow.Core.Framework.NodeDef (op)
|
import Proto.Tensorflow.Core.Framework.NodeDef (op)
|
||||||
|
@ -149,10 +148,20 @@ testDiamond = testCase "testDiamond" $ do
|
||||||
(4 :: Float) @=? TF.unScalar dx
|
(4 :: Float) @=? TF.unScalar dx
|
||||||
|
|
||||||
|
|
||||||
|
testMaxGradient :: Test
|
||||||
|
testMaxGradient = testCase "testMaxGradient" $ do
|
||||||
|
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
|
||||||
|
let x = TF.vector [1, 2, 3, 0, 1 :: Float]
|
||||||
|
y = TF.max x (0 :: TF.Tensor TF.Value Int32)
|
||||||
|
TF.gradients y [x]
|
||||||
|
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
|
||||||
|
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = googleTest [ testGradientSimple
|
main = googleTest [ testGradientSimple
|
||||||
, testGradientDisconnected
|
, testGradientDisconnected
|
||||||
, testCreateGraphStateful
|
, testCreateGraphStateful
|
||||||
, testCreateGraphNameScopes
|
, testCreateGraphNameScopes
|
||||||
, testDiamond
|
, testDiamond
|
||||||
|
, testMaxGradient
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in a new issue