mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 19:13:34 +02:00
Refactor of the SpaceToBatchND and BatchToSpaceND gradient functions
This commit is contained in:
parent
22f9a7f925
commit
285ffe38c4
|
@ -710,19 +710,13 @@ opGrad "Pad" _ [toT -> x, toT -> padPattern] [dz] =
|
|||
gradientSliceBegin = CoreOps.reshape padPatternSliced rankx
|
||||
gradientSliceSize = shape (x :: Tensor Build Float)
|
||||
|
||||
opGrad "BatchToSpaceND" _ [_, toT -> blockShape, toT -> crops] [dz] =
|
||||
[Just $ CoreOps.spaceToBatchND dz blockShape' crops', Nothing, Nothing]
|
||||
where
|
||||
-- TODO: This could be either Int32 or Int64.
|
||||
(blockShape' :: Tensor Build Int32) = blockShape
|
||||
(crops' :: Tensor Build Int32) = crops
|
||||
-- TODO: This could be either Int32 or Int64.
|
||||
opGrad "BatchToSpaceND" _ [_, toT @Int32 -> blockShape, toT @Int32 -> crops] [dz] =
|
||||
[Just $ CoreOps.spaceToBatchND dz blockShape crops, Nothing, Nothing]
|
||||
|
||||
opGrad "SpaceToBatchND" _ [_, toT -> blockShape, toT -> paddings] [dz] =
|
||||
[Just $ CoreOps.batchToSpaceND dz blockShape' paddings', Nothing, Nothing]
|
||||
where
|
||||
-- TODO: This could be either Int32 or Int64.
|
||||
(blockShape' :: Tensor Build Int32) = blockShape
|
||||
(paddings' :: Tensor Build Int32) = paddings
|
||||
-- TODO: This could be either Int32 or Int64.
|
||||
opGrad "SpaceToBatchND" _ [_, toT @Int32 -> blockShape, toT @Int32 -> paddings] [dz] =
|
||||
[Just $ CoreOps.batchToSpaceND dz blockShape paddings, Nothing, Nothing]
|
||||
|
||||
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
|
||||
opGrad "TruncatedNormal" _ _ _ = [Nothing]
|
||||
|
|
Loading…
Reference in New Issue
Block a user