1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 19:13:34 +02:00

Refactor of the SpaceToBatchND and BatchToSpaceND gradient functions

This commit is contained in:
Rik van der Kleij 2018-11-27 16:39:16 +01:00
parent 22f9a7f925
commit 285ffe38c4

View File

@ -710,19 +710,13 @@ opGrad "Pad" _ [toT -> x, toT -> padPattern] [dz] =
gradientSliceBegin = CoreOps.reshape padPatternSliced rankx
gradientSliceSize = shape (x :: Tensor Build Float)
opGrad "BatchToSpaceND" _ [_, toT -> blockShape, toT -> crops] [dz] =
[Just $ CoreOps.spaceToBatchND dz blockShape' crops', Nothing, Nothing]
where
-- TODO: This could be either Int32 or Int64.
(blockShape' :: Tensor Build Int32) = blockShape
(crops' :: Tensor Build Int32) = crops
-- TODO: This could be either Int32 or Int64.
opGrad "BatchToSpaceND" _ [_, toT @Int32 -> blockShape, toT @Int32 -> crops] [dz] =
[Just $ CoreOps.spaceToBatchND dz blockShape crops, Nothing, Nothing]
opGrad "SpaceToBatchND" _ [_, toT -> blockShape, toT -> paddings] [dz] =
[Just $ CoreOps.batchToSpaceND dz blockShape' paddings', Nothing, Nothing]
where
-- TODO: This could be either Int32 or Int64.
(blockShape' :: Tensor Build Int32) = blockShape
(paddings' :: Tensor Build Int32) = paddings
-- TODO: This could be either Int32 or Int64.
opGrad "SpaceToBatchND" _ [_, toT @Int32 -> blockShape, toT @Int32 -> paddings] [dz] =
[Just $ CoreOps.batchToSpaceND dz blockShape paddings, Nothing, Nothing]
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
opGrad "TruncatedNormal" _ _ _ = [Nothing]