mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-26 18:55:02 +01:00
Adding gradient for Concat (#144)
This commit is contained in:
parent
cac45d1cd6
commit
79d8d7edea
3 changed files with 106 additions and 0 deletions
|
@ -459,6 +459,39 @@ opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
|
|||
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
|
||||
opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
|
||||
|
||||
opGrad "Concat" _ _ix [dy]
|
||||
-- Concat concatenates input tensors
|
||||
-- x1 of shape s1 = [k1, ..., ki_1, ..., kn]
|
||||
-- x2 of shape s2 = [k1, ..., ki_2, ..., kn]
|
||||
-- . . . . .
|
||||
-- . . . . .
|
||||
-- . . . . .
|
||||
-- xm of shape sm = [k1, ..., ki_m, ..., kn]
|
||||
-- along dimension i to an output tensor
|
||||
-- y of shape sy = [k1, ..., k, ..., kn]
|
||||
-- where k = sum ki = sum [ki_1,...,ki_m]
|
||||
--
|
||||
-- The incoming gradient dy from backpropagation is
|
||||
-- simply forwarded split across input tensors yielding dx.
|
||||
-- Forwarded gradients have shapes s = [s1, ..., sm].
|
||||
| m == 1 = Nothing : [Just $ expr dy]
|
||||
| otherwise = Nothing : map Just (dx `reshapeZip` s)
|
||||
where
|
||||
reshapeZip = zipWith reshape
|
||||
dx = CoreOps.splitV (fromIntegral m) dy ki _i
|
||||
s :: [Tensor Build Int32]
|
||||
s = map shape x
|
||||
x :: [Tensor Build a]
|
||||
x = map toT $ tail _ix
|
||||
-- i: concat dimension. Adjusted modulo n to handle negative indices.
|
||||
_i = toT (head _ix) `CoreOps.floorMod` n
|
||||
i = reshape _i $ vector [1 :: Int32]
|
||||
-- sizes along concatenated dimension
|
||||
ki :: Tensor Build Int32
|
||||
ki = CoreOps.concat 0 $ map (\t -> CoreOps.slice t i $ vector [1 :: Int32]) s
|
||||
m = length x
|
||||
n = CoreOps.rank (head x)
|
||||
|
||||
opGrad "Square" _ [toT -> x] [dz] =
|
||||
-- TODO(fmayle): Handle complex numbers.
|
||||
-- TODO(fmayle): The python code makes dz a control dependency of the 2*x
|
||||
|
@ -744,6 +777,7 @@ numOutputs o =
|
|||
"AddN" -> 1
|
||||
"Cast" -> 1
|
||||
"Const" -> 1
|
||||
"Concat" -> 1
|
||||
"Conv2D" -> 1
|
||||
"Div" -> 1
|
||||
"DynamicStitch" -> 1
|
||||
|
|
|
@ -190,6 +190,7 @@ Test-Suite GradientTest
|
|||
, base
|
||||
, proto-lens
|
||||
, lens-family
|
||||
, random
|
||||
, tensorflow
|
||||
, tensorflow-core-ops
|
||||
, tensorflow-ops
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.List (sort)
|
||||
import qualified Data.List as List
|
||||
import Data.ProtoLens.TextFormat (showMessage)
|
||||
import Test.Framework (defaultMain, Test)
|
||||
import Lens.Family2 ((^..), (.~))
|
||||
|
@ -26,6 +27,8 @@ import Lens.Family2 ((^..), (.~))
|
|||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@=?), assertEqual)
|
||||
import qualified Data.Vector as V
|
||||
import System.Random (randomIO, randomRIO)
|
||||
import Control.Monad(forM_, replicateM, zipWithM)
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
|
||||
import qualified TensorFlow.Core as TF
|
||||
|
@ -173,6 +176,71 @@ testMaxGradient = testCase "testMaxGradient" $ do
|
|||
TF.gradients y [x] >>= TF.run
|
||||
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
|
||||
|
||||
testConcatGradient :: Test
|
||||
testConcatGradient = testCase "testConcatGradient" $ do
|
||||
[dv,dv'] <- TF.runSession $ do
|
||||
v <- TF.render $ TF.vector [1 :: Float]
|
||||
v' <- TF.render $ TF.vector [2 :: Float]
|
||||
let y = TF.concat (TF.scalar 0) [ v, v' ]
|
||||
TF.gradients y [v,v'] >>= TF.run
|
||||
V.fromList [1 :: Float] @=? dv
|
||||
V.fromList [1 :: Float] @=? dv'
|
||||
[dw,dw'] <- TF.runSession $ do
|
||||
w <- TF.render $ TF.vector [1,2,3,4 :: Float]
|
||||
w' <- TF.render $ TF.vector [5,6,7,8 :: Float]
|
||||
let y = TF.concat (TF.scalar 0) [ w, w', w ]
|
||||
TF.gradients y [w,w'] >>= TF.run
|
||||
V.fromList [2,2,2,2 :: Float] @=? dw
|
||||
V.fromList [1,1,1,1 :: Float] @=? dw'
|
||||
|
||||
verifyConcatGradients :: [[Int64]] -> Int32 -> IO ()
|
||||
verifyConcatGradients shapes concatDim = do
|
||||
let floatsFromShape :: [Int64] -> IO [Float]
|
||||
floatsFromShape shape = replicateM (fromIntegral $ List.product shape) randomIO
|
||||
constantZip = zipWithM $ \x shape -> TF.render $ TF.constant (TF.Shape shape) x
|
||||
inputGrads <- mapM floatsFromShape shapes
|
||||
inputs <- mapM floatsFromShape shapes
|
||||
dinputs <- TF.runSession $ do
|
||||
inputTensors <- inputs `constantZip` shapes
|
||||
inputGradTensors <- inputGrads `constantZip` shapes
|
||||
inputTensor <- TF.render $ TF.concat (TF.scalar concatDim) inputTensors
|
||||
inputGradTensor <- TF.render $ TF.concat (TF.scalar concatDim) inputGradTensors
|
||||
output <- TF.render $ inputTensor `TF.mul` inputGradTensor
|
||||
TF.gradients output inputTensors >>= TF.run
|
||||
(V.fromList <$> inputGrads) @=? dinputs
|
||||
|
||||
-- This test checks that the gradient of a concat op
|
||||
-- is correct along the first, second, and third dimension.
|
||||
testConcatGradientSimple :: Test
|
||||
testConcatGradientSimple = testCase "testConcatGradientSimple" $ do
|
||||
-- The following check is equivalent to ConcatTest._testGradientsSimple from
|
||||
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py
|
||||
verifyConcatGradients [[10,x,2] | x <- [1,2,6]] 1
|
||||
-- The following check is equivalent to ConcatTest._testGradientsFirstDim from
|
||||
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py
|
||||
verifyConcatGradients [[x,10,2] | x <- [1,2,6]] 0
|
||||
-- The following check is equivalent to ConcatTest._testGradientsLastDim from
|
||||
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py
|
||||
verifyConcatGradients [[10,2,x] | x <- [1,2,6]] 2
|
||||
|
||||
|
||||
-- This test checks that the gradient of a concat op
|
||||
-- along a random dimension across random shapes is as expected.
|
||||
-- This test is inspired by ConcatTest._RunAndVerifyGradientsRandom from
|
||||
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py, but also
|
||||
-- verifies the gradient along negative concat dimensions.
|
||||
testConcatRunAndVerifyGradientsRandom :: Test
|
||||
testConcatRunAndVerifyGradientsRandom = testCase "testConcatRunAndVerifyGradientsRandom" $
|
||||
forM_ [1..5 :: Int] $ \_ -> do
|
||||
(shapes' :: [Int64]) <- replicateM 5 $ randomRIO (1, 5)
|
||||
(numTensors :: Int) <- randomRIO (2, 10)
|
||||
(concatDim :: Int) <- randomRIO (-4, 4)
|
||||
(concatDimSizes :: [Int64]) <- replicateM numTensors $ randomRIO (1, 5)
|
||||
let update i xs x = take i xs ++ x: drop (i+1) xs
|
||||
concatDim' = concatDim `mod` length shapes'
|
||||
shapes = map (update concatDim' shapes') concatDimSizes
|
||||
verifyConcatGradients shapes $ fromIntegral concatDim
|
||||
|
||||
-- run single test like this:
|
||||
-- stack --docker --docker-image=$IMAGE_NAME test tensorflow-ops:GradientTest --test-arguments -t"*MaximumGrad*"
|
||||
testMaximumGrad :: Test
|
||||
|
@ -329,6 +397,9 @@ main = defaultMain
|
|||
, testDiamond
|
||||
, testAddNGradient
|
||||
, testMaxGradient
|
||||
, testConcatGradient
|
||||
, testConcatGradientSimple
|
||||
, testConcatRunAndVerifyGradientsRandom
|
||||
, testMaximumGrad
|
||||
, testMaximumGradGrad
|
||||
, testReluGrad
|
||||
|
|
Loading…
Add table
Reference in a new issue