2016-10-24 21:26:42 +02:00
|
|
|
-- Copyright 2016 TensorFlow authors.
|
|
|
|
--
|
|
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
-- you may not use this file except in compliance with the License.
|
|
|
|
-- You may obtain a copy of the License at
|
|
|
|
--
|
|
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
--
|
|
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
-- See the License for the specific language governing permissions and
|
|
|
|
-- limitations under the License.
|
|
|
|
|
|
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
|
|
|
2016-12-12 18:47:02 +01:00
|
|
|
import Data.Int (Int32)
|
2016-10-24 21:26:42 +02:00
|
|
|
import Data.List (sort)
|
|
|
|
import Data.ProtoLens.TextFormat (showMessage)
|
|
|
|
import Google.Test (googleTest)
|
|
|
|
import Lens.Family2 ((^..))
|
|
|
|
import Test.Framework (Test)
|
|
|
|
import Test.Framework.Providers.HUnit (testCase)
|
|
|
|
import Test.HUnit ((@=?))
|
2016-12-12 18:47:02 +01:00
|
|
|
import qualified Data.Vector as V
|
2016-10-24 21:26:42 +02:00
|
|
|
|
2016-12-12 18:47:02 +01:00
|
|
|
import qualified TensorFlow.Core as TF
|
|
|
|
import qualified TensorFlow.GenOps.Core as TF (max)
|
2016-10-24 21:26:42 +02:00
|
|
|
import qualified TensorFlow.Gradient as TF
|
|
|
|
import qualified TensorFlow.Ops as TF
|
|
|
|
|
|
|
|
import Proto.Tensorflow.Core.Framework.Graph (node)
|
|
|
|
import Proto.Tensorflow.Core.Framework.NodeDef (op)
|
|
|
|
|
|
|
|
-- | Differentiates @y = x*x + b@ (with scalar constants @x = 3@, @b = 4@)
-- with respect to @[x, b]@. Checks both the fetched gradient values and the
-- multiset of op types that 'TF.gradients' adds to the graph.
testGradientSimple :: Test
testGradientSimple = testCase "testGradientSimple" $ do
    let xIn = TF.scalar (3 :: Float)
        bIn = TF.scalar (4 :: Float)
        gradBuild = TF.gradients (xIn*xIn + bIn) [xIn, bIn]
    -- Numeric check: d/dx (x*x + b) = 2*x = 6 and d/db = 1.
    [gradX, gradB] <- TF.runSession (TF.buildAnd TF.run gradBuild)
    6 @=? TF.unScalar gradX
    1 @=? TF.unScalar gradB
    -- Structural check: compare op types in the graph, ignoring order.
    let gDef = TF.asGraphDef gradBuild
    putStrLn (showMessage gDef)
    let actualOps = gDef ^.. node . traverse . op
        expectedOps =
            -- Forward pass: x*x + b.
            [ "Const", "Mul", "Const", "Add"
            -- Default output gradient of y.
            , "Shape", "Const", "Fill"
            -- Add gradient.
            , "Shape", "Shape", "BroadcastGradientArgs"
            , "Sum", "Sum", "Reshape", "Reshape"
            -- Mul gradient.
            , "Shape"
            -- This Op gets dedup'd because the inputs are the same.
            -- TODO(fmayle): The same would happen to the Mul and Sum ops
            -- below if the gradient function didn't multiply one as
            -- 'dz * y' and the other as 'x * dz'. We could change the
            -- order, but I'm going to keep it the same as the python
            -- version for now.
            --
            -- , "Shape"
            , "BroadcastGradientArgs"
            , "Mul", "Mul", "Sum", "Sum", "Reshape", "Reshape"
            -- AddN to combine x's output gradients.
            , "AddN"
            ]
    sort expectedOps @=? sort actualOps
