test case to show can't calculate grad for embedding (and associated fix) (#23)

* Fix for embedding gradient calculation

- Passes vectors instead of scalars to slice
- converts the numRows to a scalar
- add `toScalar` utility function
- minor change to test case so that it actually works

* added lib for testing helper functions

* add flatSlice function
This commit is contained in:
Noon van der Silk 2016-11-18 08:54:36 +11:00 committed by fkm3
parent fc3d398ca9
commit 69fdbf677f
10 changed files with 153 additions and 32 deletions

View File

@ -11,6 +11,7 @@ packages:
- tensorflow-mnist-input-data
- tensorflow-queue
- tensorflow-nn
- tensorflow-test
extra-deps:
# proto-lens is not yet in Stackage.

View File

@ -30,6 +30,7 @@ Test-Suite NNTest
, QuickCheck
, base
, tensorflow
, tensorflow-test
, tensorflow-ops
, tensorflow-nn
, google-shim

View File

@ -22,6 +22,7 @@ module Main where
import Data.Maybe (fromMaybe)
import Google.Test (googleTest)
import TensorFlow.Test (assertAllClose)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@?))
import Test.HUnit.Lang (Assertion(..))
@ -63,15 +64,6 @@ defInputs = Inputs {
}
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
("Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys
++ "\ntolerance: " ++ show tol)
where
absDiff x y = abs (x - y)
tol = 0.001 :: Float
testLogisticOutput = testCase "testLogisticOutput" $ do
let inputs = defInputs
vLogits = TF.vector $ logits inputs

View File

@ -77,6 +77,7 @@ import TensorFlow.Ops
, shape
, softmaxCrossEntropyWithLogits
, sum
, scalarize
, vector
, zerosLike
)
@ -402,6 +403,21 @@ type GradientFunc a = NodeDef
toT :: Output -> Tensor Value a
toT = Tensor ValueKind
-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for
-- simple slicing operations.
flatSlice :: forall v1 t i . (TensorType t)
=> Tensor v1 t -- ^ __input__
-> Int32 -- ^ __begin__: specifies the offset into the first dimension of
-- 'input' to slice from.
-> Int32 -- ^ __size__: specifies the number of elements of the first dimension
-- of 'input' to slice. If size is -1, all remaining elements in the dimension
-- are included in the slice (i.e. this is equivalent to setting
-- size = input.dim_size(0) - begin).
-> Tensor Value t -- ^ __output__
flatSlice input begin size = CoreOps.slice input (vector [begin]) (vector [size])
-- | The gradient function for an op type.
--
-- These implementations should match their python counterparts in:
@ -430,11 +446,10 @@ opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
where
-- TODO(gnezdo): Use colocateWith but it requires Build monad.
denseShape = shape (x :: Tensor Value a)
numRows = CoreOps.slice denseShape 0 (1 :: Tensor Value Int32)
valuesShape = CoreOps.concat 0 [
allDimensions
, CoreOps.slice denseShape 1 (-1 :: Tensor Value Int32)
]
numRows = scalarize $ flatSlice denseShape 0 1
valuesShape = CoreOps.concat 0 [ allDimensions
, flatSlice denseShape 1 (-1)
]
values = reshape dz valuesShape
-- TODO(fmayle): This could be either Int32 or Int64.
indices' = reshape indices allDimensions :: Tensor Value Int32
@ -628,7 +643,7 @@ opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
, Nothing
, Nothing
]
where inputRows = CoreOps.slice (shape (x :: Tensor Value a)) (scalar (0 :: Int32)) (scalar 1)
where inputRows = flatSlice (shape (x :: Tensor Value a)) 0 1
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
opGrad "LabelWeights" _ _ _ = [Nothing]

View File

