-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
|
|
|
{-# LANGUAGE ConstraintKinds #-}
|
|
|
|
{-# LANGUAGE DataKinds #-}
|
2016-11-29 06:15:09 +01:00
|
|
|
{-# LANGUAGE FlexibleContexts #-}
|
2016-10-24 21:26:42 +02:00
|
|
|
{-# LANGUAGE NoMonomorphismRestriction #-}
|
|
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
|
|
{-# LANGUAGE RankNTypes #-}
|
|
|
|
|
|
|
|
-- | Parallel lookups on the list of tensors.
|
|
|
|
module TensorFlow.EmbeddingOps where
|
|
|
|
|
|
|
|
import Control.Monad (zipWithM)
|
|
|
|
import Data.Int (Int32, Int64)
|
2017-04-07 00:10:33 +02:00
|
|
|
import TensorFlow.Build (MonadBuild)
|
2016-11-18 19:42:02 +01:00
|
|
|
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
|
2017-04-07 00:10:33 +02:00
|
|
|
import TensorFlow.Tensor (Tensor, Value, Rendered, colocateWith, render)
|
2016-10-24 21:26:42 +02:00
|
|
|
import TensorFlow.Types (OneOf, TensorType)
|
|
|
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
|
|
|
|
|
|
|
-- | Looks up `ids` in a list of embedding tensors.
--
-- Performs parallel lookups on the list of tensors in `params`. It is
-- a generalization of `TF.gather`, in which `params` is interpreted
-- as a partitioning of a larger embedding tensor.
--
-- The partition strategy is "mod": each id is assigned to partition
-- `p = id % len(params)`. For instance, 13 ids are split across 5
-- partitions as:
-- `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`
--
-- The results of the lookup are concatenated into a dense tensor
-- with shape `shape(ids) + shape(params)[1:]`.
embeddingLookup :: forall a b v1 v2 m .
                   ( MonadBuild m
                   , Rendered (Tensor v1)
                   , TensorType a
                   , OneOf '[Int64, Int32] b
                   , Num b
                   )
                => [Tensor v1 a]
                -- ^ A list of tensors which can be concatenated along
                -- dimension 0. Each `Tensor` must be appropriately
                -- sized for `mod` partition strategy.
                -> Tensor v2 b
                -- ^ A `Tensor` with type `int32` or `int64`
                -- containing the ids to be looked up in `params`.
                -- The ids are required to have fewer than 2^31
                -- entries.
                -> m (Tensor Value a)
                -- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
-- Fast path: with a single shard there is nothing to partition.
embeddingLookup [single] ids =
    colocateWith single $ render (CoreOps.gather single ids)
embeddingLookup params@(first : _) ids = do
    -- One gather per shard, each colocated with the shard it reads.
    shardResults <- zipWithM
        (\shard ix -> colocateWith shard (render (CoreOps.gather shard ix)))
        params idsPerShard
    -- Stitch the per-shard rows back into the original (flat) id order.
    let flatResult = CoreOps.dynamicStitch positionsPerShard shardResults
    -- Shape restoration is not as optimal as it would be with client
    -- side shape tracking.
    shardShape <- colocateWith first (render (shape first))
    let trailingDims = CoreOps.slice shardShape (scalarVec 1) (scalarVec (-1))
        resultShape = CoreOps.concat 0 [shape ids, trailingDims]
    render (CoreOps.reshape flatResult resultShape)
  where
    -- Avoids genericLength here which would be evaluated by TF.
    -- NOTE: used at two types (scalar attribute and graph constant),
    -- which relies on NoMonomorphismRestriction.
    numShards = fromIntegral (length params)
    idsFlat = CoreOps.reshape ids (scalarVec (-1))
    -- Which shard each id lives in, under the "mod" strategy.
    shardAssignments = CoreOps.cast (idsFlat `CoreOps.mod` numShards)
    -- The row index of each id within its own shard.
    idsWithinShard = idsFlat `CoreOps.div` numShards
    flatPositions = CoreOps.range 0 (CoreOps.size idsFlat) 1
    -- Partition the ids (and, in parallel, their original flat
    -- positions) into numShards separate lists.
    idsPerShard = CoreOps.dynamicPartition numShards idsWithinShard shardAssignments
    positionsPerShard = CoreOps.dynamicPartition numShards flatPositions shardAssignments
    scalarVec i = vector [i :: Int32]
embeddingLookup [] _ = error "embeddingLookup requires params to be non empty"
|