2016-10-24 21:26:42 +02:00
|
|
|
-- Copyright 2016 TensorFlow authors.
|
|
|
|
--
|
|
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
-- you may not use this file except in compliance with the License.
|
|
|
|
-- You may obtain a copy of the License at
|
|
|
|
--
|
|
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
--
|
|
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
-- See the License for the specific language governing permissions and
|
|
|
|
-- limitations under the License.
|
|
|
|
|
|
|
|
{-# LANGUAGE OverloadedStrings #-}
|
2017-03-18 20:08:53 +01:00
|
|
|
{-# LANGUAGE NoMonomorphismRestriction #-}
|
2016-10-24 21:26:42 +02:00
|
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
2017-05-04 09:39:15 +02:00
|
|
|
{-# LANGUAGE FlexibleContexts #-}
|
2016-10-24 21:26:42 +02:00
|
|
|
|
2017-05-04 09:39:15 +02:00
|
|
|
import Data.Int (Int32, Int64)
|
2016-10-24 21:26:42 +02:00
|
|
|
import Data.List (sort)
|
2017-07-30 05:29:33 +02:00
|
|
|
import qualified Data.List as List
|
2016-10-24 21:26:42 +02:00
|
|
|
import Data.ProtoLens.TextFormat (showMessage)
|
2017-05-11 00:26:03 +02:00
|
|
|
import Test.Framework (defaultMain, Test)
|
2017-05-04 09:39:15 +02:00
|
|
|
import Lens.Family2 ((^..), (.~))
|
|
|
|
|
2016-10-24 21:26:42 +02:00
|
|
|
import Test.Framework.Providers.HUnit (testCase)
|
2017-05-04 09:39:15 +02:00
|
|
|
import Test.HUnit ((@=?), assertEqual)
|
2016-12-12 18:47:02 +01:00
|
|
|
import qualified Data.Vector as V
|
2017-07-30 05:29:33 +02:00
|
|
|
import System.Random (randomIO, randomRIO)
|
|
|
|
import Control.Monad(forM_, replicateM, zipWithM)
|
2017-05-04 09:39:15 +02:00
|
|
|
import Control.Monad.IO.Class (liftIO)
|
2016-10-24 21:26:42 +02:00
|
|
|
|
2016-12-12 18:47:02 +01:00
|
|
|
import qualified TensorFlow.Core as TF
|
2017-10-15 20:49:44 +02:00
|
|
|
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile)
|
2016-10-24 21:26:42 +02:00
|
|
|
import qualified TensorFlow.Gradient as TF
|
2017-05-17 22:20:51 +02:00
|
|
|
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
|
2017-05-04 09:39:15 +02:00
|
|
|
import qualified TensorFlow.Output as TF
|
|
|
|
import qualified TensorFlow.Types as TF
|
2017-05-17 22:20:51 +02:00
|
|
|
import qualified TensorFlow.Variable as TF
|
2016-10-24 21:26:42 +02:00
|
|
|
|
|
|
|
import Proto.Tensorflow.Core.Framework.Graph (node)
|
|
|
|
import Proto.Tensorflow.Core.Framework.NodeDef (op)
|
|
|
|
|
2017-10-15 20:49:44 +02:00
|
|
|
import qualified Data.ByteString.Char8 as BS
|
|
|
|
|
2016-10-24 21:26:42 +02:00
|
|
|
-- | Differentiates @y = x * x + b@ (with @x = 3@ and @b = 4@) with respect
-- to both inputs.  The test first checks the numeric gradients and then
-- checks that the op names found in the generated graph match the expected
-- collection (compared after sorting, so node order is irrelevant).
testGradientSimple :: Test
testGradientSimple = testCase "testGradientSimple" $ do
    let buildGrads = do
            x <- TF.render (TF.scalar (3 :: Float))
            b <- TF.render (TF.scalar (4 :: Float))
            TF.gradients (TF.add (TF.mul x x) b) [x, b]
    -- Numeric check: dy/dx = 2 * x = 6 and dy/db = 1.
    [gradX, gradB] <- TF.runSession (TF.run =<< buildGrads)
    assertEqual "" 6 (TF.unScalar gradX)
    assertEqual "" 1 (TF.unScalar gradB)
    -- Structural check: the graph must contain exactly the expected ops.
    let graphDef = TF.asGraphDef buildGrads
    putStrLn (showMessage graphDef)
    let actualOps = graphDef ^.. node . traverse . op
        expectedOps =
            [ "Const"
            , "Mul"
            , "Const"
            , "Add"
              -- Default output gradient of y.
            , "Shape"
            , "Const"
            , "Fill"
              -- Add gradient.
            , "Shape"
            , "Shape"
            , "BroadcastGradientArgs"
            , "Sum"
            , "Sum"
            , "Reshape"
            , "Reshape"
              -- Mul gradient.
            , "Shape"
              -- This Op gets dedup'd because the inputs are the same.
              -- TODO(fmayle): The same would happen to the Mul and Sum ops
              -- below if the gradient function didn't multiply one as
              -- 'dz * y' and the other as 'x * dz'. We could change the
              -- order, but I'm going to keep it the same as the python
              -- version for now.
              --
              -- , "Shape"
            , "BroadcastGradientArgs"
            , "Mul"
            , "Mul"
            , "Sum"
            , "Sum"
            , "Reshape"
            , "Reshape"
              -- AddN to combine x's output gradients.
            , "AddN"
            ]
    assertEqual "" (sort expectedOps) (sort actualOps)
|
|
|
|
|
|
|
|
-- | Takes the gradient of @x@ with respect to @x@ and to an unrelated
-- tensor @b@.  Expects 1 for @x@ and 0 for the disconnected @b@, and checks
-- the op names in the resulting graph (compared after sorting).
testGradientDisconnected :: Test
testGradientDisconnected = testCase "testGradientDisconnected" $ do
    let buildGrads = do
            x <- TF.render (TF.scalar (3 :: Float))
            b <- TF.render (TF.scalar (4 :: Float))
            TF.gradients x [x, b]
    -- Numeric check.
    [gradX, gradB] <- TF.runSession (TF.run =<< buildGrads)
    assertEqual "" 1 (TF.unScalar gradX)
    assertEqual "" 0 (TF.unScalar gradB)
    -- Structural check: the graph must contain exactly the expected ops.
    let graphDef = TF.asGraphDef buildGrads
    putStrLn (showMessage graphDef)
    let actualOps = graphDef ^.. node . traverse . op
        expectedOps =
            [ "Const"
            , "Const"
              -- Default output gradient of x.
            , "Shape"
            , "Const"
            , "Fill"
              -- Default output gradient of b.
            , "ZerosLike"
            ]
    assertEqual "" (sort expectedOps) (sort actualOps)
|
|
|
|
|
|
|
|
|
|
|
|
-- | Checks that two identical "stateful" ops (two 'TF.truncatedNormal'
-- calls on the same shape tensor) work with createGraph: the gradient of
-- @x + y * 3@ is taken with respect to both samples.
testCreateGraphStateful :: Test
testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
    [gradX, gradY] <- TF.runSession $ do
        let dims = TF.constant (TF.Shape [1]) [1]
        (x :: TF.Tensor TF.Value Float) <- TF.truncatedNormal dims
        (y :: TF.Tensor TF.Value Float) <- TF.truncatedNormal dims
        let total = TF.expr x + TF.expr y * 3
        TF.run =<< TF.gradients total [x, y]
    -- A failure here is most likely an exception thrown from inside
    -- `TF.gradients`; the value checks below are an extra safeguard.
    assertEqual "" 1 (TF.unScalar gradX)
    assertEqual "" 3 (TF.unScalar gradY)
|
|
|
|
|
|
|
|
|
|
|
|
-- | Checks that name scopes work with createGraph: a tensor created inside
-- the name scope "foo" is differentiated with respect to itself.
testCreateGraphNameScopes :: Test
testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
    [gradX] <- TF.runSession $ do
        let dims = TF.constant (TF.Shape [1]) [1]
        (x :: TF.Tensor TF.Value Float) <-
            TF.withNameScope "foo" $ TF.truncatedNormal dims
        TF.run =<< TF.gradients x [x]
    -- A failure here is most likely an exception thrown from inside
    -- `TF.gradients`; the value check below is an extra safeguard.
    assertEqual "" 1 (TF.unScalar gradX)
|
|
|
|
|
|
|
|
|
|
|
|
-- | Checks that createGraph handles diamond-shaped graphs: @x@ feeds
-- @y = x * x@ twice, and @y@ feeds @z = y * y@ twice.  With @x = 1@ the
-- expected gradient dz/dx is 4.
testDiamond :: Test
testDiamond = testCase "testDiamond" $ do
    [gradX] <- TF.runSession $ do
        x <- TF.render (TF.vector [1])
        let squared = x `TF.mul` x
            fourth = squared * squared
        TF.run =<< TF.gradients fourth [x]
    assertEqual "" (4 :: Float) (TF.unScalar gradX)
|
|
|
|
|
|
|
|
|
2017-06-16 13:42:33 +02:00
|
|
|
-- | Differentiates @addN [x, x]@ with respect to @x@; since @x@ appears
-- twice in the sum, every component of the gradient is expected to be 2.
testAddNGradient :: Test
testAddNGradient = testCase "testAddNGradient" $ do
    [gradX] <- TF.runSession $ do
        x <- TF.render (TF.vector [1, 2, 0 :: Float])
        TF.run =<< TF.gradients (TF.addN [x, x]) [x]
    assertEqual "" (V.fromList [2, 2, 2 :: Float]) gradX
|
|
|
|
|
|
|
|
|
2016-12-12 18:47:02 +01:00
|
|
|
-- | Differentiates @TF.max x 0@ for @x = [1, 2, 3, 0, 1]@.  The expected
-- gradient is 1 at the position of the largest element and 0 elsewhere.
testMaxGradient :: Test
testMaxGradient = testCase "testMaxGradient" $ do
    [gradX] <- TF.runSession $ do
        x <- TF.render (TF.vector [1, 2, 3, 0, 1 :: Float])
        let reduced = TF.max x (0 :: TF.Tensor TF.Build Int32)
        TF.run =<< TF.gradients reduced [x]
    assertEqual "" (V.fromList [0, 0, 1, 0, 0 :: Float]) gradX
|
|
|
|
|
2017-07-30 05:29:33 +02:00
|
|
|
-- | Two small concat-gradient checks, each run in its own session:
--
-- 1. Concatenating two one-element vectors gives a gradient of 1 for each.
-- 2. Concatenating @[w, w', w]@ gives a gradient of 2 for @w@ (it is used
--    twice) and 1 for @w'@.
testConcatGradient :: Test
testConcatGradient = testCase "testConcatGradient" $ do
    [gradV, gradV'] <- TF.runSession $ do
        v <- TF.render (TF.vector [1 :: Float])
        v' <- TF.render (TF.vector [2 :: Float])
        let joined = TF.concat (TF.scalar 0) [v, v']
        TF.run =<< TF.gradients joined [v, v']
    assertEqual "" (V.fromList [1 :: Float]) gradV
    assertEqual "" (V.fromList [1 :: Float]) gradV'
    [gradW, gradW'] <- TF.runSession $ do
        w <- TF.render (TF.vector [1, 2, 3, 4 :: Float])
        w' <- TF.render (TF.vector [5, 6, 7, 8 :: Float])
        let joined = TF.concat (TF.scalar 0) [w, w', w]
        TF.run =<< TF.gradients joined [w, w']
    assertEqual "" (V.fromList [2, 2, 2, 2 :: Float]) gradW
    assertEqual "" (V.fromList [1, 1, 1, 1 :: Float]) gradW'
|
|
|
|
|
|
|
|
-- | Checks the gradient of a concat op for the given input shapes and
-- concat dimension.
--
-- Random float data is generated twice per shape (first the "gradient"
-- data, then the input data).  Both sets are concatenated along
-- @concatDim@ and multiplied element-wise, and the product is
-- differentiated with respect to the individual input tensors.  The
-- resulting gradients must equal the generated "gradient" data.
verifyConcatGradients :: [[Int64]] -> Int32 -> IO ()
verifyConcatGradients shapes concatDim = do
    -- Keep this order: the gradient data is drawn before the input data.
    gradData <- traverse randomFloats shapes
    inputData <- traverse randomFloats shapes
    computed <- TF.runSession $ do
        inputTensors <- renderConstants inputData shapes
        gradTensors <- renderConstants gradData shapes
        joinedInputs <- TF.render (TF.concat (TF.scalar concatDim) inputTensors)
        joinedGrads <- TF.render (TF.concat (TF.scalar concatDim) gradTensors)
        productTensor <- TF.render (joinedInputs `TF.mul` joinedGrads)
        TF.run =<< TF.gradients productTensor inputTensors
    assertEqual "" (map V.fromList gradData) computed
  where
    -- One random Float per element of a tensor with the given shape.
    randomFloats :: [Int64] -> IO [Float]
    randomFloats dims = replicateM (fromIntegral (List.product dims)) randomIO

    -- Render one constant tensor per (values, shape) pair.
    renderConstants values dimsList =
        zipWithM (\vals dims -> TF.render (TF.constant (TF.Shape dims) vals))
                 values
                 dimsList
|
|
|
|
|
|
|
|
-- | Checks that the gradient of a concat op is correct along the first,
-- second, and third dimension.  Each check mirrors a test from
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py.
testConcatGradientSimple :: Test
testConcatGradientSimple = testCase "testConcatGradientSimple" $ do
    -- Equivalent to ConcatTest._testGradientsSimple.
    verifyConcatGradients (map (\n -> [10, n, 2]) variedSizes) 1
    -- Equivalent to ConcatTest._testGradientsFirstDim.
    verifyConcatGradients (map (\n -> [n, 10, 2]) variedSizes) 0
    -- Equivalent to ConcatTest._testGradientsLastDim.
    verifyConcatGradients (map (\n -> [10, 2, n]) variedSizes) 2
  where
    -- Sizes used along the concat dimension for the three inputs.
    variedSizes = [1, 2, 6]
|
|
|
|
|
|
|
|
|
|
|
|
-- | Checks that the gradient of a concat op along a random dimension
-- across random shapes is as expected.
--
-- Inspired by ConcatTest._RunAndVerifyGradientsRandom from
-- tensorflow/tensorflow/compiler/tests/concat_ops_test.py, but this
-- version also verifies the gradient along negative concat dimensions.
testConcatRunAndVerifyGradientsRandom :: Test
testConcatRunAndVerifyGradientsRandom =
    testCase "testConcatRunAndVerifyGradientsRandom"
             (mapM_ (const oneRound) [1 .. 5 :: Int])
  where
    -- One randomized round: draw a rank-5 base shape, a tensor count, a
    -- (possibly negative) concat dimension and one size per tensor along
    -- that dimension, then verify the concat gradients.
    oneRound = do
        (baseShape :: [Int64]) <- replicateM 5 (randomRIO (1, 5))
        (tensorCount :: Int) <- randomRIO (2, 10)
        (axis :: Int) <- randomRIO (-4, 4)
        (axisSizes :: [Int64]) <- replicateM tensorCount (randomRIO (1, 5))
        let -- Non-negative list index corresponding to the chosen axis.
            axisIndex = axis `mod` length baseShape
            -- The base shape with its size along the concat axis replaced.
            withAxisSize size =
                let (before, rest) = splitAt axisIndex baseShape
                in before ++ size : drop 1 rest
        verifyConcatGradients (map withAxisSize axisSizes) (fromIntegral axis)
|
|
|
|
|
2017-06-16 10:26:10 +02:00
|
|
|
-- To run a single test, use something like:
-- stack --docker --docker-image=$IMAGE_NAME test tensorflow-ops:GradientTest --test-arguments -t"*MaximumGrad*"

-- | Differentiates @maximum x y@ where both inputs are @[0]@; the test
-- expects a gradient of 1 for each input in this tie case.
testMaximumGrad :: Test
testMaximumGrad = testCase "testMaximumGrad" $ do
    [gradX, gradY] <- TF.runSession $ do
        x <- TF.render (TF.vector [0 :: Float])
        y <- TF.render (TF.vector [0 :: Float])
        TF.run =<< TF.gradients (TF.maximum x y) [x, y]
    assertEqual "" (V.fromList [1]) gradX
    assertEqual "" (V.fromList [1]) gradY
|
|
|
|
|
|
|
|
-- | Takes the gradient of the gradient of @maximum x y@ with respect to
-- @x@ (for @x = [2]@, @y = [1]@) and expects the second derivative to be 0.
testMaximumGradGrad :: Test
testMaximumGradGrad = testCase "testMaximumGradGrad" $ do
    [secondGrad] <- TF.runSession $ do
        x <- TF.render (TF.vector [2 :: Float])
        y <- TF.render (TF.vector [1 :: Float])
        [firstGradX, _firstGradY] <- TF.gradients (TF.maximum x y) [x, y]
        TF.run =<< TF.gradients firstGradX [x]
    assertEqual "" (V.fromList [0]) secondGrad
|
2016-12-12 18:47:02 +01:00
|
|
|
|
2017-04-30 20:18:06 +02:00
|
|
|
-- | Differentiates @relu x@ at @x = [2]@ and expects a gradient of 1.
testReluGrad :: Test
testReluGrad = testCase "testReluGrad" $ do
    [gradX] <- TF.runSession $ do
        x <- TF.render (TF.vector [2 :: Float])
        TF.run =<< TF.gradients (TF.relu x) [x]
    assertEqual "" (V.fromList [1]) gradX
|
|
|
|
|
|
|
|
-- | Takes the gradient of the gradient of @relu x@ at @x = [2]@ and
-- expects the second derivative to be 0.
testReluGradGrad :: Test
testReluGradGrad = testCase "testReluGradGrad" $ do
    [secondGrad] <- TF.runSession $ do
        x <- TF.render (TF.vector [2 :: Float])
        [firstGrad] <- TF.gradients (TF.relu x) [x]
        TF.run =<< TF.gradients firstGrad [x]
    assertEqual "" (V.fromList [0]) secondGrad
|
|
|
|
|
|
|
|
-- | Differentiates @fill [2, 3] x@ with respect to the scalar @x@.  The
-- scalar is copied into 2 * 3 = 6 output elements, so the expected
-- gradient is 6.
testFillGrad :: Test
testFillGrad = testCase "testFillGrad" $ do
    [gradX] <- TF.runSession $ do
        x <- TF.render (TF.scalar (9 :: Float))
        let dims = TF.vector [2, 3 :: Int32]
            filled = TF.fill dims x
        TF.run =<< TF.gradients filled [x]
    assertEqual "" (V.fromList [6]) gradX
|
|
|
|
|
2017-04-29 15:57:42 +02:00
|
|
|
|
|
|
|
-- | Differentiates @tile x [2]@ for a two-element vector @x@.  Each input
-- element is repeated twice, so the expected gradient is @[2, 2]@.
testTileGrad :: Test
testTileGrad = testCase "testTileGrad" $ do
    [gradX] <- TF.runSession $ do
        x <- TF.render (TF.vector [5, 9 :: Float])
        let repeats = TF.vector [2 :: Int32]
            tiled = TF.tile x repeats
        TF.run =<< TF.gradients tiled [x]
    assertEqual "" (V.fromList [2, 2]) gradX
|
|
|
|
|
|
|
|
|
|
|
|
-- | Differentiates a 2-D tile: a tensor filled with ones using shape
-- @[3, 2]@ is tiled with multiples @[2, 3]@.  The gradient must have the
-- same shape as the input, and each of its six elements must be 6
-- (2 * 3 copies of every input element).
testTile2DGrad :: Test
testTile2DGrad = testCase "testTileGrad2D" $ do
    (gradValues, gradShape, inputShape) <- TF.runSession $ do
        let dims = TF.vector [3, 2 :: Int32]
        x <- TF.render (TF.fill dims (TF.scalar (1 :: Float)))
        let repeats = TF.vector [2, 3 :: Int32]
            tiled = TF.tile x repeats
        [gradX] <- TF.gradients tiled [x]
        TF.run (gradX, TF.shape gradX, TF.shape x)
    assertEqual "" inputShape (gradShape :: V.Vector Int32)
    assertEqual ""
                (V.fromList [6, 6, 6, 6, 6, 6 :: Float])
                (gradValues :: V.Vector Float)
|
|
|
|
|
2017-05-04 09:39:15 +02:00
|
|
|
-- | Multiplies a zero tensor of shape @[3, 1]@ by a zero-initialized
-- variable of shape @[1, 2]@ and checks that the gradient with respect to
-- the first operand has the same shape as that operand.
matMulGradient :: Test
matMulGradient = testCase "matMulGradients" $ do
    let buildGraph = do
            x <- TF.render (TF.zeros (TF.Shape [3, 1 :: Int64]))
            w <- TF.zeroInitializedVariable (TF.Shape [1, 2 :: Int64])
            let prod = x `TF.matMul` TF.readValue w :: TF.Tensor TF.Build Float
            grads <- TF.gradients prod [x]
            return (x, grads)

    (inputShape, gradShape) <- TF.runSession $ do
        (x, [gradX]) <- TF.build buildGraph
        TF.run (TF.shape x, TF.shape gradX)

    assertEqual "Shape of gradient must match shape of input"
                inputShape
                (gradShape :: V.Vector Int32)
|
|
|
|
|
|
|
|
|
|
|
|
-- | Checks that the gradient of a matMul can itself be differentiated.
--
-- The graph computes @f = x `matMul` w@, takes df/dx, sums it, and then
-- differentiates that sum with respect to the variable @w@.  The test
-- checks that the second gradient has the shape of @w@ and that
-- @w + dfdw@ evaluates to @[4, 4]@.
matMulGradGrad :: Test
matMulGradGrad = testCase "matMulGradGrad" $ do
    let numCols = 2 :: Int64
        numRows = 4 :: Int64

    let buildGraph = do
            x <- TF.render (TF.zeros (TF.Shape [numRows, 1]))
            w <- TF.zeroInitializedVariable (TF.Shape [1, numCols])
            let prod = x `TF.matMul` TF.readValue w
            [gradX] <- TF.gradients prod [x]
            let summed = TF.reduceSum gradX
            -- Differentiate a second time, now with respect to w.
            [gradW] <- TF.gradients summed [w]
            return [TF.readValue w, TF.expr gradW]

    TF.runSession $ do
        [wValue, gradW] <- TF.build buildGraph
        (wShape, gradWShape) <- TF.run (TF.shape wValue, TF.shape gradW)
        liftIO (assertEqual "Shape of gradient must match input"
                            wShape
                            (gradWShape :: V.Vector Int32))

        let updated = wValue `TF.add` gradW
        result <- TF.run updated
        liftIO (assertEqual "" (V.fromList [4, 4 :: Float]) result)
|
2017-05-04 09:39:15 +02:00
|
|
|
|
|
|
|
|
|
|
|
-- | Checks that the gradient of matMul handles the @transpose_a@ and
-- @transpose_b@ attributes correctly.
--
-- The argument is the pair @(transposeX, transposeW)@.  When a flag is
-- set, the corresponding operand is transposed before the multiplication
-- and the matching attribute is set on the op.  For both operands the
-- gradient must have the same shape as the operand itself.
matMulTransposeGradient :: (Bool, Bool) -> Test
matMulTransposeGradient txw@(transposeX, transposeW) =
    testCase ("matMulTransposeGradients " ++ show txw) $ do
        let buildGraph = do
                let zerosX = TF.zeros (TF.Shape [3, 1 :: Int64])
                x <- TF.render (transposeIf transposeX zerosX)
                variable <- TF.zeroInitializedVariable (TF.Shape [1, 2 :: Int64])
                let wExpr = transposeIf transposeW (TF.readValue variable)
                    prod = TF.matMul' (transAttrs transposeX transposeW) x wExpr
                               :: TF.Tensor TF.Build Float
                w <- TF.render wExpr
                grads <- TF.gradients prod [x, w]
                return (x, w, grads)

        TF.runSession $ do
            (x, w, [gradX, gradW]) <- TF.build buildGraph
            xShape <- TF.run (TF.shape x)
            gradXShape <- TF.run (TF.shape gradX)
            liftIO (assertEqual "xShape must match dxShape"
                                xShape
                                (gradXShape :: V.Vector Int32))

            wShape <- TF.run (TF.shape w)
            gradWShape <- TF.run (TF.shape gradW)
            liftIO (assertEqual "wShape must match dwShape"
                                wShape
                                (gradWShape :: V.Vector Int32))
  where
    -- Transpose the tensor only when the flag is set.
    transposeIf flag tensor
        | flag = TF.matTranspose tensor
        | otherwise = tensor
|
|
|
|
|
|
|
|
-- | Builds an 'TF.OpDef' modifier that sets the "transpose_a" attribute to
-- the first argument and the "transpose_b" attribute to the second.
transAttrs
    :: (TF.Attribute a, TF.Attribute b)
    => a -> b -> TF.OpDef -> TF.OpDef
transAttrs a b opDef =
    (TF.opAttr "transpose_a" .~ a) ((TF.opAttr "transpose_b" .~ b) opDef)
|
|
|
|
|
2017-10-15 20:49:44 +02:00
|
|
|
-- | Differentiates @conv2DBackpropInput@ with respect to its third
-- argument @x@, a tensor of ones with shape @[1, 1, 1, 1]@.  The filter is
-- a tensor of ones with shape @[2, 2, 1, 1]@.  The gradient must have the
-- shape of @x@, and its single element is expected to be 4.
testConv2DBackpropInputGrad :: Test
testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
    (gradValues, gradShape, inputShape) <- TF.runSession $ do
        let convInputShape = TF.vector [1, 2, 2, 1 :: Int32] -- [batch, h, w, in_channels]
            convOutShape = TF.vector [1, 1, 1, 1 :: Int32] -- [batch, h, w, out_channels]
        x <- TF.render (TF.fill convOutShape (TF.scalar (1 :: Float)))

        let kernelShape = TF.vector [2, 2, 1, 1 :: Int32] -- [fh, fw, inc, out]
        kernel <- TF.render (TF.fill kernelShape (TF.scalar (1 :: Float)))
        let attrs = (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
                  . (TF.opAttr "padding" .~ (BS.pack "VALID"))
                  . (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
            backprop = TF.conv2DBackpropInput' attrs convInputShape kernel x

        [gradX] <- TF.gradients backprop [x]
        TF.run (gradX, TF.shape gradX, TF.shape x)
    assertEqual "" inputShape (gradShape :: V.Vector Int32)
    assertEqual "" (V.fromList [4 :: Float]) (gradValues :: V.Vector Float)
|
|
|
|
|
|
|
|
|
2016-10-24 21:26:42 +02:00
|
|
|
-- | Test-suite entry point: hands every gradient test in this module to
-- the test-framework runner.
main :: IO ()
main = defaultMain allTests
  where
    allTests =
        [ testGradientSimple
        , testGradientDisconnected
        , testCreateGraphStateful
        , testCreateGraphNameScopes
        , testDiamond
        , testAddNGradient
        , testMaxGradient
        , testConcatGradient
        , testConcatGradientSimple
        , testConcatRunAndVerifyGradientsRandom
        , testMaximumGrad
        , testMaximumGradGrad
        , testReluGrad
        , testReluGradGrad
        , testFillGrad
        , testTileGrad
        , testTile2DGrad
        , matMulGradient
        , matMulGradGrad
        ]
        -- One matMul transpose test per (transposeX, transposeW) combination.
        ++ map matMulTransposeGradient
               [ (False, False)
               , (False, True)
               , (True, False)
               , (True, True)
               ]
        ++ [ testConv2DBackpropInputGrad ]
|