diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index 476222d..5e12e74 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -650,6 +650,25 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] = [ Just $ matMul' (transAttrs True True) y dz , Just $ matMul' (transAttrs True True) dz x] +opGrad "BatchMatMul" nodeDef [toT -> x, toT -> y] [dz] = + let adjX = lookupAttr nodeDef "adj_x" + adjY = lookupAttr nodeDef "adj_y" + adjAttrs a b = + (opAttr "adj_x" .~ a) . (opAttr "adj_y" .~ b) + in case (adjX, adjY) of + (False, False) -> + [ Just $ CoreOps.batchMatMul' (adjAttrs False True) dz y + , Just $ CoreOps.batchMatMul' (adjAttrs True False) x dz] + (False, True) -> + [ Just $ CoreOps.batchMatMul dz y + , Just $ CoreOps.batchMatMul' (adjAttrs True False) dz x] + (True, False) -> + [ Just $ CoreOps.batchMatMul' (adjAttrs False True) y dz + , Just $ CoreOps.batchMatMul x dz] + (True, True) -> + [ Just $ CoreOps.batchMatMul' (adjAttrs True True) y dz + , Just $ CoreOps.batchMatMul' (adjAttrs True True) dz x] + opGrad "Transpose" _ [_, toT -> p] [dz] = [ Just $ CoreOps.transpose dz (CoreOps.invertPermutation p :: Tensor Build Int32) @@ -915,6 +934,7 @@ numOutputs o = "Add" -> 1 "AddN" -> 1 "BatchToSpaceND" -> 1 + "BatchMatMul" -> 1 "Cast" -> 1 "Const" -> 1 "Concat" -> 1 diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs index b306bb9..24cb427 100644 --- a/tensorflow-ops/tests/GradientTest.hs +++ b/tensorflow-ops/tests/GradientTest.hs @@ -33,7 +33,7 @@ import Control.Monad(forM_, replicateM, zipWithM) import Control.Monad.IO.Class (liftIO) import qualified TensorFlow.Core as TF -import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput') +import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput', batchMatMul, batchMatMul', sum, conjugateTranspose) import qualified TensorFlow.Gradient as TF import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape) import qualified TensorFlow.Output as TF @@ -604,6 +604,83 @@ transAttrs :: (TF.Attribute a, transAttrs a b = (TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b) + +batchMatMulGradient :: Test +batchMatMulGradient = testCase "batchMatMulGradients" $ do + + let dfBuild = do + x <- TF.render $ TF.zeros $ TF.Shape [2,3, 1 :: Int64] + w <- TF.zeroInitializedVariable $ TF.Shape [2,1, 2 :: Int64] + let f = x `TF.batchMatMul` TF.readValue w :: TF.Tensor TF.Build Float + dfs <- TF.gradients f [x] + return (x, dfs) + + (xShape, dxShape) <- TF.runSession $ do + (x, [dx]) <- TF.build dfBuild + TF.run (TF.shape x, TF.shape dx) + + assertEqual "Shape of gradient must match shape of input" xShape (dxShape :: V.Vector Int32) + + +-- test that gradient of batchMatMul can be taken gradient of +batchMatMulGradGrad :: Test +batchMatMulGradGrad = testCase "batchMatMulGradGrad" $ do + let width = 2 :: Int64 + height = 3 :: Int64 + batch = 4 :: Int64 + + let tower = do + x <- TF.render $ TF.zeros $ TF.Shape [batch, height, 1] + w <- TF.zeroInitializedVariable $ TF.Shape [batch, 1, width] + let f = x `TF.batchMatMul` TF.readValue w + [dfdx] <- TF.gradients f [x] + let f'x = TF.sum dfdx (TF.vector [1, 2 :: Int32]) + [dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w) + return [TF.readValue w, TF.expr dfdw] + + TF.runSession $ do + [w, dfdw] <- TF.build tower + (wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw) + liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32) + + let step = w `TF.add` dfdw + w0 <- TF.run step + liftIO $ V.fromList [3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0 :: Float] @=? w0 + + +-- test that gradient of batchMatMul deals correctly with adj_x and adj_y +batchMatMulAdjointGradient :: (Bool, Bool) -> Test +batchMatMulAdjointGradient axw = testCase ("batchMatMulAdjointGradients " ++ show axw) $ do + let (adjX, adjW) = axw + + let dfBuild = do + let xShape = TF.Shape [2, 3, 1 :: Int64] + let xZeros = TF.zeros xShape + x <- TF.render $ if adjX then TF.conjugateTranspose xZeros (TF.vector [0, 2, 1 :: Int32]) else xZeros + variable <- TF.zeroInitializedVariable $ TF.Shape [2, 1, 2 :: Int64] + let wv = if adjW then TF.conjugateTranspose (TF.readValue variable) (TF.vector [0, 2, 1 :: Int32]) else TF.readValue variable + let f = TF.batchMatMul' (adjAttrs adjX adjW) x wv :: TF.Tensor TF.Build Float + w <- TF.render wv + ds <- TF.gradients f [x, w] + return (x, w, ds) + + TF.runSession $ do + (x, w, [dx, dw]) <- TF.build dfBuild + xShape <- TF.run $ TF.shape x + dxShape <- TF.run $ TF.shape dx + liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32) + + wShape <- TF.run $ TF.shape w + dwShape <- TF.run $ TF.shape dw + liftIO $ assertEqual "wShape must match dwShape" wShape (dwShape :: V.Vector Int32) + +adjAttrs :: (TF.Attribute x, + TF.Attribute y) => + x -> y -> TF.OpDef -> TF.OpDef +adjAttrs x y = + (TF.opAttr "adj_x" .~ x) . (TF.opAttr "adj_y" .~ y) + + -- TODO check gradient with regard to filter also testConv2DBackpropInputGrad :: Test testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do @@ -708,6 +785,12 @@ main = defaultMain , matMulTransposeGradient (False, True) , matMulTransposeGradient (True, False) , matMulTransposeGradient (True, True) + , batchMatMulGradient + , batchMatMulGradGrad + , batchMatMulAdjointGradient (False, False) + , batchMatMulAdjointGradient (False, True) + , batchMatMulAdjointGradient (True, False) + , batchMatMulAdjointGradient (True, True) , testConv2DBackpropInputGrad , testDepthwiseConv2dGrad , testDepthwiseConv2dBackpropInputGrad