diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index d6b867b..8f9c44c 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -667,6 +667,21 @@ opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] = opGrad "LabelClasses" _ _ _ = [Nothing, Nothing] opGrad "LabelWeights" _ _ _ = [Nothing] opGrad "Size" _ _ _ = [Nothing] + +-- TODO (jcberentsen): Python implementation uses set_shape for +-- static shape inference, which is unsupported. +-- TODO: implement support for static shape inference +opGrad "Tile" _ [toT -> x, toT -> multiples] [dz] = + [Just inputGrad, Nothing] + where + inputGrad = sum reshapedDz axes + inputShape = shape (x :: Tensor Build a) + packed = CoreOps.pack [multiples, inputShape] + perm = vector [1, 0 :: Int32] + splitShape = CoreOps.reshape (CoreOps.transpose packed perm) allDimensions + axes = CoreOps.range 0 (CoreOps.size splitShape) (2 :: Tensor Build Int32) + reshapedDz = CoreOps.reshape dz splitShape + opGrad "ZerosLike" _ _ _ = [Nothing] opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx] where @@ -719,6 +734,7 @@ numOutputs o = "SparseSegmentSum" -> 1 "Sub" -> 1 "Sum" -> 1 + "Tile" -> 1 "Transpose" -> 1 "TruncatedNormal" -> 1 "Variable" -> 1 diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs index b82d1e2..56fc1d4 100644 --- a/tensorflow-ops/tests/GradientTest.hs +++ b/tensorflow-ops/tests/GradientTest.hs @@ -27,7 +27,7 @@ import Test.HUnit ((@=?)) import qualified Data.Vector as V import qualified TensorFlow.Core as TF -import qualified TensorFlow.GenOps.Core as TF (max) +import qualified TensorFlow.GenOps.Core as TF (max, tile) import qualified TensorFlow.Gradient as TF import qualified TensorFlow.Ops as TF @@ -187,6 +187,35 @@ testFillGrad = testCase "testFillGrad" $ do TF.gradients y [x] >>= TF.run V.fromList [6] @=? dx + +testTileGrad :: Test +testTileGrad = testCase "testTileGrad" $ do + [dx] <- TF.runSession $ do + x <- TF.render $ TF.vector [5, 9 :: Float] + let multiples = TF.vector [2 :: Int32] + let y = TF.tile x multiples + TF.gradients y [x] >>= TF.run + V.fromList [2, 2] @=? dx + + +testTile2DGrad :: Test +testTile2DGrad = testCase "testTileGrad2D" $ do + (dx, shapeDX, shapeX) <- TF.runSession $ do + let shape = TF.vector [3, 2 :: Int32] + x <- TF.render $ TF.fill shape (TF.scalar (1::Float)) + let multiples = TF.vector [2, 3 :: Int32] + let y = TF.tile x multiples + + [dx] <- TF.gradients y [x] + + shapeDX <- TF.run $ TF.shape dx + shapeX <- TF.run $ TF.shape x + dxv <- TF.run dx + return (dxv, shapeDX, shapeX) + shapeX @=? (shapeDX :: V.Vector Int32) + V.fromList [6, 6, 6, 6, 6, 6::Float] @=? (dx :: V.Vector Float) + + main :: IO () main = googleTest [ testGradientSimple , testGradientDisconnected @@ -197,4 +226,6 @@ main = googleTest [ testGradientSimple , testReluGrad , testReluGradGrad , testFillGrad + , testTileGrad + , testTile2DGrad ]