mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Tests for "embedding_lookup" and minor fix
- added a test that fails for a partitioned embedding - added a test that passes for a single embedding
This commit is contained in:
parent
4ec78a8fca
commit
9c81241439
2 changed files with 59 additions and 3 deletions
|
@ -60,9 +60,12 @@ embeddingLookup :: forall a b v .
|
||||||
-- fewer than 2^31 entries.
|
-- fewer than 2^31 entries.
|
||||||
-> Build (Tensor Value a)
|
-> Build (Tensor Value a)
|
||||||
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
|
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
|
||||||
embeddingLookup params ids =
|
embeddingLookup params@(p1 : _) ids =
|
||||||
CoreOps.dynamicStitch pindices <$> partitionedResult
|
go (np :: Int32)
|
||||||
where np = genericLength params
|
where
|
||||||
|
go 1 = colocateWith p1 (render $ CoreOps.gather p1 ids)
|
||||||
|
go _ = CoreOps.dynamicStitch pindices <$> partitionedResult
|
||||||
|
np = genericLength params
|
||||||
pAssignments = CoreOps.cast (ids `CoreOps.mod` np)
|
pAssignments = CoreOps.cast (ids `CoreOps.mod` np)
|
||||||
newIds = ids `CoreOps.div` np
|
newIds = ids `CoreOps.div` np
|
||||||
originalIndices = CoreOps.range 0 (CoreOps.size ids) 1
|
originalIndices = CoreOps.range 0 (CoreOps.size ids) 1
|
||||||
|
|
|
@ -24,6 +24,7 @@ import Google.Test (googleTest)
|
||||||
import TensorFlow.EmbeddingOps (embeddingLookup)
|
import TensorFlow.EmbeddingOps (embeddingLookup)
|
||||||
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||||||
import Test.HUnit ((@=?))
|
import Test.HUnit ((@=?))
|
||||||
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
|
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
|
||||||
import Test.QuickCheck.Monadic (monadicIO, run)
|
import Test.QuickCheck.Monadic (monadicIO, run)
|
||||||
|
|
||||||
|
@ -34,6 +35,56 @@ import qualified TensorFlow.Session as TF
|
||||||
import qualified TensorFlow.Tensor as TF
|
import qualified TensorFlow.Tensor as TF
|
||||||
import qualified TensorFlow.Types as TF
|
import qualified TensorFlow.Types as TF
|
||||||
|
|
||||||
|
|
||||||
|
buildAndRun = TF.runSession . TF.buildAnd TF.run
|
||||||
|
|
||||||
|
-- | Tries to perform a simple embedding lookup, with two partitions.
|
||||||
|
testEmbeddingLookupHasRightShapeWithPartition = testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
|
||||||
|
let shape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
|
||||||
|
let embedding1 = [ 1, 1, 1 ] :: [Int32]
|
||||||
|
let embedding2 = [ 0, 0, 0 ] :: [Int32]
|
||||||
|
|
||||||
|
let embedding = [ TF.constant shape embedding1
|
||||||
|
, TF.constant shape embedding2
|
||||||
|
]
|
||||||
|
|
||||||
|
let idValues = [0, 1] :: [Int32]
|
||||||
|
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||||
|
let op = embeddingLookup embedding ids
|
||||||
|
|
||||||
|
(values, shape) <- buildAndRun $ do
|
||||||
|
vs <- op
|
||||||
|
return (vs, TF.shape vs)
|
||||||
|
|
||||||
|
-- This is the shape that is returned in the equiv. Python.
|
||||||
|
shape @=? V.fromList [ 1, 2, 3 ]
|
||||||
|
|
||||||
|
-- "[0, 1]" should pull out the resulting vector.
|
||||||
|
values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ]
|
||||||
|
|
||||||
|
|
||||||
|
-- | Tries to perform a simple embedding lookup, with only a single partition.
|
||||||
|
testEmbeddingLookupHasRightShape = testCase "testEmbeddingLookupHasRightShape" $ do
|
||||||
|
let shape = TF.Shape [2, 3] -- Consider a 3-dim embedding of two items.
|
||||||
|
let embeddingInit = [ 1, 1, 1
|
||||||
|
, 0, 0, 0 ] :: [Int32]
|
||||||
|
|
||||||
|
let embedding = TF.constant shape embeddingInit
|
||||||
|
let idValues = [0, 1] :: [Int32]
|
||||||
|
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||||
|
let op = embeddingLookup [embedding] ids
|
||||||
|
|
||||||
|
(values, shape) <- buildAndRun $ do
|
||||||
|
vs <- op
|
||||||
|
return (vs, TF.shape vs)
|
||||||
|
|
||||||
|
-- This is the shape that is returned in the equiv. Python.
|
||||||
|
shape @=? V.fromList [ 1, 2, 3 ]
|
||||||
|
|
||||||
|
-- "[0, 1]" should pull out the resulting vector.
|
||||||
|
values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ]
|
||||||
|
|
||||||
|
|
||||||
-- Verifies that direct gather is the same as dynamic split into
|
-- Verifies that direct gather is the same as dynamic split into
|
||||||
-- partitions, followed by embedding lookup.
|
-- partitions, followed by embedding lookup.
|
||||||
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
|
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
|
||||||
|
@ -85,4 +136,6 @@ main :: IO ()
|
||||||
main = googleTest
|
main = googleTest
|
||||||
[ testProperty "EmbeddingLookupUndoesSplit"
|
[ testProperty "EmbeddingLookupUndoesSplit"
|
||||||
(testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
|
(testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
|
||||||
|
, testEmbeddingLookupHasRightShape
|
||||||
|
, testEmbeddingLookupHasRightShapeWithPartition
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in a new issue