mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
test case to show can't calculate grad for embedding (and associated fix) (#23)
* Fix for embedding gradient calculation - Passes vectors instead of scalars to slice - converts the numRows to a scalar - add `toScalar` utility function - minor change to test case so that it actually works * added lib for testing helper functions * add flatSlice function
This commit is contained in:
parent
fc3d398ca9
commit
69fdbf677f
10 changed files with 153 additions and 32 deletions
|
@ -11,6 +11,7 @@ packages:
|
|||
- tensorflow-mnist-input-data
|
||||
- tensorflow-queue
|
||||
- tensorflow-nn
|
||||
- tensorflow-test
|
||||
|
||||
extra-deps:
|
||||
# proto-lens is not yet in Stackage.
|
||||
|
|
|
@ -30,6 +30,7 @@ Test-Suite NNTest
|
|||
, QuickCheck
|
||||
, base
|
||||
, tensorflow
|
||||
, tensorflow-test
|
||||
, tensorflow-ops
|
||||
, tensorflow-nn
|
||||
, google-shim
|
||||
|
|
|
@ -22,6 +22,7 @@ module Main where
|
|||
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Google.Test (googleTest)
|
||||
import TensorFlow.Test (assertAllClose)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@?))
|
||||
import Test.HUnit.Lang (Assertion(..))
|
||||
|
@ -63,15 +64,6 @@ defInputs = Inputs {
|
|||
}
|
||||
|
||||
|
||||
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
|
||||
assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
|
||||
("Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys
|
||||
++ "\ntolerance: " ++ show tol)
|
||||
where
|
||||
absDiff x y = abs (x - y)
|
||||
tol = 0.001 :: Float
|
||||
|
||||
|
||||
testLogisticOutput = testCase "testLogisticOutput" $ do
|
||||
let inputs = defInputs
|
||||
vLogits = TF.vector $ logits inputs
|
||||
|
|
|
@ -77,6 +77,7 @@ import TensorFlow.Ops
|
|||
, shape
|
||||
, softmaxCrossEntropyWithLogits
|
||||
, sum
|
||||
, scalarize
|
||||
, vector
|
||||
, zerosLike
|
||||
)
|
||||
|
@ -402,6 +403,21 @@ type GradientFunc a = NodeDef
|
|||
toT :: Output -> Tensor Value a
|
||||
toT = Tensor ValueKind
|
||||
|
||||
|
||||
-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for
|
||||
-- simple slicing operations.
|
||||
flatSlice :: forall v1 t i . (TensorType t)
|
||||
=> Tensor v1 t -- ^ __input__
|
||||
-> Int32 -- ^ __begin__: specifies the offset into the first dimension of
|
||||
-- 'input' to slice from.
|
||||
-> Int32 -- ^ __size__: specifies the number of elements of the first dimension
|
||||
-- of 'input' to slice. If size is -1, all remaining elements in the dimension
|
||||
-- are included in the slice (i.e. this is equivalent to setting
|
||||
-- size = input.dim_size(0) - begin).
|
||||
-> Tensor Value t -- ^ __output__
|
||||
flatSlice input begin size = CoreOps.slice input (vector [begin]) (vector [size])
|
||||
|
||||
|
||||
-- | The gradient function for an op type.
|
||||
--
|
||||
-- These implementations should match their python counterparts in:
|
||||
|
@ -430,10 +446,9 @@ opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
|
|||
where
|
||||
-- TODO(gnezdo): Use colocateWith but it requires Build monad.
|
||||
denseShape = shape (x :: Tensor Value a)
|
||||
numRows = CoreOps.slice denseShape 0 (1 :: Tensor Value Int32)
|
||||
valuesShape = CoreOps.concat 0 [
|
||||
allDimensions
|
||||
, CoreOps.slice denseShape 1 (-1 :: Tensor Value Int32)
|
||||
numRows = scalarize $ flatSlice denseShape 0 1
|
||||
valuesShape = CoreOps.concat 0 [ allDimensions
|
||||
, flatSlice denseShape 1 (-1)
|
||||
]
|
||||
values = reshape dz valuesShape
|
||||
-- TODO(fmayle): This could be either Int32 or Int64.
|
||||
|
@ -628,7 +643,7 @@ opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
|
|||
, Nothing
|
||||
, Nothing
|
||||
]
|
||||
where inputRows = CoreOps.slice (shape (x :: Tensor Value a)) (scalar (0 :: Int32)) (scalar 1)
|
||||
where inputRows = flatSlice (shape (x :: Tensor Value a)) 0 1
|
||||
|
||||
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
|
||||
opGrad "LabelWeights" _ _ _ = [Nothing]
|
||||
|
|
|
@ -101,6 +101,7 @@ module TensorFlow.Ops
|
|||
, vector
|
||||
, zeros
|
||||
, CoreOps.zerosLike
|
||||
, scalarize
|
||||
) where
|
||||
|
||||
import Data.ByteString (ByteString)
|
||||
|
@ -256,6 +257,15 @@ constant (Shape shape') values
|
|||
[def & TensorShape.size .~ x | x <- shape']
|
||||
& tensorVal .~ values
|
||||
|
||||
-- | Reshape a N-D tensor down to a scalar.
|
||||
--
|
||||
-- See `TensorFlow.GenOps.Core.reshape`.
|
||||
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
|
||||
scalarize t = CoreOps.reshape t (vector scalarShape)
|
||||
where
|
||||
scalarShape = [] :: [Int32]
|
||||
|
||||
|
||||
-- | Create a constant vector.
|
||||
vector :: TensorType a => [a] -> Tensor Value a
|
||||
vector xs = constant [fromIntegral $ length xs] xs
|
||||
|
|
|
@ -60,6 +60,7 @@ Test-Suite EmbeddingOpsTest
|
|||
, lens-family
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-test
|
||||
, tensorflow-core-ops
|
||||
, tensorflow-ops
|
||||
, tensorflow-proto
|
||||
|
|
|
@ -23,10 +23,12 @@ import Data.List (genericLength)
|
|||
import Google.Test (googleTest)
|
||||
import TensorFlow.EmbeddingOps (embeddingLookup)
|
||||
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||||
import Test.HUnit ((@=?))
|
||||
import Test.HUnit.Lang (Assertion(..))
|
||||
import Test.HUnit ((@=?), (@?))
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
|
||||
import Test.QuickCheck.Monadic (monadicIO, run)
|
||||
import TensorFlow.Test (assertAllClose)
|
||||
|
||||
import qualified Data.Vector as V
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
|
@ -34,21 +36,26 @@ import qualified TensorFlow.Ops as TF
|
|||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Build as TF
|
||||
import qualified TensorFlow.Nodes as TF
|
||||
|
||||
|
||||
buildAndRun :: TF.Fetchable t a => TF.Build t -> IO a
|
||||
buildAndRun = TF.runSession . TF.buildAnd TF.run
|
||||
|
||||
-- | Tries to perform a simple embedding lookup, with two partitions.
|
||||
testEmbeddingLookupHasRightShapeWithPartition = testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
|
||||
let shape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
|
||||
let embedding1 = [ 1, 1, 1 ] :: [Int32]
|
||||
let embedding2 = [ 0, 0, 0 ] :: [Int32]
|
||||
|
||||
-- | Tries to perform a simple embedding lookup, with two partitions.
|
||||
testEmbeddingLookupHasRightShapeWithPartition =
|
||||
testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
|
||||
let shape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
|
||||
let embedding1 = [1, 1, 1 :: Int32]
|
||||
let embedding2 = [0, 0, 0 :: Int32]
|
||||
let embedding = [ TF.constant shape embedding1
|
||||
, TF.constant shape embedding2
|
||||
]
|
||||
|
||||
let idValues = [0, 1] :: [Int32]
|
||||
let idValues = [0, 1 :: Int32]
|
||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||
let op = embeddingLookup embedding ids
|
||||
|
||||
|
@ -57,20 +64,23 @@ testEmbeddingLookupHasRightShapeWithPartition = testCase "testEmbeddingLookupHas
|
|||
return (vs, TF.shape vs)
|
||||
|
||||
-- This is the shape that is returned in the equiv. Python.
|
||||
shape @=? V.fromList [ 1, 2, 3 ]
|
||||
shape @=? V.fromList [1, 2, 3]
|
||||
|
||||
-- "[0, 1]" should pull out the resulting vector.
|
||||
values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ]
|
||||
values @=? V.fromList [1, 1, 1, 0, 0, 0]
|
||||
|
||||
|
||||
-- | Tries to perform a simple embedding lookup, with only a single partition.
|
||||
testEmbeddingLookupHasRightShape = testCase "testEmbeddingLookupHasRightShape" $ do
|
||||
let shape = TF.Shape [2, 3] -- Consider a 3-dim embedding of two items.
|
||||
testEmbeddingLookupHasRightShape =
|
||||
testCase "testEmbeddingLookupHasRightShape" $ do
|
||||
-- Consider a 3-dim embedding of two items
|
||||
let shape = TF.Shape [2, 3]
|
||||
let embeddingInit = [ 1, 1, 1
|
||||
, 0, 0, 0 ] :: [Int32]
|
||||
, 0, 0, 0 :: Int32
|
||||
]
|
||||
|
||||
let embedding = TF.constant shape embeddingInit
|
||||
let idValues = [0, 1] :: [Int32]
|
||||
let idValues = [0, 1 :: Int32]
|
||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||
let op = embeddingLookup [embedding] ids
|
||||
|
||||
|
@ -79,10 +89,39 @@ testEmbeddingLookupHasRightShape = testCase "testEmbeddingLookupHasRightShape" $
|
|||
return (vs, TF.shape vs)
|
||||
|
||||
-- This is the shape that is returned in the equiv. Python.
|
||||
shape @=? V.fromList [ 1, 2, 3 ]
|
||||
shape @=? V.fromList [1, 2, 3]
|
||||
|
||||
-- "[0, 1]" should pull out the resulting vector.
|
||||
values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ]
|
||||
values @=? V.fromList [1, 1, 1, 0, 0, 0]
|
||||
|
||||
|
||||
-- | Check that we can calculate gradients w.r.t embeddings.
|
||||
testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
||||
-- Agrees with "embedding", so gradient should be zero.
|
||||
let xVals = V.fromList ([20, 20 :: Float])
|
||||
let shape = TF.Shape [2]
|
||||
|
||||
gs <- TF.runSession $ do
|
||||
grads <- TF.build $ do
|
||||
let shape = TF.Shape [2, 1]
|
||||
let embeddingInit = [1, 20 ::Float]
|
||||
let idValues = [1, 1 :: Int32]
|
||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||
|
||||
x <- TF.placeholder (TF.Shape [2])
|
||||
embedding <- TF.initializedVariable
|
||||
=<< TF.render (TF.constant shape embeddingInit)
|
||||
|
||||
op <- embeddingLookup [embedding] ids
|
||||
let twoNorm = CoreOps.square $ TF.abs (op - x)
|
||||
loss = TF.mean twoNorm (TF.scalar (0 :: Int32))
|
||||
|
||||
grad <- fmap head (TF.gradients loss [embedding])
|
||||
return $ \xs -> TF.runWithFeeds [TF.feed x xs] grad
|
||||
|
||||
grads (TF.encodeTensorData shape xVals :: TF.TensorData Float)
|
||||
-- Gradients should be zero (or close)
|
||||
assertAllClose gs (V.fromList ([0, 0 :: Float]))
|
||||
|
||||
|
||||
-- Verifies that direct gather is the same as dynamic split into
|
||||
|
@ -138,4 +177,5 @@ main = googleTest
|
|||
(testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
|
||||
, testEmbeddingLookupHasRightShape
|
||||
, testEmbeddingLookupHasRightShapeWithPartition
|
||||
, testEmbeddingLookupGradients
|
||||
]
|
||||
|
|
3
tensorflow-test/Setup.hs
Normal file
3
tensorflow-test/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
|||
import Distribution.Simple
|
||||
|
||||
main = defaultMain
|
34
tensorflow-test/src/TensorFlow/Test.hs
Normal file
34
tensorflow-test/src/TensorFlow/Test.hs
Normal file
|
@ -0,0 +1,34 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module TensorFlow.Test
|
||||
( assertAllClose
|
||||
) where
|
||||
|
||||
import qualified Data.Vector as V
|
||||
import Test.HUnit ((@?))
|
||||
import Test.HUnit.Lang (Assertion(..))
|
||||
|
||||
-- | Compares that the vectors are element-by-element equal within the given
|
||||
-- tolerance. Raises an assertion and prints some information if not.
|
||||
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
|
||||
assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
|
||||
"Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys
|
||||
++ "\ntolerance: " ++ show tol
|
||||
where
|
||||
absDiff x y = abs (x - y)
|
||||
tol = 0.001 :: Float
|
||||
|
24
tensorflow-test/tensorflow-test.cabal
Normal file
24
tensorflow-test/tensorflow-test.cabal
Normal file
|
@ -0,0 +1,24 @@
|
|||
name: tensorflow-test
|
||||
version: 0.1.0.0
|
||||
synopsis: Some common functions for test suites.
|
||||
homepage: https://github.com/tensorflow/haskell#readme
|
||||
license: Apache
|
||||
author: TensorFlow authors
|
||||
maintainer: tensorflow-haskell@googlegroups.com
|
||||
copyright: Google Inc.
|
||||
category: Machine Learning
|
||||
build-type: Simple
|
||||
cabal-version: >=1.22
|
||||
|
||||
library
|
||||
hs-source-dirs: src
|
||||
exposed-modules: TensorFlow.Test
|
||||
build-depends: base >= 4.7 && < 5
|
||||
, HUnit
|
||||
, vector
|
||||
default-language: Haskell2010
|
||||
|
||||
|
||||
source-repository head
|
||||
type: git
|
||||
location: https://github.com/tensorflow/haskell
|
Loading…
Reference in a new issue