From 69fdbf677ff5c56ee99576310ee65ccf0efdf2b2 Mon Sep 17 00:00:00 2001 From: Noon van der Silk Date: Fri, 18 Nov 2016 08:54:36 +1100 Subject: [PATCH] test case to show can't calculate grad for embedding (and associated fix) (#23) * Fix for embedding gradient calculation - Passes vectors instead of scalars to slice - converts the numRows to a scalar - add `scalarize` utility function - minor change to test case so that it actually works * added lib for testing helper functions * add flatSlice function --- stack.yaml | 1 + tensorflow-nn/tensorflow-nn.cabal | 1 + tensorflow-nn/tests/NNTest.hs | 10 +-- tensorflow-ops/src/TensorFlow/Gradient.hs | 27 +++++++-- tensorflow-ops/src/TensorFlow/Ops.hs | 10 +++ tensorflow-ops/tensorflow-ops.cabal | 1 + tensorflow-ops/tests/EmbeddingOpsTest.hs | 74 +++++++++++++++++------ tensorflow-test/Setup.hs | 3 + tensorflow-test/src/TensorFlow/Test.hs | 34 +++++++++++ tensorflow-test/tensorflow-test.cabal | 24 ++++++++ 10 files changed, 153 insertions(+), 32 deletions(-) create mode 100644 tensorflow-test/Setup.hs create mode 100644 tensorflow-test/src/TensorFlow/Test.hs create mode 100644 tensorflow-test/tensorflow-test.cabal diff --git a/stack.yaml b/stack.yaml index 862882a..ff4f8c5 100644 --- a/stack.yaml +++ b/stack.yaml @@ -11,6 +11,7 @@ packages: - tensorflow-mnist-input-data - tensorflow-queue - tensorflow-nn +- tensorflow-test extra-deps: # proto-lens is not yet in Stackage. 
diff --git a/tensorflow-nn/tensorflow-nn.cabal b/tensorflow-nn/tensorflow-nn.cabal index de029b9..ca51474 100644 --- a/tensorflow-nn/tensorflow-nn.cabal +++ b/tensorflow-nn/tensorflow-nn.cabal @@ -30,6 +30,7 @@ Test-Suite NNTest , QuickCheck , base , tensorflow + , tensorflow-test , tensorflow-ops , tensorflow-nn , google-shim diff --git a/tensorflow-nn/tests/NNTest.hs b/tensorflow-nn/tests/NNTest.hs index ce8a9b1..a05285f 100644 --- a/tensorflow-nn/tests/NNTest.hs +++ b/tensorflow-nn/tests/NNTest.hs @@ -22,6 +22,7 @@ module Main where import Data.Maybe (fromMaybe) import Google.Test (googleTest) +import TensorFlow.Test (assertAllClose) import Test.Framework.Providers.HUnit (testCase) import Test.HUnit ((@?)) import Test.HUnit.Lang (Assertion(..)) @@ -63,15 +64,6 @@ defInputs = Inputs { } -assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion -assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @? - ("Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys - ++ "\ntolerance: " ++ show tol) - where - absDiff x y = abs (x - y) - tol = 0.001 :: Float - - testLogisticOutput = testCase "testLogisticOutput" $ do let inputs = defInputs vLogits = TF.vector $ logits inputs diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index f863e36..002c8bf 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -77,6 +77,7 @@ import TensorFlow.Ops , shape , softmaxCrossEntropyWithLogits , sum + , scalarize , vector , zerosLike ) @@ -402,6 +403,21 @@ type GradientFunc a = NodeDef toT :: Output -> Tensor Value a toT = Tensor ValueKind + +-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for +-- simple slicing operations. +flatSlice :: forall v1 t i . (TensorType t) + => Tensor v1 t -- ^ __input__ + -> Int32 -- ^ __begin__: specifies the offset into the first dimension of + -- 'input' to slice from. 
+ -> Int32 -- ^ __size__: specifies the number of elements of the first dimension + -- of 'input' to slice. If size is -1, all remaining elements in the dimension + -- are included in the slice (i.e. this is equivalent to setting + -- size = input.dim_size(0) - begin). + -> Tensor Value t -- ^ __output__ +flatSlice input begin size = CoreOps.slice input (vector [begin]) (vector [size]) + + -- | The gradient function for an op type. -- -- These implementations should match their python counterparts in: @@ -430,11 +446,10 @@ opGrad "Gather" _ [toT -> x, toT -> indices] [dz] = where -- TODO(gnezdo): Use colocateWith but it requires Build monad. denseShape = shape (x :: Tensor Value a) - numRows = CoreOps.slice denseShape 0 (1 :: Tensor Value Int32) - valuesShape = CoreOps.concat 0 [ - allDimensions - , CoreOps.slice denseShape 1 (-1 :: Tensor Value Int32) - ] + numRows = scalarize $ flatSlice denseShape 0 1 + valuesShape = CoreOps.concat 0 [ allDimensions + , flatSlice denseShape 1 (-1) + ] values = reshape dz valuesShape -- TODO(fmayle): This could be either Int32 or Int64. 
indices' = reshape indices allDimensions :: Tensor Value Int32 @@ -628,7 +643,7 @@ opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] = , Nothing , Nothing ] - where inputRows = CoreOps.slice (shape (x :: Tensor Value a)) (scalar (0 :: Int32)) (scalar 1) + where inputRows = flatSlice (shape (x :: Tensor Value a)) 0 1 opGrad "LabelClasses" _ _ _ = [Nothing, Nothing] opGrad "LabelWeights" _ _ _ = [Nothing] diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs index 9d124b0..71b4bf1 100644 --- a/tensorflow-ops/src/TensorFlow/Ops.hs +++ b/tensorflow-ops/src/TensorFlow/Ops.hs @@ -101,6 +101,7 @@ module TensorFlow.Ops , vector , zeros , CoreOps.zerosLike + , scalarize ) where import Data.ByteString (ByteString) @@ -256,6 +257,15 @@ constant (Shape shape') values [def & TensorShape.size .~ x | x <- shape'] & tensorVal .~ values +-- | Reshape a single-element N-D tensor down to a scalar. +-- +-- See `TensorFlow.GenOps.Core.reshape`. +scalarize :: (TensorType a) => Tensor v a -> Tensor Value a +scalarize t = CoreOps.reshape t (vector scalarShape) + where + scalarShape = [] :: [Int32] + + -- | Create a constant vector. 
vector :: TensorType a => [a] -> Tensor Value a vector xs = constant [fromIntegral $ length xs] xs diff --git a/tensorflow-ops/tensorflow-ops.cabal b/tensorflow-ops/tensorflow-ops.cabal index e3905af..b6b60ad 100644 --- a/tensorflow-ops/tensorflow-ops.cabal +++ b/tensorflow-ops/tensorflow-ops.cabal @@ -60,6 +60,7 @@ Test-Suite EmbeddingOpsTest , lens-family , google-shim , tensorflow + , tensorflow-test , tensorflow-core-ops , tensorflow-ops , tensorflow-proto diff --git a/tensorflow-ops/tests/EmbeddingOpsTest.hs b/tensorflow-ops/tests/EmbeddingOpsTest.hs index 722492b..c4ecaab 100644 --- a/tensorflow-ops/tests/EmbeddingOpsTest.hs +++ b/tensorflow-ops/tests/EmbeddingOpsTest.hs @@ -23,10 +23,12 @@ import Data.List (genericLength) import Google.Test (googleTest) import TensorFlow.EmbeddingOps (embeddingLookup) import Test.Framework.Providers.QuickCheck2 (testProperty) -import Test.HUnit ((@=?)) +import Test.HUnit.Lang (Assertion(..)) +import Test.HUnit ((@=?), (@?)) import Test.Framework.Providers.HUnit (testCase) import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf) import Test.QuickCheck.Monadic (monadicIO, run) +import TensorFlow.Test (assertAllClose) import qualified Data.Vector as V import qualified TensorFlow.GenOps.Core as CoreOps @@ -34,21 +36,26 @@ import qualified TensorFlow.Ops as TF import qualified TensorFlow.Session as TF import qualified TensorFlow.Tensor as TF import qualified TensorFlow.Types as TF +import qualified TensorFlow.Gradient as TF +import qualified TensorFlow.Build as TF +import qualified TensorFlow.Nodes as TF +buildAndRun :: TF.Fetchable t a => TF.Build t -> IO a buildAndRun = TF.runSession . TF.buildAnd TF.run + -- | Tries to perform a simple embedding lookup, with two partitions. 
-testEmbeddingLookupHasRightShapeWithPartition = testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do +testEmbeddingLookupHasRightShapeWithPartition = + testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do let shape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items. - let embedding1 = [ 1, 1, 1 ] :: [Int32] - let embedding2 = [ 0, 0, 0 ] :: [Int32] + let embedding1 = [1, 1, 1 :: Int32] + let embedding2 = [0, 0, 0 :: Int32] + let embedding = [ TF.constant shape embedding1 + , TF.constant shape embedding2 + ] - let embedding = [ TF.constant shape embedding1 - , TF.constant shape embedding2 - ] - - let idValues = [0, 1] :: [Int32] + let idValues = [0, 1 :: Int32] let ids = TF.constant (TF.Shape [1, 2]) idValues let op = embeddingLookup embedding ids @@ -57,20 +64,23 @@ testEmbeddingLookupHasRightShapeWithPartition = testCase "testEmbeddingLookupHas return (vs, TF.shape vs) -- This is the shape that is returned in the equiv. Python. - shape @=? V.fromList [ 1, 2, 3 ] + shape @=? V.fromList [1, 2, 3] -- "[0, 1]" should pull out the resulting vector. - values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ] + values @=? V.fromList [1, 1, 1, 0, 0, 0] -- | Tries to perform a simple embedding lookup, with only a single partition. -testEmbeddingLookupHasRightShape = testCase "testEmbeddingLookupHasRightShape" $ do - let shape = TF.Shape [2, 3] -- Consider a 3-dim embedding of two items. 
+testEmbeddingLookupHasRightShape = + testCase "testEmbeddingLookupHasRightShape" $ do + -- Consider a 3-dim embedding of two items + let shape = TF.Shape [2, 3] let embeddingInit = [ 1, 1, 1 - , 0, 0, 0 ] :: [Int32] + , 0, 0, 0 :: Int32 + ] let embedding = TF.constant shape embeddingInit - let idValues = [0, 1] :: [Int32] + let idValues = [0, 1 :: Int32] let ids = TF.constant (TF.Shape [1, 2]) idValues let op = embeddingLookup [embedding] ids @@ -79,10 +89,39 @@ testEmbeddingLookupHasRightShape = testCase "testEmbeddingLookupHasRightShape" $ return (vs, TF.shape vs) -- This is the shape that is returned in the equiv. Python. - shape @=? V.fromList [ 1, 2, 3 ] + shape @=? V.fromList [1, 2, 3] -- "[0, 1]" should pull out the resulting vector. - values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ] + values @=? V.fromList [1, 1, 1, 0, 0, 0] + + +-- | Check that we can calculate gradients w.r.t embeddings. +testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do + -- Agrees with "embedding", so gradient should be zero. 
+ let xVals = V.fromList ([20, 20 :: Float]) + let shape = TF.Shape [2] + + gs <- TF.runSession $ do + grads <- TF.build $ do + let shape = TF.Shape [2, 1] + let embeddingInit = [1, 20 ::Float] + let idValues = [1, 1 :: Int32] + let ids = TF.constant (TF.Shape [1, 2]) idValues + + x <- TF.placeholder (TF.Shape [2]) + embedding <- TF.initializedVariable + =<< TF.render (TF.constant shape embeddingInit) + + op <- embeddingLookup [embedding] ids + let twoNorm = CoreOps.square $ TF.abs (op - x) + loss = TF.mean twoNorm (TF.scalar (0 :: Int32)) + + grad <- fmap head (TF.gradients loss [embedding]) + return $ \xs -> TF.runWithFeeds [TF.feed x xs] grad + + grads (TF.encodeTensorData shape xVals :: TF.TensorData Float) + -- Gradients should be zero (or close) + assertAllClose gs (V.fromList ([0, 0 :: Float])) -- Verifies that direct gather is the same as dynamic split into @@ -138,4 +177,5 @@ main = googleTest (testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property) , testEmbeddingLookupHasRightShape , testEmbeddingLookupHasRightShapeWithPartition + , testEmbeddingLookupGradients ] diff --git a/tensorflow-test/Setup.hs b/tensorflow-test/Setup.hs new file mode 100644 index 0000000..e8ef27d --- /dev/null +++ b/tensorflow-test/Setup.hs @@ -0,0 +1,3 @@ +import Distribution.Simple + +main = defaultMain diff --git a/tensorflow-test/src/TensorFlow/Test.hs b/tensorflow-test/src/TensorFlow/Test.hs new file mode 100644 index 0000000..c12aaa2 --- /dev/null +++ b/tensorflow-test/src/TensorFlow/Test.hs @@ -0,0 +1,34 @@ +-- Copyright 2016 TensorFlow authors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. 
+-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +{-# LANGUAGE OverloadedStrings #-} + +module TensorFlow.Test + ( assertAllClose + ) where + +import qualified Data.Vector as V +import Test.HUnit ((@?)) +import Test.HUnit.Lang (Assertion(..)) + +-- | Asserts that two vectors are element-wise equal to within a fixed +-- tolerance (0.001); on failure, reports both vectors and the tolerance. +assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion +assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @? + "Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys + ++ "\ntolerance: " ++ show tol + where + absDiff x y = abs (x - y) + tol = 0.001 :: Float + diff --git a/tensorflow-test/tensorflow-test.cabal b/tensorflow-test/tensorflow-test.cabal new file mode 100644 index 0000000..2c3c057 --- /dev/null +++ b/tensorflow-test/tensorflow-test.cabal @@ -0,0 +1,24 @@ +name: tensorflow-test +version: 0.1.0.0 +synopsis: Some common functions for test suites. +homepage: https://github.com/tensorflow/haskell#readme +license: Apache +author: TensorFlow authors +maintainer: tensorflow-haskell@googlegroups.com +copyright: Google Inc. +category: Machine Learning +build-type: Simple +cabal-version: >=1.22 + +library + hs-source-dirs: src + exposed-modules: TensorFlow.Test + build-depends: base >= 4.7 && < 5 + , HUnit + , vector + default-language: Haskell2010 + + +source-repository head + type: git + location: https://github.com/tensorflow/haskell