-- Mirror of https://github.com/tensorflow/haskell.git
-- Synced 2024-06-02 11:03:34 +02:00, commit e511f49828:
-- "Only a handful of types had sensible tensorVal implementations.
--  This is now evident in type signatures at the expense of them
--  being more verbose."
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
-- NoMonomorphismRestriction is required so that the where-bound 'np'
-- in 'embeddingLookup' stays polymorphic (it is used both as an
-- integral partition count and as a scalar tensor constant).
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}

-- | Parallel lookups on the list of tensors.
module TensorFlow.EmbeddingOps where

import Control.Monad (zipWithM)
import Data.Int (Int32, Int64)
import Data.List (genericLength)

import TensorFlow.Build (Build, colocateWith, render)
import TensorFlow.Ops () -- Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value)
import TensorFlow.Types (OneOf, TensorType, TensorProtoLens)
import qualified TensorFlow.GenOps.Core as CoreOps
-- | Looks up `ids` in a list of embedding tensors.
--
-- This function is used to perform parallel lookups on the list of
-- tensors in `params`.  It is a generalization of 'CoreOps.gather', where
-- `params` is interpreted as a partition of a larger embedding
-- tensor.
--
-- The partition_strategy is "mod": we assign each id to partition
-- `p = id % len(params)`.  For instance,
-- 13 ids are split across 5 partitions as:
-- `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`
--
-- The results of the lookups are stitched back together into a dense
-- tensor.  The returned tensor has shape `shape(ids) + shape(params)[1:]`.
embeddingLookup :: forall a b v .
                   ( TensorType a
                   , TensorProtoLens b
                   , OneOf '[Int64, Int32] b
                   , Num b
                   )
                => [Tensor v a]
                -- ^ A list of tensors which can be concatenated along
                -- dimension 0. Each `Tensor` must be appropriately
                -- sized for `mod` partition strategy.
                -> Tensor Value b
                -- ^ A `Tensor` with type `int32` or `int64`
                -- containing the ids to be looked up in `params`.
                -- The ids are required to be flat on entry and have
                -- fewer than 2^31 entries.
                -> Build (Tensor Value a)
                -- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
embeddingLookup params ids =
    CoreOps.dynamicStitch pindices <$> partitionedResult
  where -- NOTE: 'np' deliberately has no type signature; under
        -- NoMonomorphismRestriction it stays polymorphic (Num i => i), so
        -- the single binding serves both as the Int64 partition count for
        -- 'dynamicPartition' and as a scalar tensor divisor for the ids.
        np = genericLength params
        -- Which partition each id belongs to: p = id `mod` np.
        pAssignments = CoreOps.cast (ids `CoreOps.mod` np)
        -- Index of each id within its partition: id `div` np.
        newIds = ids `CoreOps.div` np
        originalIndices = CoreOps.range 0 (CoreOps.size ids) 1
        -- Partition list of ids based on assignments into np separate lists.
        gatherIds = CoreOps.dynamicPartition np newIds pAssignments
        -- Similarly, partition the original indices so dynamicStitch can
        -- restore the results to the input order.
        pindices = CoreOps.dynamicPartition np originalIndices pAssignments
        -- Do np separate lookups, finding embeddings for plist[p] in
        -- params[p]; each gather is colocated with its parameter shard.
        partitionedResult = zipWithM
            (\p g -> colocateWith p $ render $ CoreOps.gather p g)
            params gatherIds