mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-05 02:29:41 +01:00
Added support for ExpandDims gradient. (#224)
This commit is contained in:
parent
915015018c
commit
95c6b6f277
2 changed files with 35 additions and 2 deletions
|
@ -691,8 +691,8 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] =
|
|||
padding = lookupAttr nodeDef "padding" :: ByteString
|
||||
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
||||
|
||||
opGrad "Reshape" _ [toT -> x, _] [dz] =
|
||||
[Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]
|
||||
opGrad "Reshape" _ [toT -> x, _] [dz] = [Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]
|
||||
opGrad "ExpandDims" n xs@[toT -> _, _] dzs@[_] = opGrad "Reshape" n xs dzs
|
||||
|
||||
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
|
||||
opGrad "TruncatedNormal" _ _ _ = [Nothing]
|
||||
|
@ -810,6 +810,7 @@ numOutputs o =
|
|||
"DynamicPartition" ->
|
||||
fromIntegral (lookupAttr o "num_partitions" :: Int64)
|
||||
"Exp" -> 1
|
||||
"ExpandDims" -> 1
|
||||
"Gather" -> 1
|
||||
"LabelClasses" -> 1
|
||||
"LabelWeights" -> 1
|
||||
|
|
|
@ -43,6 +43,7 @@ import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
|
|||
import Proto.Tensorflow.Core.Framework.NodeDef_Fields (op)
|
||||
|
||||
import qualified Data.ByteString.Char8 as BS
|
||||
import TensorFlow.Session (SessionT)
|
||||
|
||||
testGradientSimple :: Test
|
||||
testGradientSimple = testCase "testGradientSimple" $ do
|
||||
|
@ -290,6 +291,35 @@ testTanhGrad = testCase "testTanhGrad" $ do
|
|||
TF.gradients y [x] >>= TF.run
|
||||
V.fromList [1] @=? dx
|
||||
|
||||
testExpandDims :: Test
|
||||
testExpandDims =
|
||||
testCase "testExpandDims" $ do
|
||||
([dx], [s]) <-
|
||||
TF.runSession $ do
|
||||
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [1, 2, 3 :: Int64]
|
||||
let y = TF.expandDims x $ TF.constant (TF.Shape [1]) [0 :: Int32]
|
||||
calculateGradWithShape y x
|
||||
V.fromList [1, 1, 1, 1, 1, 1] @=? dx
|
||||
V.fromList [1, 2, 3] @=? s
|
||||
|
||||
testReshape :: Test
|
||||
testReshape =
|
||||
testCase "testReshape" $ do
|
||||
([dx], [s]) <-
|
||||
TF.runSession $ do
|
||||
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [2, 2 :: Int64]
|
||||
let y = TF.reshape x $ TF.constant (TF.Shape [2]) [1, 4 :: Int32]
|
||||
calculateGradWithShape y x
|
||||
V.fromList [1, 1, 1, 1] @=? dx
|
||||
V.fromList [2, 2] @=? s
|
||||
|
||||
calculateGradWithShape :: TF.Tensor TF.Build Float -> TF.Tensor TF.Value Float -> SessionT IO ([V.Vector Float], [V.Vector Int32])
|
||||
calculateGradWithShape y x = do
|
||||
gs <- TF.gradients y [x]
|
||||
xs <- TF.run gs
|
||||
(shapes :: [V.Vector Int32]) <- mapM (TF.run . TF.shape) gs
|
||||
return (xs, shapes)
|
||||
|
||||
testFillGrad :: Test
|
||||
testFillGrad = testCase "testFillGrad" $ do
|
||||
[dx] <- TF.runSession $ do
|
||||
|
@ -436,6 +466,8 @@ main = defaultMain
|
|||
, testReluGrad
|
||||
, testReluGradGrad
|
||||
, testTanhGrad
|
||||
, testExpandDims
|
||||
, testReshape
|
||||
, testFillGrad
|
||||
, testTileGrad
|
||||
, testTile2DGrad
|
||||
|
|
Loading…
Reference in a new issue