1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-12-24 10:39:45 +01:00

Fixed matMul gradients for transposed arguments

This commit is contained in:
Jarl Christian Berentsen 2017-05-04 09:39:15 +02:00 committed by fkm3
parent 51014a015c
commit d153d0aded
4 changed files with 100 additions and 24 deletions

View file

@ -551,13 +551,13 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
, Just $ matMul' (transAttrs True False) x dz]
(False, True) ->
[ Just $ matMul dz y
, Just $ matMul' (transAttrs True False) x dz]
, Just $ matMul' (transAttrs True False) dz x]
(True, False) ->
[ Just $ matMul' (transAttrs False True) dz y
[ Just $ matMul' (transAttrs False True) y dz
, Just $ matMul x dz]
(True, True) ->
[ Just $ matMul' (transAttrs True True) dz y
, Just $ matMul' (transAttrs True True) x dz]
[ Just $ matMul' (transAttrs True True) y dz
, Just $ matMul' (transAttrs True True) dz x]
opGrad "Transpose" _ [_, toT -> p] [dz] =
[ Just $ CoreOps.transpose dz

View file

@ -318,21 +318,15 @@ scalarize t = CoreOps.reshape t (vector scalarShape)
-- | Sum a tensor down to a scalar
-- Seee `TensorFlow.GenOps.Core.sum`
reduceSum
:: ( TensorType a
, OneOf '[ Double, Float, Int32, Int64
, Complex Float, Complex Double] a
)
=> Tensor v a -> Tensor Build a
reduceSum :: (OneOf '[ Double, Float, Int32, Int64
, Complex Float, Complex Double] a) =>
Tensor v a -> Tensor Build a
reduceSum x = CoreOps.sum x allAxes
where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1
reduceSum'
:: ( TensorType a
, OneOf '[ Double, Float, Int32, Int64
, Complex Float, Complex Double] a
)
=> OpParams -> Tensor v a -> Tensor Build a
reduceSum' :: (OneOf '[ Double, Float, Int32, Int64
, Complex Float, Complex Double] a) =>
OpParams -> Tensor v a -> Tensor Build a
reduceSum' params x = CoreOps.sum' params x allAxes
where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1

View file

@ -200,6 +200,7 @@ Test-Suite GradientTest
, tensorflow-proto
, test-framework
, test-framework-hunit
, transformers
, vector
Test-Suite MiscTest

View file

@ -15,21 +15,26 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
import Data.Int (Int32)
import Data.Int (Int32, Int64)
import Data.List (sort)
import Data.ProtoLens.TextFormat (showMessage)
import Google.Test (googleTest)
import Lens.Family2 ((^..))
import Lens.Family2 ((^..), (.~))
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
import Test.HUnit ((@=?), assertEqual)
import qualified Data.Vector as V
import Control.Monad.IO.Class (liftIO)
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (max, tile)
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Output as TF
import qualified TensorFlow.Types as TF
import Proto.Tensorflow.Core.Framework.Graph (node)
import Proto.Tensorflow.Core.Framework.NodeDef (op)
@ -207,15 +212,85 @@ testTile2DGrad = testCase "testTileGrad2D" $ do
let y = TF.tile x multiples
[dx] <- TF.gradients y [x]
shapeDX <- TF.run $ TF.shape dx
shapeX <- TF.run $ TF.shape x
dxv <- TF.run dx
return (dxv, shapeDX, shapeX)
TF.run (dx, TF.shape dx, TF.shape x)
shapeX @=? (shapeDX :: V.Vector Int32)
V.fromList [6, 6, 6, 6, 6, 6::Float] @=? (dx :: V.Vector Float)
matMulGradient :: Test
matMulGradient = testCase "matMulGradients" $ do
let dfBuild = do
x <- TF.render $ TF.zeros $ TF.Shape [3, 1 :: Int64]
w <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
let f = x `TF.matMul` w :: TF.Tensor TF.Build Float
dfs <- TF.gradients f [x]
return (x, dfs)
(xShape, dxShape) <- TF.runSession $ do
(x, [dx]) <- TF.build dfBuild
TF.run (TF.shape x, TF.shape dx)
assertEqual "Shape of gradient must match shape of input" xShape (dxShape :: V.Vector Int32)
-- test that gradient of matMul can be taken gradient of
matMulGradGrad :: Test
matMulGradGrad = testCase "matMulGradGrad" $ do
let width = 2 :: Int64
batch = 4 :: Int64
let tower = do
x <- TF.render $ TF.zeros $ TF.Shape [batch, 1]
w <- TF.zeroInitializedVariable $ TF.Shape [1, width]
let f = x `TF.matMul` w
[dfdx] <- TF.gradients f [x]
let f'x = TF.reduceSum dfdx
[dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
return [TF.value w, dfdw]
TF.runSession $ do
[w, dfdw] <- TF.build tower
(wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw)
liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32)
let step = w `TF.add` dfdw
w0 <- TF.run step
liftIO $ ((V.fromList [4, 4 :: Float]) @=? w0)
-- test that gradient of matMul deals correctly with transpose_a and transpose_b
matMulTransposeGradient :: (Bool, Bool) -> Test
matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ (show txw)) $ do
let (transposeX, transposeW) = txw
let dfBuild = do
let xShape = TF.Shape [3, 1 :: Int64]
let xZeros = TF.zeros xShape
x <- TF.render $ if transposeX then TF.matTranspose xZeros else xZeros
variable <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
let wv = if transposeW then TF.matTranspose variable else TF.expr variable
let f = TF.matMul' (transAttrs transposeX transposeW) x wv :: TF.Tensor TF.Build Float
w <- TF.render wv
ds <- TF.gradients f [x, w]
return (x, w, ds)
TF.runSession $ do
(x, w, [dx, dw]) <- TF.build dfBuild
xShape <- TF.run $ TF.shape x
dxShape <- TF.run $ TF.shape dx
liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)
wShape <- TF.run $ TF.shape w
dwShape <- TF.run $ TF.shape dw
liftIO $ assertEqual "wShape must match dwShape" wShape (dwShape :: V.Vector Int32)
transAttrs :: (TF.Attribute a,
TF.Attribute b) =>
a -> b -> TF.OpDef -> TF.OpDef
transAttrs a b =
(TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b)
main :: IO ()
main = googleTest [ testGradientSimple
, testGradientDisconnected
@ -228,4 +303,10 @@ main = googleTest [ testGradientSimple
, testFillGrad
, testTileGrad
, testTile2DGrad
, matMulGradient
, matMulGradGrad
, matMulTransposeGradient (False, False)
, matMulTransposeGradient (False, True)
, matMulTransposeGradient (True, False)
, matMulTransposeGradient (True, True)
]