1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2025-01-24 09:49:49 +01:00

Fix gradients calculation for min and max (#48)

This commit is contained in:
fkm3 2016-12-12 09:47:02 -08:00 committed by Greg Steuck
parent 1539783ee5
commit cc08520dc7
3 changed files with 19 additions and 7 deletions

View file

@ -459,9 +459,10 @@ opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
where
sx = shape (x :: Tensor Value a)
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
x' = reshape x outputShapeKeptDims
y = CoreOps.max x indices
y' = reshape y outputShapeKeptDims
dz' = reshape dz outputShapeKeptDims
indicators = CoreOps.cast $ CoreOps.equal x' x
indicators = CoreOps.cast $ CoreOps.equal y' x
numSelected = reshape (sum indicators indices) outputShapeKeptDims
-- Min and Max have identical gradient implementations.

View file

@ -144,10 +144,12 @@ Test-Suite GradientTest
, lens-family
, google-shim
, tensorflow
, tensorflow-core-ops
, tensorflow-ops
, tensorflow-proto
, test-framework
, test-framework-hunit
, vector
Test-Suite MiscTest
default-language: Haskell2010

View file

@ -15,6 +15,7 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
import Data.Int (Int32)
import Data.List (sort)
import Data.ProtoLens.TextFormat (showMessage)
import Google.Test (googleTest)
@ -22,14 +23,12 @@ import Lens.Family2 ((^..))
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
import qualified Data.Vector as V
import qualified TensorFlow.Build as TF
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (max)
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
import Proto.Tensorflow.Core.Framework.Graph (node)
import Proto.Tensorflow.Core.Framework.NodeDef (op)
@ -149,10 +148,20 @@ testDiamond = testCase "testDiamond" $ do
(4 :: Float) @=? TF.unScalar dx
testMaxGradient :: Test
testMaxGradient = testCase "testMaxGradient" $ do
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
let x = TF.vector [1, 2, 3, 0, 1 :: Float]
y = TF.max x (0 :: TF.Tensor TF.Value Int32)
TF.gradients y [x]
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
main :: IO ()
main = googleTest [ testGradientSimple
, testGradientDisconnected
, testCreateGraphStateful
, testCreateGraphNameScopes
, testDiamond
, testMaxGradient
]