mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Added support for ExpandDims gradient. (#224)
This commit is contained in:
parent
915015018c
commit
95c6b6f277
2 changed files with 35 additions and 2 deletions
|
@ -691,8 +691,8 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] =
|
||||||
padding = lookupAttr nodeDef "padding" :: ByteString
|
padding = lookupAttr nodeDef "padding" :: ByteString
|
||||||
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
||||||
|
|
||||||
opGrad "Reshape" _ [toT -> x, _] [dz] =
|
opGrad "Reshape" _ [toT -> x, _] [dz] = [Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]
|
||||||
[Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]
|
opGrad "ExpandDims" n xs@[toT -> _, _] dzs@[_] = opGrad "Reshape" n xs dzs
|
||||||
|
|
||||||
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
|
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
|
||||||
opGrad "TruncatedNormal" _ _ _ = [Nothing]
|
opGrad "TruncatedNormal" _ _ _ = [Nothing]
|
||||||
|
@ -810,6 +810,7 @@ numOutputs o =
|
||||||
"DynamicPartition" ->
|
"DynamicPartition" ->
|
||||||
fromIntegral (lookupAttr o "num_partitions" :: Int64)
|
fromIntegral (lookupAttr o "num_partitions" :: Int64)
|
||||||
"Exp" -> 1
|
"Exp" -> 1
|
||||||
|
"ExpandDims" -> 1
|
||||||
"Gather" -> 1
|
"Gather" -> 1
|
||||||
"LabelClasses" -> 1
|
"LabelClasses" -> 1
|
||||||
"LabelWeights" -> 1
|
"LabelWeights" -> 1
|
||||||
|
|
|
@ -43,6 +43,7 @@ import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
|
||||||
import Proto.Tensorflow.Core.Framework.NodeDef_Fields (op)
|
import Proto.Tensorflow.Core.Framework.NodeDef_Fields (op)
|
||||||
|
|
||||||
import qualified Data.ByteString.Char8 as BS
|
import qualified Data.ByteString.Char8 as BS
|
||||||
|
import TensorFlow.Session (SessionT)
|
||||||
|
|
||||||
testGradientSimple :: Test
|
testGradientSimple :: Test
|
||||||
testGradientSimple = testCase "testGradientSimple" $ do
|
testGradientSimple = testCase "testGradientSimple" $ do
|
||||||
|
@ -290,6 +291,35 @@ testTanhGrad = testCase "testTanhGrad" $ do
|
||||||
TF.gradients y [x] >>= TF.run
|
TF.gradients y [x] >>= TF.run
|
||||||
V.fromList [1] @=? dx
|
V.fromList [1] @=? dx
|
||||||
|
|
||||||
|
testExpandDims :: Test
|
||||||
|
testExpandDims =
|
||||||
|
testCase "testExpandDims" $ do
|
||||||
|
([dx], [s]) <-
|
||||||
|
TF.runSession $ do
|
||||||
|
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [1, 2, 3 :: Int64]
|
||||||
|
let y = TF.expandDims x $ TF.constant (TF.Shape [1]) [0 :: Int32]
|
||||||
|
calculateGradWithShape y x
|
||||||
|
V.fromList [1, 1, 1, 1, 1, 1] @=? dx
|
||||||
|
V.fromList [1, 2, 3] @=? s
|
||||||
|
|
||||||
|
testReshape :: Test
|
||||||
|
testReshape =
|
||||||
|
testCase "testReshape" $ do
|
||||||
|
([dx], [s]) <-
|
||||||
|
TF.runSession $ do
|
||||||
|
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [2, 2 :: Int64]
|
||||||
|
let y = TF.reshape x $ TF.constant (TF.Shape [2]) [1, 4 :: Int32]
|
||||||
|
calculateGradWithShape y x
|
||||||
|
V.fromList [1, 1, 1, 1] @=? dx
|
||||||
|
V.fromList [2, 2] @=? s
|
||||||
|
|
||||||
|
calculateGradWithShape :: TF.Tensor TF.Build Float -> TF.Tensor TF.Value Float -> SessionT IO ([V.Vector Float], [V.Vector Int32])
|
||||||
|
calculateGradWithShape y x = do
|
||||||
|
gs <- TF.gradients y [x]
|
||||||
|
xs <- TF.run gs
|
||||||
|
(shapes :: [V.Vector Int32]) <- mapM (TF.run . TF.shape) gs
|
||||||
|
return (xs, shapes)
|
||||||
|
|
||||||
testFillGrad :: Test
|
testFillGrad :: Test
|
||||||
testFillGrad = testCase "testFillGrad" $ do
|
testFillGrad = testCase "testFillGrad" $ do
|
||||||
[dx] <- TF.runSession $ do
|
[dx] <- TF.runSession $ do
|
||||||
|
@ -436,6 +466,8 @@ main = defaultMain
|
||||||
, testReluGrad
|
, testReluGrad
|
||||||
, testReluGradGrad
|
, testReluGradGrad
|
||||||
, testTanhGrad
|
, testTanhGrad
|
||||||
|
, testExpandDims
|
||||||
|
, testReshape
|
||||||
, testFillGrad
|
, testFillGrad
|
||||||
, testTileGrad
|
, testTileGrad
|
||||||
, testTile2DGrad
|
, testTile2DGrad
|
||||||
|
|
Loading…
Reference in a new issue