Safe Haskell | None |
---|---|
Language | Haskell2010 |
Parallel lookups on the list of tensors.
Synopsis
- embeddingLookup :: forall a b v1 v2 m. (MonadBuild m, Rendered (Tensor v1), TensorType a, OneOf '[Int64, Int32] b, Num b) => [Tensor v1 a] -> Tensor v2 b -> m (Tensor Value a)
Documentation
embeddingLookup :: forall a b v1 v2 m. (MonadBuild m, Rendered (Tensor v1), TensorType a, OneOf '[Int64, Int32] b, Num b) | |
=> [Tensor v1 a] | A list of tensors which can be concatenated along dimension 0. Each tensor must be appropriately sized for the "mod" partition strategy. |
-> Tensor v2 b | A tensor of type `Int32` or `Int64` containing the ids to be looked up in `params`. |
-> m (Tensor Value a) | A dense tensor with shape `shape(ids) + shape(params)[1:]`. |
Looks up `ids` in a list of embedding tensors.
This function is used to perform parallel lookups on the list of tensors in `params`. It is a generalization of `gather`, where `params` is interpreted as a partition of a larger embedding tensor.
The partition strategy is "mod": each id is assigned to partition `p = id % len(params)`. For instance, 13 ids are split across 5 partitions as: `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`.
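The "mod" assignment can be modelled on plain lists. The following is a minimal sketch only, not the library's implementation; `partitionOf` and `partitionIds` are hypothetical helper names introduced for illustration.

```haskell
-- Pure-list model of the "mod" partition strategy (illustrative only;
-- these helpers are not part of the TensorFlow Haskell API).

-- | Partition index for an id: p = id `mod` numPartitions.
partitionOf :: Int -> Int -> Int
partitionOf numPartitions i = i `mod` numPartitions

-- | Group ids by partition, keeping their original order inside each group.
partitionIds :: Int -> [Int] -> [[Int]]
partitionIds numPartitions ids =
  [ [ i | i <- ids, partitionOf numPartitions i == p ]
  | p <- [0 .. numPartitions - 1]
  ]
```

Evaluating `partitionIds 5 [0 .. 12]` reproduces the split of 13 ids across 5 partitions described above.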
The results of the lookup are concatenated into a dense tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
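To make the result shape concrete, the lookup can be sketched over nested lists. This is a simplified model under stated assumptions, not the library's code: each partition is a list of rows, and an id `i` is assumed to sit at row `i div n` of partition `i mod n`, where `n` is the number of partitions (consistent with the example split above). The names `lookupRow`, `lookupRows`, and `exampleParams` are hypothetical.

```haskell
-- Pure-list model of a partitioned embedding lookup (illustrative only).
-- params :: [[[a]]] is a list of partitions; each partition is a list of
-- rows; each row is one embedding vector.

-- | Fetch the row for one id: partition i `mod` n, row i `div` n.
--   Assumes the id is in range for the given partitions.
lookupRow :: [[[a]]] -> Int -> [a]
lookupRow params i = (params !! p) !! r
  where
    n = length params
    p = i `mod` n
    r = i `div` n

-- | Look up a flat list of ids; the result has one row per id.
lookupRows :: [[[a]]] -> [Int] -> [[a]]
lookupRows params = map (lookupRow params)

-- | Two partitions: the first holds ids 0 and 2, the second ids 1 and 3.
exampleParams :: [[[Int]]]
exampleParams =
  [ [[0, 0], [2, 2]]
  , [[1, 1], [3, 3]]
  ]
```

With three ids and rows of length 2, `lookupRows exampleParams [3, 0, 2]` yields three rows of length 2, matching `shape(ids) + shape(params)[1:]`.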