<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html xmlns="http://www.w3.org/1999/xhtml"><head><link rel="stylesheet" type="text/css" href="style.css" /><script type="text/javascript" src="highlight.js"></script></head><body><pre><span class="hs-comment">-- Copyright 2016 TensorFlow authors.</span><span> </span><a name="line-2"></a><span class="hs-comment">--</span><span> </span><a name="line-3"></a><span class="hs-comment">-- Licensed under the Apache License, Version 2.0 (the "License");</span><span> </span><a name="line-4"></a><span class="hs-comment">-- you may not use this file except in compliance with the License.</span><span> </span><a name="line-5"></a><span class="hs-comment">-- You may obtain a copy of the License at</span><span> </span><a name="line-6"></a><span class="hs-comment">--</span><span> </span><a name="line-7"></a><span class="hs-comment">-- http://www.apache.org/licenses/LICENSE-2.0</span><span> </span><a name="line-8"></a><span class="hs-comment">--</span><span> </span><a name="line-9"></a><span class="hs-comment">-- Unless required by applicable law or agreed to in writing, software</span><span> </span><a name="line-10"></a><span class="hs-comment">-- distributed under the License is distributed on an "AS IS" BASIS,</span><span> </span><a name="line-11"></a><span class="hs-comment">-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span><span> </span><a name="line-12"></a><span class="hs-comment">-- See the License for the specific language governing permissions and</span><span> </span><a name="line-13"></a><span class="hs-comment">-- limitations under the License.</span><span> </span><a name="line-14"></a><span> </span><a name="line-15"></a><span class="hs-pragma">{-# LANGUAGE ConstraintKinds #-}</span><span> </span><a name="line-16"></a><span class="hs-pragma">{-# LANGUAGE DataKinds #-}</span><span> </span><a name="line-17"></a><span 
class="hs-pragma">{-# LANGUAGE FlexibleContexts #-}</span><span> </span><a name="line-18"></a><span class="hs-pragma">{-# LANGUAGE NoMonomorphismRestriction #-}</span><span> </span><a name="line-19"></a><span class="hs-pragma">{-# LANGUAGE OverloadedStrings #-}</span><span> </span><a name="line-20"></a><span class="hs-pragma">{-# LANGUAGE RankNTypes #-}</span><span> </span><a name="line-21"></a><span> </span><a name="line-22"></a><span class="hs-comment">-- | Parallel lookups on the list of tensors.</span><span> </span><a name="line-23"></a><span class="hs-keyword">module</span><span> </span><span class="hs-identifier">TensorFlow</span><span class="hs-operator">.</span><span class="hs-identifier">EmbeddingOps</span><span> </span><span class="hs-keyword">where</span><span> </span><a name="line-24"></a><span> </span><a name="line-25"></a><span class="hs-keyword">import</span><span> </span><span class="hs-identifier">Control</span><span class="hs-operator">.</span><span class="hs-identifier">Monad</span><span> </span><span class="hs-special">(</span><span class="hs-identifier hs-var">zipWithM</span><span class="hs-special">)</span><span> </span><a name="line-26"></a><span class="hs-keyword">import</span><span> </span><span class="hs-identifier">Data</span><span class="hs-operator">.</span><span class="hs-identifier">Int</span><span> </span><span class="hs-special">(</span><span class="hs-identifier hs-type">Int32</span><span class="hs-special">,</span><span> </span><span class="hs-identifier hs-type">Int64</span><span class="hs-special">)</span><span> </span><a name="line-27"></a><span class="hs-keyword">import</span><span> </span><span class="hs-identifier">TensorFlow</span><span class="hs-operator">.</span><span class="hs-identifier">Build</span><span> </span><span class="hs-special">(</span><span class="hs-identifier hs-type">MonadBuild</span><span class="hs-special">)</span><span> </span><a name="line-28"></a><span class="hs-keyword">import</span><span> </span><a 
href="TensorFlow.Ops.html"><span class="hs-identifier">TensorFlow</span><span class="hs-operator">.</span><span class="hs-identifier">Ops</span></a><span> </span><span class="hs-special">(</span><a href="TensorFlow.Ops.html#shape"><span class="hs-identifier hs-var">shape</span></a><span class="hs-special">,</span><span> </span><a href="TensorFlow.Ops.html#vector"><span class="hs-identifier hs-var">vector</span></a><span class="hs-special">)</span><span> </span><span class="hs-comment">-- Also Num instance for Tensor</span><span> </span><a name="line-29"></a><span class="hs-keyword">import</span><span> </span><span class="hs-identifier">TensorFlow</span><span class="hs-operator">.</span><span class="hs-identifier">Tensor</span><span> </span><span class="hs-special">(</span><span class="hs-identifier hs-type">Tensor</span><span class="hs-special">,</span><span> </span><span class="hs-identifier hs-type">Value</span><span class="hs-special">,</span><span> </span><span class="hs-identifier hs-type">Rendered</span><span class="hs-special">,</span><span> </span><span class="hs-identifier hs-var">colocateWith</span><span class="hs-special">,</span><span> </span><span class="hs-identifier hs-var">render</span><span class="hs-special">)</span><span> </span><a name="line-30"></a><span class="hs-keyword">import</span><span> </span><span class="hs-identifier">TensorFlow</span><span class="hs-operator">.</span><span class="hs-identifier">Types</span><span> </span><span class="hs-special">(</span><span class="hs-identifier hs-type">OneOf</span><span class="hs-special">,</span><span> </span><span class="hs-identifier hs-type">TensorType</span><span class="hs-special">)</span><span> </span><a name="line-31"></a><span class="hs-keyword">import</span><span> </span><span class="hs-keyword">qualified</span><span> </span><span class="hs-identifier">TensorFlow</span><span class="hs-operator">.</span><span class="hs-identifier">GenOps</span><span class="hs-operator">.</span><span 
class="hs-identifier">Core</span><span> </span><span class="hs-keyword">as</span><span> </span><span class="hs-identifier">CoreOps</span><span> </span><a name="line-32"></a><span> </span><a name="line-33"></a><span class="hs-comment">-- | Looks up `ids` in a list of embedding tensors.</span><span> </span><a name="line-34"></a><span class="hs-comment">--</span><span> </span><a name="line-35"></a><span class="hs-comment">-- This function is used to perform parallel lookups on the list of</span><span> </span><a name="line-36"></a><span class="hs-comment">-- tensors in `params`. It is a generalization of `TF.gather`, where</span><span> </span><a name="line-37"></a><span class="hs-comment">-- `params` is interpreted as a partition of a larger embedding</span><span> </span><a name="line-38"></a><span class="hs-comment">-- tensor.</span><span> </span><a name="line-39"></a><span class="hs-comment">--</span><span> </span><a name="line-40"></a><span class="hs-comment">-- The partition_strategy is "mod", we assign each id to partition</span><span> </span><a name="line-41"></a><span class="hs-comment">-- `p = id % len(params)`. For instance,</span><span> </span><a name="line-42"></a><span class="hs-comment">-- 13 ids are split across 5 partitions as:</span><span> </span><a name="line-43"></a><span class="hs-comment">-- `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`</span><span> </span><a name="line-44"></a><span class="hs-comment">--</span><span> </span><a name="line-45"></a><span class="hs-comment">-- The results of the lookup are concatenated into a dense</span><span> </span><a name="line-46"></a><span class="hs-comment">-- tensor. 
The returned tensor has shape `shape(ids) + shape(params)[1:]`.</span><span> </span><a name="line-47"></a><span class="hs-identifier">embeddingLookup</span><span> </span><span class="hs-glyph">::</span><span> </span><span class="hs-keyword">forall</span><span> </span><a name="local-6989586621679110665"><a href="#local-6989586621679110665"><span class="hs-identifier">a</span></a></a><span> </span><a name="local-6989586621679110666"><a href="#local-6989586621679110666"><span class="hs-identifier">b</span></a></a><span> </span><a name="local-6989586621679110667"><a href="#local-6989586621679110667"><span class="hs-identifier">v1</span></a></a><span> </span><a name="local-6989586621679110668"><a href="#local-6989586621679110668"><span class="hs-identifier">v2</span></a></a><span> </span><a name="local-6989586621679110669"><a href="#local-6989586621679110669"><span class="hs-identifier">m</span></a></a><span> </span><span class="hs-operator">.</span><span> </span><a name="line-48"></a><span> </span><span class="hs-special">(</span><span> </span><span class="hs-identifier hs-type">MonadBuild</span><span> </span><a href="#local-6989586621679110669"><span class="hs-identifier hs-type">m</span></a><span> </span><a name="line-49"></a><span> </span><span class="hs-special">,</span><span> </span><span class="hs-identifier hs-type">Rendered</span><span> </span><span class="hs-special">(</span><span class="hs-identifier hs-type">Tensor</span><span> </span><a href="#local-6989586621679110667"><span class="hs-identifier hs-type">v1</span></a><span class="hs-special">)</span><span> </span><a name="line-50"></a><span> </span><span class="hs-special">,</span><span> </span><span class="hs-identifier hs-type">TensorType</span><span> </span><a href="#local-6989586621679110665"><span class="hs-identifier hs-type">a</span></a><span> </span><a name="line-51"></a><span> </span><span class="hs-special">,</span><span> </span><span class="hs-identifier hs-type">OneOf</span><span> </span><span 
class="hs-char">'[Int64, Int32] b , Num b ) =&gt; [Tensor v1 a] -- ^ A list of tensors which can be concatenated along -- dimension 0. Each `Tensor` must be appropriately -- sized for `mod` partition strategy. -&gt; Tensor v2 b -- ^ A `Tensor` with type `int32` or `int64` -- containing the ids to be looked up in `params`. -- The ids are required to have fewer than 2^31 -- entries. -&gt; m (Tensor Value a) -- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`. embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids) embeddingLookup params@(p0 : _) ids = do -- Do np separate lookups, finding embeddings for plist[p] in params[p] partitionedResult &lt;- zipWithM (\p g -&gt; colocateWith p $ render $ CoreOps.gather p g) params gatherIds let unshapedResult = CoreOps.dynamicStitch pindices partitionedResult -- Shape restoration is not as optimal as it would be with client -- side shape tracking. paramShape &lt;- colocateWith p0 (render (shape p0)) let finalShape = CoreOps.concat 0 [shape ids, tailShape] tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1)) render $ CoreOps.reshape unshapedResult finalShape where -- Avoids genericLength here which would be evaluated by TF. np = fromIntegral (length params) flatIds = CoreOps.reshape ids (singleton (-1)) pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np) newIds = flatIds `CoreOps.div` np originalIndices = CoreOps.range 0 (CoreOps.size flatIds) 1 -- Partition list of ids based on assignments into np separate lists gatherIds = CoreOps.dynamicPartition np newIds pAssignments -- Similarly, partition the original indices. pindices = CoreOps.dynamicPartition np originalIndices pAssignments singleton i = vector [i :: Int32] embeddingLookup [] _ = error "embeddingLookup requires params to be non empty" </span></pre></body></html>