1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-26 21:09:44 +01:00

Implemented TileGrad

Some notes about static shape inference
This commit is contained in:
Jarl Christian Berentsen 2017-04-29 15:57:42 +02:00 committed by fkm3
parent 97b4bb5bab
commit 51014a015c
2 changed files with 48 additions and 1 deletions

View file

@ -667,6 +667,21 @@ opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
opGrad "LabelWeights" _ _ _ = [Nothing]
opGrad "Size" _ _ _ = [Nothing]
-- TODO (jcberentsen): Python implementation uses set_shape for
-- static shape inference, which is unsupported.
-- TODO: implement support for static shape inference
opGrad "Tile" _ [toT -> x, toT -> multiples] [dz] =
[Just inputGrad, Nothing]
where
inputGrad = sum reshapedDz axes
inputShape = shape (x :: Tensor Build a)
packed = CoreOps.pack [multiples, inputShape]
perm = vector [1, 0 :: Int32]
splitShape = CoreOps.reshape (CoreOps.transpose packed perm) allDimensions
axes = CoreOps.range 0 (CoreOps.size splitShape) (2 :: Tensor Build Int32)
reshapedDz = CoreOps.reshape dz splitShape
opGrad "ZerosLike" _ _ _ = [Nothing]
opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
where
@ -719,6 +734,7 @@ numOutputs o =
"SparseSegmentSum" -> 1
"Sub" -> 1
"Sum" -> 1
"Tile" -> 1
"Transpose" -> 1
"TruncatedNormal" -> 1
"Variable" -> 1

View file

@ -27,7 +27,7 @@ import Test.HUnit ((@=?))
import qualified Data.Vector as V
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (max)
import qualified TensorFlow.GenOps.Core as TF (max, tile)
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF
@ -187,6 +187,35 @@ testFillGrad = testCase "testFillGrad" $ do
TF.gradients y [x] >>= TF.run
V.fromList [6] @=? dx
testTileGrad :: Test
testTileGrad = testCase "testTileGrad" $ do
[dx] <- TF.runSession $ do
x <- TF.render $ TF.vector [5, 9 :: Float]
let multiples = TF.vector [2 :: Int32]
let y = TF.tile x multiples
TF.gradients y [x] >>= TF.run
V.fromList [2, 2] @=? dx
testTile2DGrad :: Test
testTile2DGrad = testCase "testTileGrad2D" $ do
(dx, shapeDX, shapeX) <- TF.runSession $ do
let shape = TF.vector [3, 2 :: Int32]
x <- TF.render $ TF.fill shape (TF.scalar (1::Float))
let multiples = TF.vector [2, 3 :: Int32]
let y = TF.tile x multiples
[dx] <- TF.gradients y [x]
shapeDX <- TF.run $ TF.shape dx
shapeX <- TF.run $ TF.shape x
dxv <- TF.run dx
return (dxv, shapeDX, shapeX)
shapeX @=? (shapeDX :: V.Vector Int32)
V.fromList [6, 6, 6, 6, 6, 6::Float] @=? (dx :: V.Vector Float)
main :: IO ()
main = googleTest [ testGradientSimple
, testGradientDisconnected
@ -197,4 +226,6 @@ main = googleTest [ testGradientSimple
, testReluGrad
, testReluGradGrad
, testFillGrad
, testTileGrad
, testTile2DGrad
]