mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-24 09:49:49 +01:00
Fix gradients calculation for min and max (#48)
This commit is contained in:
parent
1539783ee5
commit
cc08520dc7
3 changed files with 19 additions and 7 deletions
|
@ -459,9 +459,10 @@ opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
|
|||
where
|
||||
sx = shape (x :: Tensor Value a)
|
||||
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
|
||||
x' = reshape x outputShapeKeptDims
|
||||
y = CoreOps.max x indices
|
||||
y' = reshape y outputShapeKeptDims
|
||||
dz' = reshape dz outputShapeKeptDims
|
||||
indicators = CoreOps.cast $ CoreOps.equal x' x
|
||||
indicators = CoreOps.cast $ CoreOps.equal y' x
|
||||
numSelected = reshape (sum indicators indices) outputShapeKeptDims
|
||||
|
||||
-- Min and Max have identical gradient implementations.
|
||||
|
|
|
@ -144,10 +144,12 @@ Test-Suite GradientTest
|
|||
, lens-family
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-core-ops
|
||||
, tensorflow-ops
|
||||
, tensorflow-proto
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
, vector
|
||||
|
||||
Test-Suite MiscTest
|
||||
default-language: Haskell2010
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
import Data.Int (Int32)
|
||||
import Data.List (sort)
|
||||
import Data.ProtoLens.TextFormat (showMessage)
|
||||
import Google.Test (googleTest)
|
||||
|
@ -22,14 +23,12 @@ import Lens.Family2 ((^..))
|
|||
import Test.Framework (Test)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@=?))
|
||||
import qualified Data.Vector as V
|
||||
|
||||
import qualified TensorFlow.Build as TF
|
||||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.GenOps.Core as TF (max)
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Nodes as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
|
||||
import Proto.Tensorflow.Core.Framework.Graph (node)
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef (op)
|
||||
|
@ -149,10 +148,20 @@ testDiamond = testCase "testDiamond" $ do
|
|||
(4 :: Float) @=? TF.unScalar dx
|
||||
|
||||
|
||||
testMaxGradient :: Test
|
||||
testMaxGradient = testCase "testMaxGradient" $ do
|
||||
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
|
||||
let x = TF.vector [1, 2, 3, 0, 1 :: Float]
|
||||
y = TF.max x (0 :: TF.Tensor TF.Value Int32)
|
||||
TF.gradients y [x]
|
||||
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
|
||||
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testGradientSimple
|
||||
, testGradientDisconnected
|
||||
, testCreateGraphStateful
|
||||
, testCreateGraphNameScopes
|
||||
, testDiamond
|
||||
, testMaxGradient
|
||||
]
|
||||
|
|
Loading…
Add table
Reference in a new issue