1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2025-01-26 18:55:02 +01:00

Adding gradient for Concat (#144)

This commit is contained in:
Jonathan Kochems 2017-07-30 04:29:33 +01:00 committed by fkm3
parent cac45d1cd6
commit 79d8d7edea
3 changed files with 106 additions and 0 deletions

View file

@ -459,6 +459,39 @@ opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
opGrad "Concat" _ _ix [dy]
-- Concat concatenates input tensors
-- x1 of shape s1 = [k1, ..., ki_1, ..., kn]
-- x2 of shape s2 = [k1, ..., ki_2, ..., kn]
-- . . . . .
-- . . . . .
-- . . . . .
-- xm of shape sm = [k1, ..., ki_m, ..., kn]
-- along dimension i to an output tensor
-- y of shape sy = [k1, ..., k, ..., kn]
-- where k = sum ki = sum [ki_1,...,ki_m]
--
-- The incoming gradient dy from backpropagation is
-- simply forwarded split across input tensors yielding dx.
-- Forwarded gradients have shapes s = [s1, ..., sm].
| m == 1 = Nothing : [Just $ expr dy]
| otherwise = Nothing : map Just (dx `reshapeZip` s)
where
reshapeZip = zipWith reshape
dx = CoreOps.splitV (fromIntegral m) dy ki _i
s :: [Tensor Build Int32]
s = map shape x
x :: [Tensor Build a]
x = map toT $ tail _ix
-- i: concat dimension. Adjusted modulo n to handle negative indices.
_i = toT (head _ix) `CoreOps.floorMod` n
i = reshape _i $ vector [1 :: Int32]
-- sizes along concatenated dimension
ki :: Tensor Build Int32
ki = CoreOps.concat 0 $ map (\t -> CoreOps.slice t i $ vector [1 :: Int32]) s
m = length x
n = CoreOps.rank (head x)
opGrad "Square" _ [toT -> x] [dz] =
-- TODO(fmayle): Handle complex numbers.
-- TODO(fmayle): The python code makes dz a control dependency of the 2*x
@ -744,6 +777,7 @@ numOutputs o =
"AddN" -> 1
"Cast" -> 1
"Const" -> 1
"Concat" -> 1
"Conv2D" -> 1
"Div" -> 1
"DynamicStitch" -> 1

View file

@ -190,6 +190,7 @@ Test-Suite GradientTest
, base
, proto-lens
, lens-family
, random
, tensorflow
, tensorflow-core-ops
, tensorflow-ops

View file

@ -19,6 +19,7 @@
import Data.Int (Int32, Int64)
import Data.List (sort)
import qualified Data.List as List
import Data.ProtoLens.TextFormat (showMessage)
import Test.Framework (defaultMain, Test)
import Lens.Family2 ((^..), (.~))
@ -26,6 +27,8 @@ import Lens.Family2 ((^..), (.~))
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?), assertEqual)
import qualified Data.Vector as V
import System.Random (randomIO, randomRIO)
import Control.Monad(forM_, replicateM, zipWithM)
import Control.Monad.IO.Class (liftIO)
import qualified TensorFlow.Core as TF
@ -173,6 +176,71 @@ testMaxGradient = testCase "testMaxGradient" $ do
TF.gradients y [x] >>= TF.run
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
testConcatGradient :: Test
testConcatGradient = testCase "testConcatGradient" $ do
[dv,dv'] <- TF.runSession $ do
v <- TF.render $ TF.vector [1 :: Float]
v' <- TF.render $ TF.vector [2 :: Float]
let y = TF.concat (TF.scalar 0) [ v, v' ]
TF.gradients y [v,v'] >>= TF.run
V.fromList [1 :: Float] @=? dv
V.fromList [1 :: Float] @=? dv'
[dw,dw'] <- TF.runSession $ do
w <- TF.render $ TF.vector [1,2,3,4 :: Float]
w' <- TF.render $ TF.vector [5,6,7,8 :: Float]
let y = TF.concat (TF.scalar 0) [ w, w', w ]
TF.gradients y [w,w'] >>= TF.run
V.fromList [2,2,2,2 :: Float] @=? dw
V.fromList [1,1,1,1 :: Float] @=? dw'
verifyConcatGradients :: [[Int64]] -> Int32 -> IO ()
verifyConcatGradients shapes concatDim = do
let floatsFromShape :: [Int64] -> IO [Float]
floatsFromShape shape = replicateM (fromIntegral $ List.product shape) randomIO
constantZip = zipWithM $ \x shape -> TF.render $ TF.constant (TF.Shape shape) x
inputGrads <- mapM floatsFromShape shapes
inputs <- mapM floatsFromShape shapes
dinputs <- TF.runSession $ do
inputTensors <- inputs `constantZip` shapes
inputGradTensors <- inputGrads `constantZip` shapes
inputTensor <- TF.render $ TF.concat (TF.scalar concatDim) inputTensors
inputGradTensor <- TF.render $ TF.concat (TF.scalar concatDim) inputGradTensors
output <- TF.render $ inputTensor `TF.mul` inputGradTensor
TF.gradients output inputTensors >>= TF.run
(V.fromList <$> inputGrads) @=? dinputs
-- This test checks that the gradient of a concat op
-- is correct along the first, second, and third dimension.
testConcatGradientSimple :: Test
testConcatGradientSimple = testCase "testConcatGradientSimple" $ do
-- The following check is equivalent to ConcatTest._testGradientsSimple from
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py
verifyConcatGradients [[10,x,2] | x <- [1,2,6]] 1
-- The following check is equivalent to ConcatTest._testGradientsFirstDim from
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py
verifyConcatGradients [[x,10,2] | x <- [1,2,6]] 0
-- The following check is equivalent to ConcatTest._testGradientsLastDim from
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py
verifyConcatGradients [[10,2,x] | x <- [1,2,6]] 2
-- This test checks that the gradient of a concat op
-- along a random dimension across random shapes is as expected.
-- This test is inspired by ConcatTest._RunAndVerifyGradientsRandom from
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py, but also
-- verifies the gradient along negative concat dimensions.
testConcatRunAndVerifyGradientsRandom :: Test
testConcatRunAndVerifyGradientsRandom = testCase "testConcatRunAndVerifyGradientsRandom" $
forM_ [1..5 :: Int] $ \_ -> do
(shapes' :: [Int64]) <- replicateM 5 $ randomRIO (1, 5)
(numTensors :: Int) <- randomRIO (2, 10)
(concatDim :: Int) <- randomRIO (-4, 4)
(concatDimSizes :: [Int64]) <- replicateM numTensors $ randomRIO (1, 5)
let update i xs x = take i xs ++ x: drop (i+1) xs
concatDim' = concatDim `mod` length shapes'
shapes = map (update concatDim' shapes') concatDimSizes
verifyConcatGradients shapes $ fromIntegral concatDim
-- run single test like this:
-- stack --docker --docker-image=$IMAGE_NAME test tensorflow-ops:GradientTest --test-arguments -t"*MaximumGrad*"
testMaximumGrad :: Test
@ -329,6 +397,9 @@ main = defaultMain
, testDiamond
, testAddNGradient
, testMaxGradient
, testConcatGradient
, testConcatGradientSimple
, testConcatRunAndVerifyGradientsRandom
, testMaximumGrad
, testMaximumGradGrad
, testReluGrad