Mirror of https://github.com/tensorflow/haskell.git
Implemented TileGrad
Some notes about static shape inference
parent 97b4bb5bab
commit 51014a015c

2 changed files with 48 additions and 1 deletion
@@ -667,6 +667,21 @@ opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
 opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
 opGrad "LabelWeights" _ _ _ = [Nothing]
 opGrad "Size" _ _ _ = [Nothing]
+
+-- TODO (jcberentsen): Python implementation uses set_shape for
+-- static shape inference, which is unsupported.
+-- TODO: implement support for static shape inference
+opGrad "Tile" _ [toT -> x, toT -> multiples] [dz] =
+    [Just inputGrad, Nothing]
+  where
+    inputGrad = sum reshapedDz axes
+    inputShape = shape (x :: Tensor Build a)
+    packed = CoreOps.pack [multiples, inputShape]
+    perm = vector [1, 0 :: Int32]
+    splitShape = CoreOps.reshape (CoreOps.transpose packed perm) allDimensions
+    axes = CoreOps.range 0 (CoreOps.size splitShape) (2 :: Tensor Build Int32)
+    reshapedDz = CoreOps.reshape dz splitShape
+
 opGrad "ZerosLike" _ _ _ = [Nothing]
 opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
   where
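The Tile gradient above works by interleaving multiples and inputShape into splitShape (via pack, a [1, 0] transpose, and a flattening reshape), so that when dz is reshaped through it every input dimension is preceded by its tile count, and then summing over the even axes (range 0 (size splitShape) 2) so each tiled copy of an input element contributes one unit. Below is a minimal numeric sketch of the same idea for the 1-D case, written against plain Haskell lists rather than the TensorFlow API; tileGrad1D and chunksOf are illustrative helpers invented here, not part of this commit.

    -- Sketch only: treat dz as `n` consecutive copies of the input's gradient
    -- and add the copies elementwise, mirroring "reshape to [n, inputLen],
    -- then sum over axis 0" from the opGrad "Tile" case above.
    tileGrad1D :: Int -> [Float] -> [Float]
    tileGrad1D n dz = foldr (zipWith (+)) (replicate size 0) blocks
      where
        size   = length dz `div` n      -- length of the original input
        blocks = chunksOf size dz       -- one block per tile copy

    chunksOf :: Int -> [a] -> [[a]]
    chunksOf _ [] = []
    chunksOf k xs = take k xs : chunksOf k (drop k xs)

For instance, tileGrad1D 2 [1, 1, 1, 1] gives [2, 2], which is the gradient the new testTileGrad below expects for a length-2 input tiled twice.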
@@ -719,6 +734,7 @@ numOutputs o =
         "SparseSegmentSum" -> 1
         "Sub" -> 1
         "Sum" -> 1
+        "Tile" -> 1
         "Transpose" -> 1
         "TruncatedNormal" -> 1
         "Variable" -> 1
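The hunk above also registers "Tile" in the numOutputs table, which maps an op name to its number of output tensors; the gradient traversal presumably consults this when wiring incoming gradients, so an op only gains gradient support once it has both an opGrad case and a numOutputs entry.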
@@ -27,7 +27,7 @@ import Test.HUnit ((@=?))
 import qualified Data.Vector as V
 
 import qualified TensorFlow.Core as TF
-import qualified TensorFlow.GenOps.Core as TF (max)
+import qualified TensorFlow.GenOps.Core as TF (max, tile)
 import qualified TensorFlow.Gradient as TF
 import qualified TensorFlow.Ops as TF
 
@@ -187,6 +187,35 @@ testFillGrad = testCase "testFillGrad" $ do
         TF.gradients y [x] >>= TF.run
     V.fromList [6] @=? dx
+
+
+testTileGrad :: Test
+testTileGrad = testCase "testTileGrad" $ do
+    [dx] <- TF.runSession $ do
+        x <- TF.render $ TF.vector [5, 9 :: Float]
+        let multiples = TF.vector [2 :: Int32]
+        let y = TF.tile x multiples
+        TF.gradients y [x] >>= TF.run
+    V.fromList [2, 2] @=? dx
+
+
+testTile2DGrad :: Test
+testTile2DGrad = testCase "testTileGrad2D" $ do
+    (dx, shapeDX, shapeX) <- TF.runSession $ do
+        let shape = TF.vector [3, 2 :: Int32]
+        x <- TF.render $ TF.fill shape (TF.scalar (1::Float))
+        let multiples = TF.vector [2, 3 :: Int32]
+        let y = TF.tile x multiples
+
+        [dx] <- TF.gradients y [x]
+
+        shapeDX <- TF.run $ TF.shape dx
+        shapeX <- TF.run $ TF.shape x
+        dxv <- TF.run dx
+        return (dxv, shapeDX, shapeX)
+    shapeX @=? (shapeDX :: V.Vector Int32)
+    V.fromList [6, 6, 6, 6, 6, 6::Float] @=? (dx :: V.Vector Float)
+
 
 main :: IO ()
 main = googleTest [ testGradientSimple
                   , testGradientDisconnected
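The expected values in the two new tests follow from the tiling counts: with an all-ones incoming gradient (which is what the [2, 2] and all-6 expectations imply), each input element receives one unit per tiled copy. In testTileGrad each of the two input elements appears twice, so the gradient is [2, 2]; in testTile2DGrad each element of the 3x2 input is copied 2 * 3 = 6 times, so the gradient is uniformly 6 and has the input's shape. A quick plain-Haskell check of that counting, with copiesPerElement as a hypothetical helper that is not part of this commit:

    -- Each input element is replicated (product multiples) times by Tile,
    -- so that is the sum's gradient at every position.
    copiesPerElement :: [Int] -> Int
    copiesPerElement = product
    -- copiesPerElement [2] == 2, copiesPerElement [2, 3] == 6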
@@ -197,4 +226,6 @@ main = googleTest [ testGradientSimple
                   , testReluGrad
                   , testReluGradGrad
                   , testFillGrad
+                  , testTileGrad
+                  , testTile2DGrad
                   ]