1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-26 21:09:44 +01:00

Tests for "embedding_lookup" and minor fix

- added a test that fails for a partitioned embedding
- added a test that passes for a single embedding
This commit is contained in:
silky 2016-11-09 10:30:05 +11:00
parent 4ec78a8fca
commit 9c81241439
2 changed files with 59 additions and 3 deletions

View file

@ -60,9 +60,12 @@ embeddingLookup :: forall a b v .
-- fewer than 2^31 entries. -- fewer than 2^31 entries.
-> Build (Tensor Value a) -> Build (Tensor Value a)
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`. -- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
embeddingLookup params ids = embeddingLookup params@(p1 : _) ids =
CoreOps.dynamicStitch pindices <$> partitionedResult go (np :: Int32)
where np = genericLength params where
go 1 = colocateWith p1 (render $ CoreOps.gather p1 ids)
go _ = CoreOps.dynamicStitch pindices <$> partitionedResult
np = genericLength params
pAssignments = CoreOps.cast (ids `CoreOps.mod` np) pAssignments = CoreOps.cast (ids `CoreOps.mod` np)
newIds = ids `CoreOps.div` np newIds = ids `CoreOps.div` np
originalIndices = CoreOps.range 0 (CoreOps.size ids) 1 originalIndices = CoreOps.range 0 (CoreOps.size ids) 1

View file

@ -24,6 +24,7 @@ import Google.Test (googleTest)
import TensorFlow.EmbeddingOps (embeddingLookup) import TensorFlow.EmbeddingOps (embeddingLookup)
import Test.Framework.Providers.QuickCheck2 (testProperty) import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.HUnit ((@=?)) import Test.HUnit ((@=?))
import Test.Framework.Providers.HUnit (testCase)
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf) import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
import Test.QuickCheck.Monadic (monadicIO, run) import Test.QuickCheck.Monadic (monadicIO, run)
@ -34,6 +35,56 @@ import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF import qualified TensorFlow.Types as TF
buildAndRun = TF.runSession . TF.buildAnd TF.run
-- | Tries to perform a simple embedding lookup, with two partitions.
testEmbeddingLookupHasRightShapeWithPartition = testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
let shape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
let embedding1 = [ 1, 1, 1 ] :: [Int32]
let embedding2 = [ 0, 0, 0 ] :: [Int32]
let embedding = [ TF.constant shape embedding1
, TF.constant shape embedding2
]
let idValues = [0, 1] :: [Int32]
let ids = TF.constant (TF.Shape [1, 2]) idValues
let op = embeddingLookup embedding ids
(values, shape) <- buildAndRun $ do
vs <- op
return (vs, TF.shape vs)
-- This is the shape that is returned in the equiv. Python.
shape @=? V.fromList [ 1, 2, 3 ]
-- "[0, 1]" should pull out the resulting vector.
values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ]
-- | Tries to perform a simple embedding lookup, with only a single partition.
testEmbeddingLookupHasRightShape = testCase "testEmbeddingLookupHasRightShape" $ do
let shape = TF.Shape [2, 3] -- Consider a 3-dim embedding of two items.
let embeddingInit = [ 1, 1, 1
, 0, 0, 0 ] :: [Int32]
let embedding = TF.constant shape embeddingInit
let idValues = [0, 1] :: [Int32]
let ids = TF.constant (TF.Shape [1, 2]) idValues
let op = embeddingLookup [embedding] ids
(values, shape) <- buildAndRun $ do
vs <- op
return (vs, TF.shape vs)
-- This is the shape that is returned in the equiv. Python.
shape @=? V.fromList [ 1, 2, 3 ]
-- "[0, 1]" should pull out the resulting vector.
values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ]
-- Verifies that direct gather is the same as dynamic split into -- Verifies that direct gather is the same as dynamic split into
-- partitions, followed by embedding lookup. -- partitions, followed by embedding lookup.
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a) testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
@ -85,4 +136,6 @@ main :: IO ()
main = googleTest main = googleTest
[ testProperty "EmbeddingLookupUndoesSplit" [ testProperty "EmbeddingLookupUndoesSplit"
(testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property) (testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
, testEmbeddingLookupHasRightShape
, testEmbeddingLookupHasRightShapeWithPartition
] ]