mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
test case to show can't calculate grad for embedding (and associated fix) (#23)
* Fix for embedding gradient calculation - Passes vectors instead of scalars to slice - converts the numRows to a scalar - add `toScalar` utility function - minor change to test case so that it actually works * added lib for testing helper functions * add flatSlice function
This commit is contained in:
parent
fc3d398ca9
commit
69fdbf677f
10 changed files with 153 additions and 32 deletions
|
@ -11,6 +11,7 @@ packages:
|
||||||
- tensorflow-mnist-input-data
|
- tensorflow-mnist-input-data
|
||||||
- tensorflow-queue
|
- tensorflow-queue
|
||||||
- tensorflow-nn
|
- tensorflow-nn
|
||||||
|
- tensorflow-test
|
||||||
|
|
||||||
extra-deps:
|
extra-deps:
|
||||||
# proto-lens is not yet in Stackage.
|
# proto-lens is not yet in Stackage.
|
||||||
|
|
|
@ -30,6 +30,7 @@ Test-Suite NNTest
|
||||||
, QuickCheck
|
, QuickCheck
|
||||||
, base
|
, base
|
||||||
, tensorflow
|
, tensorflow
|
||||||
|
, tensorflow-test
|
||||||
, tensorflow-ops
|
, tensorflow-ops
|
||||||
, tensorflow-nn
|
, tensorflow-nn
|
||||||
, google-shim
|
, google-shim
|
||||||
|
|
|
@ -22,6 +22,7 @@ module Main where
|
||||||
|
|
||||||
import Data.Maybe (fromMaybe)
|
import Data.Maybe (fromMaybe)
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
|
import TensorFlow.Test (assertAllClose)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.HUnit ((@?))
|
import Test.HUnit ((@?))
|
||||||
import Test.HUnit.Lang (Assertion(..))
|
import Test.HUnit.Lang (Assertion(..))
|
||||||
|
@ -63,15 +64,6 @@ defInputs = Inputs {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
|
|
||||||
assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
|
|
||||||
("Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys
|
|
||||||
++ "\ntolerance: " ++ show tol)
|
|
||||||
where
|
|
||||||
absDiff x y = abs (x - y)
|
|
||||||
tol = 0.001 :: Float
|
|
||||||
|
|
||||||
|
|
||||||
testLogisticOutput = testCase "testLogisticOutput" $ do
|
testLogisticOutput = testCase "testLogisticOutput" $ do
|
||||||
let inputs = defInputs
|
let inputs = defInputs
|
||||||
vLogits = TF.vector $ logits inputs
|
vLogits = TF.vector $ logits inputs
|
||||||
|
|
|
@ -77,6 +77,7 @@ import TensorFlow.Ops
|
||||||
, shape
|
, shape
|
||||||
, softmaxCrossEntropyWithLogits
|
, softmaxCrossEntropyWithLogits
|
||||||
, sum
|
, sum
|
||||||
|
, scalarize
|
||||||
, vector
|
, vector
|
||||||
, zerosLike
|
, zerosLike
|
||||||
)
|
)
|
||||||
|
@ -402,6 +403,21 @@ type GradientFunc a = NodeDef
|
||||||
toT :: Output -> Tensor Value a
|
toT :: Output -> Tensor Value a
|
||||||
toT = Tensor ValueKind
|
toT = Tensor ValueKind
|
||||||
|
|
||||||
|
|
||||||
|
-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for
|
||||||
|
-- simple slicing operations.
|
||||||
|
flatSlice :: forall v1 t i . (TensorType t)
|
||||||
|
=> Tensor v1 t -- ^ __input__
|
||||||
|
-> Int32 -- ^ __begin__: specifies the offset into the first dimension of
|
||||||
|
-- 'input' to slice from.
|
||||||
|
-> Int32 -- ^ __size__: specifies the number of elements of the first dimension
|
||||||
|
-- of 'input' to slice. If size is -1, all remaining elements in the dimension
|
||||||
|
-- are included in the slice (i.e. this is equivalent to setting
|
||||||
|
-- size = input.dim_size(0) - begin).
|
||||||
|
-> Tensor Value t -- ^ __output__
|
||||||
|
flatSlice input begin size = CoreOps.slice input (vector [begin]) (vector [size])
|
||||||
|
|
||||||
|
|
||||||
-- | The gradient function for an op type.
|
-- | The gradient function for an op type.
|
||||||
--
|
--
|
||||||
-- These implementations should match their python counterparts in:
|
-- These implementations should match their python counterparts in:
|
||||||
|
@ -430,10 +446,9 @@ opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
|
||||||
where
|
where
|
||||||
-- TODO(gnezdo): Use colocateWith but it requires Build monad.
|
-- TODO(gnezdo): Use colocateWith but it requires Build monad.
|
||||||
denseShape = shape (x :: Tensor Value a)
|
denseShape = shape (x :: Tensor Value a)
|
||||||
numRows = CoreOps.slice denseShape 0 (1 :: Tensor Value Int32)
|
numRows = scalarize $ flatSlice denseShape 0 1
|
||||||
valuesShape = CoreOps.concat 0 [
|
valuesShape = CoreOps.concat 0 [ allDimensions
|
||||||
allDimensions
|
, flatSlice denseShape 1 (-1)
|
||||||
, CoreOps.slice denseShape 1 (-1 :: Tensor Value Int32)
|
|
||||||
]
|
]
|
||||||
values = reshape dz valuesShape
|
values = reshape dz valuesShape
|
||||||
-- TODO(fmayle): This could be either Int32 or Int64.
|
-- TODO(fmayle): This could be either Int32 or Int64.
|
||||||
|
@ -628,7 +643,7 @@ opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
|
||||||
, Nothing
|
, Nothing
|
||||||
, Nothing
|
, Nothing
|
||||||
]
|
]
|
||||||
where inputRows = CoreOps.slice (shape (x :: Tensor Value a)) (scalar (0 :: Int32)) (scalar 1)
|
where inputRows = flatSlice (shape (x :: Tensor Value a)) 0 1
|
||||||
|
|
||||||
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
|
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
|
||||||
opGrad "LabelWeights" _ _ _ = [Nothing]
|
opGrad "LabelWeights" _ _ _ = [Nothing]
|
||||||
|
|
|
@ -101,6 +101,7 @@ module TensorFlow.Ops
|
||||||
, vector
|
, vector
|
||||||
, zeros
|
, zeros
|
||||||
, CoreOps.zerosLike
|
, CoreOps.zerosLike
|
||||||
|
, scalarize
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Data.ByteString (ByteString)
|
import Data.ByteString (ByteString)
|
||||||
|
@ -256,6 +257,15 @@ constant (Shape shape') values
|
||||||
[def & TensorShape.size .~ x | x <- shape']
|
[def & TensorShape.size .~ x | x <- shape']
|
||||||
& tensorVal .~ values
|
& tensorVal .~ values
|
||||||
|
|
||||||
|
-- | Reshape a N-D tensor down to a scalar.
|
||||||
|
--
|
||||||
|
-- See `TensorFlow.GenOps.Core.reshape`.
|
||||||
|
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
|
||||||
|
scalarize t = CoreOps.reshape t (vector scalarShape)
|
||||||
|
where
|
||||||
|
scalarShape = [] :: [Int32]
|
||||||
|
|
||||||
|
|
||||||
-- | Create a constant vector.
|
-- | Create a constant vector.
|
||||||
vector :: TensorType a => [a] -> Tensor Value a
|
vector :: TensorType a => [a] -> Tensor Value a
|
||||||
vector xs = constant [fromIntegral $ length xs] xs
|
vector xs = constant [fromIntegral $ length xs] xs
|
||||||
|
|
|
@ -60,6 +60,7 @@ Test-Suite EmbeddingOpsTest
|
||||||
, lens-family
|
, lens-family
|
||||||
, google-shim
|
, google-shim
|
||||||
, tensorflow
|
, tensorflow
|
||||||
|
, tensorflow-test
|
||||||
, tensorflow-core-ops
|
, tensorflow-core-ops
|
||||||
, tensorflow-ops
|
, tensorflow-ops
|
||||||
, tensorflow-proto
|
, tensorflow-proto
|
||||||
|
|
|
@ -23,10 +23,12 @@ import Data.List (genericLength)
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
import TensorFlow.EmbeddingOps (embeddingLookup)
|
import TensorFlow.EmbeddingOps (embeddingLookup)
|
||||||
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||||||
import Test.HUnit ((@=?))
|
import Test.HUnit.Lang (Assertion(..))
|
||||||
|
import Test.HUnit ((@=?), (@?))
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
|
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
|
||||||
import Test.QuickCheck.Monadic (monadicIO, run)
|
import Test.QuickCheck.Monadic (monadicIO, run)
|
||||||
|
import TensorFlow.Test (assertAllClose)
|
||||||
|
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
@ -34,21 +36,26 @@ import qualified TensorFlow.Ops as TF
|
||||||
import qualified TensorFlow.Session as TF
|
import qualified TensorFlow.Session as TF
|
||||||
import qualified TensorFlow.Tensor as TF
|
import qualified TensorFlow.Tensor as TF
|
||||||
import qualified TensorFlow.Types as TF
|
import qualified TensorFlow.Types as TF
|
||||||
|
import qualified TensorFlow.Gradient as TF
|
||||||
|
import qualified TensorFlow.Build as TF
|
||||||
|
import qualified TensorFlow.Nodes as TF
|
||||||
|
|
||||||
|
|
||||||
|
buildAndRun :: TF.Fetchable t a => TF.Build t -> IO a
|
||||||
buildAndRun = TF.runSession . TF.buildAnd TF.run
|
buildAndRun = TF.runSession . TF.buildAnd TF.run
|
||||||
|
|
||||||
-- | Tries to perform a simple embedding lookup, with two partitions.
|
|
||||||
testEmbeddingLookupHasRightShapeWithPartition = testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
|
|
||||||
let shape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
|
|
||||||
let embedding1 = [ 1, 1, 1 ] :: [Int32]
|
|
||||||
let embedding2 = [ 0, 0, 0 ] :: [Int32]
|
|
||||||
|
|
||||||
|
-- | Tries to perform a simple embedding lookup, with two partitions.
|
||||||
|
testEmbeddingLookupHasRightShapeWithPartition =
|
||||||
|
testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
|
||||||
|
let shape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
|
||||||
|
let embedding1 = [1, 1, 1 :: Int32]
|
||||||
|
let embedding2 = [0, 0, 0 :: Int32]
|
||||||
let embedding = [ TF.constant shape embedding1
|
let embedding = [ TF.constant shape embedding1
|
||||||
, TF.constant shape embedding2
|
, TF.constant shape embedding2
|
||||||
]
|
]
|
||||||
|
|
||||||
let idValues = [0, 1] :: [Int32]
|
let idValues = [0, 1 :: Int32]
|
||||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||||
let op = embeddingLookup embedding ids
|
let op = embeddingLookup embedding ids
|
||||||
|
|
||||||
|
@ -57,20 +64,23 @@ testEmbeddingLookupHasRightShapeWithPartition = testCase "testEmbeddingLookupHas
|
||||||
return (vs, TF.shape vs)
|
return (vs, TF.shape vs)
|
||||||
|
|
||||||
-- This is the shape that is returned in the equiv. Python.
|
-- This is the shape that is returned in the equiv. Python.
|
||||||
shape @=? V.fromList [ 1, 2, 3 ]
|
shape @=? V.fromList [1, 2, 3]
|
||||||
|
|
||||||
-- "[0, 1]" should pull out the resulting vector.
|
-- "[0, 1]" should pull out the resulting vector.
|
||||||
values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ]
|
values @=? V.fromList [1, 1, 1, 0, 0, 0]
|
||||||
|
|
||||||
|
|
||||||
-- | Tries to perform a simple embedding lookup, with only a single partition.
|
-- | Tries to perform a simple embedding lookup, with only a single partition.
|
||||||
testEmbeddingLookupHasRightShape = testCase "testEmbeddingLookupHasRightShape" $ do
|
testEmbeddingLookupHasRightShape =
|
||||||
let shape = TF.Shape [2, 3] -- Consider a 3-dim embedding of two items.
|
testCase "testEmbeddingLookupHasRightShape" $ do
|
||||||
|
-- Consider a 3-dim embedding of two items
|
||||||
|
let shape = TF.Shape [2, 3]
|
||||||
let embeddingInit = [ 1, 1, 1
|
let embeddingInit = [ 1, 1, 1
|
||||||
, 0, 0, 0 ] :: [Int32]
|
, 0, 0, 0 :: Int32
|
||||||
|
]
|
||||||
|
|
||||||
let embedding = TF.constant shape embeddingInit
|
let embedding = TF.constant shape embeddingInit
|
||||||
let idValues = [0, 1] :: [Int32]
|
let idValues = [0, 1 :: Int32]
|
||||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||||
let op = embeddingLookup [embedding] ids
|
let op = embeddingLookup [embedding] ids
|
||||||
|
|
||||||
|
@ -79,10 +89,39 @@ testEmbeddingLookupHasRightShape = testCase "testEmbeddingLookupHasRightShape" $
|
||||||
return (vs, TF.shape vs)
|
return (vs, TF.shape vs)
|
||||||
|
|
||||||
-- This is the shape that is returned in the equiv. Python.
|
-- This is the shape that is returned in the equiv. Python.
|
||||||
shape @=? V.fromList [ 1, 2, 3 ]
|
shape @=? V.fromList [1, 2, 3]
|
||||||
|
|
||||||
-- "[0, 1]" should pull out the resulting vector.
|
-- "[0, 1]" should pull out the resulting vector.
|
||||||
values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ]
|
values @=? V.fromList [1, 1, 1, 0, 0, 0]
|
||||||
|
|
||||||
|
|
||||||
|
-- | Check that we can calculate gradients w.r.t embeddings.
|
||||||
|
testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
||||||
|
-- Agrees with "embedding", so gradient should be zero.
|
||||||
|
let xVals = V.fromList ([20, 20 :: Float])
|
||||||
|
let shape = TF.Shape [2]
|
||||||
|
|
||||||
|
gs <- TF.runSession $ do
|
||||||
|
grads <- TF.build $ do
|
||||||
|
let shape = TF.Shape [2, 1]
|
||||||
|
let embeddingInit = [1, 20 ::Float]
|
||||||
|
let idValues = [1, 1 :: Int32]
|
||||||
|
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||||
|
|
||||||
|
x <- TF.placeholder (TF.Shape [2])
|
||||||
|
embedding <- TF.initializedVariable
|
||||||
|
=<< TF.render (TF.constant shape embeddingInit)
|
||||||
|
|
||||||
|
op <- embeddingLookup [embedding] ids
|
||||||
|
let twoNorm = CoreOps.square $ TF.abs (op - x)
|
||||||
|
loss = TF.mean twoNorm (TF.scalar (0 :: Int32))
|
||||||
|
|
||||||
|
grad <- fmap head (TF.gradients loss [embedding])
|
||||||
|
return $ \xs -> TF.runWithFeeds [TF.feed x xs] grad
|
||||||
|
|
||||||
|
grads (TF.encodeTensorData shape xVals :: TF.TensorData Float)
|
||||||
|
-- Gradients should be zero (or close)
|
||||||
|
assertAllClose gs (V.fromList ([0, 0 :: Float]))
|
||||||
|
|
||||||
|
|
||||||
-- Verifies that direct gather is the same as dynamic split into
|
-- Verifies that direct gather is the same as dynamic split into
|
||||||
|
@ -138,4 +177,5 @@ main = googleTest
|
||||||
(testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
|
(testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
|
||||||
, testEmbeddingLookupHasRightShape
|
, testEmbeddingLookupHasRightShape
|
||||||
, testEmbeddingLookupHasRightShapeWithPartition
|
, testEmbeddingLookupHasRightShapeWithPartition
|
||||||
|
, testEmbeddingLookupGradients
|
||||||
]
|
]
|
||||||
|
|
3
tensorflow-test/Setup.hs
Normal file
3
tensorflow-test/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
import Distribution.Simple
|
||||||
|
|
||||||
|
main = defaultMain
|
34
tensorflow-test/src/TensorFlow/Test.hs
Normal file
34
tensorflow-test/src/TensorFlow/Test.hs
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
|
module TensorFlow.Test
|
||||||
|
( assertAllClose
|
||||||
|
) where
|
||||||
|
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
import Test.HUnit ((@?))
|
||||||
|
import Test.HUnit.Lang (Assertion(..))
|
||||||
|
|
||||||
|
-- | Compares that the vectors are element-by-element equal within the given
|
||||||
|
-- tolerance. Raises an assertion and prints some information if not.
|
||||||
|
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
|
||||||
|
assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
|
||||||
|
"Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys
|
||||||
|
++ "\ntolerance: " ++ show tol
|
||||||
|
where
|
||||||
|
absDiff x y = abs (x - y)
|
||||||
|
tol = 0.001 :: Float
|
||||||
|
|
24
tensorflow-test/tensorflow-test.cabal
Normal file
24
tensorflow-test/tensorflow-test.cabal
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
name: tensorflow-test
|
||||||
|
version: 0.1.0.0
|
||||||
|
synopsis: Some common functions for test suites.
|
||||||
|
homepage: https://github.com/tensorflow/haskell#readme
|
||||||
|
license: Apache
|
||||||
|
author: TensorFlow authors
|
||||||
|
maintainer: tensorflow-haskell@googlegroups.com
|
||||||
|
copyright: Google Inc.
|
||||||
|
category: Machine Learning
|
||||||
|
build-type: Simple
|
||||||
|
cabal-version: >=1.22
|
||||||
|
|
||||||
|
library
|
||||||
|
hs-source-dirs: src
|
||||||
|
exposed-modules: TensorFlow.Test
|
||||||
|
build-depends: base >= 4.7 && < 5
|
||||||
|
, HUnit
|
||||||
|
, vector
|
||||||
|
default-language: Haskell2010
|
||||||
|
|
||||||
|
|
||||||
|
source-repository head
|
||||||
|
type: git
|
||||||
|
location: https://github.com/tensorflow/haskell
|
Loading…
Reference in a new issue