diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index 02686be..4d2833d 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -459,9 +459,10 @@ opGrad "Max" _ [toT -> x, toT -> indices] [dz] = where sx = shape (x :: Tensor Value a) outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32) - x' = reshape x outputShapeKeptDims + y = CoreOps.max x indices + y' = reshape y outputShapeKeptDims dz' = reshape dz outputShapeKeptDims - indicators = CoreOps.cast $ CoreOps.equal x' x + indicators = CoreOps.cast $ CoreOps.equal y' x numSelected = reshape (sum indicators indices) outputShapeKeptDims -- Min and Max have identical gradient implementations. diff --git a/tensorflow-ops/tensorflow-ops.cabal b/tensorflow-ops/tensorflow-ops.cabal index b6b60ad..80d5a03 100644 --- a/tensorflow-ops/tensorflow-ops.cabal +++ b/tensorflow-ops/tensorflow-ops.cabal @@ -144,10 +144,12 @@ Test-Suite GradientTest , lens-family , google-shim , tensorflow + , tensorflow-core-ops , tensorflow-ops , tensorflow-proto , test-framework , test-framework-hunit + , vector Test-Suite MiscTest default-language: Haskell2010 diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs index 037b309..2f4dd30 100644 --- a/tensorflow-ops/tests/GradientTest.hs +++ b/tensorflow-ops/tests/GradientTest.hs @@ -15,6 +15,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} +import Data.Int (Int32) import Data.List (sort) import Data.ProtoLens.TextFormat (showMessage) import Google.Test (googleTest) @@ -22,14 +23,12 @@ import Lens.Family2 ((^..)) import Test.Framework (Test) import Test.Framework.Providers.HUnit (testCase) import Test.HUnit ((@=?)) +import qualified Data.Vector as V -import qualified TensorFlow.Build as TF +import qualified TensorFlow.Core as TF +import qualified TensorFlow.GenOps.Core as TF (max) import qualified TensorFlow.Gradient as TF -import qualified TensorFlow.Nodes as TF import qualified TensorFlow.Ops as TF -import qualified TensorFlow.Session as TF -import qualified TensorFlow.Tensor as TF -import qualified TensorFlow.Types as TF import Proto.Tensorflow.Core.Framework.Graph (node) import Proto.Tensorflow.Core.Framework.NodeDef (op) @@ -149,10 +148,20 @@ testDiamond = testCase "testDiamond" $ do (4 :: Float) @=? TF.unScalar dx +testMaxGradient :: Test +testMaxGradient = testCase "testMaxGradient" $ do + [dx] <- TF.runSession $ TF.buildAnd TF.run $ do + let x = TF.vector [1, 2, 3, 0, 1 :: Float] + y = TF.max x (0 :: TF.Tensor TF.Value Int32) + TF.gradients y [x] + V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx + + main :: IO () main = googleTest [ testGradientSimple , testGradientDisconnected , testCreateGraphStateful , testCreateGraphNameScopes , testDiamond + , testMaxGradient ]