mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-30 06:49:44 +01:00
Fixed matMul gradients for transposed arguments
This commit is contained in:
parent
51014a015c
commit
d153d0aded
4 changed files with 100 additions and 24 deletions
|
@ -551,13 +551,13 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
|
||||||
, Just $ matMul' (transAttrs True False) x dz]
|
, Just $ matMul' (transAttrs True False) x dz]
|
||||||
(False, True) ->
|
(False, True) ->
|
||||||
[ Just $ matMul dz y
|
[ Just $ matMul dz y
|
||||||
, Just $ matMul' (transAttrs True False) x dz]
|
, Just $ matMul' (transAttrs True False) dz x]
|
||||||
(True, False) ->
|
(True, False) ->
|
||||||
[ Just $ matMul' (transAttrs False True) dz y
|
[ Just $ matMul' (transAttrs False True) y dz
|
||||||
, Just $ matMul x dz]
|
, Just $ matMul x dz]
|
||||||
(True, True) ->
|
(True, True) ->
|
||||||
[ Just $ matMul' (transAttrs True True) dz y
|
[ Just $ matMul' (transAttrs True True) y dz
|
||||||
, Just $ matMul' (transAttrs True True) x dz]
|
, Just $ matMul' (transAttrs True True) dz x]
|
||||||
|
|
||||||
opGrad "Transpose" _ [_, toT -> p] [dz] =
|
opGrad "Transpose" _ [_, toT -> p] [dz] =
|
||||||
[ Just $ CoreOps.transpose dz
|
[ Just $ CoreOps.transpose dz
|
||||||
|
|
|
@ -318,21 +318,15 @@ scalarize t = CoreOps.reshape t (vector scalarShape)
|
||||||
|
|
||||||
-- | Sum a tensor down to a scalar
|
-- | Sum a tensor down to a scalar
|
||||||
-- Seee `TensorFlow.GenOps.Core.sum`
|
-- Seee `TensorFlow.GenOps.Core.sum`
|
||||||
reduceSum
|
reduceSum :: (OneOf '[ Double, Float, Int32, Int64
|
||||||
:: ( TensorType a
|
, Complex Float, Complex Double] a) =>
|
||||||
, OneOf '[ Double, Float, Int32, Int64
|
Tensor v a -> Tensor Build a
|
||||||
, Complex Float, Complex Double] a
|
|
||||||
)
|
|
||||||
=> Tensor v a -> Tensor Build a
|
|
||||||
reduceSum x = CoreOps.sum x allAxes
|
reduceSum x = CoreOps.sum x allAxes
|
||||||
where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1
|
where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1
|
||||||
|
|
||||||
reduceSum'
|
reduceSum' :: (OneOf '[ Double, Float, Int32, Int64
|
||||||
:: ( TensorType a
|
, Complex Float, Complex Double] a) =>
|
||||||
, OneOf '[ Double, Float, Int32, Int64
|
OpParams -> Tensor v a -> Tensor Build a
|
||||||
, Complex Float, Complex Double] a
|
|
||||||
)
|
|
||||||
=> OpParams -> Tensor v a -> Tensor Build a
|
|
||||||
reduceSum' params x = CoreOps.sum' params x allAxes
|
reduceSum' params x = CoreOps.sum' params x allAxes
|
||||||
where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1
|
where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1
|
||||||
|
|
||||||
|
|
|
@ -200,6 +200,7 @@ Test-Suite GradientTest
|
||||||
, tensorflow-proto
|
, tensorflow-proto
|
||||||
, test-framework
|
, test-framework
|
||||||
, test-framework-hunit
|
, test-framework-hunit
|
||||||
|
, transformers
|
||||||
, vector
|
, vector
|
||||||
|
|
||||||
Test-Suite MiscTest
|
Test-Suite MiscTest
|
||||||
|
|
|
@ -15,21 +15,26 @@
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE NoMonomorphismRestriction #-}
|
{-# LANGUAGE NoMonomorphismRestriction #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
|
|
||||||
import Data.Int (Int32)
|
import Data.Int (Int32, Int64)
|
||||||
import Data.List (sort)
|
import Data.List (sort)
|
||||||
import Data.ProtoLens.TextFormat (showMessage)
|
import Data.ProtoLens.TextFormat (showMessage)
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
import Lens.Family2 ((^..))
|
import Lens.Family2 ((^..), (.~))
|
||||||
|
|
||||||
import Test.Framework (Test)
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.HUnit ((@=?))
|
import Test.HUnit ((@=?), assertEqual)
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
|
||||||
import qualified TensorFlow.Core as TF
|
import qualified TensorFlow.Core as TF
|
||||||
import qualified TensorFlow.GenOps.Core as TF (max, tile)
|
import qualified TensorFlow.GenOps.Core as TF (max, tile)
|
||||||
import qualified TensorFlow.Gradient as TF
|
import qualified TensorFlow.Gradient as TF
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
|
import qualified TensorFlow.Output as TF
|
||||||
|
import qualified TensorFlow.Types as TF
|
||||||
|
|
||||||
import Proto.Tensorflow.Core.Framework.Graph (node)
|
import Proto.Tensorflow.Core.Framework.Graph (node)
|
||||||
import Proto.Tensorflow.Core.Framework.NodeDef (op)
|
import Proto.Tensorflow.Core.Framework.NodeDef (op)
|
||||||
|
@ -207,15 +212,85 @@ testTile2DGrad = testCase "testTileGrad2D" $ do
|
||||||
let y = TF.tile x multiples
|
let y = TF.tile x multiples
|
||||||
|
|
||||||
[dx] <- TF.gradients y [x]
|
[dx] <- TF.gradients y [x]
|
||||||
|
TF.run (dx, TF.shape dx, TF.shape x)
|
||||||
shapeDX <- TF.run $ TF.shape dx
|
|
||||||
shapeX <- TF.run $ TF.shape x
|
|
||||||
dxv <- TF.run dx
|
|
||||||
return (dxv, shapeDX, shapeX)
|
|
||||||
shapeX @=? (shapeDX :: V.Vector Int32)
|
shapeX @=? (shapeDX :: V.Vector Int32)
|
||||||
V.fromList [6, 6, 6, 6, 6, 6::Float] @=? (dx :: V.Vector Float)
|
V.fromList [6, 6, 6, 6, 6, 6::Float] @=? (dx :: V.Vector Float)
|
||||||
|
|
||||||
|
|
||||||
|
matMulGradient :: Test
|
||||||
|
matMulGradient = testCase "matMulGradients" $ do
|
||||||
|
|
||||||
|
let dfBuild = do
|
||||||
|
x <- TF.render $ TF.zeros $ TF.Shape [3, 1 :: Int64]
|
||||||
|
w <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
|
||||||
|
let f = x `TF.matMul` w :: TF.Tensor TF.Build Float
|
||||||
|
dfs <- TF.gradients f [x]
|
||||||
|
return (x, dfs)
|
||||||
|
|
||||||
|
(xShape, dxShape) <- TF.runSession $ do
|
||||||
|
(x, [dx]) <- TF.build dfBuild
|
||||||
|
TF.run (TF.shape x, TF.shape dx)
|
||||||
|
|
||||||
|
assertEqual "Shape of gradient must match shape of input" xShape (dxShape :: V.Vector Int32)
|
||||||
|
|
||||||
|
|
||||||
|
-- test that gradient of matMul can be taken gradient of
|
||||||
|
matMulGradGrad :: Test
|
||||||
|
matMulGradGrad = testCase "matMulGradGrad" $ do
|
||||||
|
let width = 2 :: Int64
|
||||||
|
batch = 4 :: Int64
|
||||||
|
|
||||||
|
let tower = do
|
||||||
|
x <- TF.render $ TF.zeros $ TF.Shape [batch, 1]
|
||||||
|
w <- TF.zeroInitializedVariable $ TF.Shape [1, width]
|
||||||
|
let f = x `TF.matMul` w
|
||||||
|
[dfdx] <- TF.gradients f [x]
|
||||||
|
let f'x = TF.reduceSum dfdx
|
||||||
|
[dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
|
||||||
|
return [TF.value w, dfdw]
|
||||||
|
|
||||||
|
TF.runSession $ do
|
||||||
|
[w, dfdw] <- TF.build tower
|
||||||
|
(wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw)
|
||||||
|
liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32)
|
||||||
|
|
||||||
|
let step = w `TF.add` dfdw
|
||||||
|
w0 <- TF.run step
|
||||||
|
liftIO $ ((V.fromList [4, 4 :: Float]) @=? w0)
|
||||||
|
|
||||||
|
|
||||||
|
-- test that gradient of matMul deals correctly with transpose_a and transpose_b
|
||||||
|
matMulTransposeGradient :: (Bool, Bool) -> Test
|
||||||
|
matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ (show txw)) $ do
|
||||||
|
let (transposeX, transposeW) = txw
|
||||||
|
|
||||||
|
let dfBuild = do
|
||||||
|
let xShape = TF.Shape [3, 1 :: Int64]
|
||||||
|
let xZeros = TF.zeros xShape
|
||||||
|
x <- TF.render $ if transposeX then TF.matTranspose xZeros else xZeros
|
||||||
|
variable <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
|
||||||
|
let wv = if transposeW then TF.matTranspose variable else TF.expr variable
|
||||||
|
let f = TF.matMul' (transAttrs transposeX transposeW) x wv :: TF.Tensor TF.Build Float
|
||||||
|
w <- TF.render wv
|
||||||
|
ds <- TF.gradients f [x, w]
|
||||||
|
return (x, w, ds)
|
||||||
|
|
||||||
|
TF.runSession $ do
|
||||||
|
(x, w, [dx, dw]) <- TF.build dfBuild
|
||||||
|
xShape <- TF.run $ TF.shape x
|
||||||
|
dxShape <- TF.run $ TF.shape dx
|
||||||
|
liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)
|
||||||
|
|
||||||
|
wShape <- TF.run $ TF.shape w
|
||||||
|
dwShape <- TF.run $ TF.shape dw
|
||||||
|
liftIO $ assertEqual "wShape must match dwShape" wShape (dwShape :: V.Vector Int32)
|
||||||
|
|
||||||
|
transAttrs :: (TF.Attribute a,
|
||||||
|
TF.Attribute b) =>
|
||||||
|
a -> b -> TF.OpDef -> TF.OpDef
|
||||||
|
transAttrs a b =
|
||||||
|
(TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b)
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = googleTest [ testGradientSimple
|
main = googleTest [ testGradientSimple
|
||||||
, testGradientDisconnected
|
, testGradientDisconnected
|
||||||
|
@ -228,4 +303,10 @@ main = googleTest [ testGradientSimple
|
||||||
, testFillGrad
|
, testFillGrad
|
||||||
, testTileGrad
|
, testTileGrad
|
||||||
, testTile2DGrad
|
, testTile2DGrad
|
||||||
|
, matMulGradient
|
||||||
|
, matMulGradGrad
|
||||||
|
, matMulTransposeGradient (False, False)
|
||||||
|
, matMulTransposeGradient (False, True)
|
||||||
|
, matMulTransposeGradient (True, False)
|
||||||
|
, matMulTransposeGradient (True, True)
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in a new issue