<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html xmlns="http://www.w3.org/1999/xhtml"><head><link rel="stylesheet" type="text/css" href="style.css" /><script type="text/javascript" src="highlight.js"></script></head><body><pre><span class="hs-comment">-- Copyright 2016 TensorFlow authors.</span><span>
</span><span id="line-2"></span><span class="hs-comment">--</span><span>
</span><span id="line-3"></span><span class="hs-comment">-- Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span><span>
</span><span id="line-4"></span><span class="hs-comment">-- you may not use this file except in compliance with the License.</span><span>
</span><span id="line-5"></span><span class="hs-comment">-- You may obtain a copy of the License at</span><span>
</span><span id="line-6"></span><span class="hs-comment">--</span><span>
</span><span id="line-7"></span><span class="hs-comment">--     http://www.apache.org/licenses/LICENSE-2.0</span><span>
</span><span id="line-8"></span><span class="hs-comment">--</span><span>
</span><span id="line-9"></span><span class="hs-comment">-- Unless required by applicable law or agreed to in writing, software</span><span>
</span><span id="line-10"></span><span class="hs-comment">-- distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span><span>
</span><span id="line-11"></span><span class="hs-comment">-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span><span>
</span><span id="line-12"></span><span class="hs-comment">-- See the License for the specific language governing permissions and</span><span>
</span><span id="line-13"></span><span class="hs-comment">-- limitations under the License.</span><span>
</span><span id="line-14"></span><span>
</span><span id="line-15"></span><span class="hs-pragma">{-# LANGUAGE ConstraintKinds #-}</span><span>
</span><span id="line-16"></span><span class="hs-pragma">{-# LANGUAGE DataKinds #-}</span><span>
</span><span id="line-17"></span><span class="hs-pragma">{-# LANGUAGE FlexibleContexts #-}</span><span>
</span><span id="line-18"></span><span class="hs-pragma">{-# LANGUAGE NoMonomorphismRestriction #-}</span><span>
</span><span id="line-19"></span><span class="hs-pragma">{-# LANGUAGE OverloadedStrings #-}</span><span>
</span><span id="line-20"></span><span class="hs-pragma">{-# LANGUAGE RankNTypes #-}</span><span>
</span><span id="line-21"></span><span>
</span><span id="line-22"></span><span class="hs-comment">-- | Parallel lookups on the list of tensors.</span><span>
</span><span id="line-23"></span><span class="hs-keyword">module</span><span> </span><span class="hs-identifier">TensorFlow.EmbeddingOps</span><span> </span><span class="hs-keyword">where</span><span>
</span><span id="line-24"></span><span>
</span><span id="line-25"></span><span class="hs-keyword">import</span><span> </span><span class="annot"><span class="hs-identifier">Control.Monad</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="hs-identifier">zipWithM</span></span><span class="hs-special">)</span><span>
</span><span id="line-26"></span><span class="hs-keyword">import</span><span> </span><span class="annot"><span class="hs-identifier">Data.Int</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="hs-identifier">Int32</span></span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier">Int64</span></span><span class="hs-special">)</span><span>
</span><span id="line-27"></span><span class="hs-keyword">import</span><span> </span><span class="annot"><span class="hs-identifier">TensorFlow.Build</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="hs-identifier">MonadBuild</span></span><span class="hs-special">)</span><span>
</span><span id="line-28"></span><span class="hs-keyword">import</span><span> </span><span class="annot"><a href="TensorFlow.Ops.html"><span class="hs-identifier">TensorFlow.Ops</span></a></span><span> </span><span class="hs-special">(</span><span class="annot"><a href="TensorFlow.Ops.html#shape"><span class="hs-identifier">shape</span></a></span><span class="hs-special">,</span><span> </span><span class="annot"><a href="TensorFlow.Ops.html#vector"><span class="hs-identifier">vector</span></a></span><span class="hs-special">)</span><span>  </span><span class="hs-comment">-- Also Num instance for Tensor</span><span>
</span><span id="line-29"></span><span class="hs-keyword">import</span><span> </span><span class="annot"><span class="hs-identifier">TensorFlow.Tensor</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="hs-identifier">Tensor</span></span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier">Value</span></span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier">Rendered</span></span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier">colocateWith</span></span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier">render</span></span><span class="hs-special">)</span><span>
</span><span id="line-30"></span><span class="hs-keyword">import</span><span> </span><span class="annot"><span class="hs-identifier">TensorFlow.Types</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="hs-identifier">OneOf</span></span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier">TensorType</span></span><span class="hs-special">)</span><span>
</span><span id="line-31"></span><span class="hs-keyword">import</span><span> </span><span class="hs-keyword">qualified</span><span> </span><span class="annot"><span class="hs-identifier">TensorFlow.GenOps.Core</span></span><span> </span><span class="hs-keyword">as</span><span> </span><span class="annot"><span class="hs-identifier">CoreOps</span></span><span>
</span><span id="line-32"></span><span>
</span><span id="line-33"></span><span class="hs-comment">-- | Looks up `ids` in a list of embedding tensors.</span><span>
</span><span id="line-34"></span><span class="hs-comment">--</span><span>
</span><span id="line-35"></span><span class="hs-comment">-- This function is used to perform parallel lookups on the list of</span><span>
</span><span id="line-36"></span><span class="hs-comment">-- tensors in `params`.  It is a generalization of `TF.gather`, where</span><span>
</span><span id="line-37"></span><span class="hs-comment">-- `params` is interpreted as a partition of a larger embedding</span><span>
</span><span id="line-38"></span><span class="hs-comment">-- tensor.</span><span>
</span><span id="line-39"></span><span class="hs-comment">--</span><span>
</span><span id="line-40"></span><span class="hs-comment">-- The partition strategy is &quot;mod&quot;: each id is assigned to partition</span><span>
</span><span id="line-41"></span><span class="hs-comment">-- `p = id % len(params)`. For instance,</span><span>
</span><span id="line-42"></span><span class="hs-comment">-- 13 ids are split across 5 partitions as:</span><span>
</span><span id="line-43"></span><span class="hs-comment">-- `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`</span><span>
</span><span id="line-44"></span><span class="hs-comment">--</span><span>
</span><span id="line-45"></span><span class="hs-comment">-- The results of the lookup are concatenated into a dense</span><span>
</span><span id="line-46"></span><span class="hs-comment">-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.</span><span>
</span><span id="line-47"></span><span class="annot"><a href="TensorFlow.EmbeddingOps.html#embeddingLookup"><span class="hs-identifier hs-type">embeddingLookup</span></a></span><span> </span><span class="hs-glyph">::</span><span> </span><span class="hs-keyword">forall</span><span> </span><span id="local-6989586621679157529"><span class="annot"><a href="#local-6989586621679157529"><span class="hs-identifier hs-type">a</span></a></span></span><span> </span><span id="local-6989586621679157528"><span class="annot"><a href="#local-6989586621679157528"><span class="hs-identifier hs-type">b</span></a></span></span><span> </span><span id="local-6989586621679157527"><span class="annot"><a href="#local-6989586621679157527"><span class="hs-identifier hs-type">v1</span></a></span></span><span> </span><span id="local-6989586621679157526"><span class="annot"><a href="#local-6989586621679157526"><span class="hs-identifier hs-type">v2</span></a></span></span><span> </span><span id="local-6989586621679157525"><span class="annot"><a href="#local-6989586621679157525"><span class="hs-identifier hs-type">m</span></a></span></span><span> </span><span class="hs-operator">.</span><span>
</span><span id="line-48"></span><span>                   </span><span class="hs-special">(</span><span> </span><span class="annot"><span class="hs-identifier hs-type">MonadBuild</span></span><span> </span><span class="annot"><a href="#local-6989586621679157525"><span class="hs-identifier hs-type">m</span></a></span><span>
</span><span id="line-49"></span><span>                   </span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier hs-type">Rendered</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="hs-identifier hs-type">Tensor</span></span><span> </span><span class="annot"><a href="#local-6989586621679157527"><span class="hs-identifier hs-type">v1</span></a></span><span class="hs-special">)</span><span>
</span><span id="line-50"></span><span>                   </span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier hs-type">TensorType</span></span><span> </span><span class="annot"><a href="#local-6989586621679157529"><span class="hs-identifier hs-type">a</span></a></span><span>
</span><span id="line-51"></span><span>                   </span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier hs-type">OneOf</span></span><span> </span><span class="hs-special">'</span><span class="hs-special">[</span><span class="annot"><span class="hs-identifier hs-type">Int64</span></span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier hs-type">Int32</span></span><span class="hs-special">]</span><span> </span><span class="annot"><a href="#local-6989586621679157528"><span class="hs-identifier hs-type">b</span></a></span><span>
</span><span id="line-52"></span><span>                   </span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier hs-type">Num</span></span><span> </span><span class="annot"><a href="#local-6989586621679157528"><span class="hs-identifier hs-type">b</span></a></span><span>
</span><span id="line-53"></span><span>                   </span><span class="hs-special">)</span><span>
</span><span id="line-54"></span><span>                </span><span class="hs-glyph">=&gt;</span><span> </span><span class="hs-special">[</span><span class="annot"><span class="hs-identifier hs-type">Tensor</span></span><span> </span><span class="annot"><a href="#local-6989586621679157527"><span class="hs-identifier hs-type">v1</span></a></span><span> </span><span class="annot"><a href="#local-6989586621679157529"><span class="hs-identifier hs-type">a</span></a></span><span class="hs-special">]</span><span>
</span><span id="line-55"></span><span>                </span><span class="hs-comment">-- ^ A list of tensors which can be concatenated along</span><span>
</span><span id="line-56"></span><span>                </span><span class="hs-comment">-- dimension 0. Each `Tensor` must be appropriately</span><span>
</span><span id="line-57"></span><span>                </span><span class="hs-comment">-- sized for `mod` partition strategy.</span><span>
</span><span id="line-58"></span><span>                </span><span class="hs-glyph">-&gt;</span><span> </span><span class="annot"><span class="hs-identifier hs-type">Tensor</span></span><span> </span><span class="annot"><a href="#local-6989586621679157526"><span class="hs-identifier hs-type">v2</span></a></span><span> </span><span class="annot"><a href="#local-6989586621679157528"><span class="hs-identifier hs-type">b</span></a></span><span>
</span><span id="line-59"></span><span>                </span><span class="hs-comment">-- ^ A `Tensor` with type `int32` or `int64`</span><span>
</span><span id="line-60"></span><span>                </span><span class="hs-comment">-- containing the ids to be looked up in `params`.</span><span>
</span><span id="line-61"></span><span>                </span><span class="hs-comment">-- The ids are required to have fewer than 2^31</span><span>
</span><span id="line-62"></span><span>                </span><span class="hs-comment">-- entries.</span><span>
</span><span id="line-63"></span><span>                </span><span class="hs-glyph">-&gt;</span><span> </span><span class="annot"><a href="#local-6989586621679157525"><span class="hs-identifier hs-type">m</span></a></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="hs-identifier hs-type">Tensor</span></span><span> </span><span class="annot"><span class="hs-identifier hs-type">Value</span></span><span> </span><span class="annot"><a href="#local-6989586621679157529"><span class="hs-identifier hs-type">a</span></a></span><span class="hs-special">)</span><span>
</span><span id="line-64"></span><span>                </span><span class="hs-comment">-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.</span><span>
</span><span id="line-65"></span><span id="embeddingLookup"><span class="annot"><span class="annottext">embeddingLookup :: [Tensor v1 a] -&gt; Tensor v2 b -&gt; m (Tensor Value a)
</span><a href="TensorFlow.EmbeddingOps.html#embeddingLookup"><span class="hs-identifier hs-var hs-var">embeddingLookup</span></a></span></span><span> </span><span class="hs-special">[</span><span id="local-6989586621679157524"><span class="annot"><span class="annottext">p0 :: Tensor v1 a
</span><a href="#local-6989586621679157524"><span class="hs-identifier hs-var">p0</span></a></span></span><span class="hs-special">]</span><span> </span><span id="local-6989586621679157523"><span class="annot"><span class="annottext">ids :: Tensor v2 b
</span><a href="#local-6989586621679157523"><span class="hs-identifier hs-var">ids</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Tensor v1 a -&gt; m (Tensor Value a) -&gt; m (Tensor Value a)
forall (m :: * -&gt; *) (t :: * -&gt; *) b a.
(MonadBuild m, Rendered t) =&gt;
t b -&gt; m a -&gt; m a
</span><span class="hs-identifier hs-var">colocateWith</span></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a
</span><a href="#local-6989586621679157524"><span class="hs-identifier hs-var">p0</span></a></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">Tensor Build a -&gt; m (Tensor Value a)
forall (m :: * -&gt; *) a.
MonadBuild m =&gt;
Tensor Build a -&gt; m (Tensor Value a)
</span><span class="hs-identifier hs-var">render</span></span><span> </span><span class="annot"><span class="annottext">(Tensor Build a -&gt; m (Tensor Value a))
-&gt; Tensor Build a -&gt; m (Tensor Value a)
forall a b. (a -&gt; b) -&gt; a -&gt; b
</span><span class="hs-operator hs-var">$</span></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a -&gt; Tensor v2 b -&gt; Tensor Build a
forall (v'1 :: * -&gt; *) (v'2 :: * -&gt; *) tparams tindices.
(TensorType tparams, OneOf '[Int32, Int64] tindices) =&gt;
Tensor v'1 tparams -&gt; Tensor v'2 tindices -&gt; Tensor Build tparams
</span><span class="hs-identifier hs-var">CoreOps.gather</span></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a
</span><a href="#local-6989586621679157524"><span class="hs-identifier hs-var">p0</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor v2 b
</span><a href="#local-6989586621679157523"><span class="hs-identifier hs-var">ids</span></a></span><span class="hs-special">)</span><span>
</span><span id="line-66"></span><span class="annot"><a href="TensorFlow.EmbeddingOps.html#embeddingLookup"><span class="hs-identifier hs-var">embeddingLookup</span></a></span><span> </span><span id="local-6989586621679157521"><span class="annot"><span class="annottext">params :: [Tensor v1 a]
</span><a href="#local-6989586621679157521"><span class="hs-identifier hs-var">params</span></a></span></span><span class="hs-glyph">@</span><span class="hs-special">(</span><span id="local-6989586621679157520"><span class="annot"><span class="annottext">p0 :: Tensor v1 a
</span><a href="#local-6989586621679157520"><span class="hs-identifier hs-var">p0</span></a></span></span><span> </span><span class="annot"><span class="hs-glyph hs-type">:</span></span><span> </span><span class="hs-identifier">_</span><span class="hs-special">)</span><span> </span><span id="local-6989586621679157519"><span class="annot"><span class="annottext">ids :: Tensor v2 b
</span><a href="#local-6989586621679157519"><span class="hs-identifier hs-var">ids</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="hs-keyword">do</span><span>
</span><span id="line-67"></span><span>    </span><span class="hs-comment">-- Do np separate lookups, finding embeddings for plist[p] in params[p]</span><span>
</span><span id="line-68"></span><span>    </span><span id="local-6989586621679157518"><span class="annot"><span class="annottext">[Tensor Value a]
</span><a href="#local-6989586621679157518"><span class="hs-identifier hs-var">partitionedResult</span></a></span></span><span> </span><span class="hs-glyph">&lt;-</span><span> </span><span class="annot"><span class="annottext">(Tensor v1 a -&gt; Tensor Build b -&gt; m (Tensor Value a))
-&gt; [Tensor v1 a] -&gt; [Tensor Build b] -&gt; m [Tensor Value a]
forall (m :: * -&gt; *) a b c.
Applicative m =&gt;
(a -&gt; b -&gt; m c) -&gt; [a] -&gt; [b] -&gt; m [c]
</span><span class="hs-identifier hs-var">zipWithM</span></span><span>
</span><span id="line-69"></span><span>                        </span><span class="hs-special">(</span><span class="hs-glyph">\</span><span id="local-6989586621679157517"><span class="annot"><span class="annottext">p :: Tensor v1 a
</span><a href="#local-6989586621679157517"><span class="hs-identifier hs-var">p</span></a></span></span><span> </span><span id="local-6989586621679157516"><span class="annot"><span class="annottext">g :: Tensor Build b
</span><a href="#local-6989586621679157516"><span class="hs-identifier hs-var">g</span></a></span></span><span> </span><span class="hs-glyph">-&gt;</span><span> </span><span class="annot"><span class="annottext">Tensor v1 a -&gt; m (Tensor Value a) -&gt; m (Tensor Value a)
forall (m :: * -&gt; *) (t :: * -&gt; *) b a.
(MonadBuild m, Rendered t) =&gt;
t b -&gt; m a -&gt; m a
</span><span class="hs-identifier hs-var">colocateWith</span></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a
</span><a href="#local-6989586621679157517"><span class="hs-identifier hs-var">p</span></a></span><span> </span><span class="annot"><span class="annottext">(m (Tensor Value a) -&gt; m (Tensor Value a))
-&gt; m (Tensor Value a) -&gt; m (Tensor Value a)
forall a b. (a -&gt; b) -&gt; a -&gt; b
</span><span class="hs-operator hs-var">$</span></span><span> </span><span class="annot"><span class="annottext">Tensor Build a -&gt; m (Tensor Value a)
forall (m :: * -&gt; *) a.
MonadBuild m =&gt;
Tensor Build a -&gt; m (Tensor Value a)
</span><span class="hs-identifier hs-var">render</span></span><span> </span><span class="annot"><span class="annottext">(Tensor Build a -&gt; m (Tensor Value a))
-&gt; Tensor Build a -&gt; m (Tensor Value a)
forall a b. (a -&gt; b) -&gt; a -&gt; b
</span><span class="hs-operator hs-var">$</span></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a -&gt; Tensor Build b -&gt; Tensor Build a
forall (v'1 :: * -&gt; *) (v'2 :: * -&gt; *) tparams tindices.
(TensorType tparams, OneOf '[Int32, Int64] tindices) =&gt;
Tensor v'1 tparams -&gt; Tensor v'2 tindices -&gt; Tensor Build tparams
</span><span class="hs-identifier hs-var">CoreOps.gather</span></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a
</span><a href="#local-6989586621679157517"><span class="hs-identifier hs-var">p</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor Build b
</span><a href="#local-6989586621679157516"><span class="hs-identifier hs-var">g</span></a></span><span class="hs-special">)</span><span>
</span><span id="line-70"></span><span>                        </span><span class="annot"><span class="annottext">[Tensor v1 a]
</span><a href="#local-6989586621679157521"><span class="hs-identifier hs-var">params</span></a></span><span> </span><span class="annot"><span class="annottext">[Tensor Build b]
</span><a href="#local-6989586621679157515"><span class="hs-identifier hs-var">gatherIds</span></a></span><span>
</span><span id="line-71"></span><span>    </span><span class="hs-keyword">let</span><span> </span><span id="local-6989586621679157514"><span class="annot"><span class="annottext">unshapedResult :: Tensor Build a
</span><a href="#local-6989586621679157514"><span class="hs-identifier hs-var hs-var">unshapedResult</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">[Tensor Build Int32] -&gt; [Tensor Value a] -&gt; Tensor Build a
forall (v'1 :: * -&gt; *) (v'2 :: * -&gt; *) t.
TensorType t =&gt;
[Tensor v'1 Int32] -&gt; [Tensor v'2 t] -&gt; Tensor Build t
</span><span class="hs-identifier hs-var">CoreOps.dynamicStitch</span></span><span> </span><span class="annot"><span class="annottext">[Tensor Build Int32]
forall t.
(t /= Int8, t /= Int16, t /= Word8, t /= ByteString, t /= Bool,
 t /= Word16, t /= Float, t /= Double, TensorType t, Num t) =&gt;
[Tensor Build t]
</span><a href="#local-6989586621679157512"><span class="hs-identifier hs-var">pindices</span></a></span><span> </span><span class="annot"><span class="annottext">[Tensor Value a]
</span><a href="#local-6989586621679157518"><span class="hs-identifier hs-var">partitionedResult</span></a></span><span>
</span><span id="line-72"></span><span>    </span><span class="hs-comment">-- Shape restoration is not as optimal as it would be with client</span><span>
</span><span id="line-73"></span><span>    </span><span class="hs-comment">-- side shape tracking.</span><span>
</span><span id="line-74"></span><span>    </span><span id="local-6989586621679157511"><span class="annot"><span class="annottext">Tensor Value Int32
</span><a href="#local-6989586621679157511"><span class="hs-identifier hs-var">paramShape</span></a></span></span><span> </span><span class="hs-glyph">&lt;-</span><span> </span><span class="annot"><span class="annottext">Tensor v1 a -&gt; m (Tensor Value Int32) -&gt; m (Tensor Value Int32)
forall (m :: * -&gt; *) (t :: * -&gt; *) b a.
(MonadBuild m, Rendered t) =&gt;
t b -&gt; m a -&gt; m a
</span><span class="hs-identifier hs-var">colocateWith</span></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a
</span><a href="#local-6989586621679157520"><span class="hs-identifier hs-var">p0</span></a></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">Tensor Build Int32 -&gt; m (Tensor Value Int32)
forall (m :: * -&gt; *) a.
MonadBuild m =&gt;
Tensor Build a -&gt; m (Tensor Value a)
</span><span class="hs-identifier hs-var">render</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">Tensor v1 a -&gt; Tensor Build Int32
forall t (v :: * -&gt; *).
TensorType t =&gt;
Tensor v t -&gt; Tensor Build Int32
</span><a href="TensorFlow.Ops.html#shape"><span class="hs-identifier hs-var">shape</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a
</span><a href="#local-6989586621679157520"><span class="hs-identifier hs-var">p0</span></a></span><span class="hs-special">)</span><span class="hs-special">)</span><span>
</span><span id="line-75"></span><span>    </span><span class="hs-keyword">let</span><span> </span><span id="local-6989586621679157510"><span class="annot"><span class="annottext">finalShape :: Tensor Build Int32
</span><a href="#local-6989586621679157510"><span class="hs-identifier hs-var hs-var">finalShape</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Tensor Build Int32 -&gt; [Tensor Build Int32] -&gt; Tensor Build Int32
forall (v'1 :: * -&gt; *) (v'2 :: * -&gt; *) t.
TensorType t =&gt;
Tensor v'1 Int32 -&gt; [Tensor v'2 t] -&gt; Tensor Build t
</span><span class="hs-identifier hs-var">CoreOps.concat</span></span><span> </span><span class="annot"><span class="hs-number">0</span></span><span> </span><span class="hs-special">[</span><span class="annot"><span class="annottext">Tensor v2 b -&gt; Tensor Build Int32
forall t (v :: * -&gt; *).
TensorType t =&gt;
Tensor v t -&gt; Tensor Build Int32
</span><a href="TensorFlow.Ops.html#shape"><span class="hs-identifier hs-var">shape</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor v2 b
</span><a href="#local-6989586621679157519"><span class="hs-identifier hs-var">ids</span></a></span><span class="hs-special">,</span><span> </span><span class="annot"><span class="annottext">Tensor Build Int32
</span><a href="#local-6989586621679157508"><span class="hs-identifier hs-var">tailShape</span></a></span><span class="hs-special">]</span><span>
</span><span id="line-76"></span><span>        </span><span id="local-6989586621679157508"><span class="annot"><span class="annottext">tailShape :: Tensor Build Int32
</span><a href="#local-6989586621679157508"><span class="hs-identifier hs-var hs-var">tailShape</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Tensor Value Int32
-&gt; Tensor Build Int32 -&gt; Tensor Build Int32 -&gt; Tensor Build Int32
forall (v'1 :: * -&gt; *) (v'2 :: * -&gt; *) (v'3 :: * -&gt; *) t index.
(TensorType t, OneOf '[Int32, Int64] index) =&gt;
Tensor v'1 t
-&gt; Tensor v'2 index -&gt; Tensor v'3 index -&gt; Tensor Build t
</span><span class="hs-identifier hs-var">CoreOps.slice</span></span><span> </span><span class="annot"><span class="annottext">Tensor Value Int32
</span><a href="#local-6989586621679157511"><span class="hs-identifier hs-var">paramShape</span></a></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">Int32 -&gt; Tensor Build Int32
</span><a href="#local-6989586621679157506"><span class="hs-identifier hs-var">singleton</span></a></span><span> </span><span class="annot"><span class="hs-number">1</span></span><span class="hs-special">)</span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">Int32 -&gt; Tensor Build Int32
</span><a href="#local-6989586621679157506"><span class="hs-identifier hs-var">singleton</span></a></span><span> </span><span class="hs-special">(</span><span class="hs-glyph">-</span><span class="annot"><span class="hs-number">1</span></span><span class="hs-special">)</span><span class="hs-special">)</span><span>
</span><span id="line-77"></span><span>    </span><span class="annot"><span class="annottext">Tensor Build a -&gt; m (Tensor Value a)
forall (m :: * -&gt; *) a.
MonadBuild m =&gt;
Tensor Build a -&gt; m (Tensor Value a)
</span><span class="hs-identifier hs-var">render</span></span><span> </span><span class="annot"><span class="annottext">(Tensor Build a -&gt; m (Tensor Value a))
-&gt; Tensor Build a -&gt; m (Tensor Value a)
forall a b. (a -&gt; b) -&gt; a -&gt; b
</span><span class="hs-operator hs-var">$</span></span><span> </span><span class="annot"><span class="annottext">Tensor Build a -&gt; Tensor Build Int32 -&gt; Tensor Build a
forall (v'1 :: * -&gt; *) (v'2 :: * -&gt; *) tparams tindices.
(TensorType tparams, OneOf '[Int32, Int64] tindices) =&gt;
Tensor v'1 tparams -&gt; Tensor v'2 tindices -&gt; Tensor Build tparams
</span><span class="hs-identifier hs-var">CoreOps.reshape</span></span><span> </span><span class="annot"><span class="annottext">Tensor Build a
</span><a href="#local-6989586621679157514"><span class="hs-identifier hs-var">unshapedResult</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor Build Int32
</span><a href="#local-6989586621679157510"><span class="hs-identifier hs-var">finalShape</span></a></span><span>
</span><span id="line-78"></span><span>  </span><span class="hs-keyword">where</span><span>
</span><span id="line-79"></span><span>    </span><span class="hs-comment">-- Avoids genericLength here which would be evaluated by TF.</span><span>
</span><span id="line-80"></span><span>    </span><span id="local-6989586621679157504"><span class="annot"><span class="annottext">np :: b
</span><a href="#local-6989586621679157504"><span class="hs-identifier hs-var hs-var">np</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Int -&gt; b
forall a b. (Integral a, Num b) =&gt; a -&gt; b
</span><span class="hs-identifier hs-var">fromIntegral</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">[Tensor v1 a] -&gt; Int
forall (t :: * -&gt; *) a. Foldable t =&gt; t a -&gt; Int
</span><span class="hs-identifier hs-var">length</span></span><span> </span><span class="annot"><span class="annottext">[Tensor v1 a]
</span><a href="#local-6989586621679157521"><span class="hs-identifier hs-var">params</span></a></span><span class="hs-special">)</span><span>
</span><span id="line-81"></span><span>    </span><span id="local-6989586621679157502"><span class="annot"><span class="annottext">flatIds :: Tensor Build b
</span><a href="#local-6989586621679157502"><span class="hs-identifier hs-var hs-var">flatIds</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Tensor v2 b -&gt; Tensor Build Int32 -&gt; Tensor Build b
forall (v'1 :: * -&gt; *) (v'2 :: * -&gt; *) tparams tindices.
(TensorType tparams, OneOf '[Int32, Int64] tindices) =&gt;
Tensor v'1 tparams -&gt; Tensor v'2 tindices -&gt; Tensor Build tparams
</span><span class="hs-identifier hs-var">CoreOps.reshape</span></span><span> </span><span class="annot"><span class="annottext">Tensor v2 b
</span><a href="#local-6989586621679157519"><span class="hs-identifier hs-var">ids</span></a></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">Int32 -&gt; Tensor Build Int32
</span><a href="#local-6989586621679157506"><span class="hs-identifier hs-var">singleton</span></a></span><span> </span><span class="hs-special">(</span><span class="hs-glyph">-</span><span class="annot"><span class="hs-number">1</span></span><span class="hs-special">)</span><span class="hs-special">)</span><span>
</span><span id="line-82"></span><span>    </span><span id="local-6989586621679157501"><span class="annot"><span class="annottext">pAssignments :: Tensor Build dstT
</span><a href="#local-6989586621679157501"><span class="hs-identifier hs-var hs-var">pAssignments</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Tensor Build b -&gt; Tensor Build dstT
forall (v'1 :: * -&gt; *) srcT dstT.
(TensorType srcT, TensorType dstT) =&gt;
Tensor v'1 srcT -&gt; Tensor Build dstT
</span><span class="hs-identifier hs-var">CoreOps.cast</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">Tensor Build b
</span><a href="#local-6989586621679157502"><span class="hs-identifier hs-var">flatIds</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor Build b -&gt; Tensor Build b -&gt; Tensor Build b
forall (v'1 :: * -&gt; *) (v'2 :: * -&gt; *) t.
OneOf '[Int32, Int64, Word16, Double, Float] t =&gt;
Tensor v'1 t -&gt; Tensor v'2 t -&gt; Tensor Build t
</span><span class="hs-operator hs-var">`CoreOps.mod`</span></span><span> </span><span class="annot"><span class="annottext">Tensor Build b
forall b. Num b =&gt; b
</span><a href="#local-6989586621679157504"><span class="hs-identifier hs-var">np</span></a></span><span class="hs-special">)</span><span>
</span><span id="line-83"></span><span>    </span><span id="local-6989586621679157498"><span class="annot"><span class="annottext">newIds :: Tensor Build b
</span><a href="#local-6989586621679157498"><span class="hs-identifier hs-var hs-var">newIds</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Tensor Build b
</span><a href="#local-6989586621679157502"><span class="hs-identifier hs-var">flatIds</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor Build b -&gt; Tensor Build b -&gt; Tensor Build b
forall (v'1 :: * -&gt; *) (v'2 :: * -&gt; *) t.
OneOf
  '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
    Word8, Double, Float]
  t =&gt;
