mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Gradient of Conv2DBackpropInput (#155)
This commit is contained in:
parent
d8bf349962
commit
2dcc921f6e
3 changed files with 49 additions and 2 deletions
|
@ -650,6 +650,27 @@ opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
|
|||
useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
|
||||
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
||||
|
||||
opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
|
||||
[ Nothing
|
||||
, Just $ CoreOps.conv2DBackpropFilter'
|
||||
((opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
dz (shape x) y
|
||||
, Just $ CoreOps.conv2D'
|
||||
((opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
dz x
|
||||
]
|
||||
where
|
||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||
padding = lookupAttr nodeDef "padding" :: ByteString
|
||||
useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
|
||||
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
||||
|
||||
opGrad "MaxPool" nodeDef [toT -> x] [dz] =
|
||||
[ Just $ CoreOps.maxPoolGrad'
|
||||
((opAttr "ksize" .~ ksize)
|
||||
|
@ -779,6 +800,7 @@ numOutputs o =
|
|||
"Const" -> 1
|
||||
"Concat" -> 1
|
||||
"Conv2D" -> 1
|
||||
"Conv2DBackpropInput" -> 1
|
||||
"Div" -> 1
|
||||
"DynamicStitch" -> 1
|
||||
"DynamicPartition" ->
|
||||
|
|
|
@ -188,6 +188,7 @@ Test-Suite GradientTest
|
|||
hs-source-dirs: tests
|
||||
build-depends: HUnit
|
||||
, base
|
||||
, bytestring
|
||||
, proto-lens
|
||||
, lens-family
|
||||
, random
|
||||
|
|
|
@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
|
|||
import Control.Monad.IO.Class (liftIO)
|
||||
|
||||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.GenOps.Core as TF (max, tile, maximum)
|
||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile)
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
|
||||
import qualified TensorFlow.Output as TF
|
||||
|
@ -42,6 +42,8 @@ import qualified TensorFlow.Variable as TF
|
|||
import Proto.Tensorflow.Core.Framework.Graph (node)
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef (op)
|
||||
|
||||
import qualified Data.ByteString.Char8 as BS
|
||||
|
||||
testGradientSimple :: Test
|
||||
testGradientSimple = testCase "testGradientSimple" $ do
|
||||
let grads = do
|
||||
|
@ -313,7 +315,6 @@ testTile2DGrad = testCase "testTileGrad2D" $ do
|
|||
shapeX @=? (shapeDX :: V.Vector Int32)
|
||||
V.fromList [6, 6, 6, 6, 6, 6::Float] @=? (dx :: V.Vector Float)
|
||||
|
||||
|
||||
matMulGradient :: Test
|
||||
matMulGradient = testCase "matMulGradients" $ do
|
||||
|
||||
|
@ -388,6 +389,28 @@ transAttrs :: (TF.Attribute a,
|
|||
transAttrs a b =
|
||||
(TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b)
|
||||
|
||||
testConv2DBackpropInputGrad :: Test
|
||||
testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
|
||||
(dx, shapeDX, shapeX) <- TF.runSession $ do
|
||||
let conv_input_shape = TF.vector [1, 2, 2, 1 :: Int32] -- [batch, h, w, in_channels]
|
||||
let conv_out_shape = TF.vector [1, 1, 1, 1 :: Int32] -- [batch, h, w, out_channels]
|
||||
x <- TF.render $ TF.fill conv_out_shape (TF.scalar (1::Float))
|
||||
|
||||
let filterShape = TF.vector [2, 2, 1, 1 :: Int32] -- [fh, fw, inc, out]
|
||||
filter <- TF.render $ TF.fill filterShape (TF.scalar (1::Float))
|
||||
let y = TF.conv2DBackpropInput'
|
||||
( (TF.opAttr "strides" .~ [1::Int64, 1, 1, 1])
|
||||
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
|
||||
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
|
||||
)
|
||||
conv_input_shape filter x
|
||||
|
||||
[dx] <- TF.gradients y [x]
|
||||
TF.run (dx, TF.shape dx, TF.shape x)
|
||||
shapeX @=? (shapeDX :: V.Vector Int32)
|
||||
V.fromList [4::Float] @=? (dx :: V.Vector Float)
|
||||
|
||||
|
||||
main :: IO ()
|
||||
main = defaultMain
|
||||
[ testGradientSimple
|
||||
|
@ -413,4 +436,5 @@ main = defaultMain
|
|||
, matMulTransposeGradient (False, True)
|
||||
, matMulTransposeGradient (True, False)
|
||||
, matMulTransposeGradient (True, True)
|
||||
, testConv2DBackpropInputGrad
|
||||
]
|
||||
|
|
Loading…
Reference in a new issue