1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-23 03:19:44 +01:00

Add gradient for batchMatMul (#246)

This commit is contained in:
rschlotterbeck 2019-07-08 19:41:35 +02:00 committed by fkm3
parent c811037cb9
commit d741c3ee59
2 changed files with 104 additions and 1 deletions

View file

@ -650,6 +650,25 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
[ Just $ matMul' (transAttrs True True) y dz [ Just $ matMul' (transAttrs True True) y dz
, Just $ matMul' (transAttrs True True) dz x] , Just $ matMul' (transAttrs True True) dz x]
opGrad "BatchMatMul" nodeDef [toT -> x, toT -> y] [dz] =
let adjX = lookupAttr nodeDef "adj_x"
adjY = lookupAttr nodeDef "adj_y"
adjAttrs a b =
(opAttr "adj_x" .~ a) . (opAttr "adj_y" .~ b)
in case (adjX, adjY) of
(False, False) ->
[ Just $ CoreOps.batchMatMul' (adjAttrs False True) dz y
, Just $ CoreOps.batchMatMul' (adjAttrs True False) x dz]
(False, True) ->
[ Just $ CoreOps.batchMatMul dz y
, Just $ CoreOps.batchMatMul' (adjAttrs True False) dz x]
(True, False) ->
[ Just $ CoreOps.batchMatMul' (adjAttrs False True) y dz
, Just $ CoreOps.batchMatMul x dz]
(True, True) ->
[ Just $ CoreOps.batchMatMul' (adjAttrs True True) y dz
, Just $ CoreOps.batchMatMul' (adjAttrs True True) dz x]
opGrad "Transpose" _ [_, toT -> p] [dz] = opGrad "Transpose" _ [_, toT -> p] [dz] =
[ Just $ CoreOps.transpose dz [ Just $ CoreOps.transpose dz
(CoreOps.invertPermutation p :: Tensor Build Int32) (CoreOps.invertPermutation p :: Tensor Build Int32)
@ -915,6 +934,7 @@ numOutputs o =
"Add" -> 1 "Add" -> 1
"AddN" -> 1 "AddN" -> 1
"BatchToSpaceND" -> 1 "BatchToSpaceND" -> 1
"BatchMatMul" -> 1
"Cast" -> 1 "Cast" -> 1
"Const" -> 1 "Const" -> 1
"Concat" -> 1 "Concat" -> 1

View file

@ -33,7 +33,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import qualified TensorFlow.Core as TF import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput') import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput', batchMatMul, batchMatMul', sum, conjugateTranspose)
import qualified TensorFlow.Gradient as TF import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape) import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
import qualified TensorFlow.Output as TF import qualified TensorFlow.Output as TF
@ -604,6 +604,83 @@ transAttrs :: (TF.Attribute a,
transAttrs a b = transAttrs a b =
(TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b) (TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b)
batchMatMulGradient :: Test
batchMatMulGradient = testCase "batchMatMulGradients" $ do
let dfBuild = do
x <- TF.render $ TF.zeros $ TF.Shape [2,3, 1 :: Int64]
w <- TF.zeroInitializedVariable $ TF.Shape [2,1, 2 :: Int64]
let f = x `TF.batchMatMul` TF.readValue w :: TF.Tensor TF.Build Float
dfs <- TF.gradients f [x]
return (x, dfs)
(xShape, dxShape) <- TF.runSession $ do
(x, [dx]) <- TF.build dfBuild
TF.run (TF.shape x, TF.shape dx)
assertEqual "Shape of gradient must match shape of input" xShape (dxShape :: V.Vector Int32)
-- test that gradient of batchMatMul can be taken gradient of
batchMatMulGradGrad :: Test
batchMatMulGradGrad = testCase "batchMatMulGradGrad" $ do
let width = 2 :: Int64
height = 3 :: Int64
batch = 4 :: Int64
let tower = do
x <- TF.render $ TF.zeros $ TF.Shape [batch, height, 1]
w <- TF.zeroInitializedVariable $ TF.Shape [batch, 1, width]
let f = x `TF.batchMatMul` TF.readValue w
[dfdx] <- TF.gradients f [x]
let f'x = TF.sum dfdx (TF.vector [1, 2 :: Int32])
[dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
return [TF.readValue w, TF.expr dfdw]
TF.runSession $ do
[w, dfdw] <- TF.build tower
(wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw)
liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32)
let step = w `TF.add` dfdw
w0 <- TF.run step
liftIO $ V.fromList [3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0 :: Float] @=? w0
-- test that gradient of batchMatMul deals correctly with adj_x and adj_y
batchMatMulAdjointGradient :: (Bool, Bool) -> Test
batchMatMulAdjointGradient axw = testCase ("batchMatMulAdjointGradients " ++ show axw) $ do
let (adjX, adjW) = axw
let dfBuild = do
let xShape = TF.Shape [2, 3, 1 :: Int64]
let xZeros = TF.zeros xShape
x <- TF.render $ if adjX then TF.conjugateTranspose xZeros (TF.vector [0, 2, 1 :: Int32]) else xZeros
variable <- TF.zeroInitializedVariable $ TF.Shape [2, 1, 2 :: Int64]
let wv = if adjW then TF.conjugateTranspose (TF.readValue variable) (TF.vector [0, 2, 1 :: Int32]) else TF.readValue variable
let f = TF.batchMatMul' (adjAttrs adjX adjW) x wv :: TF.Tensor TF.Build Float
w <- TF.render wv
ds <- TF.gradients f [x, w]
return (x, w, ds)
TF.runSession $ do
(x, w, [dx, dw]) <- TF.build dfBuild
xShape <- TF.run $ TF.shape x
dxShape <- TF.run $ TF.shape dx
liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)
wShape <- TF.run $ TF.shape w
dwShape <- TF.run $ TF.shape dw
liftIO $ assertEqual "wShape must match dwShape" wShape (dwShape :: V.Vector Int32)
adjAttrs :: (TF.Attribute x,
TF.Attribute y) =>
x -> y -> TF.OpDef -> TF.OpDef
adjAttrs x y =
(TF.opAttr "adj_x" .~ x) . (TF.opAttr "adj_y" .~ y)
-- TODO check gradient with regard to filter also -- TODO check gradient with regard to filter also
testConv2DBackpropInputGrad :: Test testConv2DBackpropInputGrad :: Test
testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
@ -708,6 +785,12 @@ main = defaultMain
, matMulTransposeGradient (False, True) , matMulTransposeGradient (False, True)
, matMulTransposeGradient (True, False) , matMulTransposeGradient (True, False)
, matMulTransposeGradient (True, True) , matMulTransposeGradient (True, True)
, batchMatMulGradient
, batchMatMulGradGrad
, batchMatMulAdjointGradient (False, False)
, batchMatMulAdjointGradient (False, True)
, batchMatMulAdjointGradient (True, False)
, batchMatMulAdjointGradient (True, True)
, testConv2DBackpropInputGrad , testConv2DBackpropInputGrad
, testDepthwiseConv2dGrad , testDepthwiseConv2dGrad
, testDepthwiseConv2dBackpropInputGrad , testDepthwiseConv2dBackpropInputGrad