
Add gradients for DepthwiseConv2dNative (#240)

Authored by Christian Berentsen on 2019-04-22 06:46:27 +02:00, committed by fkm3
parent 4a2e46ba57
commit 1fbd5d41dd
2 changed files with 82 additions and 1 deletion


@@ -694,6 +694,41 @@ opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
opGrad "DepthwiseConv2dNative" nodeDef [toT -> x, toT -> y] [dz] =
[ Just $ CoreOps.depthwiseConv2dNativeBackpropInput'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat))
(shape x) y dz
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat))
x (shape y) dz
]
where
strides = lookupAttr nodeDef "strides" :: [Int64]
padding = lookupAttr nodeDef "padding" :: ByteString
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
opGrad "DepthwiseConv2dNativeBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
[ Nothing
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat))
dz (shape x) y
, Just $ CoreOps.depthwiseConv2dNative'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat))
dz x
]
where
strides = lookupAttr nodeDef "strides" :: [Int64]
padding = lookupAttr nodeDef "padding" :: ByteString
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
opGrad "MaxPool" nodeDef [toT -> x] [dz] = opGrad "MaxPool" nodeDef [toT -> x] [dz] =
[ Just $ CoreOps.maxPoolGrad' [ Just $ CoreOps.maxPoolGrad'
((opAttr "ksize" .~ ksize) ((opAttr "ksize" .~ ksize)
@ -882,6 +917,8 @@ numOutputs o =
"Concat" -> 1 "Concat" -> 1
"Conv2D" -> 1 "Conv2D" -> 1
"Conv2DBackpropInput" -> 1 "Conv2DBackpropInput" -> 1
"DepthwiseConv2dNative" -> 1
"DepthwiseConv2dNativeBackpropInput" -> 1
"Div" -> 1 "Div" -> 1
"DynamicStitch" -> 1 "DynamicStitch" -> 1
"DynamicPartition" -> "DynamicPartition" ->


@@ -33,7 +33,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
import Control.Monad.IO.Class (liftIO)
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput')
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
import qualified TensorFlow.Output as TF
@@ -596,6 +596,7 @@ transAttrs :: (TF.Attribute a,
transAttrs a b =
(TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b)
-- TODO: also check the gradient with respect to the filter
testConv2DBackpropInputGrad :: Test
testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
(dx, shapeDX, shapeX) <- TF.runSession $ do
@@ -617,6 +618,47 @@ testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
shapeX @=? (shapeDX :: V.Vector Int32)
V.fromList [4::Float] @=? (dx :: V.Vector Float)
testDepthwiseConv2dGrad :: Test
testDepthwiseConv2dGrad = testCase "testDepthwiseConv2dGrad" $ do
(dx, shapeDX, shapeX) <- TF.runSession $ do
let conv_input_shape = TF.vector [1, 2, 2, 1 :: Int32]
x <- TF.render $ TF.fill conv_input_shape (TF.scalar (2 :: Float))
let filterShape = TF.vector [2, 2, 1, 1 :: Int32]
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
let y = TF.depthwiseConv2dNative'
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
)
x filter'
[dx] <- TF.gradients y [x]
TF.run (dx, TF.shape dx, TF.shape x)
shapeX @=? (shapeDX :: V.Vector Int32)
V.fromList [1, 1, 1, 1 :: Float] @=? (dx :: V.Vector Float)
-- TODO also test filter gradient
testDepthwiseConv2dBackpropInputGrad :: Test
testDepthwiseConv2dBackpropInputGrad = testCase "testDepthwiseConv2dBackpropInputGrad" $ do
(dx, shapeDX, shapeX) <- TF.runSession $ do
let conv_input_shape = TF.vector [1, 2, 2, 1 :: Int32]
let conv_out_shape = TF.vector [1, 1, 1, 1 :: Int32] -- [batch, h, w, out_channels]
x <- TF.render $ TF.fill conv_out_shape (TF.scalar (1::Float))
let filterShape = TF.vector [2, 2, 1, 1 :: Int32]
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
let y = TF.depthwiseConv2dNativeBackpropInput'
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
)
conv_input_shape filter' x
[dx] <- TF.gradients y [x]
TF.run (dx, TF.shape dx, TF.shape x)
shapeX @=? (shapeDX :: V.Vector Int32)
V.fromList [4::Float] @=? (dx :: V.Vector Float)
main :: IO ()
main = defaultMain
@@ -658,4 +700,6 @@ main = defaultMain
, matMulTransposeGradient (True, False)
, matMulTransposeGradient (True, True)
, testConv2DBackpropInputGrad
, testDepthwiseConv2dGrad
, testDepthwiseConv2dBackpropInputGrad
]
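The TODO in testDepthwiseConv2dGrad notes that only the input gradient is checked. A hedged sketch of what a filter-gradient check could look like, following the same pattern as the new tests and written against the imports already present in the test file; the test name and expected values below are illustrative, not part of the commit:

-- Sketch only: with the 2x2 input filled with 2 and VALID padding, each filter
-- element sees exactly one input value, so d(output)/d(filter) is 2 everywhere.
testDepthwiseConv2dFilterGrad :: Test
testDepthwiseConv2dFilterGrad = testCase "testDepthwiseConv2dFilterGrad" $ do
    (df, shapeDF, shapeF) <- TF.runSession $ do
        let conv_input_shape = TF.vector [1, 2, 2, 1 :: Int32]
        x <- TF.render $ TF.fill conv_input_shape (TF.scalar (2 :: Float))
        let filterShape = TF.vector [2, 2, 1, 1 :: Int32]
        filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
        let y = TF.depthwiseConv2dNative'
                  ( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
                  . (TF.opAttr "padding" .~ BS.pack "VALID")
                  . (TF.opAttr "data_format" .~ BS.pack "NHWC")
                  )
                  x filter'
        -- Differentiate with respect to the filter instead of the input.
        [df] <- TF.gradients y [filter']
        TF.run (df, TF.shape df, TF.shape filter')
    shapeF @=? (shapeDF :: V.Vector Int32)
    V.fromList [2, 2, 2, 2 :: Float] @=? (df :: V.Vector Float)

If such a test were added, it would also need to be registered in the defaultMain list alongside the two new tests above.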