77 lines
3.3 KiB
Haskell
77 lines
3.3 KiB
Haskell
-- Copyright 2016 TensorFlow authors.
|
|
--
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
-- you may not use this file except in compliance with the License.
|
|
-- You may obtain a copy of the License at
|
|
--
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
--
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
-- See the License for the specific language governing permissions and
|
|
-- limitations under the License.
|
|
|
|
{-# LANGUAGE ConstraintKinds #-}
|
|
{-# LANGUAGE DataKinds #-}
|
|
{-# LANGUAGE NoMonomorphismRestriction #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE RankNTypes #-}
|
|
|
|
-- | Parallel lookups on the list of tensors.
|
|
module TensorFlow.EmbeddingOps where
|
|
|
|
import Control.Monad (zipWithM)
|
|
import Data.Int (Int32, Int64)
|
|
import Data.List (genericLength)
|
|
import TensorFlow.Build (Build, colocateWith, render)
|
|
import TensorFlow.Ops () -- Num instance for Tensor
|
|
import TensorFlow.Tensor (Tensor, Value)
|
|
import TensorFlow.Types (OneOf, TensorType)
|
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
|
|
|
-- | Looks up `ids` in a list of embedding tensors.
--
-- This function performs parallel lookups on the list of tensors in
-- `params`. It is a generalization of `TF.gather`, where `params` is
-- interpreted as a partition of a larger embedding tensor.
--
-- The partition strategy is "mod": each id is assigned to partition
-- `p = id % len(params)` and is looked up at row `id / len(params)` of
-- that partition. For instance, 13 ids are split across 5 partitions as:
-- `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`
--
-- When `params` holds exactly one tensor no partitioning is needed, so the
-- lookup is a single gather colocated with that tensor. An empty `params`
-- list is rejected with a call to `error`.
--
-- The results of the lookup are concatenated into a dense
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
embeddingLookup :: forall a b v .
                   ( TensorType a
                   , OneOf '[Int64, Int32] b
                   , Num b
                   )
                => [Tensor v a]
                -- ^ A non-empty list of tensors which can be concatenated
                -- along dimension 0. Each `Tensor` must be appropriately
                -- sized for `mod` partition strategy.
                -> Tensor Value b
                -- ^ A `Tensor` with type `int32` or `int64`
                -- containing the ids to be looked up in `params`.
                -- The ids are required to be flat on entry and have
                -- fewer than 2^31 entries.
                -> Build (Tensor Value a)
                -- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
embeddingLookup [] _ =
    -- With no partitions the "mod" strategy would divide by zero inside the
    -- graph; fail early with a message that names the actual problem.
    error "embeddingLookup: params must contain at least one tensor"
embeddingLookup [table] ids =
    -- Single partition: every id maps to partition 0 and keeps its value,
    -- so the partition/stitch machinery below would be pure overhead.
    colocateWith table $ render $ CoreOps.gather table ids
embeddingLookup params ids =
    CoreOps.dynamicStitch pindices <$> partitionedResult
  where
    -- Number of partitions. It is consumed both as an operand of the tensor
    -- `mod`/`div` below and as the partition count handed to
    -- 'CoreOps.dynamicPartition' (presumably at different types), so it is
    -- deliberately left polymorphic.
    np = genericLength params
    -- Partition each id belongs to.
    pAssignments = CoreOps.cast (ids `CoreOps.mod` np)
    -- Row of each id inside its partition.
    newIds = ids `CoreOps.div` np
    -- Position of each id in the input, used to restore the input order.
    originalIndices = CoreOps.range 0 (CoreOps.size ids) 1
    -- Partition list of ids based on assignments into np separate lists.
    gatherIds = CoreOps.dynamicPartition np newIds pAssignments
    -- Similarly, partition the original indices.
    pindices = CoreOps.dynamicPartition np originalIndices pAssignments
    -- Do np separate lookups, finding embeddings for gatherIds[p] in
    -- params[p]; each gather is rendered colocated with its partition.
    partitionedResult =
        zipWithM
            (\p g -> colocateWith p $ render $ CoreOps.gather p g)
            params
            gatherIds
|