diff --git a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs
index 9eb396b..af90406 100644
--- a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs
+++ b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs
@@ -23,9 +23,8 @@ module TensorFlow.EmbeddingOps where
 
 import Control.Monad (zipWithM)
 import Data.Int (Int32, Int64)
-import Data.List (genericLength)
 import TensorFlow.Build (Build, colocateWith, render)
-import TensorFlow.Ops ()  -- Num instance for Tensor
+import TensorFlow.Ops (scalar, shape, vector)  -- Also Num instance for Tensor
 import TensorFlow.Tensor (Tensor, Value)
 import TensorFlow.Types (OneOf, TensorType)
 import qualified TensorFlow.GenOps.Core as CoreOps
@@ -56,21 +55,34 @@ embeddingLookup :: forall a b v .
     -> Tensor Value b
     -- ^ A `Tensor` with type `int32` or `int64`
     -- containing the ids to be looked up in `params`.
-    -- The ids are required to be flat on entry and have
-    -- fewer than 2^31 entries.
+    -- The ids are required to have fewer than 2^31
+    -- entries.
     -> Build (Tensor Value a)
     -- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
-embeddingLookup params ids =
-    CoreOps.dynamicStitch pindices <$> partitionedResult
-  where np = genericLength params
-        pAssignments = CoreOps.cast (ids `CoreOps.mod` np)
-        newIds = ids `CoreOps.div` np
-        originalIndices = CoreOps.range 0 (CoreOps.size ids) 1
-        -- Partition list of ids based on assignments into np separate lists
-        gatherIds = CoreOps.dynamicPartition np newIds pAssignments
-        -- Similarly, partition the original indices.
-        pindices = CoreOps.dynamicPartition np originalIndices pAssignments
-        -- Do np separate lookups, finding embeddings for plist[p] in params[p]
-        partitionedResult = zipWithM
-            (\p g -> colocateWith p $ render $ CoreOps.gather p g)
-            params gatherIds
+embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
+embeddingLookup params@(p0 : _) ids = do
+    -- Do np separate lookups, finding embeddings for plist[p] in params[p].
+    partitionedResult <- zipWithM
+        (\p g -> colocateWith p $ render $ CoreOps.gather p g)
+        params gatherIds
+    let unshapedResult = CoreOps.dynamicStitch pindices partitionedResult
+    -- Shape restoration is not as optimal as it would be with
+    -- client-side shape tracking.
+    paramShape <- colocateWith p0 (render (shape p0))
+    let finalShape = CoreOps.concat 0 [shape ids, tailShape]
+        tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1))
+    render $ CoreOps.reshape unshapedResult finalShape
+  where
+    -- Avoids genericLength here, which would be evaluated by TF.
+    np = fromIntegral (length params)
+    flatIds = CoreOps.reshape ids (singleton (-1))
+    pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np)
+    newIds = flatIds `CoreOps.div` np
+    originalIndices = CoreOps.range 0 (CoreOps.size flatIds) 1
+    -- Partition list of ids based on assignments into np separate lists.
+    gatherIds = CoreOps.dynamicPartition np newIds pAssignments
+    -- Similarly, partition the original indices.
+    pindices = CoreOps.dynamicPartition np originalIndices pAssignments
+    singleton i = vector [i :: Int32]
+
+embeddingLookup [] _ = error "embeddingLookup requires params to be non-empty"
diff --git a/tensorflow-ops/tests/EmbeddingOpsTest.hs b/tensorflow-ops/tests/EmbeddingOpsTest.hs
index 0a6b97d..722492b 100644
--- a/tensorflow-ops/tests/EmbeddingOpsTest.hs
+++ b/tensorflow-ops/tests/EmbeddingOpsTest.hs
@@ -24,6 +24,7 @@ import Google.Test (googleTest)
 import TensorFlow.EmbeddingOps (embeddingLookup)
 import Test.Framework.Providers.QuickCheck2 (testProperty)
 import Test.HUnit ((@=?))
+import Test.Framework.Providers.HUnit (testCase)
 import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
 import Test.QuickCheck.Monadic (monadicIO, run)
 
@@ -34,6 +35,56 @@ import qualified TensorFlow.Session as TF
 import qualified TensorFlow.Tensor as TF
 import qualified TensorFlow.Types as TF
 
+
+buildAndRun = TF.runSession . TF.buildAnd TF.run
+
+-- | Tries to perform a simple embedding lookup, with two partitions.
+testEmbeddingLookupHasRightShapeWithPartition = testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
+    let shape = TF.Shape [1, 3]  -- One 3-dim item in each of two partitions.
+    let embedding1 = [ 1, 1, 1 ] :: [Int32]
+    let embedding2 = [ 0, 0, 0 ] :: [Int32]
+
+    let embedding = [ TF.constant shape embedding1
+                    , TF.constant shape embedding2
+                    ]
+
+    let idValues = [0, 1] :: [Int32]
+    let ids = TF.constant (TF.Shape [1, 2]) idValues
+    let op = embeddingLookup embedding ids
+
+    (values, shape) <- buildAndRun $ do
+        vs <- op
+        return (vs, TF.shape vs)
+
+    -- This is the shape returned by the equivalent Python code.
+    shape @=? V.fromList [ 1, 2, 3 ]
+
+    -- Ids [0, 1] should pull out both embedding rows, in order.
+    values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ]
+
+
+-- | Tries to perform a simple embedding lookup, with only a single partition.
+testEmbeddingLookupHasRightShape = testCase "testEmbeddingLookupHasRightShape" $ do
+    let shape = TF.Shape [2, 3]  -- Consider a 3-dim embedding of two items.
+    let embeddingInit = [ 1, 1, 1
+                        , 0, 0, 0 ] :: [Int32]
+
+    let embedding = TF.constant shape embeddingInit
+    let idValues = [0, 1] :: [Int32]
+    let ids = TF.constant (TF.Shape [1, 2]) idValues
+    let op = embeddingLookup [embedding] ids
+
+    (values, shape) <- buildAndRun $ do
+        vs <- op
+        return (vs, TF.shape vs)
+
+    -- This is the shape returned by the equivalent Python code.
+    shape @=? V.fromList [ 1, 2, 3 ]
+
+    -- Ids [0, 1] should pull out both embedding rows, in order.
+    values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ]
+
+
 -- Verifies that direct gather is the same as dynamic split into
 -- partitions, followed by embedding lookup.
 testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
@@ -85,4 +136,6 @@
 main :: IO ()
 main = googleTest [ testProperty "EmbeddingLookupUndoesSplit"
                     (testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
+                  , testEmbeddingLookupHasRightShape
+                  , testEmbeddingLookupHasRightShapeWithPartition
                   ]
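
For reference, the `mod` partition strategy implemented above sends id `i` to partition `i mod np` and to row `i div np` within that partition; `pAssignments` and `newIds` build exactly this arithmetic as graph ops. Below is a minimal host-side Haskell sketch of the same mapping, illustration only and not part of the patch (`modStrategy` is a hypothetical name):

import Data.Int (Int32)

-- Partition assignment and within-partition row for each id, mirroring
-- what pAssignments (i `mod` np) and newIds (i `div` np) compute in the graph.
modStrategy :: Int32 -> [Int32] -> [(Int32, Int32)]
modStrategy np ids = [(i `mod` np, i `div` np) | i <- ids]

main :: IO ()
main = print (modStrategy 2 [0, 1, 2, 3])
-- prints [(0,0),(1,0),(0,1),(1,1)]: ids 0 and 2 live in partition 0
-- (rows 0 and 1), ids 1 and 3 in partition 1 (rows 0 and 1).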
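
Likewise, the shape-restoration step concatenates `shape ids` with a slice of the first partition's shape to rebuild the documented result shape `shape(ids) + shape(params)[1:]`. The same computation on plain host-side shape lists, again only a sketch (`resultShape` is a hypothetical name, distinct from the patch's `finalShape` graph node):

import Data.Int (Int64)

-- shape(ids) ++ shape(params[0])[1:], computed on ordinary lists.
resultShape :: [Int64] -> [Int64] -> [Int64]
resultShape idsShape paramShape = idsShape ++ drop 1 paramShape

main :: IO ()
main = print (resultShape [1, 2] [1, 3])
-- prints [1,2,3], the shape both new tests assert via TF.shape.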