mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Implemented TileGrad
Some notes about static shape inference
This commit is contained in:
parent
97b4bb5bab
commit
51014a015c
2 changed files with 48 additions and 1 deletions
|
@ -667,6 +667,21 @@ opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
|
||||||
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
|
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
|
||||||
opGrad "LabelWeights" _ _ _ = [Nothing]
|
opGrad "LabelWeights" _ _ _ = [Nothing]
|
||||||
opGrad "Size" _ _ _ = [Nothing]
|
opGrad "Size" _ _ _ = [Nothing]
|
||||||
|
|
||||||
|
-- TODO (jcberentsen): Python implementation uses set_shape for
|
||||||
|
-- static shape inference, which is unsupported.
|
||||||
|
-- TODO: implement support for static shape inference
|
||||||
|
opGrad "Tile" _ [toT -> x, toT -> multiples] [dz] =
|
||||||
|
[Just inputGrad, Nothing]
|
||||||
|
where
|
||||||
|
inputGrad = sum reshapedDz axes
|
||||||
|
inputShape = shape (x :: Tensor Build a)
|
||||||
|
packed = CoreOps.pack [multiples, inputShape]
|
||||||
|
perm = vector [1, 0 :: Int32]
|
||||||
|
splitShape = CoreOps.reshape (CoreOps.transpose packed perm) allDimensions
|
||||||
|
axes = CoreOps.range 0 (CoreOps.size splitShape) (2 :: Tensor Build Int32)
|
||||||
|
reshapedDz = CoreOps.reshape dz splitShape
|
||||||
|
|
||||||
opGrad "ZerosLike" _ _ _ = [Nothing]
|
opGrad "ZerosLike" _ _ _ = [Nothing]
|
||||||
opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
|
opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
|
||||||
where
|
where
|
||||||
|
@ -719,6 +734,7 @@ numOutputs o =
|
||||||
"SparseSegmentSum" -> 1
|
"SparseSegmentSum" -> 1
|
||||||
"Sub" -> 1
|
"Sub" -> 1
|
||||||
"Sum" -> 1
|
"Sum" -> 1
|
||||||
|
"Tile" -> 1
|
||||||
"Transpose" -> 1
|
"Transpose" -> 1
|
||||||
"TruncatedNormal" -> 1
|
"TruncatedNormal" -> 1
|
||||||
"Variable" -> 1
|
"Variable" -> 1
|
||||||
|
|
|
@ -27,7 +27,7 @@ import Test.HUnit ((@=?))
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
import qualified TensorFlow.Core as TF
|
import qualified TensorFlow.Core as TF
|
||||||
import qualified TensorFlow.GenOps.Core as TF (max)
|
import qualified TensorFlow.GenOps.Core as TF (max, tile)
|
||||||
import qualified TensorFlow.Gradient as TF
|
import qualified TensorFlow.Gradient as TF
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
|
|
||||||
|
@ -187,6 +187,35 @@ testFillGrad = testCase "testFillGrad" $ do
|
||||||
TF.gradients y [x] >>= TF.run
|
TF.gradients y [x] >>= TF.run
|
||||||
V.fromList [6] @=? dx
|
V.fromList [6] @=? dx
|
||||||
|
|
||||||
|
|
||||||
|
testTileGrad :: Test
|
||||||
|
testTileGrad = testCase "testTileGrad" $ do
|
||||||
|
[dx] <- TF.runSession $ do
|
||||||
|
x <- TF.render $ TF.vector [5, 9 :: Float]
|
||||||
|
let multiples = TF.vector [2 :: Int32]
|
||||||
|
let y = TF.tile x multiples
|
||||||
|
TF.gradients y [x] >>= TF.run
|
||||||
|
V.fromList [2, 2] @=? dx
|
||||||
|
|
||||||
|
|
||||||
|
testTile2DGrad :: Test
|
||||||
|
testTile2DGrad = testCase "testTileGrad2D" $ do
|
||||||
|
(dx, shapeDX, shapeX) <- TF.runSession $ do
|
||||||
|
let shape = TF.vector [3, 2 :: Int32]
|
||||||
|
x <- TF.render $ TF.fill shape (TF.scalar (1::Float))
|
||||||
|
let multiples = TF.vector [2, 3 :: Int32]
|
||||||
|
let y = TF.tile x multiples
|
||||||
|
|
||||||
|
[dx] <- TF.gradients y [x]
|
||||||
|
|
||||||
|
shapeDX <- TF.run $ TF.shape dx
|
||||||
|
shapeX <- TF.run $ TF.shape x
|
||||||
|
dxv <- TF.run dx
|
||||||
|
return (dxv, shapeDX, shapeX)
|
||||||
|
shapeX @=? (shapeDX :: V.Vector Int32)
|
||||||
|
V.fromList [6, 6, 6, 6, 6, 6::Float] @=? (dx :: V.Vector Float)
|
||||||
|
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = googleTest [ testGradientSimple
|
main = googleTest [ testGradientSimple
|
||||||
, testGradientDisconnected
|
, testGradientDisconnected
|
||||||
|
@ -197,4 +226,6 @@ main = googleTest [ testGradientSimple
|
||||||
, testReluGrad
|
, testReluGrad
|
||||||
, testReluGradGrad
|
, testReluGradGrad
|
||||||
, testFillGrad
|
, testFillGrad
|
||||||
|
, testTileGrad
|
||||||
|
, testTile2DGrad
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in a new issue