Mirror of https://github.com/tensorflow/haskell.git
Add gradient for batchMatMul (#246)
Commit d741c3ee59 (parent c811037cb9). 2 changed files (the gradient implementation and its tests) with 104 additions and 1 deletion.

The change adds an opGrad case for the "BatchMatMul" op covering all four combinations of the adj_x and adj_y attributes, registers the op in numOutputs, and adds gradient tests, including a second-order test and adjoint variants.
@@ -650,6 +650,25 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
             [ Just $ matMul' (transAttrs True True) y dz
             , Just $ matMul' (transAttrs True True) dz x]
 
+opGrad "BatchMatMul" nodeDef [toT -> x, toT -> y] [dz] =
+    let adjX = lookupAttr nodeDef "adj_x"
+        adjY = lookupAttr nodeDef "adj_y"
+        adjAttrs a b =
+            (opAttr "adj_x" .~ a) . (opAttr "adj_y" .~ b)
+    in case (adjX, adjY) of
+        (False, False) ->
+            [ Just $ CoreOps.batchMatMul' (adjAttrs False True) dz y
+            , Just $ CoreOps.batchMatMul' (adjAttrs True False) x dz]
+        (False, True) ->
+            [ Just $ CoreOps.batchMatMul dz y
+            , Just $ CoreOps.batchMatMul' (adjAttrs True False) dz x]
+        (True, False) ->
+            [ Just $ CoreOps.batchMatMul' (adjAttrs False True) y dz
+            , Just $ CoreOps.batchMatMul x dz]
+        (True, True) ->
+            [ Just $ CoreOps.batchMatMul' (adjAttrs True True) y dz
+            , Just $ CoreOps.batchMatMul' (adjAttrs True True) dz x]
+
 opGrad "Transpose" _ [_, toT -> p] [dz] =
     [ Just $ CoreOps.transpose dz
         (CoreOps.invertPermutation p :: Tensor Build Int32)
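
For reference, the four cases added above follow the standard matrix-calculus rules for a per-batch product Z = op(X) op(Y), where op is the identity or a transpose according to adj_x / adj_y (for complex types the adj attributes mean conjugate transpose). Written out, with rows corresponding to (adj_x, adj_y) = (F,F), (F,T), (T,F), (T,T):

\[
\begin{aligned}
Z &= X\,Y:               & \partial X &= \partial Z\,Y^{\top},            & \partial Y &= X^{\top}\,\partial Z \\
Z &= X\,Y^{\top}:        & \partial X &= \partial Z\,Y,                   & \partial Y &= \partial Z^{\top}\,X \\
Z &= X^{\top}\,Y:        & \partial X &= Y\,\partial Z^{\top},            & \partial Y &= X\,\partial Z \\
Z &= X^{\top}\,Y^{\top}: & \partial X &= Y^{\top}\,\partial Z^{\top},     & \partial Y &= \partial Z^{\top}\,X^{\top}
\end{aligned}
\]

Each right-hand side maps directly onto the corresponding CoreOps.batchMatMul' call, with the adj attributes expressing the transposes.
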
@@ -915,6 +934,7 @@ numOutputs o =
         "Add" -> 1
         "AddN" -> 1
         "BatchToSpaceND" -> 1
+        "BatchMatMul" -> 1
         "Cast" -> 1
         "Const" -> 1
         "Concat" -> 1

@@ -33,7 +33,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
 import Control.Monad.IO.Class (liftIO)
 
 import qualified TensorFlow.Core as TF
-import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput')
+import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput', batchMatMul, batchMatMul', sum, conjugateTranspose)
 import qualified TensorFlow.Gradient as TF
 import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
 import qualified TensorFlow.Output as TF

@@ -604,6 +604,83 @@ transAttrs :: (TF.Attribute a,
 transAttrs a b =
     (TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b)
 
+batchMatMulGradient :: Test
+batchMatMulGradient = testCase "batchMatMulGradients" $ do
+    let dfBuild = do
+            x <- TF.render $ TF.zeros $ TF.Shape [2, 3, 1 :: Int64]
+            w <- TF.zeroInitializedVariable $ TF.Shape [2, 1, 2 :: Int64]
+            let f = x `TF.batchMatMul` TF.readValue w :: TF.Tensor TF.Build Float
+            dfs <- TF.gradients f [x]
+            return (x, dfs)
+
+    (xShape, dxShape) <- TF.runSession $ do
+        (x, [dx]) <- TF.build dfBuild
+        TF.run (TF.shape x, TF.shape dx)
+
+    assertEqual "Shape of gradient must match shape of input" xShape (dxShape :: V.Vector Int32)
+
+-- test that the gradient of batchMatMul can itself be differentiated
+batchMatMulGradGrad :: Test
+batchMatMulGradGrad = testCase "batchMatMulGradGrad" $ do
+    let width = 2 :: Int64
+        height = 3 :: Int64
+        batch = 4 :: Int64
+
+    let tower = do
+            x <- TF.render $ TF.zeros $ TF.Shape [batch, height, 1]
+            w <- TF.zeroInitializedVariable $ TF.Shape [batch, 1, width]
+            let f = x `TF.batchMatMul` TF.readValue w
+            [dfdx] <- TF.gradients f [x]
+            let f'x = TF.sum dfdx (TF.vector [1, 2 :: Int32])
+            [dfdw] <- TF.gradients f'x [w] -- take the gradient again (this time w.r.t. w)
+            return [TF.readValue w, TF.expr dfdw]
+
+    TF.runSession $ do
+        [w, dfdw] <- TF.build tower
+        (wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw)
+        liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32)
+
+        let step = w `TF.add` dfdw
+        w0 <- TF.run step
+        liftIO $ V.fromList [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0 :: Float] @=? w0
+
+-- test that the gradient of batchMatMul handles adj_x and adj_y correctly
+batchMatMulAdjointGradient :: (Bool, Bool) -> Test
+batchMatMulAdjointGradient axw = testCase ("batchMatMulAdjointGradients " ++ show axw) $ do
+    let (adjX, adjW) = axw
+
+    let dfBuild = do
+            let xShape = TF.Shape [2, 3, 1 :: Int64]
+            let xZeros = TF.zeros xShape
+            x <- TF.render $ if adjX then TF.conjugateTranspose xZeros (TF.vector [0, 2, 1 :: Int32]) else xZeros
+            variable <- TF.zeroInitializedVariable $ TF.Shape [2, 1, 2 :: Int64]
+            let wv = if adjW then TF.conjugateTranspose (TF.readValue variable) (TF.vector [0, 2, 1 :: Int32]) else TF.readValue variable
+            let f = TF.batchMatMul' (adjAttrs adjX adjW) x wv :: TF.Tensor TF.Build Float
+            w <- TF.render wv
+            ds <- TF.gradients f [x, w]
+            return (x, w, ds)
+
+    TF.runSession $ do
+        (x, w, [dx, dw]) <- TF.build dfBuild
+        xShape <- TF.run $ TF.shape x
+        dxShape <- TF.run $ TF.shape dx
+        liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)
+
+        wShape <- TF.run $ TF.shape w
+        dwShape <- TF.run $ TF.shape dw
+        liftIO $ assertEqual "wShape must match dwShape" wShape (dwShape :: V.Vector Int32)
+
+adjAttrs :: (TF.Attribute x,
+             TF.Attribute y) =>
+            x -> y -> TF.OpDef -> TF.OpDef
+adjAttrs x y =
+    (TF.opAttr "adj_x" .~ x) . (TF.opAttr "adj_y" .~ y)
+
 -- TODO check gradient with regard to filter also
 testConv2DBackpropInputGrad :: Test
 testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
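
A note on the expected values in batchMatMulGradGrad (my reading of the test, not text from the commit): x and w start as zeros, and TF.gradients backpropagates a ones tensor for a non-scalar target, so per batch element

\[
\frac{\partial f}{\partial x_{i,k}} = \sum_{j} w_{k,j}, \qquad
f'_x = \sum_{i,k} \frac{\partial f}{\partial x_{i,k}} = \text{height} \cdot \sum_{k,j} w_{k,j}, \qquad
\frac{\partial f'_x}{\partial w_{k,j}} = \text{height} = 3,
\]

which is why w `TF.add` dfdw is expected to be 3.0 in every one of its batch * 1 * width = 8 entries.
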
@@ -708,6 +785,12 @@ main = defaultMain
           , matMulTransposeGradient (False, True)
           , matMulTransposeGradient (True, False)
           , matMulTransposeGradient (True, True)
+          , batchMatMulGradient
+          , batchMatMulGradGrad
+          , batchMatMulAdjointGradient (False, False)
+          , batchMatMulAdjointGradient (False, True)
+          , batchMatMulAdjointGradient (True, False)
+          , batchMatMulAdjointGradient (True, True)
           , testConv2DBackpropInputGrad
           , testDepthwiseConv2dGrad
           , testDepthwiseConv2dBackpropInputGrad
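
As a usage sketch (not part of the commit): with the gradient registered, user code can differentiate through batchMatMul directly in a session. This is a minimal sketch modeled on batchMatMulGradient above; the TensorFlow.Variable import for zeroInitializedVariable / readValue is an assumption about where those helpers live and may need adjusting for your project.

-- Minimal hypothetical standalone program: gradient of
-- f = x `batchMatMul` w with respect to x, using the gradient added above.
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64)
import qualified Data.Vector as V

import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (batchMatMul, shape)
import qualified TensorFlow.Gradient as TF (gradients)
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
import qualified TensorFlow.Variable as TF (readValue, zeroInitializedVariable) -- assumed module path

main :: IO ()
main = TF.runSession $ do
    -- x : [batch, height, 1], w : [batch, 1, width], both zero-initialized
    x <- TF.render $ TF.zeros $ TF.Shape [2, 3, 1 :: Int64]
    w <- TF.zeroInitializedVariable $ TF.Shape [2, 1, 2 :: Int64]
    let f = x `TF.batchMatMul` TF.readValue w :: TF.Tensor TF.Build Float
    -- uses the newly registered "BatchMatMul" gradient
    [dx] <- TF.gradients f [x]
    dxShape <- TF.run $ TF.shape dx
    liftIO $ print (dxShape :: V.Vector Int32) -- expected: [2,3,1], the shape of x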