1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 11:03:34 +02:00

Fix gradients calculation for min and max (#48)

This commit is contained in:
fkm3 2016-12-12 09:47:02 -08:00 committed by Greg Steuck
parent 1539783ee5
commit cc08520dc7
3 changed files with 19 additions and 7 deletions

View File

@ -459,9 +459,10 @@ opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
where where
sx = shape (x :: Tensor Value a) sx = shape (x :: Tensor Value a)
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32) outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
x' = reshape x outputShapeKeptDims y = CoreOps.max x indices
y' = reshape y outputShapeKeptDims
dz' = reshape dz outputShapeKeptDims dz' = reshape dz outputShapeKeptDims
indicators = CoreOps.cast $ CoreOps.equal x' x indicators = CoreOps.cast $ CoreOps.equal y' x
numSelected = reshape (sum indicators indices) outputShapeKeptDims numSelected = reshape (sum indicators indices) outputShapeKeptDims
-- Min and Max have identical gradient implementations. -- Min and Max have identical gradient implementations.

View File

@ -144,10 +144,12 @@ Test-Suite GradientTest
, lens-family , lens-family
, google-shim , google-shim
, tensorflow , tensorflow
, tensorflow-core-ops
, tensorflow-ops , tensorflow-ops
, tensorflow-proto , tensorflow-proto
, test-framework , test-framework
, test-framework-hunit , test-framework-hunit
, vector
Test-Suite MiscTest Test-Suite MiscTest
default-language: Haskell2010 default-language: Haskell2010

View File

@ -15,6 +15,7 @@
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
import Data.Int (Int32)
import Data.List (sort) import Data.List (sort)
import Data.ProtoLens.TextFormat (showMessage) import Data.ProtoLens.TextFormat (showMessage)
import Google.Test (googleTest) import Google.Test (googleTest)
@ -22,14 +23,12 @@ import Lens.Family2 ((^..))
import Test.Framework (Test) import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?)) import Test.HUnit ((@=?))
import qualified Data.Vector as V
import qualified TensorFlow.Build as TF import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (max)
import qualified TensorFlow.Gradient as TF import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.Ops as TF import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
import Proto.Tensorflow.Core.Framework.Graph (node) import Proto.Tensorflow.Core.Framework.Graph (node)
import Proto.Tensorflow.Core.Framework.NodeDef (op) import Proto.Tensorflow.Core.Framework.NodeDef (op)
@ -149,10 +148,20 @@ testDiamond = testCase "testDiamond" $ do
(4 :: Float) @=? TF.unScalar dx (4 :: Float) @=? TF.unScalar dx
testMaxGradient :: Test
testMaxGradient = testCase "testMaxGradient" $ do
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
let x = TF.vector [1, 2, 3, 0, 1 :: Float]
y = TF.max x (0 :: TF.Tensor TF.Value Int32)
TF.gradients y [x]
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
main :: IO () main :: IO ()
main = googleTest [ testGradientSimple main = googleTest [ testGradientSimple
, testGradientDisconnected , testGradientDisconnected
, testCreateGraphStateful , testCreateGraphStateful
, testCreateGraphNameScopes , testCreateGraphNameScopes
, testDiamond , testDiamond
, testMaxGradient
] ]