<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">
<html xmlns="http://www.w3.org/1999/xhtml" xml:lang="en" lang="en">
<head>
<!-- Declares the encoding for HTML parsers, which ignore the XML declaration when this page is served as text/html. -->
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />
<!-- Generated by HsColour, http://code.haskell.org/~malcolm/hscolour/ -->
<title>src/TensorFlow/EmbeddingOps.hs</title>
<link type='text/css' rel='stylesheet' href='hscolour.css' />
</head>
<body>
<pre><a name="line-1"></a><span class='hs-comment'>-- Copyright 2016 TensorFlow authors.</span>
<a name="line-2"></a><span class='hs-comment'>--</span>
<a name="line-3"></a><span class='hs-comment'>-- Licensed under the Apache License, Version 2.0 (the "License");</span>
<a name="line-4"></a><span class='hs-comment'>-- you may not use this file except in compliance with the License.</span>
<a name="line-5"></a><span class='hs-comment'>-- You may obtain a copy of the License at</span>
<a name="line-6"></a><span class='hs-comment'>--</span>
<a name="line-7"></a><span class='hs-comment'>--     <a href="http://www.apache.org/licenses/LICENSE-2.0">http://www.apache.org/licenses/LICENSE-2.0</a></span>
<a name="line-8"></a><span class='hs-comment'>--</span>
<a name="line-9"></a><span class='hs-comment'>-- Unless required by applicable law or agreed to in writing, software</span>
<a name="line-10"></a><span class='hs-comment'>-- distributed under the License is distributed on an "AS IS" BASIS,</span>
<a name="line-11"></a><span class='hs-comment'>-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<a name="line-12"></a><span class='hs-comment'>-- See the License for the specific language governing permissions and</span>
<a name="line-13"></a><span class='hs-comment'>-- limitations under the License.</span>
<a name="line-14"></a>
<a name="line-15"></a><span class='hs-comment'>{-# LANGUAGE ConstraintKinds #-}</span>
<a name="line-16"></a><span class='hs-comment'>{-# LANGUAGE DataKinds #-}</span>
<a name="line-17"></a><span class='hs-comment'>{-# LANGUAGE NoMonomorphismRestriction #-}</span>
<a name="line-18"></a><span class='hs-comment'>{-# LANGUAGE OverloadedStrings #-}</span>
<a name="line-19"></a><span class='hs-comment'>{-# LANGUAGE RankNTypes #-}</span>
<a name="line-20"></a>
<a name="line-21"></a><span class='hs-comment'>-- | Parallel lookups on a list of tensors that partition a larger embedding tensor.</span>
<a name="line-22"></a><span class='hs-keyword'>module</span> <span class='hs-conid'>TensorFlow</span><span class='hs-varop'>.</span><span class='hs-conid'>EmbeddingOps</span> <span class='hs-keyword'>where</span>
<a name="line-23"></a>
<a name="line-24"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Control</span><span class='hs-varop'>.</span><span class='hs-conid'>Monad</span> <span class='hs-layout'>(</span><span class='hs-varid'>zipWithM</span><span class='hs-layout'>)</span>
<a name="line-25"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Int</span> <span class='hs-layout'>(</span><span class='hs-conid'>Int32</span><span class='hs-layout'>,</span> <span class='hs-conid'>Int64</span><span class='hs-layout'>)</span>
<a name="line-26"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>TensorFlow</span><span class='hs-varop'>.</span><span class='hs-conid'>Build</span> <span class='hs-layout'>(</span><span class='hs-conid'>Build</span><span class='hs-layout'>,</span> <span class='hs-varid'>colocateWith</span><span class='hs-layout'>,</span> <span class='hs-varid'>render</span><span class='hs-layout'>)</span>
<a name="line-27"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>TensorFlow</span><span class='hs-varop'>.</span><span class='hs-conid'>Ops</span> <span class='hs-layout'>(</span><span class='hs-varid'>shape</span><span class='hs-layout'>,</span> <span class='hs-varid'>vector</span><span class='hs-layout'>)</span>  <span class='hs-comment'>-- Also Num instance for Tensor</span>
<a name="line-28"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>TensorFlow</span><span class='hs-varop'>.</span><span class='hs-conid'>Tensor</span> <span class='hs-layout'>(</span><span class='hs-conid'>Tensor</span><span class='hs-layout'>,</span> <span class='hs-conid'>Value</span><span class='hs-layout'>)</span>
<a name="line-29"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>TensorFlow</span><span class='hs-varop'>.</span><span class='hs-conid'>Types</span> <span class='hs-layout'>(</span><span class='hs-conid'>OneOf</span><span class='hs-layout'>,</span> <span class='hs-conid'>TensorType</span><span class='hs-layout'>)</span>
<a name="line-30"></a><span class='hs-keyword'>import</span> <span class='hs-keyword'>qualified</span> <span class='hs-conid'>TensorFlow</span><span class='hs-varop'>.</span><span class='hs-conid'>GenOps</span><span class='hs-varop'>.</span><span class='hs-conid'>Core</span> <span class='hs-keyword'>as</span> <span class='hs-conid'>CoreOps</span>
<a name="line-31"></a>
<a name="line-32"></a><a name="embeddingLookup"></a><span class='hs-comment'>-- | Looks up `ids` in a list of embedding tensors.</span>
<a name="line-33"></a><span class='hs-comment'>--</span>
<a name="line-34"></a><span class='hs-comment'>-- This function is used to perform parallel lookups on the list of</span>
<a name="line-35"></a><span class='hs-comment'>-- tensors in `params`.  It is a generalization of `TF.gather`, where</span>
<a name="line-36"></a><span class='hs-comment'>-- `params` is interpreted as a partition of a larger embedding</span>
<a name="line-37"></a><span class='hs-comment'>-- tensor.</span>
<a name="line-38"></a><span class='hs-comment'>--</span>
<a name="line-39"></a><span class='hs-comment'>-- The partition strategy is "mod": each id is assigned to partition</span>
<a name="line-40"></a><span class='hs-comment'>-- `p = id % len(params)`. For instance,</span>
<a name="line-41"></a><span class='hs-comment'>-- 13 ids are split across 5 partitions as:</span>
<a name="line-42"></a><span class='hs-comment'>-- `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`</span>
<a name="line-43"></a><span class='hs-comment'>--</span>
<a name="line-44"></a><span class='hs-comment'>-- The results of the lookup are concatenated into a dense</span>
<a name="line-45"></a><span class='hs-comment'>-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.</span>
<a name="line-46"></a><span class='hs-definition'>embeddingLookup</span> <span class='hs-keyglyph'>::</span> <span class='hs-keyword'>forall</span> <span class='hs-varid'>a</span> <span class='hs-varid'>b</span> <span class='hs-varid'>v</span> <span class='hs-varop'>.</span>
<a name="line-47"></a>                   <span class='hs-layout'>(</span> <span class='hs-conid'>TensorType</span> <span class='hs-varid'>a</span>
<a name="line-48"></a>                   <span class='hs-layout'>,</span> <span class='hs-conid'>OneOf</span> <span class='hs-chr'>'</span><span class='hs-keyglyph'>[</span><span class='hs-conid'>Int64</span><span class='hs-layout'>,</span> <span class='hs-conid'>Int32</span><span class='hs-keyglyph'>]</span> <span class='hs-varid'>b</span>
<a name="line-49"></a>                   <span class='hs-layout'>,</span> <span class='hs-conid'>Num</span> <span class='hs-varid'>b</span>
<a name="line-50"></a>                   <span class='hs-layout'>)</span>
<a name="line-51"></a>                <span class='hs-keyglyph'>=&gt;</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Tensor</span> <span class='hs-varid'>v</span> <span class='hs-varid'>a</span><span class='hs-keyglyph'>]</span>
<a name="line-52"></a>                <span class='hs-comment'>-- ^ A list of tensors which can be concatenated along</span>
<a name="line-53"></a>                <span class='hs-comment'>-- dimension 0. Each `Tensor` must be appropriately</span>
<a name="line-54"></a>                <span class='hs-comment'>-- sized for the `mod` partition strategy. Must be non-empty.</span>
<a name="line-55"></a>                <span class='hs-keyglyph'>-&gt;</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>b</span>
<a name="line-56"></a>                <span class='hs-comment'>-- ^ A `Tensor` with type `int32` or `int64`</span>
<a name="line-57"></a>                <span class='hs-comment'>-- containing the ids to be looked up in `params`.</span>
<a name="line-58"></a>                <span class='hs-comment'>-- The ids are required to have fewer than 2^31</span>
<a name="line-59"></a>                <span class='hs-comment'>-- entries.</span>
<a name="line-60"></a>                <span class='hs-keyglyph'>-&gt;</span> <span class='hs-conid'>Build</span> <span class='hs-layout'>(</span><span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
<a name="line-61"></a>                <span class='hs-comment'>-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.</span>
<a name="line-62"></a><span class='hs-definition'>embeddingLookup</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>p0</span><span class='hs-keyglyph'>]</span> <span class='hs-varid'>ids</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>colocateWith</span> <span class='hs-varid'>p0</span> <span class='hs-layout'>(</span><span class='hs-varid'>render</span> <span class='hs-varop'>$</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>gather</span> <span class='hs-varid'>p0</span> <span class='hs-varid'>ids</span><span class='hs-layout'>)</span>
<a name="line-63"></a><span class='hs-definition'>embeddingLookup</span> <span class='hs-varid'>params</span><span class='hs-keyglyph'>@</span><span class='hs-layout'>(</span><span class='hs-varid'>p0</span> <span class='hs-conop'>:</span> <span class='hs-keyword'>_</span><span class='hs-layout'>)</span> <span class='hs-varid'>ids</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyword'>do</span>
<a name="line-64"></a>    <span class='hs-comment'>-- Do np separate lookups, finding embeddings for gatherIds[p] in params[p]</span>
<a name="line-65"></a>    <span class='hs-varid'>partitionedResult</span> <span class='hs-keyglyph'>&lt;-</span> <span class='hs-varid'>zipWithM</span>
<a name="line-66"></a>                        <span class='hs-layout'>(</span><span class='hs-keyglyph'>\</span><span class='hs-varid'>p</span> <span class='hs-varid'>g</span> <span class='hs-keyglyph'>-&gt;</span> <span class='hs-varid'>colocateWith</span> <span class='hs-varid'>p</span> <span class='hs-varop'>$</span> <span class='hs-varid'>render</span> <span class='hs-varop'>$</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>gather</span> <span class='hs-varid'>p</span> <span class='hs-varid'>g</span><span class='hs-layout'>)</span>
<a name="line-67"></a>                        <span class='hs-varid'>params</span> <span class='hs-varid'>gatherIds</span>
<a name="line-68"></a>    <span class='hs-keyword'>let</span> <span class='hs-varid'>unshapedResult</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>dynamicStitch</span> <span class='hs-varid'>pindices</span> <span class='hs-varid'>partitionedResult</span>
<a name="line-69"></a>    <span class='hs-comment'>-- Shape restoration is not as optimal as it would be with client</span>
<a name="line-70"></a>    <span class='hs-comment'>-- side shape tracking.</span>
<a name="line-71"></a>    <span class='hs-varid'>paramShape</span> <span class='hs-keyglyph'>&lt;-</span> <span class='hs-varid'>colocateWith</span> <span class='hs-varid'>p0</span> <span class='hs-layout'>(</span><span class='hs-varid'>render</span> <span class='hs-layout'>(</span><span class='hs-varid'>shape</span> <span class='hs-varid'>p0</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span>
<a name="line-72"></a>    <span class='hs-keyword'>let</span> <span class='hs-varid'>finalShape</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>concat</span> <span class='hs-num'>0</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>shape</span> <span class='hs-varid'>ids</span><span class='hs-layout'>,</span> <span class='hs-varid'>tailShape</span><span class='hs-keyglyph'>]</span>
<a name="line-73"></a>        <span class='hs-varid'>tailShape</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>slice</span> <span class='hs-varid'>paramShape</span> <span class='hs-layout'>(</span><span class='hs-varid'>singleton</span> <span class='hs-num'>1</span><span class='hs-layout'>)</span> <span class='hs-layout'>(</span><span class='hs-varid'>singleton</span> <span class='hs-layout'>(</span><span class='hs-varop'>-</span><span class='hs-num'>1</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span>
<a name="line-74"></a>    <span class='hs-varid'>render</span> <span class='hs-varop'>$</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>reshape</span> <span class='hs-varid'>unshapedResult</span> <span class='hs-varid'>finalShape</span>
<a name="line-75"></a>  <span class='hs-keyword'>where</span>
<a name="line-76"></a>    <span class='hs-comment'>-- Avoids genericLength here, which would be evaluated by TF.</span>
<a name="line-77"></a>    <span class='hs-varid'>np</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>fromIntegral</span> <span class='hs-layout'>(</span><span class='hs-varid'>length</span> <span class='hs-varid'>params</span><span class='hs-layout'>)</span>
<a name="line-78"></a>    <span class='hs-varid'>flatIds</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>reshape</span> <span class='hs-varid'>ids</span> <span class='hs-layout'>(</span><span class='hs-varid'>singleton</span> <span class='hs-layout'>(</span><span class='hs-varop'>-</span><span class='hs-num'>1</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span>
<a name="line-79"></a>    <span class='hs-varid'>pAssignments</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>cast</span> <span class='hs-layout'>(</span><span class='hs-varid'>flatIds</span> <span class='hs-varop'>`</span><span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>mod</span><span class='hs-varop'>`</span> <span class='hs-varid'>np</span><span class='hs-layout'>)</span>
<a name="line-80"></a>    <span class='hs-varid'>newIds</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>flatIds</span> <span class='hs-varop'>`</span><span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>div</span><span class='hs-varop'>`</span> <span class='hs-varid'>np</span>
<a name="line-81"></a>    <span class='hs-varid'>originalIndices</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>range</span> <span class='hs-num'>0</span> <span class='hs-layout'>(</span><span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>size</span> <span class='hs-varid'>flatIds</span><span class='hs-layout'>)</span> <span class='hs-num'>1</span>
<a name="line-82"></a>    <span class='hs-comment'>-- Partition list of ids based on assignments into np separate lists</span>
<a name="line-83"></a>    <span class='hs-varid'>gatherIds</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>dynamicPartition</span> <span class='hs-varid'>np</span> <span class='hs-varid'>newIds</span> <span class='hs-varid'>pAssignments</span>
<a name="line-84"></a>    <span class='hs-comment'>-- Similarly, partition the original indices.</span>
<a name="line-85"></a>    <span class='hs-varid'>pindices</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>dynamicPartition</span> <span class='hs-varid'>np</span> <span class='hs-varid'>originalIndices</span> <span class='hs-varid'>pAssignments</span>
<a name="line-86"></a>    <span class='hs-varid'>singleton</span> <span class='hs-varid'>i</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>vector</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>i</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Int32</span><span class='hs-keyglyph'>]</span>
<a name="line-87"></a>
<a name="line-88"></a><span class='hs-definition'>embeddingLookup</span> <span class='hs-conid'>[]</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>error</span> <span class='hs-str'>"embeddingLookup requires params to be non empty"</span>
</pre></body>
</html>