mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Add gradient for slice function (#234)
This commit is contained in:
parent
666dce94bd
commit
3cfd96ef08
2 changed files with 45 additions and 2 deletions
|
@ -710,6 +710,27 @@ opGrad "Pad" _ [toT -> x, toT -> padPattern] [dz] =
|
||||||
gradientSliceBegin = CoreOps.reshape padPatternSliced rankx
|
gradientSliceBegin = CoreOps.reshape padPatternSliced rankx
|
||||||
gradientSliceSize = shape (x :: Tensor Build Float)
|
gradientSliceSize = shape (x :: Tensor Build Float)
|
||||||
|
|
||||||
|
-- Gradient for Slice
|
||||||
|
-- Create an Nx2 padding where N is the rank of (grad of) Slice and the first
|
||||||
|
-- column represents how many zeros are to be prepended for each dimension, and the second
|
||||||
|
-- column indicates how many zeros are appended.
|
||||||
|
-- The number of zeros to prepend is the shape of the beginvec.
|
||||||
|
-- The number of zeros to append is the shape of the inputvec
|
||||||
|
-- elementwise-subtracted by both the beginvec and sizevec.
|
||||||
|
-- Some more reshaping is needed to assemble this tensor with the
|
||||||
|
-- right dimensions.
|
||||||
|
opGrad "Slice" _ [toT -> inputvec, toT -> beginvec, _] [dz] =
|
||||||
|
[Just $ CoreOps.pad dz paddings, Nothing, Nothing]
|
||||||
|
where
|
||||||
|
v1 = vector [1 :: Int32]
|
||||||
|
inputRank' = CoreOps.rank (inputvec :: Tensor Build Float)
|
||||||
|
-- For some reason inputRank' has an empty shape
|
||||||
|
inputRank = CoreOps.reshape inputRank' v1
|
||||||
|
padShape = CoreOps.concat 0 [inputRank, v1]
|
||||||
|
beforePad = CoreOps.reshape beginvec padShape
|
||||||
|
afterPad = CoreOps.reshape (shape inputvec - shape dz - beginvec) padShape
|
||||||
|
paddings = CoreOps.concat 1 [beforePad, afterPad]
|
||||||
|
|
||||||
-- TODO: This could be either Int32 or Int64.
|
-- TODO: This could be either Int32 or Int64.
|
||||||
opGrad "BatchToSpaceND" _ [_, toT @Int32 -> blockShape, toT @Int32 -> crops] [dz] =
|
opGrad "BatchToSpaceND" _ [_, toT @Int32 -> blockShape, toT @Int32 -> crops] [dz] =
|
||||||
[Just $ CoreOps.spaceToBatchND dz blockShape crops, Nothing, Nothing]
|
[Just $ CoreOps.spaceToBatchND dz blockShape crops, Nothing, Nothing]
|
||||||
|
@ -862,6 +883,7 @@ numOutputs o =
|
||||||
"Reshape" -> 1
|
"Reshape" -> 1
|
||||||
"Select" -> 1
|
"Select" -> 1
|
||||||
"Size" -> 1
|
"Size" -> 1
|
||||||
|
"Slice" -> 1
|
||||||
"SoftmaxCrossEntropyWithLogits" -> 2
|
"SoftmaxCrossEntropyWithLogits" -> 2
|
||||||
"SpaceToBatchND" -> 1
|
"SpaceToBatchND" -> 1
|
||||||
"SparseSegmentSum" -> 1
|
"SparseSegmentSum" -> 1
|
||||||
|
|
|
@ -32,9 +32,9 @@ import Control.Monad(forM_, replicateM, zipWithM)
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
|
||||||
import qualified TensorFlow.Core as TF
|
import qualified TensorFlow.Core as TF
|
||||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt)
|
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape)
|
||||||
import qualified TensorFlow.Gradient as TF
|
import qualified TensorFlow.Gradient as TF
|
||||||
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
|
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
|
||||||
import qualified TensorFlow.Output as TF
|
import qualified TensorFlow.Output as TF
|
||||||
import qualified TensorFlow.Types as TF
|
import qualified TensorFlow.Types as TF
|
||||||
import qualified TensorFlow.Variable as TF
|
import qualified TensorFlow.Variable as TF
|
||||||
|
@ -324,6 +324,7 @@ testPad =
|
||||||
V.fromList [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @=? dx
|
V.fromList [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @=? dx
|
||||||
V.fromList [2, 2, 3] @=? s
|
V.fromList [2, 2, 3] @=? s
|
||||||
|
|
||||||
|
|
||||||
testSqrt :: Test
|
testSqrt :: Test
|
||||||
testSqrt = testCase "testSqrt" $ do
|
testSqrt = testCase "testSqrt" $ do
|
||||||
[dx] <- TF.runSession $ do
|
[dx] <- TF.runSession $ do
|
||||||
|
@ -332,6 +333,25 @@ testSqrt = testCase "testSqrt" $ do
|
||||||
TF.gradients y [x] >>= TF.run
|
TF.gradients y [x] >>= TF.run
|
||||||
V.fromList [2] @=? dx
|
V.fromList [2] @=? dx
|
||||||
|
|
||||||
|
testSlice :: Test
|
||||||
|
testSlice =
|
||||||
|
testCase "testSlice" $ do
|
||||||
|
([dx], [s]) <-
|
||||||
|
TF.runSession $ do
|
||||||
|
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [2, 3, 4 :: Int64]
|
||||||
|
(z :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [1, 2, 2 :: Int64]
|
||||||
|
let y = TF.slice x (TF.constant (TF.Shape [3]) [1, 1, 1 :: Int32]) (TF.shape z)
|
||||||
|
calculateGradWithShape y x
|
||||||
|
let expected =
|
||||||
|
[0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0,
|
||||||
|
0, 1, 1, 0,
|
||||||
|
0, 1, 1, 0]
|
||||||
|
V.fromList expected @=? dx
|
||||||
|
V.fromList [2, 3, 4] @=? s
|
||||||
|
|
||||||
testBatchToSpaceND :: Test
|
testBatchToSpaceND :: Test
|
||||||
testBatchToSpaceND =
|
testBatchToSpaceND =
|
||||||
testCase "testBatchToSpaceND" $ do
|
testCase "testBatchToSpaceND" $ do
|
||||||
|
@ -526,6 +546,7 @@ main = defaultMain
|
||||||
, testReshape
|
, testReshape
|
||||||
, testPad
|
, testPad
|
||||||
, testSqrt
|
, testSqrt
|
||||||
|
, testSlice
|
||||||
, testBatchToSpaceND
|
, testBatchToSpaceND
|
||||||
, testSpaceToBatchND
|
, testSpaceToBatchND
|
||||||
, testSqueeze
|
, testSqueeze
|
||||||
|
|
Loading…
Reference in a new issue