diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index 34a9dea..996b0bc 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -691,8 +691,8 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] = padding = lookupAttr nodeDef "padding" :: ByteString dataFormat = lookupAttr nodeDef "data_format" :: ByteString -opGrad "Reshape" _ [toT -> x, _] [dz] = - [Just $ reshape dz $ shape (x :: Tensor Build a), Nothing] +opGrad "Reshape" _ [toT -> x, _] [dz] = [Just $ reshape dz $ shape (x :: Tensor Build a), Nothing] +opGrad "ExpandDims" n xs@[toT -> _, _] dzs@[_] = opGrad "Reshape" n xs dzs opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing] opGrad "TruncatedNormal" _ _ _ = [Nothing] @@ -810,6 +810,7 @@ numOutputs o = "DynamicPartition" -> fromIntegral (lookupAttr o "num_partitions" :: Int64) "Exp" -> 1 + "ExpandDims" -> 1 "Gather" -> 1 "LabelClasses" -> 1 "LabelWeights" -> 1 diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs index b01971b..8659e90 100644 --- a/tensorflow-ops/tests/GradientTest.hs +++ b/tensorflow-ops/tests/GradientTest.hs @@ -43,6 +43,7 @@ import Proto.Tensorflow.Core.Framework.Graph_Fields (node) import Proto.Tensorflow.Core.Framework.NodeDef_Fields (op) import qualified Data.ByteString.Char8 as BS +import TensorFlow.Session (SessionT) testGradientSimple :: Test testGradientSimple = testCase "testGradientSimple" $ do @@ -290,6 +291,35 @@ testTanhGrad = testCase "testTanhGrad" $ do TF.gradients y [x] >>= TF.run V.fromList [1] @=? dx +testExpandDims :: Test +testExpandDims = + testCase "testExpandDims" $ do + ([dx], [s]) <- + TF.runSession $ do + (x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [1, 2, 3 :: Int64] + let y = TF.expandDims x $ TF.constant (TF.Shape [1]) [0 :: Int32] + calculateGradWithShape y x + V.fromList [1, 1, 1, 1, 1, 1] @=? dx + V.fromList [1, 2, 3] @=? s + +testReshape :: Test +testReshape = + testCase "testReshape" $ do + ([dx], [s]) <- + TF.runSession $ do + (x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [2, 2 :: Int64] + let y = TF.reshape x $ TF.constant (TF.Shape [2]) [1, 4 :: Int32] + calculateGradWithShape y x + V.fromList [1, 1, 1, 1] @=? dx + V.fromList [2, 2] @=? s + +calculateGradWithShape :: TF.Tensor TF.Build Float -> TF.Tensor TF.Value Float -> SessionT IO ([V.Vector Float], [V.Vector Int32]) +calculateGradWithShape y x = do + gs <- TF.gradients y [x] + xs <- TF.run gs + (shapes :: [V.Vector Int32]) <- mapM (TF.run . TF.shape) gs + return (xs, shapes) + testFillGrad :: Test testFillGrad = testCase "testFillGrad" $ do [dx] <- TF.runSession $ do @@ -436,6 +466,8 @@ main = defaultMain , testReluGrad , testReluGradGrad , testTanhGrad + , testExpandDims + , testReshape , testFillGrad , testTileGrad , testTile2DGrad