Safe Haskell | None |
---|---|
Language | Haskell2010 |
Parallel lookups on the list of tensors.
Documentation
Signature | Description |
---|---|
:: (TensorType a, OneOf '[Int64, Int32] b, Num b) | |
=> [Tensor v a] | A list of tensors which can be concatenated along dimension 0. Each |
-> Tensor Value b | A |
-> Build (Tensor Value a) | A dense tensor with shape `shape(ids) + shape(params)[1:]`. |
Looks up `ids` in a list of embedding tensors.

This function is used to perform parallel lookups on the list of tensors in `params`. It is a generalization of `gather`, where `params` is interpreted as a partition of a larger embedding tensor.
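To make the "generalization of `gather`" concrete, here is a minimal plain-Haskell sketch that uses nested lists in place of tensors. It assumes the "mod" partitioning: row `r` of the larger table lives in shard `r mod n` at position `r div n`. The names `gatherRows`, `shardTable` and `lookupSharded` are hypothetical illustrations, not part of the library.

```haskell
-- | List stand-in for gather: pick rows of a full table by index.
gatherRows :: [[Double]] -> [Int] -> [[Double]]
gatherRows table = map (table !!)

-- | Split a full table into n shards the "mod" way (assumption):
-- row r goes to shard (r `mod` n); order within each shard is preserved,
-- so row r ends up at position (r `div` n) inside its shard.
shardTable :: Int -> [[Double]] -> [[[Double]]]
shardTable n table =
  [ [ row | (r, row) <- zip [0 ..] table, r `mod` n == p ]
  | p <- [0 .. n - 1] ]

-- | Look up ids against the shards: first choose the shard,
-- then the row inside that shard.
lookupSharded :: [[[Double]]] -> [Int] -> [[Double]]
lookupSharded shards = map pick
  where
    n = length shards
    pick i = (shards !! (i `mod` n)) !! (i `div` n)
```

For any valid ids, `lookupSharded (shardTable n table) ids` returns the same rows as `gatherRows table ids`; the partitioned form simply lets each shard live in a separate tensor.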
The `partition_strategy` is "mod": each id is assigned to partition `p = id % len(params)`. For instance, 13 ids are split across 5 partitions as `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`.
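The "mod" assignment can be sketched on plain lists of ids. `partitionIds` and `rowInPartition` are hypothetical helper names; the row index `id div n` inside a partition is an assumption consistent with the example above (partition 0 holds ids 0, 5 and 10 as its rows 0, 1 and 2).

```haskell
-- | Group ids by partition under the "mod" strategy:
-- id goes to partition (id `mod` n), keeping the original order.
partitionIds :: Int -> [Int] -> [[Int]]
partitionIds n ids = [ [ i | i <- ids, i `mod` n == p ] | p <- [0 .. n - 1] ]

-- | (partition, row within that partition) for a single id,
-- assuming rows are numbered by (id `div` n).
rowInPartition :: Int -> Int -> (Int, Int)
rowInPartition n i = (i `mod` n, i `div` n)
```

In GHCi, `partitionIds 5 [0 .. 12]` evaluates to `[[0,5,10],[1,6,11],[2,7,12],[3,8],[4,9]]`, matching the split given above, and `rowInPartition 5 12` is `(2,2)`.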
The results of the lookup are concatenated into a dense tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
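The shape rule and the reassembly step can be sketched with shapes as lists of dimensions and tensors as nested lists. `lookupResultShape` and `partitionedLookup` are hypothetical names; restoring the order of `ids` with `sortOn` is just one simple way to do it and is not necessarily how the library combines the per-partition results.

```haskell
import Data.List (sortOn)

-- | Shape of the lookup result: every dimension of ids, followed by
-- every dimension of params except the first (the row dimension).
lookupResultShape :: [Int] -> [Int] -> [Int]
lookupResultShape idsShape paramShape = idsShape ++ drop 1 paramShape

-- | Run one lookup per partition ("mod" strategy assumed), tagging each
-- row with the position of its id, then sort by position so the output
-- rows follow the order of the original (flattened) ids.
partitionedLookup :: [[[Double]]] -> [Int] -> [[Double]]
partitionedLookup shards ids = map snd (sortOn fst (concat perShard))
  where
    n = length shards
    indexed = zip [0 :: Int ..] ids
    perShard =
      [ [ (pos, shard !! (i `div` n)) | (pos, i) <- indexed, i `mod` n == p ]
      | (p, shard) <- zip [0 ..] shards ]
```

For example, ids of shape `[2, 3]` looked up in params of shape `[10, 4]` give `lookupResultShape [2, 3] [10, 4] == [2, 3, 4]`: each of the six ids contributes one row of length 4.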