mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Add gradients for DepthwiseConv2dNative (#240)
This commit is contained in:
parent
4a2e46ba57
commit
1fbd5d41dd
2 changed files with 82 additions and 1 deletions
|
@ -694,6 +694,41 @@ opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
|
||||||
useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
|
useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
|
||||||
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
||||||
|
|
||||||
|
opGrad "DepthwiseConv2dNative" nodeDef [toT -> x, toT -> y] [dz] =
|
||||||
|
[ Just $ CoreOps.depthwiseConv2dNativeBackpropInput'
|
||||||
|
((opAttr "strides" .~ strides)
|
||||||
|
. (opAttr "padding" .~ padding)
|
||||||
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
|
(shape x) y dz
|
||||||
|
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
|
||||||
|
((opAttr "strides" .~ strides)
|
||||||
|
. (opAttr "padding" .~ padding)
|
||||||
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
|
x (shape y) dz
|
||||||
|
]
|
||||||
|
where
|
||||||
|
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||||
|
padding = lookupAttr nodeDef "padding" :: ByteString
|
||||||
|
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
||||||
|
|
||||||
|
opGrad "DepthwiseConv2dNativeBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
|
||||||
|
[ Nothing
|
||||||
|
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
|
||||||
|
((opAttr "strides" .~ strides)
|
||||||
|
. (opAttr "padding" .~ padding)
|
||||||
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
|
dz (shape x) y
|
||||||
|
, Just $ CoreOps.depthwiseConv2dNative'
|
||||||
|
((opAttr "strides" .~ strides)
|
||||||
|
. (opAttr "padding" .~ padding)
|
||||||
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
|
dz x
|
||||||
|
]
|
||||||
|
where
|
||||||
|
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||||
|
padding = lookupAttr nodeDef "padding" :: ByteString
|
||||||
|
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
||||||
|
|
||||||
opGrad "MaxPool" nodeDef [toT -> x] [dz] =
|
opGrad "MaxPool" nodeDef [toT -> x] [dz] =
|
||||||
[ Just $ CoreOps.maxPoolGrad'
|
[ Just $ CoreOps.maxPoolGrad'
|
||||||
((opAttr "ksize" .~ ksize)
|
((opAttr "ksize" .~ ksize)
|
||||||
|
@ -882,6 +917,8 @@ numOutputs o =
|
||||||
"Concat" -> 1
|
"Concat" -> 1
|
||||||
"Conv2D" -> 1
|
"Conv2D" -> 1
|
||||||
"Conv2DBackpropInput" -> 1
|
"Conv2DBackpropInput" -> 1
|
||||||
|
"DepthwiseConv2dNative" -> 1
|
||||||
|
"DepthwiseConv2dNativeBackpropInput" -> 1
|
||||||
"Div" -> 1
|
"Div" -> 1
|
||||||
"DynamicStitch" -> 1
|
"DynamicStitch" -> 1
|
||||||
"DynamicPartition" ->
|
"DynamicPartition" ->
|
||||||
|
|
|
@ -33,7 +33,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
|
||||||
import qualified TensorFlow.Core as TF
|
import qualified TensorFlow.Core as TF
|
||||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag)
|
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput')
|
||||||
import qualified TensorFlow.Gradient as TF
|
import qualified TensorFlow.Gradient as TF
|
||||||
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
|
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
|
||||||
import qualified TensorFlow.Output as TF
|
import qualified TensorFlow.Output as TF
|
||||||
|
@ -596,6 +596,7 @@ transAttrs :: (TF.Attribute a,
|
||||||
transAttrs a b =
|
transAttrs a b =
|
||||||
(TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b)
|
(TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b)
|
||||||
|
|
||||||
|
-- TODO check gradient with regard to filter also
|
||||||
testConv2DBackpropInputGrad :: Test
|
testConv2DBackpropInputGrad :: Test
|
||||||
testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
|
testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
|
||||||
(dx, shapeDX, shapeX) <- TF.runSession $ do
|
(dx, shapeDX, shapeX) <- TF.runSession $ do
|
||||||
|
@ -617,6 +618,47 @@ testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
|
||||||
shapeX @=? (shapeDX :: V.Vector Int32)
|
shapeX @=? (shapeDX :: V.Vector Int32)
|
||||||
V.fromList [4::Float] @=? (dx :: V.Vector Float)
|
V.fromList [4::Float] @=? (dx :: V.Vector Float)
|
||||||
|
|
||||||
|
testDepthwiseConv2dGrad :: Test
|
||||||
|
testDepthwiseConv2dGrad = testCase "testDepthwiseConv2dGrad" $ do
|
||||||
|
(dx, shapeDX, shapeX) <- TF.runSession $ do
|
||||||
|
let conv_input_shape = TF.vector [1, 2, 2, 1 :: Int32]
|
||||||
|
x <- TF.render $ TF.fill conv_input_shape (TF.scalar (2 :: Float))
|
||||||
|
|
||||||
|
let filterShape = TF.vector [2, 2, 1, 1 :: Int32]
|
||||||
|
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
|
||||||
|
let y = TF.depthwiseConv2dNative'
|
||||||
|
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
|
||||||
|
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
|
||||||
|
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
|
||||||
|
)
|
||||||
|
x filter'
|
||||||
|
|
||||||
|
[dx] <- TF.gradients y [x]
|
||||||
|
TF.run (dx, TF.shape dx, TF.shape x)
|
||||||
|
shapeX @=? (shapeDX :: V.Vector Int32)
|
||||||
|
V.fromList [1, 1, 1, 1 :: Float] @=? (dx :: V.Vector Float)
|
||||||
|
|
||||||
|
-- TODO also test filter gradient
|
||||||
|
testDepthwiseConv2dBackpropInputGrad :: Test
|
||||||
|
testDepthwiseConv2dBackpropInputGrad = testCase "testDepthwiseConv2dBackpropInputGrad" $ do
|
||||||
|
(dx, shapeDX, shapeX) <- TF.runSession $ do
|
||||||
|
let conv_input_shape = TF.vector [1, 2, 2, 1 :: Int32]
|
||||||
|
let conv_out_shape = TF.vector [1, 1, 1, 1 :: Int32] -- [batch, h, w, out_channels]
|
||||||
|
x <- TF.render $ TF.fill conv_out_shape (TF.scalar (1::Float))
|
||||||
|
|
||||||
|
let filterShape = TF.vector [2, 2, 1, 1 :: Int32]
|
||||||
|
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
|
||||||
|
let y = TF.depthwiseConv2dNativeBackpropInput'
|
||||||
|
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
|
||||||
|
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
|
||||||
|
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
|
||||||
|
)
|
||||||
|
conv_input_shape filter' x
|
||||||
|
|
||||||
|
[dx] <- TF.gradients y [x]
|
||||||
|
TF.run (dx, TF.shape dx, TF.shape x)
|
||||||
|
shapeX @=? (shapeDX :: V.Vector Int32)
|
||||||
|
V.fromList [4::Float] @=? (dx :: V.Vector Float)
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = defaultMain
|
main = defaultMain
|
||||||
|
@ -658,4 +700,6 @@ main = defaultMain
|
||||||
, matMulTransposeGradient (True, False)
|
, matMulTransposeGradient (True, False)
|
||||||
, matMulTransposeGradient (True, True)
|
, matMulTransposeGradient (True, True)
|
||||||
, testConv2DBackpropInputGrad
|
, testConv2DBackpropInputGrad
|
||||||
|
, testDepthwiseConv2dGrad
|
||||||
|
, testDepthwiseConv2dBackpropInputGrad
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in a new issue