mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 11:29:43 +01:00
Adding gradient for Concat (#144)
This commit is contained in:
parent
cac45d1cd6
commit
79d8d7edea
3 changed files with 106 additions and 0 deletions
|
@ -459,6 +459,39 @@ opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
|
||||||
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
|
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
|
||||||
opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
|
opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
|
||||||
|
|
||||||
|
opGrad "Concat" _ _ix [dy]
    -- Gradient of Concat.
    --
    -- The forward op joins m input tensors
    --   x1 of shape s1 = [k1, ..., ki_1, ..., kn]
    --   ...
    --   xm of shape sm = [k1, ..., ki_m, ..., kn]
    -- along dimension i into an output tensor
    --   y  of shape sy = [k1, ..., k, ..., kn], where k = sum [ki_1, ..., ki_m].
    --
    -- Backpropagation splits the incoming gradient dy along dimension i into
    -- one piece per input tensor and reshapes each piece to the shape of the
    -- input it belongs to. The leading op input (the concat dimension) never
    -- receives a gradient, hence the leading Nothing.
    | inputCount == 1 = [Nothing, Just (expr dy)]
    | otherwise       = Nothing : map Just (zipWith reshape pieces inputShapes)
  where
    -- Data inputs: every op input after the leading concat-dimension input.
    inputs :: [Tensor Build a]
    inputs = map toT (tail _ix)
    inputCount = length inputs
    -- Forwarded gradients take these shapes, one per data input.
    inputShapes :: [Tensor Build Int32]
    inputShapes = map shape inputs
    -- Concat dimension, reduced modulo the rank of the first data input so
    -- that negative indices are handled.
    axis = toT (head _ix) `CoreOps.floorMod` CoreOps.rank (head inputs)
    -- The concat dimension reshaped to a one-element vector, as needed by
    -- the slice below.
    axisVec = reshape axis (vector [1 :: Int32])
    -- Sizes of all data inputs along the concat dimension, as one vector.
    extents :: Tensor Build Int32
    extents = CoreOps.concat 0 (map extentAlongAxis inputShapes)
    -- Picks the single entry at the concat dimension out of a shape vector.
    extentAlongAxis :: Tensor Build Int32 -> Tensor Build Int32
    extentAlongAxis shp = CoreOps.slice shp axisVec (vector [1 :: Int32])
    -- dy split along the concat dimension into one piece per data input.
    pieces = CoreOps.splitV (fromIntegral inputCount) dy extents axis
|
||||||
|
|
||||||
opGrad "Square" _ [toT -> x] [dz] =
|
opGrad "Square" _ [toT -> x] [dz] =
|
||||||
-- TODO(fmayle): Handle complex numbers.
|
-- TODO(fmayle): Handle complex numbers.
|
||||||
-- TODO(fmayle): The python code makes dz a control dependency of the 2*x
|
-- TODO(fmayle): The python code makes dz a control dependency of the 2*x
|
||||||
|
@ -744,6 +777,7 @@ numOutputs o =
|
||||||
"AddN" -> 1
|
"AddN" -> 1
|
||||||
"Cast" -> 1
|
"Cast" -> 1
|
||||||
"Const" -> 1
|
"Const" -> 1
|
||||||
|
"Concat" -> 1
|
||||||
"Conv2D" -> 1
|
"Conv2D" -> 1
|
||||||
"Div" -> 1
|
"Div" -> 1
|
||||||
"DynamicStitch" -> 1
|
"DynamicStitch" -> 1
|
||||||
|
|
|
@ -190,6 +190,7 @@ Test-Suite GradientTest
|
||||||
, base
|
, base
|
||||||
, proto-lens
|
, proto-lens
|
||||||
, lens-family
|
, lens-family
|
||||||
|
, random
|
||||||
, tensorflow
|
, tensorflow
|
||||||
, tensorflow-core-ops
|
, tensorflow-core-ops
|
||||||
, tensorflow-ops
|
, tensorflow-ops
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
import Data.List (sort)
|
import Data.List (sort)
|
||||||
|
import qualified Data.List as List
|
||||||
import Data.ProtoLens.TextFormat (showMessage)
|
import Data.ProtoLens.TextFormat (showMessage)
|
||||||
import Test.Framework (defaultMain, Test)
|
import Test.Framework (defaultMain, Test)
|
||||||
import Lens.Family2 ((^..), (.~))
|
import Lens.Family2 ((^..), (.~))
|
||||||
|
@ -26,6 +27,8 @@ import Lens.Family2 ((^..), (.~))
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.HUnit ((@=?), assertEqual)
|
import Test.HUnit ((@=?), assertEqual)
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
|
import System.Random (randomIO, randomRIO)
|
||||||
|
import Control.Monad(forM_, replicateM, zipWithM)
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
|
||||||
import qualified TensorFlow.Core as TF
|
import qualified TensorFlow.Core as TF
|
||||||
|
@ -173,6 +176,71 @@ testMaxGradient = testCase "testMaxGradient" $ do
|
||||||
TF.gradients y [x] >>= TF.run
|
TF.gradients y [x] >>= TF.run
|
||||||
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
|
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
|
||||||
|
|
||||||
|
-- Checks the Concat gradient on two small, hand-computed examples.
testConcatGradient :: Test
testConcatGradient = testCase "testConcatGradient" $ do
    -- Two one-element vectors, each used once in the concatenation:
    -- both gradients are expected to be 1.
    [gradA, gradB] <- TF.runSession $ do
        a <- TF.render $ TF.vector [1 :: Float]
        b <- TF.render $ TF.vector [2 :: Float]
        let joined = TF.concat (TF.scalar 0) [a, b]
        TF.gradients joined [a, b] >>= TF.run
    V.fromList [1 :: Float] @=? gradA
    V.fromList [1 :: Float] @=? gradB
    -- The first vector appears twice in the concatenation, so its expected
    -- gradient is 2 in every position; the second vector appears once.
    [gradP, gradQ] <- TF.runSession $ do
        p <- TF.render $ TF.vector [1, 2, 3, 4 :: Float]
        q <- TF.render $ TF.vector [5, 6, 7, 8 :: Float]
        let joined = TF.concat (TF.scalar 0) [p, q, p]
        TF.gradients joined [p, q] >>= TF.run
    V.fromList [2, 2, 2, 2 :: Float] @=? gradP
    V.fromList [1, 1, 1, 1 :: Float] @=? gradQ
|
||||||
|
|
||||||
|
-- Builds one random constant tensor per entry of 'shapes', concatenates them
-- along 'concatDim', and multiplies the result element-wise with the
-- concatenation of a second set of random tensors of the same shapes.
-- Asserts that the gradient of that product with respect to each input
-- tensor equals the matching tensor of the second set.
verifyConcatGradients :: [[Int64]] -> Int32 -> IO ()
verifyConcatGradients shapes concatDim = do
    -- The expected gradients are drawn before the input values.
    expectedGrads <- mapM randomFloats shapes
    inputValues <- mapM randomFloats shapes
    actualGrads <- TF.runSession $ do
        inputTensors <- renderConstants inputValues
        gradTensors <- renderConstants expectedGrads
        joinedInputs <- TF.render $ TF.concat axis inputTensors
        joinedGrads <- TF.render $ TF.concat axis gradTensors
        weighted <- TF.render $ joinedInputs `TF.mul` joinedGrads
        TF.gradients weighted inputTensors >>= TF.run
    map V.fromList expectedGrads @=? actualGrads
  where
    axis = TF.scalar concatDim
    -- As many random floats as a tensor of the given shape has elements.
    randomFloats :: [Int64] -> IO [Float]
    randomFloats dims = replicateM (fromIntegral $ List.product dims) randomIO
    -- Renders each list of values as a constant with the shape at the same
    -- position in 'shapes'.
    renderConstants values =
        zipWithM (\vals dims -> TF.render $ TF.constant (TF.Shape dims) vals)
                 values
                 shapes
|
||||||
|
|
||||||
|
-- Verifies the gradient of a concat op on rank-3 inputs, concatenating along
-- the second, the first, and the third dimension (in that order).
testConcatGradientSimple :: Test
testConcatGradientSimple = testCase "testConcatGradientSimple" $
    mapM_ (uncurry verifyConcatGradients)
        [ -- Counterpart of ConcatTest._testGradientsSimple from
          -- tensorflow/tensorflow/compiler/tests/concat_ops_test.py
          ([[10, k, 2] | k <- sizes], 1)
          -- Counterpart of ConcatTest._testGradientsFirstDim from
          -- tensorflow/tensorflow/compiler/tests/concat_ops_test.py
        , ([[k, 10, 2] | k <- sizes], 0)
          -- Counterpart of ConcatTest._testGradientsLastDim from
          -- tensorflow/tensorflow/compiler/tests/concat_ops_test.py
        , ([[10, 2, k] | k <- sizes], 2)
        ]
  where
    -- Sizes of the three inputs along the concatenated dimension.
    sizes :: [Int64]
    sizes = [1, 2, 6]
|
||||||
|
|
||||||
|
|
||||||
|
-- Verifies the gradient of a concat op for randomly chosen shapes and a
-- randomly chosen concat dimension. Modelled on
-- ConcatTest._RunAndVerifyGradientsRandom from
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py, with the addition
-- that negative concat dimensions are exercised as well.
testConcatRunAndVerifyGradientsRandom :: Test
testConcatRunAndVerifyGradientsRandom = testCase "testConcatRunAndVerifyGradientsRandom" $
    forM_ [1..5 :: Int] $ \_ -> do
        baseShape <- replicateM 5 (randomRIO (1, 5)) :: IO [Int64]
        tensorCount <- randomRIO (2, 10) :: IO Int
        axis <- randomRIO (-4, 4) :: IO Int
        axisSizes <- replicateM tensorCount (randomRIO (1, 5)) :: IO [Int64]
        let -- Non-negative position of the concat dimension in baseShape.
            position = axis `mod` length baseShape
            -- baseShape with the entry at 'position' replaced by 'size'.
            withAxisSize size =
                take position baseShape ++ size : drop (position + 1) baseShape
        -- The possibly negative dimension is what gets passed on.
        verifyConcatGradients (map withAxisSize axisSizes) (fromIntegral axis)
|
||||||
|
|
||||||
-- run single test like this:
|
-- run single test like this:
|
||||||
-- stack --docker --docker-image=$IMAGE_NAME test tensorflow-ops:GradientTest --test-arguments -t"*MaximumGrad*"
|
-- stack --docker --docker-image=$IMAGE_NAME test tensorflow-ops:GradientTest --test-arguments -t"*MaximumGrad*"
|
||||||
testMaximumGrad :: Test
|
testMaximumGrad :: Test
|
||||||
|
@ -329,6 +397,9 @@ main = defaultMain
|
||||||
, testDiamond
|
, testDiamond
|
||||||
, testAddNGradient
|
, testAddNGradient
|
||||||
, testMaxGradient
|
, testMaxGradient
|
||||||
|
, testConcatGradient
|
||||||
|
, testConcatGradientSimple
|
||||||
|
, testConcatRunAndVerifyGradientsRandom
|
||||||
, testMaximumGrad
|
, testMaximumGrad
|
||||||
, testMaximumGradGrad
|
, testMaximumGradGrad
|
||||||
, testReluGrad
|
, testReluGrad
|
||||||
|
|
Loading…
Reference in a new issue