1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-23 03:19:44 +01:00

Merge pull request #22 from tensorflow/embedding-lookup-fix

Embedding lookup fix
This commit is contained in:
Greg Steuck 2016-11-09 15:59:10 -08:00 committed by GitHub
commit 9e005e3af7
2 changed files with 83 additions and 18 deletions

View file

@ -23,9 +23,8 @@ module TensorFlow.EmbeddingOps where
import Control.Monad (zipWithM) import Control.Monad (zipWithM)
import Data.Int (Int32, Int64) import Data.Int (Int32, Int64)
import Data.List (genericLength)
import TensorFlow.Build (Build, colocateWith, render) import TensorFlow.Build (Build, colocateWith, render)
import TensorFlow.Ops () -- Num instance for Tensor import TensorFlow.Ops (scalar, shape, vector) -- Also Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value) import TensorFlow.Tensor (Tensor, Value)
import TensorFlow.Types (OneOf, TensorType) import TensorFlow.Types (OneOf, TensorType)
import qualified TensorFlow.GenOps.Core as CoreOps import qualified TensorFlow.GenOps.Core as CoreOps
@ -56,21 +55,34 @@ embeddingLookup :: forall a b v .
-> Tensor Value b -> Tensor Value b
-- ^ A `Tensor` with type `int32` or `int64` -- ^ A `Tensor` with type `int32` or `int64`
-- containing the ids to be looked up in `params`. -- containing the ids to be looked up in `params`.
-- The ids are required to be flat on entry and have -- The ids are required to have fewer than 2^31
-- fewer than 2^31 entries. -- entries.
-> Build (Tensor Value a) -> Build (Tensor Value a)
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`. -- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
embeddingLookup params ids = embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
CoreOps.dynamicStitch pindices <$> partitionedResult embeddingLookup params@(p0 : _) ids = do
where np = genericLength params -- Do np separate lookups, finding embeddings for plist[p] in params[p]
pAssignments = CoreOps.cast (ids `CoreOps.mod` np) partitionedResult <- zipWithM
newIds = ids `CoreOps.div` np (\p g -> colocateWith p $ render $ CoreOps.gather p g)
originalIndices = CoreOps.range 0 (CoreOps.size ids) 1 params gatherIds
-- Partition list of ids based on assignments into np separate lists let unshapedResult = CoreOps.dynamicStitch pindices partitionedResult
gatherIds = CoreOps.dynamicPartition np newIds pAssignments -- Shape restoration is not as optimal as it would be with client
-- Similarly, partition the original indices. -- side shape tracking.
pindices = CoreOps.dynamicPartition np originalIndices pAssignments paramShape <- colocateWith p0 (render (shape p0))
-- Do np separate lookups, finding embeddings for plist[p] in params[p] let finalShape = CoreOps.concat 0 [shape ids, tailShape]
partitionedResult = zipWithM tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1))
(\p g -> colocateWith p $ render $ CoreOps.gather p g) render $ CoreOps.reshape unshapedResult finalShape
params gatherIds where
-- Avoids genericLength here which would be evaluated by TF.
np = fromIntegral (length params)
flatIds = CoreOps.reshape ids (singleton (-1))
pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np)
newIds = flatIds `CoreOps.div` np
originalIndices = CoreOps.range 0 (CoreOps.size flatIds) 1
-- Partition list of ids based on assignments into np separate lists
gatherIds = CoreOps.dynamicPartition np newIds pAssignments
-- Similarly, partition the original indices.
pindices = CoreOps.dynamicPartition np originalIndices pAssignments
singleton i = vector [i :: Int32]
embeddingLookup [] _ = error "embeddingLookup requires params to be non empty"

View file

@ -24,6 +24,7 @@ import Google.Test (googleTest)
import TensorFlow.EmbeddingOps (embeddingLookup) import TensorFlow.EmbeddingOps (embeddingLookup)
import Test.Framework.Providers.QuickCheck2 (testProperty) import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.HUnit ((@=?)) import Test.HUnit ((@=?))
import Test.Framework.Providers.HUnit (testCase)
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf) import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
import Test.QuickCheck.Monadic (monadicIO, run) import Test.QuickCheck.Monadic (monadicIO, run)
@ -34,6 +35,56 @@ import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF import qualified TensorFlow.Types as TF
-- | Shorthand used by the unit tests below: hands the given graph-building
-- action to @TF.buildAnd TF.run@ and executes the result with
-- @TF.runSession@.
buildAndRun graph = TF.runSession (TF.buildAnd TF.run graph)
-- | Performs a simple embedding lookup over two partitions and checks both
-- the shape and the contents of the result.
testEmbeddingLookupHasRightShapeWithPartition =
    testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
        -- Each of the two partitions holds one 3-dimensional embedding row.
        let partitionShape = TF.Shape [1, 3]
        let embedding1 = [ 1, 1, 1 ] :: [Int32]
        let embedding2 = [ 0, 0, 0 ] :: [Int32]
        let embedding = [ TF.constant partitionShape embedding1
                        , TF.constant partitionShape embedding2
                        ]

        let idValues = [0, 1] :: [Int32]
        let ids = TF.constant (TF.Shape [1, 2]) idValues
        let op = embeddingLookup embedding ids

        -- Fetch the looked-up values together with their runtime shape.
        -- The fetched shape gets its own name so that it does not shadow
        -- the 'TF.Shape' used to build the parameters above.
        (values, resultShape) <- buildAndRun $ do
            vs <- op
            return (vs, TF.shape vs)

        -- HUnit's '@=?' takes the expected value on the left, so the literal
        -- goes first; otherwise failure messages swap "expected" and "got".
        -- This is the shape that is returned in the equivalent Python.
        V.fromList [ 1, 2, 3 ] @=? resultShape
        -- Ids "[0, 1]" should pull out embedding1 followed by embedding2.
        V.fromList [ 1, 1, 1, 0, 0, 0 ] @=? values
-- | Performs a simple embedding lookup with only a single partition and
-- checks both the shape and the contents of the result.
testEmbeddingLookupHasRightShape =
    testCase "testEmbeddingLookupHasRightShape" $ do
        -- One parameter tensor holding two 3-dimensional embedding rows.
        let paramShape = TF.Shape [2, 3]
        let embeddingInit = [ 1, 1, 1
                            , 0, 0, 0 ] :: [Int32]
        let embedding = TF.constant paramShape embeddingInit

        let idValues = [0, 1] :: [Int32]
        let ids = TF.constant (TF.Shape [1, 2]) idValues
        let op = embeddingLookup [embedding] ids

        -- Fetch the looked-up values together with their runtime shape.
        -- The fetched shape gets its own name so that it does not shadow
        -- the 'TF.Shape' used to build the parameters above.
        (values, resultShape) <- buildAndRun $ do
            vs <- op
            return (vs, TF.shape vs)

        -- HUnit's '@=?' takes the expected value on the left, so the literal
        -- goes first; otherwise failure messages swap "expected" and "got".
        -- This is the shape that is returned in the equivalent Python.
        V.fromList [ 1, 2, 3 ] @=? resultShape
        -- Ids "[0, 1]" should pull out both rows in their original order.
        V.fromList [ 1, 1, 1, 0, 0, 0 ] @=? values
-- Verifies that direct gather is the same as dynamic split into -- Verifies that direct gather is the same as dynamic split into
-- partitions, followed by embedding lookup. -- partitions, followed by embedding lookup.
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a) testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
@ -85,4 +136,6 @@ main :: IO ()
main = googleTest main = googleTest
[ testProperty "EmbeddingLookupUndoesSplit" [ testProperty "EmbeddingLookupUndoesSplit"
(testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property) (testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
, testEmbeddingLookupHasRightShape
, testEmbeddingLookupHasRightShapeWithPartition
] ]