1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 11:03:34 +02:00

Implemented TileGrad

Some notes about static shape inference
This commit is contained in:
Jarl Christian Berentsen 2017-04-29 15:57:42 +02:00 committed by fkm3
parent 97b4bb5bab
commit 51014a015c
2 changed files with 48 additions and 1 deletions

View File

@ -667,6 +667,21 @@ opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
opGrad "LabelWeights" _ _ _ = [Nothing]
opGrad "Size" _ _ _ = [Nothing]
-- Gradient for "Tile": the output repeats the input `multiples` times along
-- each axis, so the gradient w.r.t. the input sums the incoming gradient over
-- every repetition of each input element.
-- TODO (jcberentsen): the Python implementation uses set_shape for
-- static shape inference, which is unsupported here.
-- TODO: implement support for static shape inference
opGrad "Tile" _ [toT -> x, toT -> multiples] [dz] =
    [Just inputGrad, Nothing]
  where
    -- Shape of the original (untiled) input.
    inputShape = shape (x :: Tensor Build a)
    -- Stack repeat counts over input dims, then transpose and flatten to get
    -- the interleaved shape [m0, s0, m1, s1, ...].
    stacked = CoreOps.pack [multiples, inputShape]
    swapRows = vector [1, 0 :: Int32]
    splitShape = CoreOps.reshape (CoreOps.transpose stacked swapRows) allDimensions
    -- The even-numbered axes of the reshaped gradient are the repeat axes.
    repeatAxes = CoreOps.range 0 (CoreOps.size splitShape) (2 :: Tensor Build Int32)
    inputGrad = sum (CoreOps.reshape dz splitShape) repeatAxes
opGrad "ZerosLike" _ _ _ = [Nothing]
opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
  where
@ -719,6 +734,7 @@ numOutputs o =
"SparseSegmentSum" -> 1
"Sub" -> 1
"Sum" -> 1
"Tile" -> 1
"Transpose" -> 1
"TruncatedNormal" -> 1
"Variable" -> 1

View File

@ -27,7 +27,7 @@ import Test.HUnit ((@=?))
import qualified Data.Vector as V
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (max, tile)
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF
@ -187,6 +187,35 @@ testFillGrad = testCase "testFillGrad" $ do
    TF.gradients y [x] >>= TF.run
V.fromList [6] @=? dx
-- | Tiling a 2-vector with multiples [2] repeats each element twice in the
-- output, so the gradient at every input position should be 2.
testTileGrad :: Test
testTileGrad = testCase "testTileGrad" $ do
    [dx] <- TF.runSession $ do
        input <- TF.render $ TF.vector [5, 9 :: Float]
        let reps = TF.vector [2 :: Int32]
            tiled = TF.tile input reps
        grads <- TF.gradients tiled [input]
        TF.run grads
    V.fromList [2, 2] @=? dx
-- | Tiling a 3x2 matrix of ones by multiples [2, 3] produces a 6x6 output in
-- which every input cell appears 2*3 = 6 times, so each gradient entry is 6.
-- Also checks that the gradient tensor has the same shape as the input.
testTile2DGrad :: Test
testTile2DGrad = testCase "testTile2DGrad" $ do
    (dx, shapeDX, shapeX) <- TF.runSession $ do
        -- Named inputShape (not `shape`) to avoid confusion with TF.shape.
        let inputShape = TF.vector [3, 2 :: Int32]
        x <- TF.render $ TF.fill inputShape (TF.scalar (1 :: Float))
        let multiples = TF.vector [2, 3 :: Int32]
        let y = TF.tile x multiples
        [dx] <- TF.gradients y [x]
        shapeDX <- TF.run $ TF.shape dx
        shapeX <- TF.run $ TF.shape x
        dxv <- TF.run dx
        return (dxv, shapeDX, shapeX)
    shapeX @=? (shapeDX :: V.Vector Int32)
    V.fromList [6, 6, 6, 6, 6, 6 :: Float] @=? (dx :: V.Vector Float)
main :: IO ()
main = googleTest [ testGradientSimple
                  , testGradientDisconnected
@ -197,4 +226,6 @@ main = googleTest [ testGradientSimple
, testReluGrad
, testReluGradGrad
, testFillGrad
, testTileGrad
, testTile2DGrad
]