Tensor v'1 t -&gt; Tensor v'2 t -&gt; Tensor Build t
</span><span class="hs-operator hs-var">`CoreOps.div`</span></span><span> </span><span class="annot"><span class="annottext">Tensor Build b
forall b. Num b =&gt; b
</span><a href="#local-6989586621679157504"><span class="hs-identifier hs-var">np</span></a></span><span>
</span><span id="line-84"></span><span>    </span><span id="local-6989586621679157496"><span class="annot"><span class="annottext">originalIndices :: Tensor Build tidx
</span><a href="#local-6989586621679157496"><span class="hs-identifier hs-var hs-var">originalIndices</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Tensor Build tidx
-&gt; Tensor Build tidx -&gt; Tensor Build tidx -&gt; Tensor Build tidx
forall (v'1 :: * -&gt; *) (v'2 :: * -&gt; *) (v'3 :: * -&gt; *) tidx.
OneOf '[Int32, Int64, Word16, Double, Float] tidx =&gt;
Tensor v'1 tidx
-&gt; Tensor v'2 tidx -&gt; Tensor v'3 tidx -&gt; Tensor Build tidx
</span><span class="hs-identifier hs-var">CoreOps.range</span></span><span> </span><span class="annot"><span class="hs-number">0</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">Tensor Build b -&gt; Tensor Build tidx
forall (v'1 :: * -&gt; *) t out_type.
(TensorType t, OneOf '[Int32, Int64] out_type) =&gt;
Tensor v'1 t -&gt; Tensor Build out_type
</span><span class="hs-identifier hs-var">CoreOps.size</span></span><span> </span><span class="annot"><span class="annottext">Tensor Build b
</span><a href="#local-6989586621679157502"><span class="hs-identifier hs-var">flatIds</span></a></span><span class="hs-special">)</span><span> </span><span class="annot"><span class="hs-number">1</span></span><span>
</span><span id="line-85"></span><span>    </span><span class="hs-comment">-- Partition list of ids based on assignments into np separate lists</span><span>
</span><span id="line-86"></span><span>    </span><span id="local-6989586621679157515"><span class="annot"><span class="annottext">gatherIds :: [Tensor Build b]
</span><a href="#local-6989586621679157515"><span class="hs-identifier hs-var hs-var">gatherIds</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Int64 -&gt; Tensor Build b -&gt; Tensor Build Int32 -&gt; [Tensor Build b]
forall (v'1 :: * -&gt; *) (v'2 :: * -&gt; *) t.
TensorType t =&gt;
Int64 -&gt; Tensor v'1 t -&gt; Tensor v'2 Int32 -&gt; [Tensor Build t]
</span><span class="hs-identifier hs-var">CoreOps.dynamicPartition</span></span><span> </span><span class="annot"><span class="annottext">Int64
forall b. Num b =&gt; b
</span><a href="#local-6989586621679157504"><span class="hs-identifier hs-var">np</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor Build b
</span><a href="#local-6989586621679157498"><span class="hs-identifier hs-var">newIds</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor Build Int32
forall dstT. TensorType dstT =&gt; Tensor Build dstT
</span><a href="#local-6989586621679157501"><span class="hs-identifier hs-var">pAssignments</span></a></span><span>
</span><span id="line-87"></span><span>    </span><span class="hs-comment">-- Similarly, partition the original indices.</span><span>
</span><span id="line-88"></span><span>    </span><span id="local-6989586621679157512"><span class="annot"><span class="annottext">pindices :: [Tensor Build t]
</span><a href="#local-6989586621679157512"><span class="hs-identifier hs-var hs-var">pindices</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Int64 -&gt; Tensor Build t -&gt; Tensor Build Int32 -&gt; [Tensor Build t]
forall (v'1 :: * -&gt; *) (v'2 :: * -&gt; *) t.
TensorType t =&gt;
Int64 -&gt; Tensor v'1 t -&gt; Tensor v'2 Int32 -&gt; [Tensor Build t]
</span><span class="hs-identifier hs-var">CoreOps.dynamicPartition</span></span><span> </span><span class="annot"><span class="annottext">Int64
forall b. Num b =&gt; b
</span><a href="#local-6989586621679157504"><span class="hs-identifier hs-var">np</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor Build t
forall tidx.
(tidx /= Int8, tidx /= Int16, tidx /= Word8, tidx /= ByteString,
 tidx /= Bool, tidx /= Word16, tidx /= Float, tidx /= Double,
 TensorType tidx, Num tidx) =&gt;
Tensor Build tidx
</span><a href="#local-6989586621679157496"><span class="hs-identifier hs-var">originalIndices</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor Build Int32
forall dstT. TensorType dstT =&gt; Tensor Build dstT
</span><a href="#local-6989586621679157501"><span class="hs-identifier hs-var">pAssignments</span></a></span><span>
</span><span id="line-89"></span><span>    </span><span id="local-6989586621679157506"><span class="annot"><span class="annottext">singleton :: Int32 -&gt; Tensor Build Int32
</span><a href="#local-6989586621679157506"><span class="hs-identifier hs-var hs-var">singleton</span></a></span></span><span> </span><span id="local-6989586621679157492"><span class="annot"><span class="annottext">i :: Int32
</span><a href="#local-6989586621679157492"><span class="hs-identifier hs-var">i</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">[Int32] -&gt; Tensor Build Int32
forall a. TensorType a =&gt; [a] -&gt; Tensor Build a
</span><a href="TensorFlow.Ops.html#vector"><span class="hs-identifier hs-var">vector</span></a></span><span> </span><span class="hs-special">[</span><span class="annot"><span class="annottext">Int32
</span><a href="#local-6989586621679157492"><span class="hs-identifier hs-var">i</span></a></span><span> </span><span class="hs-glyph">::</span><span> </span><span class="annot"><span class="hs-identifier hs-type">Int32</span></span><span class="hs-special">]</span><span>
</span><span id="line-90"></span><span>
</span><span id="line-91"></span><span class="annot"><a href="TensorFlow.EmbeddingOps.html#embeddingLookup"><span class="hs-identifier hs-var">embeddingLookup</span></a></span><span> </span><span class="hs-special">[</span><span class="hs-special">]</span><span> </span><span class="hs-identifier">_</span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">[Char] -&gt; m (Tensor Value a)
forall a. HasCallStack =&gt; [Char] -&gt; a
</span><span class="hs-identifier hs-var">error</span></span><span> </span><span class="annot"><span class="hs-string">&quot;embeddingLookup requires params to be non empty&quot;</span></span><span>
</span><span id="line-92"></span></pre></body></html>