Added support for ExpandDims gradient. (#224)

This commit is contained in:
Rik 2018-11-21 03:45:31 +01:00 committed by fkm3
parent 915015018c
commit 95c6b6f277
2 changed files with 35 additions and 2 deletions

View File

@ -691,8 +691,8 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] =
padding = lookupAttr nodeDef "padding" :: ByteString
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
opGrad "Reshape" _ [toT -> x, _] [dz] =
[Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]
opGrad "Reshape" _ [toT -> x, _] [dz] = [Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]
opGrad "ExpandDims" n xs@[toT -> _, _] dzs@[_] = opGrad "Reshape" n xs dzs
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
opGrad "TruncatedNormal" _ _ _ = [Nothing]
@ -810,6 +810,7 @@ numOutputs o =
"DynamicPartition" ->
fromIntegral (lookupAttr o "num_partitions" :: Int64)
"Exp" -> 1
"ExpandDims" -> 1
"Gather" -> 1
"LabelClasses" -> 1
"LabelWeights" -> 1

View File

@ -43,6 +43,7 @@ import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
import Proto.Tensorflow.Core.Framework.NodeDef_Fields (op)
import qualified Data.ByteString.Char8 as BS
import TensorFlow.Session (SessionT)
testGradientSimple :: Test
testGradientSimple = testCase "testGradientSimple" $ do
@ -290,6 +291,35 @@ testTanhGrad = testCase "testTanhGrad" $ do
TF.gradients y [x] >>= TF.run
V.fromList [1] @=? dx
testExpandDims :: Test
testExpandDims =
testCase "testExpandDims" $ do
([dx], [s]) <-
TF.runSession $ do
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [1, 2, 3 :: Int64]
let y = TF.expandDims x $ TF.constant (TF.Shape [1]) [0 :: Int32]
calculateGradWithShape y x
V.fromList [1, 1, 1, 1, 1, 1] @=? dx
V.fromList [1, 2, 3] @=? s
testReshape :: Test
testReshape =
testCase "testReshape" $ do
([dx], [s]) <-
TF.runSession $ do
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [2, 2 :: Int64]
let y = TF.reshape x $ TF.constant (TF.Shape [2]) [1, 4 :: Int32]
calculateGradWithShape y x
V.fromList [1, 1, 1, 1] @=? dx
V.fromList [2, 2] @=? s
calculateGradWithShape :: TF.Tensor TF.Build Float -> TF.Tensor TF.Value Float -> SessionT IO ([V.Vector Float], [V.Vector Int32])
calculateGradWithShape y x = do
gs <- TF.gradients y [x]
xs <- TF.run gs
(shapes :: [V.Vector Int32]) <- mapM (TF.run . TF.shape) gs
return (xs, shapes)
testFillGrad :: Test
testFillGrad = testCase "testFillGrad" $ do
[dx] <- TF.runSession $ do
@ -436,6 +466,8 @@ main = defaultMain
, testReluGrad
, testReluGradGrad
, testTanhGrad
, testExpandDims
, testReshape
, testFillGrad
, testTileGrad
, testTile2DGrad