1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-23 03:19:44 +01:00

Support gradients of pad, squeeze, spaceToBatchND, and batchToSpaceND (#226)

This commit is contained in:
Rik 2018-11-27 20:17:32 +01:00 committed by fkm3
parent 95c6b6f277
commit e4acd69574
2 changed files with 83 additions and 3 deletions

View file

@ -20,6 +20,7 @@
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-} {-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TypeApplications #-}
module TensorFlow.Gradient module TensorFlow.Gradient
( GradientCompatible ( GradientCompatible
@ -693,6 +694,29 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] =
opGrad "Reshape" _ [toT -> x, _] [dz] = [Just $ reshape dz $ shape (x :: Tensor Build a), Nothing] opGrad "Reshape" _ [toT -> x, _] [dz] = [Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]
opGrad "ExpandDims" n xs@[toT -> _, _] dzs@[_] = opGrad "Reshape" n xs dzs opGrad "ExpandDims" n xs@[toT -> _, _] dzs@[_] = opGrad "Reshape" n xs dzs
opGrad "Squeeze" _ [toT -> x] [dz] = [Just $ reshape dz $ shape (x :: Tensor Build a)]
opGrad "Pad" _ [toT -> x, toT -> padPattern] [dz] =
[Just $ CoreOps.slice dz gradientSliceBegin gradientSliceSize, Nothing]
where
v1 = vector [1]
-- For some reason rankx' has an empty shape
rankx' = CoreOps.rank (x :: Tensor Build Float)
rankx = CoreOps.reshape rankx' v1
-- Size of column that is sliced from pad pattern
padPatternSliceSize = CoreOps.concat 0 [rankx, v1]
padPatternSliceBegin = vector [0, 0]
padPatternSliced :: Tensor Build Int32 = CoreOps.slice padPattern padPatternSliceBegin padPatternSliceSize
-- The slice of the pad pattern has the same rank as the pad pattern itself
gradientSliceBegin = CoreOps.reshape padPatternSliced rankx
gradientSliceSize = shape (x :: Tensor Build Float)
-- TODO: This could be either Int32 or Int64.
opGrad "BatchToSpaceND" _ [_, toT @Int32 -> blockShape, toT @Int32 -> crops] [dz] =
[Just $ CoreOps.spaceToBatchND dz blockShape crops, Nothing, Nothing]
-- TODO: This could be either Int32 or Int64.
opGrad "SpaceToBatchND" _ [_, toT @Int32 -> blockShape, toT @Int32 -> paddings] [dz] =
[Just $ CoreOps.batchToSpaceND dz blockShape paddings, Nothing, Nothing]
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing] opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
opGrad "TruncatedNormal" _ _ _ = [Nothing] opGrad "TruncatedNormal" _ _ _ = [Nothing]
@ -800,6 +824,7 @@ numOutputs o =
"Abs" -> 1 "Abs" -> 1
"Add" -> 1 "Add" -> 1
"AddN" -> 1 "AddN" -> 1
"BatchToSpaceND" -> 1
"Cast" -> 1 "Cast" -> 1
"Const" -> 1 "Const" -> 1
"Concat" -> 1 "Concat" -> 1
@ -823,6 +848,7 @@ numOutputs o =
"Min" -> 1 "Min" -> 1
"Mul" -> 1 "Mul" -> 1
"Neg" -> 1 "Neg" -> 1
"Pad" -> 1
"Placeholder" -> 1 "Placeholder" -> 1
"OneHot" -> 1 "OneHot" -> 1
"ReadVariableOp" -> 1 "ReadVariableOp" -> 1
@ -833,8 +859,10 @@ numOutputs o =
"Select" -> 1 "Select" -> 1
"Size" -> 1 "Size" -> 1
"SoftmaxCrossEntropyWithLogits" -> 2 "SoftmaxCrossEntropyWithLogits" -> 2
"Square" -> 1 "SpaceToBatchND" -> 1
"SparseSegmentSum" -> 1 "SparseSegmentSum" -> 1
"Square" -> 1
"Squeeze" -> 1
"Sub" -> 1 "Sub" -> 1
"Sum" -> 1 "Sum" -> 1
"Tanh" -> 1 "Tanh" -> 1

View file

@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import qualified TensorFlow.Core as TF import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile) import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze)
import qualified TensorFlow.Gradient as TF import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable) import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
import qualified TensorFlow.Output as TF import qualified TensorFlow.Output as TF
@ -313,6 +313,54 @@ testReshape =
V.fromList [1, 1, 1, 1] @=? dx V.fromList [1, 1, 1, 1] @=? dx
V.fromList [2, 2] @=? s V.fromList [2, 2] @=? s
testPad :: Test
testPad =
testCase "testPad" $ do
([dx], [s]) <-
TF.runSession $ do
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [2, 2, 3 :: Int64]
let y = TF.pad x $ TF.constant (TF.Shape [3, 2]) [1, 4, 1, 1, 2, 3 :: Int32]
calculateGradWithShape y x
V.fromList [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @=? dx
V.fromList [2, 2, 3] @=? s
testBatchToSpaceND :: Test
testBatchToSpaceND =
testCase "testBatchToSpaceND" $ do
([dx], [s]) <-
TF.runSession $ do
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.constant (TF.Shape [4, 1, 1, 1 :: Int64]) [1, 2, 3, 4]
shape <- TF.render $ TF.vector [2, 2 :: Int32]
crops <- TF.render $ TF.constant (TF.Shape [2, 2]) [0, 0, 0, 0 :: Int32]
let y = TF.batchToSpaceND x shape crops
calculateGradWithShape y x
V.fromList [1, 1, 1, 1] @=? dx
V.fromList [4, 1, 1, 1] @=? s
testSpaceToBatchND :: Test
testSpaceToBatchND =
testCase "testSpaceToBatchND" $ do
([dx], [s]) <-
TF.runSession $ do
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.constant (TF.Shape [1, 2, 2, 1 :: Int64]) [1, 2, 3, 4]
shape <- TF.render $ TF.vector [2, 2 :: Int32]
paddings <- TF.render $ TF.constant (TF.Shape [2, 2]) [0, 0, 0, 0 :: Int32]
let y = TF.spaceToBatchND x shape paddings
calculateGradWithShape y x
V.fromList [1, 1, 1, 1] @=? dx
V.fromList [1, 2, 2, 1] @=? s
testSqueeze :: Test
testSqueeze =
testCase "testSqueeze" $ do
([dx], [s]) <-
TF.runSession $ do
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [1, 2, 3 :: Int64]
let y = TF.squeeze x
calculateGradWithShape y x
V.fromList [1, 1, 1, 1, 1, 1] @=? dx
V.fromList [1, 2, 3] @=? s
calculateGradWithShape :: TF.Tensor TF.Build Float -> TF.Tensor TF.Value Float -> SessionT IO ([V.Vector Float], [V.Vector Int32]) calculateGradWithShape :: TF.Tensor TF.Build Float -> TF.Tensor TF.Value Float -> SessionT IO ([V.Vector Float], [V.Vector Int32])
calculateGradWithShape y x = do calculateGradWithShape y x = do
gs <- TF.gradients y [x] gs <- TF.gradients y [x]
@ -468,6 +516,10 @@ main = defaultMain
, testTanhGrad , testTanhGrad
, testExpandDims , testExpandDims
, testReshape , testReshape
, testPad
, testBatchToSpaceND
, testSpaceToBatchND
, testSqueeze
, testFillGrad , testFillGrad
, testTileGrad , testTileGrad
, testTile2DGrad , testTile2DGrad
@ -478,4 +530,4 @@ main = defaultMain
, matMulTransposeGradient (True, False) , matMulTransposeGradient (True, False)
, matMulTransposeGradient (True, True) , matMulTransposeGradient (True, True)
, testConv2DBackpropInputGrad , testConv2DBackpropInputGrad
] ]