mirror of
https://github.com/tensorflow/haskell.git
synced 2024-12-24 18:49:46 +01:00
Add gradient for ResizeBilinear (#239)
This commit is contained in:
parent
3cfd96ef08
commit
96f1c88327
2 changed files with 30 additions and 1 deletions
|
@ -817,6 +817,17 @@ opGrad "Tile" _ [toT -> x, toT -> multiples] [dz] =
|
|||
axes = CoreOps.range 0 (CoreOps.size splitShape) (2 :: Tensor Build Int32)
|
||||
reshapedDz = CoreOps.reshape dz splitShape
|
||||
|
||||
opGrad "ResizeBilinear" nodeDef [toT -> x, _] [dz] =
|
||||
[ Just $ CoreOps.resizeBilinearGrad'
|
||||
(opAttr "align_corners" .~ align)
|
||||
(CoreOps.cast dz)
|
||||
x
|
||||
|
||||
, Nothing
|
||||
]
|
||||
where
|
||||
align = lookupAttr nodeDef "align_corners" :: Bool
|
||||
|
||||
opGrad "ZerosLike" _ _ _ = [Nothing]
|
||||
opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
|
||||
where
|
||||
|
@ -894,6 +905,7 @@ numOutputs o =
|
|||
"Sum" -> 1
|
||||
"Tanh" -> 1
|
||||
"Tile" -> 1
|
||||
"ResizeBilinear" -> 1
|
||||
"Transpose" -> 1
|
||||
"TruncatedNormal" -> 1
|
||||
"VarHandleOp" -> 1
|
||||
|
|
|
@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
|
|||
import Control.Monad.IO.Class (liftIO)
|
||||
|
||||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape)
|
||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape)
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
|
||||
import qualified TensorFlow.Output as TF
|
||||
|
@ -429,6 +429,22 @@ testTile2DGrad = testCase "testTileGrad2D" $ do
|
|||
shapeX @=? (shapeDX :: V.Vector Int32)
|
||||
V.fromList [6, 6, 6, 6, 6, 6::Float] @=? (dx :: V.Vector Float)
|
||||
|
||||
testResizeBilinearGrad :: Test
|
||||
testResizeBilinearGrad = testCase "testResizeBilinearGrad" $ do
|
||||
(dx, shapeDX, shapeX) <- TF.runSession $ do
|
||||
let shape = TF.vector [1, 2, 2, 1 :: Int32]
|
||||
x <- TF.render $ TF.fill shape (TF.scalar (1 :: Float))
|
||||
let outSize = TF.vector [4, 4 :: Int32]
|
||||
align = TF.opAttr "align_corners" .~ True
|
||||
y = TF.resizeBilinear' align x outSize
|
||||
|
||||
[dx] <- TF.gradients y [x]
|
||||
TF.run (dx, TF.shape dx, TF.shape x)
|
||||
shapeX @=? (shapeDX :: V.Vector Int32)
|
||||
let expect = V.fromList [4, 4, 4, 4 :: Float]
|
||||
near = 0.00001 > (V.sum $ V.zipWith (-) expect (dx :: V.Vector Float))
|
||||
near @=? True
|
||||
|
||||
matMulGradient :: Test
|
||||
matMulGradient = testCase "matMulGradients" $ do
|
||||
|
||||
|
@ -553,6 +569,7 @@ main = defaultMain
|
|||
, testFillGrad
|
||||
, testTileGrad
|
||||
, testTile2DGrad
|
||||
, testResizeBilinearGrad
|
||||
, matMulGradient
|
||||
, matMulGradGrad
|
||||
, matMulTransposeGradient (False, False)
|
||||
|
|
Loading…
Reference in a new issue