Mirror of https://github.com/tensorflow/haskell.git (synced 2024-11-23 03:19:44 +01:00)

Support gradients of pad, squeeze, spaceToBatchND, and batchToSpaceND (#226)

parent 95c6b6f277 · commit e4acd69574
2 changed files with 83 additions and 3 deletions
Changes to module TensorFlow.Gradient:

@@ -20,6 +20,7 @@
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE ViewPatterns #-}
+{-# LANGUAGE TypeApplications #-}

 module TensorFlow.Gradient
     ( GradientCompatible
@@ -693,6 +694,29 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] =
 opGrad "Reshape" _ [toT -> x, _] [dz] = [Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]
 opGrad "ExpandDims" n xs@[toT -> _, _] dzs@[_] = opGrad "Reshape" n xs dzs
+opGrad "Squeeze" _ [toT -> x] [dz] = [Just $ reshape dz $ shape (x :: Tensor Build a)]
+opGrad "Pad" _ [toT -> x, toT -> padPattern] [dz] =
+    [Just $ CoreOps.slice dz gradientSliceBegin gradientSliceSize, Nothing]
+  where
+    v1 = vector [1]
+    -- rank yields a scalar (empty shape), so reshape it into a one-element vector.
+    rankx' = CoreOps.rank (x :: Tensor Build Float)
+    rankx = CoreOps.reshape rankx' v1
+    -- Size of the column that is sliced from the pad pattern.
+    padPatternSliceSize = CoreOps.concat 0 [rankx, v1]
+    padPatternSliceBegin = vector [0, 0]
+    padPatternSliced :: Tensor Build Int32 = CoreOps.slice padPattern padPatternSliceBegin padPatternSliceSize
+    -- The slice of the pad pattern has the same rank as the pad pattern itself.
+    gradientSliceBegin = CoreOps.reshape padPatternSliced rankx
+    gradientSliceSize = shape (x :: Tensor Build Float)
+
+-- TODO: This could be either Int32 or Int64.
+opGrad "BatchToSpaceND" _ [_, toT @Int32 -> blockShape, toT @Int32 -> crops] [dz] =
+    [Just $ CoreOps.spaceToBatchND dz blockShape crops, Nothing, Nothing]
+
+-- TODO: This could be either Int32 or Int64.
+opGrad "SpaceToBatchND" _ [_, toT @Int32 -> blockShape, toT @Int32 -> paddings] [dz] =
+    [Just $ CoreOps.batchToSpaceND dz blockShape paddings, Nothing, Nothing]
+
 opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
 opGrad "TruncatedNormal" _ _ _ = [Nothing]
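The "Pad" gradient above recovers the gradient for the unpadded input by slicing the upstream gradient dz: the slice offsets are the first column of the pad pattern (the "before" padding of each dimension) and the slice size is the input's own shape. A minimal, TensorFlow-free sketch of that arithmetic, using the pad pattern and input shape from the testPad case added below (all names here are illustrative only, not part of the library):

-- Plain-Haskell sketch of the slice arithmetic behind the "Pad" gradient.
-- Assumes the pad pattern and input shape used by testPad in this commit.
module PadGradSketch where

padPattern :: [[Int]]                  -- one [before, after] row per dimension
padPattern = [[1, 4], [1, 1], [2, 3]]

inputShape :: [Int]
inputShape = [2, 2, 3]

gradSliceBegin :: [Int]                -- where the input starts inside the padded tensor
gradSliceBegin = map head padPattern   -- == [1, 1, 2]

gradSliceSize :: [Int]                 -- the gradient slice keeps the input's shape
gradSliceSize = inputShape             -- == [2, 2, 3]

The "BatchToSpaceND" and "SpaceToBatchND" cases rely on the two ops being mutual inverses for a given block shape and matching crops/paddings, so each gradient is simply the other op applied to the incoming gradient dz (hence the shared TODO about also supporting Int64 pattern tensors).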
@@ -800,6 +824,7 @@ numOutputs o =
     "Abs" -> 1
     "Add" -> 1
     "AddN" -> 1
+    "BatchToSpaceND" -> 1
     "Cast" -> 1
     "Const" -> 1
     "Concat" -> 1
@@ -823,6 +848,7 @@ numOutputs o =
     "Min" -> 1
     "Mul" -> 1
     "Neg" -> 1
+    "Pad" -> 1
     "Placeholder" -> 1
     "OneHot" -> 1
     "ReadVariableOp" -> 1
@@ -833,8 +859,10 @@ numOutputs o =
     "Select" -> 1
     "Size" -> 1
     "SoftmaxCrossEntropyWithLogits" -> 2
-    "Square" -> 1
+    "SpaceToBatchND" -> 1
     "SparseSegmentSum" -> 1
+    "Square" -> 1
+    "Squeeze" -> 1
     "Sub" -> 1
     "Sum" -> 1
     "Tanh" -> 1
Changes to the gradient test suite (second changed file):

@@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
 import Control.Monad.IO.Class (liftIO)

 import qualified TensorFlow.Core as TF
-import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile)
+import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze)
 import qualified TensorFlow.Gradient as TF
 import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
 import qualified TensorFlow.Output as TF
@@ -313,6 +313,54 @@ testReshape =
     V.fromList [1, 1, 1, 1] @=? dx
     V.fromList [2, 2] @=? s

+testPad :: Test
+testPad =
+  testCase "testPad" $ do
+    ([dx], [s]) <-
+      TF.runSession $ do
+        (x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [2, 2, 3 :: Int64]
+        let y = TF.pad x $ TF.constant (TF.Shape [3, 2]) [1, 4, 1, 1, 2, 3 :: Int32]
+        calculateGradWithShape y x
+    V.fromList [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @=? dx
+    V.fromList [2, 2, 3] @=? s
+
+testBatchToSpaceND :: Test
+testBatchToSpaceND =
+  testCase "testBatchToSpaceND" $ do
+    ([dx], [s]) <-
+      TF.runSession $ do
+        (x :: TF.Tensor TF.Value Float) <- TF.render $ TF.constant (TF.Shape [4, 1, 1, 1 :: Int64]) [1, 2, 3, 4]
+        shape <- TF.render $ TF.vector [2, 2 :: Int32]
+        crops <- TF.render $ TF.constant (TF.Shape [2, 2]) [0, 0, 0, 0 :: Int32]
+        let y = TF.batchToSpaceND x shape crops
+        calculateGradWithShape y x
+    V.fromList [1, 1, 1, 1] @=? dx
+    V.fromList [4, 1, 1, 1] @=? s
+
+testSpaceToBatchND :: Test
+testSpaceToBatchND =
+  testCase "testSpaceToBatchND" $ do
+    ([dx], [s]) <-
+      TF.runSession $ do
+        (x :: TF.Tensor TF.Value Float) <- TF.render $ TF.constant (TF.Shape [1, 2, 2, 1 :: Int64]) [1, 2, 3, 4]
+        shape <- TF.render $ TF.vector [2, 2 :: Int32]
+        paddings <- TF.render $ TF.constant (TF.Shape [2, 2]) [0, 0, 0, 0 :: Int32]
+        let y = TF.spaceToBatchND x shape paddings
+        calculateGradWithShape y x
+    V.fromList [1, 1, 1, 1] @=? dx
+    V.fromList [1, 2, 2, 1] @=? s
+
+testSqueeze :: Test
+testSqueeze =
+  testCase "testSqueeze" $ do
+    ([dx], [s]) <-
+      TF.runSession $ do
+        (x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [1, 2, 3 :: Int64]
+        let y = TF.squeeze x
+        calculateGradWithShape y x
+    V.fromList [1, 1, 1, 1, 1, 1] @=? dx
+    V.fromList [1, 2, 3] @=? s
+
 calculateGradWithShape :: TF.Tensor TF.Build Float -> TF.Tensor TF.Value Float -> SessionT IO ([V.Vector Float], [V.Vector Int32])
 calculateGradWithShape y x = do
     gs <- TF.gradients y [x]
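Each new test expects dx to be a tensor of ones with x's shape: TF.gradients effectively seeds the backward pass with ones of y's shape (which is why testReshape above also expects all ones), and pad, squeeze, batchToSpaceND, and spaceToBatchND each map every input element to exactly one output element, so every input element receives exactly one unit of gradient. As a hypothetical, TensorFlow-free illustration of the forward rearrangement exercised by testBatchToSpaceND (block shape [2, 2], no crops, input shape [4, 1, 1, 1]):

-- batchToGrid is not a library function; it only mirrors the index mapping
-- of batchToSpaceND for this particular block shape and input shape.
batchToGrid :: [a] -> [[a]]
batchToGrid [a, b, c, d] = [[a, b], [c, d]]   -- four batch entries become a 2x2 grid
batchToGrid _            = error "expected exactly four batch entries"

-- batchToGrid [1, 2, 3, 4 :: Float] == [[1, 2], [3, 4]]

spaceToBatchND with the same block shape maps the 2x2 grid straight back to the four batch entries, which is exactly the inverse relationship the "SpaceToBatchND" and "BatchToSpaceND" gradients above exploit.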
@@ -468,6 +516,10 @@ main = defaultMain
     , testTanhGrad
     , testExpandDims
     , testReshape
+    , testPad
+    , testBatchToSpaceND
+    , testSpaceToBatchND
+    , testSqueeze
     , testFillGrad
     , testTileGrad
     , testTile2DGrad
@@ -478,4 +530,4 @@ main = defaultMain
     , matMulTransposeGradient (True, False)
     , matMulTransposeGradient (True, True)
     , testConv2DBackpropInputGrad
     ]