-- Copyright 2016 TensorFlow authors. -- -- Licensed under the Apache License, Version 2.0 (the "License"); -- you may not use this file except in compliance with the License. -- You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- -- Unless required by applicable law or agreed to in writing, software -- distributed under the License is distributed on an "AS IS" BASIS, -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -- See the License for the specific language governing permissions and -- limitations under the License. {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE NoMonomorphismRestriction #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} -- | Parallel lookups on the list of tensors. module TensorFlow.EmbeddingOps where import Control.Monad (zipWithM) import Data.Int (Int32, Int64) import Data.List (genericLength) import TensorFlow.Build (Build, colocateWith, render) import TensorFlow.Ops () -- Num instance for Tensor import TensorFlow.Tensor (Tensor, Value) import TensorFlow.Types (OneOf, TensorType) import qualified TensorFlow.GenOps.Core as CoreOps -- | Looks up `ids` in a list of embedding tensors. -- -- This function is used to perform parallel lookups on the list of -- tensors in `params`. It is a generalization of `TF.gather`, where -- `params` is interpreted as a partition of a larger embedding -- tensor. -- -- The partition_strategy is "mod", we assign each id to partition -- `p = id % len(params)`. For instance, -- 13 ids are split across 5 partitions as: -- `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]` -- -- The results of the lookup are concatenated into a dense -- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`. embeddingLookup :: forall a b v . ( TensorType a , OneOf '[Int64, Int32] b , Num b ) => [Tensor v a] -- ^ A list of tensors which can be concatenated along -- dimension 0. Each `Tensor` must be appropriately -- sized for `mod` partition strategy. -> Tensor Value b -- ^ A `Tensor` with type `int32` or `int64` -- containing the ids to be looked up in `params`. -- The ids are required to be flat on entry and have -- fewer than 2^31 entries. -> Build (Tensor Value a) -- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`. embeddingLookup params ids = CoreOps.dynamicStitch pindices <$> partitionedResult where np = genericLength params pAssignments = CoreOps.cast (ids `CoreOps.mod` np) newIds = ids `CoreOps.div` np originalIndices = CoreOps.range 0 (CoreOps.size ids) 1 -- Partition list of ids based on assignments into np separate lists gatherIds = CoreOps.dynamicPartition np newIds pAssignments -- Similarly, partition the original indices. pindices = CoreOps.dynamicPartition np originalIndices pAssignments -- Do np separate lookups, finding embeddings for plist[p] in params[p] partitionedResult = zipWithM (\p g -> colocateWith p $ render $ CoreOps.gather p g) params gatherIds