mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 19:13:34 +02:00
Handling negative indices and amendments
This commit is contained in:
parent
41b47e4ce6
commit
23dda410fb
|
@ -70,7 +70,6 @@ import TensorFlow.BuildOp
|
|||
import TensorFlow.Ops
|
||||
( addN
|
||||
, broadcastGradientArgs
|
||||
, constant
|
||||
, expandDims
|
||||
, fill
|
||||
, matMul
|
||||
|
@ -103,7 +102,7 @@ import TensorFlow.Tensor
|
|||
, renderValue
|
||||
, ToTensor(..)
|
||||
)
|
||||
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens, Shape(..))
|
||||
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||
(NodeDef, attr, input, op, name)
|
||||
|
||||
|
@ -444,35 +443,38 @@ opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
|
|||
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
|
||||
opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
|
||||
|
||||
-- Concat concatenates input tensors
|
||||
-- x1 of shape s1 = [d1, ..., di_1, ..., dn]
|
||||
-- x2 of shape s2 = [d1, ..., di_2, ..., dn]
|
||||
-- . . . . .
|
||||
-- . . . . .
|
||||
-- . . . . .
|
||||
-- xm of shape sm = [d1, ..., di_m, ..., dn]
|
||||
-- along dimension i to an output tensor
|
||||
-- y of shape sy = [d1, ..., d, ..., dn]
|
||||
-- where d = sum di = sum [di_1,...,di_m]
|
||||
--
|
||||
-- The incoming gradient from backpropagation is
|
||||
-- simply forwarded split across input tensors.
|
||||
-- Forwarded gradients have shapes s = [s1, ..., sm].
|
||||
opGrad "Concat" _ _ix [dy]
|
||||
| length x == 1 = Nothing : [Just $ expr dy]
|
||||
| otherwise = Nothing : map Just (dx `reshapeZip` s)
|
||||
where x :: [Tensor Build a]
|
||||
x = map toT $ tail _ix
|
||||
_i = toT $ head _ix
|
||||
i = reshape _i one
|
||||
m = length x
|
||||
s :: [Tensor Build Int32]
|
||||
s = map shape x
|
||||
di :: Tensor Build Int32
|
||||
di = CoreOps.concat (scalar 0) $ map (\t -> CoreOps.slice t i one) s
|
||||
dx = CoreOps.splitV (fromIntegral m) dy di _i
|
||||
reshapeZip = zipWith reshape
|
||||
one = constant (Shape [1 :: Int64]) [1 :: Int32]
|
||||
-- Concat concatenates input tensors
|
||||
-- x1 of shape s1 = [k1, ..., ki_1, ..., kn]
|
||||
-- x2 of shape s2 = [k1, ..., ki_2, ..., kn]
|
||||
-- . . . . .
|
||||
-- . . . . .
|
||||
-- . . . . .
|
||||
-- xm of shape sm = [k1, ..., ki_m, ..., kn]
|
||||
-- along dimension i to an output tensor
|
||||
-- y of shape sy = [k1, ..., k, ..., kn]
|
||||
-- where k = sum ki = sum [ki_1,...,ki_m]
|
||||
--
|
||||
-- The incoming gradient dy from backpropagation is
|
||||
-- simply forwarded split across input tensors yielding dx.
|
||||
-- Forwarded gradients have shapes s = [s1, ..., sm].
|
||||
| m == 1 = Nothing : [Just $ expr dy]
|
||||
| otherwise = Nothing : map Just (dx `reshapeZip` s)
|
||||
where
|
||||
reshapeZip = zipWith reshape
|
||||
dx = CoreOps.splitV (fromIntegral m) dy ki _i
|
||||
s :: [Tensor Build Int32]
|
||||
s = map shape x
|
||||
x :: [Tensor Build a]
|
||||
x = map toT $ tail _ix
|
||||
-- i: concat dimension. Adjusted modulo n to handle negative indices.
|
||||
_i = toT (head _ix) `CoreOps.floorMod` n
|
||||
i = reshape _i $ vector [1 :: Int32]
|
||||
-- sizes along concatenated dimension
|
||||
ki :: Tensor Build Int32
|
||||
ki = CoreOps.concat 0 $ map (\t -> CoreOps.slice t i $ vector [1 :: Int32]) s
|
||||
m = length x
|
||||
n = CoreOps.rank (head x)
|
||||
|
||||
opGrad "Square" _ [toT -> x] [dz] =
|
||||
-- TODO(fmayle): Handle complex numbers.
|
||||
|
|
|
@ -28,7 +28,7 @@ import Test.Framework.Providers.HUnit (testCase)
|
|||
import Test.HUnit ((@=?), assertEqual)
|
||||
import qualified Data.Vector as V
|
||||
import System.Random (randomIO, randomRIO)
|
||||
import Control.Monad(forM, forM_, replicateM)
|
||||
import Control.Monad(forM_, replicateM, zipWithM)
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
|
||||
import qualified TensorFlow.Core as TF
|
||||
|
@ -184,100 +184,53 @@ testConcatGradient = testCase "testConcatGradient" $ do
|
|||
V.fromList [2,2,2,2 :: Float] @=? dw
|
||||
V.fromList [1,1,1,1 :: Float] @=? dw'
|
||||
|
||||
verifyConcatGradients :: [[Int64]] -> Int32 -> IO ()
|
||||
verifyConcatGradients shapes concatDim = do
|
||||
let floatsFromShape :: [Int64] -> IO [Float]
|
||||
floatsFromShape shape = replicateM (fromIntegral $ List.product shape) randomIO
|
||||
constantZip = zipWithM $ \x shape -> TF.render $ TF.constant (TF.Shape shape) x
|
||||
inputGrads <- mapM floatsFromShape shapes
|
||||
inputs <- mapM floatsFromShape shapes
|
||||
dinputs <- TF.runSession $ do
|
||||
inputTensors <- inputs `constantZip` shapes
|
||||
inputGradTensors <- inputGrads `constantZip` shapes
|
||||
inputTensor <- TF.render $ TF.concat (TF.scalar concatDim) inputTensors
|
||||
inputGradTensor <- TF.render $ TF.concat (TF.scalar concatDim) inputGradTensors
|
||||
output <- TF.render $ inputTensor `TF.mul` inputGradTensor
|
||||
TF.gradients output inputTensors >>= TF.run
|
||||
(V.fromList <$> inputGrads) @=? dinputs
|
||||
|
||||
-- This test checks that the gradient of a concat op
|
||||
-- along the second dimension is as expected.
|
||||
-- This test is a port of ConcatTest._testGradientsSimple from
|
||||
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py
|
||||
-- is correct along the first, second, and third dimension.
|
||||
testConcatGradientSimple :: Test
|
||||
testConcatGradientSimple = testCase "testConcatGradientSimple" $ do
|
||||
let shapes = [[10,x,2] | x <- [1,2,6]]
|
||||
(inputGrads :: [[Float]]) <- forM shapes $ \shape ->
|
||||
replicateM (List.product shape) randomIO
|
||||
(inputs :: [[Float]]) <- forM shapes $ \shape ->
|
||||
replicateM (List.product shape) randomIO
|
||||
dinputs <- TF.runSession $ do
|
||||
inputTensors <- forM (inputs `zip` shapes) $ \(input,shape) ->
|
||||
TF.render $ TF.constant (TF.Shape shape) input
|
||||
inputGradTensors <- forM (inputGrads `zip` shapes) $ \(inputGrad, shape) ->
|
||||
TF.render $ TF.constant (TF.Shape shape) inputGrad
|
||||
inputGradTensor <- TF.render $ TF.concat (TF.scalar 1) inputGradTensors
|
||||
inputTensor <- TF.render $ TF.concat (TF.scalar 1) inputTensors
|
||||
output <- TF.render $ inputTensor `TF.mul` inputGradTensor
|
||||
TF.gradients output inputTensors >>= TF.run
|
||||
(V.fromList <$> inputGrads) @=? dinputs
|
||||
|
||||
-- This test checks that the gradient of a concat op
|
||||
-- along the first dimension is as expected.
|
||||
-- This test is a port of ConcatTest._testGradientsFirstDim from
|
||||
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py
|
||||
testConcatGradientFirstDim :: Test
|
||||
testConcatGradientFirstDim = testCase "testConcatGradientFirstDim" $ do
|
||||
let shapes = [[x,10,2] | x <- [1,2,6]]
|
||||
(inputGrads :: [[Float]]) <- forM shapes $ \shape ->
|
||||
replicateM (List.product shape) randomIO
|
||||
(inputs :: [[Float]]) <- forM shapes $ \shape ->
|
||||
replicateM (List.product shape) randomIO
|
||||
dinputs <- TF.runSession $ do
|
||||
inputTensors <- forM (inputs `zip` shapes) $ \(input,shape) ->
|
||||
TF.render $ TF.constant (TF.Shape shape) input
|
||||
inputGradTensors <- forM (inputGrads `zip` shapes) $ \(inputGrad, shape) ->
|
||||
TF.render $ TF.constant (TF.Shape shape) inputGrad
|
||||
inputGradTensor <- TF.render $ TF.concat (TF.scalar 0) inputGradTensors
|
||||
inputTensor <- TF.render $ TF.concat (TF.scalar 0) inputTensors
|
||||
output <- TF.render $ inputTensor `TF.mul` inputGradTensor
|
||||
TF.gradients output inputTensors >>= TF.run
|
||||
(V.fromList <$> inputGrads) @=? dinputs
|
||||
|
||||
-- This test checks that the gradient of a concat op
|
||||
-- along the last dimension is as expected.
|
||||
-- This test is a port of ConcatTest._testGradientsLastDim from
|
||||
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py
|
||||
testConcatGradientLastDim :: Test
|
||||
testConcatGradientLastDim = testCase "testConcatGradientLastDim" $ do
|
||||
let shapes = [[10,2,x] | x <- [1,2,6]]
|
||||
(inputGrads :: [[Float]]) <- forM shapes $ \shape ->
|
||||
replicateM (List.product shape) randomIO
|
||||
(inputs :: [[Float]]) <- forM shapes $ \shape ->
|
||||
replicateM (List.product shape) randomIO
|
||||
dinputs <- TF.runSession $ do
|
||||
inputTensors <- forM (inputs `zip` shapes) $ \(input,shape) ->
|
||||
TF.render $ TF.constant (TF.Shape shape) input
|
||||
inputGradTensors <- forM (inputGrads `zip` shapes) $ \(inputGrad, shape) ->
|
||||
TF.render $ TF.constant (TF.Shape shape) inputGrad
|
||||
inputGradTensor <- TF.render $ TF.concat (TF.scalar 2) inputGradTensors
|
||||
inputTensor <- TF.render $ TF.concat (TF.scalar 2) inputTensors
|
||||
output <- TF.render $ inputTensor `TF.mul` inputGradTensor
|
||||
TF.gradients output inputTensors >>= TF.run
|
||||
(V.fromList <$> inputGrads) @=? dinputs
|
||||
-- The following check is equivalent to ConcatTest._testGradientsSimple from
|
||||
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py
|
||||
verifyConcatGradients [[10,x,2] | x <- [1,2,6]] 1
|
||||
-- The following check is equivalent to ConcatTest._testGradientsFirstDim from
|
||||
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py
|
||||
verifyConcatGradients [[x,10,2] | x <- [1,2,6]] 0
|
||||
-- The following check is equivalent to ConcatTest._testGradientsLastDim from
|
||||
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py
|
||||
verifyConcatGradients [[10,2,x] | x <- [1,2,6]] 2
|
||||
|
||||
|
||||
-- This test checks that the gradient of a concat op
|
||||
-- along a random dimension across random shapes is as expected.
|
||||
-- This test is a port of ConcatTest._RunAndVerifyGradientsRandom from
|
||||
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py
|
||||
-- This test is inspired by ConcatTest._RunAndVerifyGradientsRandom from
|
||||
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py, but also
|
||||
-- verifies the gradient along negative concat dimensions.
|
||||
testConcatRunAndVerifyGradientsRandom :: Test
|
||||
testConcatRunAndVerifyGradientsRandom = testCase "testConcatRunAndVerifyGradientsRandom" $
|
||||
forM_ [1..5 :: Int] $ \_ -> do
|
||||
(shapes' :: [Int64]) <- replicateM 5 $ randomRIO (1, 5)
|
||||
(numTensors :: Int) <- randomRIO (2, 10)
|
||||
(concatDim :: Int32) <- randomRIO (0, 4)
|
||||
(concatDimSizes :: [Int64]) <- replicateM numTensors $ randomRIO (1, 5)
|
||||
let update i xs x = take (fromIntegral i) xs ++ x: drop (fromIntegral $ i+1) xs
|
||||
shapes = map (update concatDim shapes') concatDimSizes
|
||||
(inputGrads :: [[Float]]) <- forM shapes $ \shape ->
|
||||
replicateM (fromIntegral $ List.product shape) randomIO
|
||||
(inputs :: [[Float]]) <- forM shapes $ \shape ->
|
||||
replicateM (fromIntegral $ List.product shape) randomIO
|
||||
dinputs <- TF.runSession $ do
|
||||
inputTensors <- forM (inputs `zip` shapes) $ \(input,shape) ->
|
||||
TF.render $ TF.constant (TF.Shape shape) input
|
||||
inputTensor <- TF.render $ TF.concat (TF.scalar concatDim) inputTensors
|
||||
inputGradTensors <- forM (inputGrads `zip` shapes) $ \(inputGrad, shape) ->
|
||||
TF.render $ TF.constant (TF.Shape shape) inputGrad
|
||||
inputGradTensor <- TF.render $ TF.concat (TF.scalar concatDim) inputGradTensors
|
||||
output <- TF.render $ inputTensor `TF.mul` inputGradTensor
|
||||
TF.gradients output inputTensors >>= TF.run
|
||||
(V.fromList <$> inputGrads) @=? dinputs
|
||||
(shapes' :: [Int64]) <- replicateM 5 $ randomRIO (1, 5)
|
||||
(numTensors :: Int) <- randomRIO (2, 10)
|
||||
(concatDim :: Int) <- randomRIO (-4, 4)
|
||||
(concatDimSizes :: [Int64]) <- replicateM numTensors $ randomRIO (1, 5)
|
||||
let update i xs x = take i xs ++ x: drop (i+1) xs
|
||||
concatDim' = concatDim `mod` length shapes'
|
||||
shapes = map (update concatDim' shapes') concatDimSizes
|
||||
verifyConcatGradients shapes $ fromIntegral concatDim
|
||||
|
||||
testReluGrad :: Test
|
||||
testReluGrad = testCase "testReluGrad" $ do
|
||||
|
@ -415,8 +368,6 @@ main = defaultMain
|
|||
, testMaxGradient
|
||||
, testConcatGradient
|
||||
, testConcatGradientSimple
|
||||
, testConcatGradientFirstDim
|
||||
, testConcatGradientLastDim
|
||||
, testConcatRunAndVerifyGradientsRandom
|
||||
, testReluGrad
|
||||
, testReluGradGrad
|
||||
|
|
Loading…
Reference in New Issue
Block a user