mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Add gradient for ResizeBilinear (#239)
This commit is contained in:
parent
3cfd96ef08
commit
96f1c88327
2 changed files with 30 additions and 1 deletions
|
@ -817,6 +817,17 @@ opGrad "Tile" _ [toT -> x, toT -> multiples] [dz] =
|
||||||
axes = CoreOps.range 0 (CoreOps.size splitShape) (2 :: Tensor Build Int32)
|
axes = CoreOps.range 0 (CoreOps.size splitShape) (2 :: Tensor Build Int32)
|
||||||
reshapedDz = CoreOps.reshape dz splitShape
|
reshapedDz = CoreOps.reshape dz splitShape
|
||||||
|
|
||||||
|
opGrad "ResizeBilinear" nodeDef [toT -> x, _] [dz] =
|
||||||
|
[ Just $ CoreOps.resizeBilinearGrad'
|
||||||
|
(opAttr "align_corners" .~ align)
|
||||||
|
(CoreOps.cast dz)
|
||||||
|
x
|
||||||
|
|
||||||
|
, Nothing
|
||||||
|
]
|
||||||
|
where
|
||||||
|
align = lookupAttr nodeDef "align_corners" :: Bool
|
||||||
|
|
||||||
opGrad "ZerosLike" _ _ _ = [Nothing]
|
opGrad "ZerosLike" _ _ _ = [Nothing]
|
||||||
opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
|
opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
|
||||||
where
|
where
|
||||||
|
@ -894,6 +905,7 @@ numOutputs o =
|
||||||
"Sum" -> 1
|
"Sum" -> 1
|
||||||
"Tanh" -> 1
|
"Tanh" -> 1
|
||||||
"Tile" -> 1
|
"Tile" -> 1
|
||||||
|
"ResizeBilinear" -> 1
|
||||||
"Transpose" -> 1
|
"Transpose" -> 1
|
||||||
"TruncatedNormal" -> 1
|
"TruncatedNormal" -> 1
|
||||||
"VarHandleOp" -> 1
|
"VarHandleOp" -> 1
|
||||||
|
|
|
@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
|
||||||
import qualified TensorFlow.Core as TF
|
import qualified TensorFlow.Core as TF
|
||||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape)
|
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape)
|
||||||
import qualified TensorFlow.Gradient as TF
|
import qualified TensorFlow.Gradient as TF
|
||||||
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
|
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
|
||||||
import qualified TensorFlow.Output as TF
|
import qualified TensorFlow.Output as TF
|
||||||
|
@ -429,6 +429,22 @@ testTile2DGrad = testCase "testTileGrad2D" $ do
|
||||||
shapeX @=? (shapeDX :: V.Vector Int32)
|
shapeX @=? (shapeDX :: V.Vector Int32)
|
||||||
V.fromList [6, 6, 6, 6, 6, 6::Float] @=? (dx :: V.Vector Float)
|
V.fromList [6, 6, 6, 6, 6, 6::Float] @=? (dx :: V.Vector Float)
|
||||||
|
|
||||||
|
testResizeBilinearGrad :: Test
|
||||||
|
testResizeBilinearGrad = testCase "testResizeBilinearGrad" $ do
|
||||||
|
(dx, shapeDX, shapeX) <- TF.runSession $ do
|
||||||
|
let shape = TF.vector [1, 2, 2, 1 :: Int32]
|
||||||
|
x <- TF.render $ TF.fill shape (TF.scalar (1 :: Float))
|
||||||
|
let outSize = TF.vector [4, 4 :: Int32]
|
||||||
|
align = TF.opAttr "align_corners" .~ True
|
||||||
|
y = TF.resizeBilinear' align x outSize
|
||||||
|
|
||||||
|
[dx] <- TF.gradients y [x]
|
||||||
|
TF.run (dx, TF.shape dx, TF.shape x)
|
||||||
|
shapeX @=? (shapeDX :: V.Vector Int32)
|
||||||
|
let expect = V.fromList [4, 4, 4, 4 :: Float]
|
||||||
|
near = 0.00001 > (V.sum $ V.zipWith (-) expect (dx :: V.Vector Float))
|
||||||
|
near @=? True
|
||||||
|
|
||||||
matMulGradient :: Test
|
matMulGradient :: Test
|
||||||
matMulGradient = testCase "matMulGradients" $ do
|
matMulGradient = testCase "matMulGradients" $ do
|
||||||
|
|
||||||
|
@ -553,6 +569,7 @@ main = defaultMain
|
||||||
, testFillGrad
|
, testFillGrad
|
||||||
, testTileGrad
|
, testTileGrad
|
||||||
, testTile2DGrad
|
, testTile2DGrad
|
||||||
|
, testResizeBilinearGrad
|
||||||
, matMulGradient
|
, matMulGradient
|
||||||
, matMulGradGrad
|
, matMulGradGrad
|
||||||
, matMulTransposeGradient (False, False)
|
, matMulTransposeGradient (False, False)
|
||||||
|
|
Loading…
Reference in a new issue