|
|
|
|
|
|
|
|
-- | Differentiates @x@ itself with respect to @[x, b]@, where @b@ does not
-- feed the output. Checks that @x@ gets gradient 1 and @b@ gets 0, and checks
-- the multiset of op types in the resulting graph.
testGradientDisconnected :: Test
testGradientDisconnected = testCase "testGradientDisconnected" $ do
    let xIn = TF.scalar (3 :: Float)
        bIn = TF.scalar (4 :: Float)
        gradBuild = TF.gradients xIn [xIn, bIn]
    -- Numeric check: dx/dx = 1; b is disconnected from x, so its gradient is 0.
    [gradX, gradB] <- TF.runSession (TF.buildAnd TF.run gradBuild)
    1 @=? TF.unScalar gradX
    0 @=? TF.unScalar gradB
    -- Structural check: compare op types in the graph, ignoring order.
    let gDef = TF.asGraphDef gradBuild
    putStrLn (showMessage gDef)
    let actualOps = gDef ^.. node . traverse . op
        expectedOps =
            -- The two scalar constants.
            [ "Const", "Const"
            -- Default output gradient of x.
            , "Shape", "Const", "Fill"
            -- Default output gradient of b.
            , "ZerosLike"
            ]
    sort expectedOps @=? sort actualOps
|
|
|
|
|
|
|
|
|
|
|
|
-- | Test that identical "stateful" ops work with createGraph.
--
-- Two 'TF.truncatedNormal' draws are built from the same shape tensor, and
-- @noiseX + noiseY*3@ is differentiated with respect to both draws.
testCreateGraphStateful :: Test
testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
    [gradX, gradY] <- TF.runSession . TF.buildAnd TF.run $ do
        let shape = TF.constant (TF.Shape [1]) [1]
        noiseX :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
        noiseY :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
        TF.gradients (noiseX + noiseY*3) [noiseX, noiseY]
    -- A failure here most likely shows up as an exception raised inside
    -- `TF.gradients`; the value checks below are only a secondary safeguard.
    1 @=? TF.unScalar gradX
    3 @=? TF.unScalar gradY
|
|
|
|
|
|
|
|
|
|
|
|
-- | Test that name scopes work with createGraph.
--
-- The differentiated tensor is created inside the @"foo"@ name scope, and its
-- gradient with respect to itself must be 1.
testCreateGraphNameScopes :: Test
testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
    [gradX] <- TF.runSession . TF.buildAnd TF.run $ do
        let shape = TF.constant (TF.Shape [1]) [1]
        scoped :: TF.Tensor TF.Value Float <-
            TF.withNameScope "foo" $ TF.truncatedNormal shape
        TF.gradients scoped [scoped]
    -- A failure here most likely shows up as an exception raised inside
    -- `TF.gradients`; the value check below is only a secondary safeguard.
    1 @=? TF.unScalar gradX
|
|
|
|
|
|
|
|
|
|
|
|
-- | Test that createGraph can handle graphs with diamond shapes.
--
-- @squared = input*input@ is consumed twice by @squared*squared@, so the
-- output is @input^4@. With @input = 1@ the expected gradient is
-- @4*input^3 = 4@.
testDiamond :: Test
testDiamond = testCase "testDiamond" $ do
    [gradX] <- TF.runSession . TF.buildAnd TF.run $ do
        let input = TF.vector [1]
            squared = input*input
        TF.gradients (squared*squared) [input]
    (4 :: Float) @=? TF.unScalar gradX
|
|
|
|
|
|
|
|
|
2016-12-12 18:47:02 +01:00
|
|
|
-- | Checks the gradient of 'TF.max' over a 5-element vector, with @0@ passed
-- as the second (presumably reduction-axis) argument. Only the position
-- holding the maximum value (3, at index 2) is expected to receive gradient.
testMaxGradient :: Test
testMaxGradient = testCase "testMaxGradient" $ do
    [gradX] <- TF.runSession . TF.buildAnd TF.run $ do
        let input = TF.vector [1, 2, 3, 0, 1 :: Float]
            axis = 0 :: TF.Tensor TF.Value Int32
        TF.gradients (TF.max input axis) [input]
    V.fromList [0, 0, 1, 0, 0 :: Float] @=? gradX
|
|
|
|
|
|
|
|
|
2016-10-24 21:26:42 +02:00
|
|
|
-- | Entry point: runs every test defined in this module via 'googleTest'.
main :: IO ()
main = googleTest allTests
  where
    -- Kept in the same order as they were originally registered.
    allTests :: [Test]
    allTests =
        [ testGradientSimple
        , testGradientDisconnected
        , testCreateGraphStateful
        , testCreateGraphNameScopes
        , testDiamond
        , testMaxGradient
        ]
|