mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-27 11:15:03 +01:00
Add gradient for batchMatMul (#246)
This commit is contained in:
parent
c811037cb9
commit
d741c3ee59
2 changed files with 104 additions and 1 deletions
|
@ -650,6 +650,25 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
|
|||
[ Just $ matMul' (transAttrs True True) y dz
|
||||
, Just $ matMul' (transAttrs True True) dz x]
|
||||
|
||||
opGrad "BatchMatMul" nodeDef [toT -> x, toT -> y] [dz] =
|
||||
let adjX = lookupAttr nodeDef "adj_x"
|
||||
adjY = lookupAttr nodeDef "adj_y"
|
||||
adjAttrs a b =
|
||||
(opAttr "adj_x" .~ a) . (opAttr "adj_y" .~ b)
|
||||
in case (adjX, adjY) of
|
||||
(False, False) ->
|
||||
[ Just $ CoreOps.batchMatMul' (adjAttrs False True) dz y
|
||||
, Just $ CoreOps.batchMatMul' (adjAttrs True False) x dz]
|
||||
(False, True) ->
|
||||
[ Just $ CoreOps.batchMatMul dz y
|
||||
, Just $ CoreOps.batchMatMul' (adjAttrs True False) dz x]
|
||||
(True, False) ->
|
||||
[ Just $ CoreOps.batchMatMul' (adjAttrs False True) y dz
|
||||
, Just $ CoreOps.batchMatMul x dz]
|
||||
(True, True) ->
|
||||
[ Just $ CoreOps.batchMatMul' (adjAttrs True True) y dz
|
||||
, Just $ CoreOps.batchMatMul' (adjAttrs True True) dz x]
|
||||
|
||||
opGrad "Transpose" _ [_, toT -> p] [dz] =
|
||||
[ Just $ CoreOps.transpose dz
|
||||
(CoreOps.invertPermutation p :: Tensor Build Int32)
|
||||
|
@ -915,6 +934,7 @@ numOutputs o =
|
|||
"Add" -> 1
|
||||
"AddN" -> 1
|
||||
"BatchToSpaceND" -> 1
|
||||
"BatchMatMul" -> 1
|
||||
"Cast" -> 1
|
||||
"Const" -> 1
|
||||
"Concat" -> 1
|
||||
|
|
|
@ -33,7 +33,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
|
|||
import Control.Monad.IO.Class (liftIO)
|
||||
|
||||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput')
|
||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput', batchMatMul, batchMatMul', sum, conjugateTranspose)
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
|
||||
import qualified TensorFlow.Output as TF
|
||||
|
@ -604,6 +604,83 @@ transAttrs :: (TF.Attribute a,
|
|||
transAttrs a b =
|
||||
(TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b)
|
||||
|
||||
|
||||
batchMatMulGradient :: Test
|
||||
batchMatMulGradient = testCase "batchMatMulGradients" $ do
|
||||
|
||||
let dfBuild = do
|
||||
x <- TF.render $ TF.zeros $ TF.Shape [2,3, 1 :: Int64]
|
||||
w <- TF.zeroInitializedVariable $ TF.Shape [2,1, 2 :: Int64]
|
||||
let f = x `TF.batchMatMul` TF.readValue w :: TF.Tensor TF.Build Float
|
||||
dfs <- TF.gradients f [x]
|
||||
return (x, dfs)
|
||||
|
||||
(xShape, dxShape) <- TF.runSession $ do
|
||||
(x, [dx]) <- TF.build dfBuild
|
||||
TF.run (TF.shape x, TF.shape dx)
|
||||
|
||||
assertEqual "Shape of gradient must match shape of input" xShape (dxShape :: V.Vector Int32)
|
||||
|
||||
|
||||
-- test that gradient of batchMatMul can be taken gradient of
|
||||
batchMatMulGradGrad :: Test
|
||||
batchMatMulGradGrad = testCase "batchMatMulGradGrad" $ do
|
||||
let width = 2 :: Int64
|
||||
height = 3 :: Int64
|
||||
batch = 4 :: Int64
|
||||
|
||||
let tower = do
|
||||
x <- TF.render $ TF.zeros $ TF.Shape [batch, height, 1]
|
||||
w <- TF.zeroInitializedVariable $ TF.Shape [batch, 1, width]
|
||||
let f = x `TF.batchMatMul` TF.readValue w
|
||||
[dfdx] <- TF.gradients f [x]
|
||||
let f'x = TF.sum dfdx (TF.vector [1, 2 :: Int32])
|
||||
[dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
|
||||
return [TF.readValue w, TF.expr dfdw]
|
||||
|
||||
TF.runSession $ do
|
||||
[w, dfdw] <- TF.build tower
|
||||
(wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw)
|
||||
liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32)
|
||||
|
||||
let step = w `TF.add` dfdw
|
||||
w0 <- TF.run step
|
||||
liftIO $ V.fromList [3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0 :: Float] @=? w0
|
||||
|
||||
|
||||
-- test that gradient of batchMatMul deals correctly with adj_x and adj_y
|
||||
batchMatMulAdjointGradient :: (Bool, Bool) -> Test
|
||||
batchMatMulAdjointGradient axw = testCase ("batchMatMulAdjointGradients " ++ show axw) $ do
|
||||
let (adjX, adjW) = axw
|
||||
|
||||
let dfBuild = do
|
||||
let xShape = TF.Shape [2, 3, 1 :: Int64]
|
||||
let xZeros = TF.zeros xShape
|
||||
x <- TF.render $ if adjX then TF.conjugateTranspose xZeros (TF.vector [0, 2, 1 :: Int32]) else xZeros
|
||||
variable <- TF.zeroInitializedVariable $ TF.Shape [2, 1, 2 :: Int64]
|
||||
let wv = if adjW then TF.conjugateTranspose (TF.readValue variable) (TF.vector [0, 2, 1 :: Int32]) else TF.readValue variable
|
||||
let f = TF.batchMatMul' (adjAttrs adjX adjW) x wv :: TF.Tensor TF.Build Float
|
||||
w <- TF.render wv
|
||||
ds <- TF.gradients f [x, w]
|
||||
return (x, w, ds)
|
||||
|
||||
TF.runSession $ do
|
||||
(x, w, [dx, dw]) <- TF.build dfBuild
|
||||
xShape <- TF.run $ TF.shape x
|
||||
dxShape <- TF.run $ TF.shape dx
|
||||
liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)
|
||||
|
||||
wShape <- TF.run $ TF.shape w
|
||||
dwShape <- TF.run $ TF.shape dw
|
||||
liftIO $ assertEqual "wShape must match dwShape" wShape (dwShape :: V.Vector Int32)
|
||||
|
||||
adjAttrs :: (TF.Attribute x,
|
||||
TF.Attribute y) =>
|
||||
x -> y -> TF.OpDef -> TF.OpDef
|
||||
adjAttrs x y =
|
||||
(TF.opAttr "adj_x" .~ x) . (TF.opAttr "adj_y" .~ y)
|
||||
|
||||
|
||||
-- TODO check gradient with regard to filter also
|
||||
testConv2DBackpropInputGrad :: Test
|
||||
testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
|
||||
|
@ -708,6 +785,12 @@ main = defaultMain
|
|||
, matMulTransposeGradient (False, True)
|
||||
, matMulTransposeGradient (True, False)
|
||||
, matMulTransposeGradient (True, True)
|
||||
, batchMatMulGradient
|
||||
, batchMatMulGradGrad
|
||||
, batchMatMulAdjointGradient (False, False)
|
||||
, batchMatMulAdjointGradient (False, True)
|
||||
, batchMatMulAdjointGradient (True, False)
|
||||
, batchMatMulAdjointGradient (True, True)
|
||||
, testConv2DBackpropInputGrad
|
||||
, testDepthwiseConv2dGrad
|
||||
, testDepthwiseConv2dBackpropInputGrad
|
||||
|
|
Loading…
Add table
Reference in a new issue