@ -101,6 +101,7 @@ module TensorFlow.Ops
, vector
, zeros
, CoreOps.zerosLike
, scalarize
) where
import Data.ByteString (ByteString)
@ -256,6 +257,15 @@ constant (Shape shape') values
[def & TensorShape.size .~ x | x <- shape']
& tensorVal .~ values
-- | Reshape a N-D tensor down to a scalar.
--
-- See `TensorFlow.GenOps.Core.reshape`.
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
scalarize t = CoreOps.reshape t (vector scalarShape)
where
scalarShape = [] :: [Int32]
-- | Create a constant vector.
vector :: TensorType a => [a] -> Tensor Value a
vector xs = constant [fromIntegral $ length xs] xs

View File

@ -60,6 +60,7 @@ Test-Suite EmbeddingOpsTest
, lens-family
, google-shim
, tensorflow
, tensorflow-test
, tensorflow-core-ops
, tensorflow-ops
, tensorflow-proto

View File

@ -23,10 +23,12 @@ import Data.List (genericLength)
import Google.Test (googleTest)
import TensorFlow.EmbeddingOps (embeddingLookup)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.HUnit ((@=?))
import Test.HUnit.Lang (Assertion(..))
import Test.HUnit ((@=?), (@?))
import Test.Framework.Providers.HUnit (testCase)
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
import Test.QuickCheck.Monadic (monadicIO, run)
import TensorFlow.Test (assertAllClose)
import qualified Data.Vector as V
import qualified TensorFlow.GenOps.Core as CoreOps
@ -34,21 +36,26 @@ import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Build as TF
import qualified TensorFlow.Nodes as TF
buildAndRun :: TF.Fetchable t a => TF.Build t -> IO a
buildAndRun = TF.runSession . TF.buildAnd TF.run
-- | Tries to perform a simple embedding lookup, with two partitions.
testEmbeddingLookupHasRightShapeWithPartition = testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
testEmbeddingLookupHasRightShapeWithPartition =
testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
let shape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
let embedding1 = [ 1, 1, 1 ] :: [Int32]
let embedding2 = [ 0, 0, 0 ] :: [Int32]
let embedding1 = [1, 1, 1 :: Int32]
let embedding2 = [0, 0, 0 :: Int32]
let embedding = [ TF.constant shape embedding1
, TF.constant shape embedding2
]
let embedding = [ TF.constant shape embedding1
, TF.constant shape embedding2
]
let idValues = [0, 1] :: [Int32]
let idValues = [0, 1 :: Int32]
let ids = TF.constant (TF.Shape [1, 2]) idValues
let op = embeddingLookup embedding ids
@ -57,20 +64,23 @@ testEmbeddingLookupHasRightShapeWithPartition = testCase "testEmbeddingLookupHas
return (vs, TF.shape vs)
-- This is the shape that is returned in the equiv. Python.
shape @=? V.fromList [ 1, 2, 3 ]
shape @=? V.fromList [1, 2, 3]
-- "[0, 1]" should pull out the resulting vector.
values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ]
values @=? V.fromList [1, 1, 1, 0, 0, 0]
-- | Tries to perform a simple embedding lookup, with only a single partition.
testEmbeddingLookupHasRightShape = testCase "testEmbeddingLookupHasRightShape" $ do
let shape = TF.Shape [2, 3] -- Consider a 3-dim embedding of two items.
testEmbeddingLookupHasRightShape =
testCase "testEmbeddingLookupHasRightShape" $ do
-- Consider a 3-dim embedding of two items
let shape = TF.Shape [2, 3]
let embeddingInit = [ 1, 1, 1
, 0, 0, 0 ] :: [Int32]
, 0, 0, 0 :: Int32
]
let embedding = TF.constant shape embeddingInit
let idValues = [0, 1] :: [Int32]
let idValues = [0, 1 :: Int32]
let ids = TF.constant (TF.Shape [1, 2]) idValues
let op = embeddingLookup [embedding] ids
@ -79,10 +89,39 @@ testEmbeddingLookupHasRightShape = testCase "testEmbeddingLookupHasRightShape" $
return (vs, TF.shape vs)
-- This is the shape that is returned in the equiv. Python.
shape @=? V.fromList [ 1, 2, 3 ]
shape @=? V.fromList [1, 2, 3]
-- "[0, 1]" should pull out the resulting vector.
values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ]
values @=? V.fromList [1, 1, 1, 0, 0, 0]
-- | Check that we can calculate gradients w.r.t embeddings.
testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
-- Agrees with "embedding", so gradient should be zero.
let xVals = V.fromList ([20, 20 :: Float])
let shape = TF.Shape [2]
gs <- TF.runSession $ do
grads <- TF.build $ do
let shape = TF.Shape [2, 1]
let embeddingInit = [1, 20 ::Float]
let idValues = [1, 1 :: Int32]
let ids = TF.constant (TF.Shape [1, 2]) idValues
x <- TF.placeholder (TF.Shape [2])
embedding <- TF.initializedVariable
=<< TF.render (TF.constant shape embeddingInit)
op <- embeddingLookup [embedding] ids
let twoNorm = CoreOps.square $ TF.abs (op - x)
loss = TF.mean twoNorm (TF.scalar (0 :: Int32))
grad <- fmap head (TF.gradients loss [embedding])
return $ \xs -> TF.runWithFeeds [TF.feed x xs] grad
grads (TF.encodeTensorData shape xVals :: TF.TensorData Float)
-- Gradients should be zero (or close)
assertAllClose gs (V.fromList ([0, 0 :: Float]))
-- Verifies that direct gather is the same as dynamic split into
@ -138,4 +177,5 @@ main = googleTest
(testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
, testEmbeddingLookupHasRightShape
, testEmbeddingLookupHasRightShapeWithPartition
, testEmbeddingLookupGradients
]

3
tensorflow-test/Setup.hs Normal file
View File

@ -0,0 +1,3 @@
import Distribution.Simple
main = defaultMain

View File

@ -0,0 +1,34 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
module TensorFlow.Test
( assertAllClose
) where
import qualified Data.Vector as V
import Test.HUnit ((@?))
import Test.HUnit.Lang (Assertion(..))
-- | Compares that the vectors are element-by-element equal within the given
-- tolerance. Raises an assertion and prints some information if not.
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
"Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys
++ "\ntolerance: " ++ show tol
where
absDiff x y = abs (x - y)
tol = 0.001 :: Float

View File

@ -0,0 +1,24 @@
name: tensorflow-test
version: 0.1.0.0
synopsis: Some common functions for test suites.
homepage: https://github.com/tensorflow/haskell#readme
license: Apache
author: TensorFlow authors
maintainer: tensorflow-haskell@googlegroups.com
copyright: Google Inc.
category: Machine Learning
build-type: Simple
cabal-version: >=1.22
library
hs-source-dirs: src
exposed-modules: TensorFlow.Test
build-depends: base >= 4.7 && < 5
, HUnit
, vector
default-language: Haskell2010
source-repository head
type: git
location: https://github.com/tensorflow/haskell