<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html xmlns="http://www.w3.org/1999/xhtml"><head><link rel="stylesheet" type="text/css" href="style.css" /><script type="text/javascript" src="highlight.js"></script></head><body><pre><span class="hs-comment">-- Copyright 2016 TensorFlow authors.</span><span> </span><span id="line-2"></span><span class="hs-comment">--</span><span> </span><span id="line-3"></span><span class="hs-comment">-- Licensed under the Apache License, Version 2.0 (the "License");</span><span> </span><span id="line-4"></span><span class="hs-comment">-- you may not use this file except in compliance with the License.</span><span> </span><span id="line-5"></span><span class="hs-comment">-- You may obtain a copy of the License at</span><span> </span><span id="line-6"></span><span class="hs-comment">--</span><span> </span><span id="line-7"></span><span class="hs-comment">-- http://www.apache.org/licenses/LICENSE-2.0</span><span> </span><span id="line-8"></span><span class="hs-comment">--</span><span> </span><span id="line-9"></span><span class="hs-comment">-- Unless required by applicable law or agreed to in writing, software</span><span> </span><span id="line-10"></span><span class="hs-comment">-- distributed under the License is distributed on an "AS IS" BASIS,</span><span> </span><span id="line-11"></span><span class="hs-comment">-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span><span> </span><span id="line-12"></span><span class="hs-comment">-- See the License for the specific language governing permissions and</span><span> </span><span id="line-13"></span><span class="hs-comment">-- limitations under the License.</span><span> </span><span id="line-14"></span><span> </span><span id="line-15"></span><span class="hs-pragma">{-# LANGUAGE ConstraintKinds #-}</span><span> </span><span id="line-16"></span><span class="hs-pragma">{-# LANGUAGE DataKinds #-}</span><span> </span><span id="line-17"></span><span class="hs-pragma">{-# LANGUAGE FlexibleContexts #-}</span><span> </span><span id="line-18"></span><span class="hs-pragma">{-# LANGUAGE NoMonomorphismRestriction #-}</span><span> </span><span id="line-19"></span><span class="hs-pragma">{-# LANGUAGE OverloadedStrings #-}</span><span> </span><span id="line-20"></span><span class="hs-pragma">{-# LANGUAGE RankNTypes #-}</span><span> </span><span id="line-21"></span><span> </span><span id="line-22"></span><span class="hs-comment">-- | Parallel lookups on the list of tensors.</span><span> </span><span id="line-23"></span><span class="hs-keyword">module</span><span> </span><span class="hs-identifier">TensorFlow.EmbeddingOps</span><span> </span><span class="hs-keyword">where</span><span> </span><span id="line-24"></span><span> </span><span id="line-25"></span><span class="hs-keyword">import</span><span> </span><span class="annot"><span class="hs-identifier">Control.Monad</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="hs-identifier">zipWithM</span></span><span class="hs-special">)</span><span> </span><span id="line-26"></span><span class="hs-keyword">import</span><span> </span><span class="annot"><span class="hs-identifier">Data.Int</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="hs-identifier">Int32</span></span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier">Int64</span></span><span class="hs-special">)</span><span> </span><span id="line-27"></span><span class="hs-keyword">import</span><span> </span><span class="annot"><span class="hs-identifier">TensorFlow.Build</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="hs-identifier">MonadBuild</span></span><span class="hs-special">)</span><span> </span><span id="line-28"></span><span class="hs-keyword">import</span><span> </span><span class="annot"><a href="TensorFlow.Ops.html"><span class="hs-identifier">TensorFlow.Ops</span></a></span><span> </span><span class="hs-special">(</span><span class="annot"><a href="TensorFlow.Ops.html#shape"><span class="hs-identifier">shape</span></a></span><span class="hs-special">,</span><span> </span><span class="annot"><a href="TensorFlow.Ops.html#vector"><span class="hs-identifier">vector</span></a></span><span class="hs-special">)</span><span> </span><span class="hs-comment">-- Also Num instance for Tensor</span><span> </span><span id="line-29"></span><span class="hs-keyword">import</span><span> </span><span class="annot"><span class="hs-identifier">TensorFlow.Tensor</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="hs-identifier">Tensor</span></span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier">Value</span></span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier">Rendered</span></span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier">colocateWith</span></span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier">render</span></span><span class="hs-special">)</span><span> </span><span id="line-30"></span><span class="hs-keyword">import</span><span> </span><span class="annot"><span class="hs-identifier">TensorFlow.Types</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="hs-identifier">OneOf</span></span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier">TensorType</span></span><span class="hs-special">)</span><span> </span><span id="line-31"></span><span class="hs-keyword">import</span><span> </span><span class="hs-keyword">qualified</span><span> </span><span class="annot"><span class="hs-identifier">TensorFlow.GenOps.Core</span></span><span> </span><span class="hs-keyword">as</span><span> </span><span class="annot"><span class="hs-identifier">CoreOps</span></span><span> </span><span id="line-32"></span><span> </span><span id="line-33"></span><span class="hs-comment">-- | Looks up `ids` in a list of embedding tensors.</span><span> </span><span id="line-34"></span><span class="hs-comment">--</span><span> </span><span id="line-35"></span><span class="hs-comment">-- This function is used to perform parallel lookups on the list of</span><span> </span><span id="line-36"></span><span class="hs-comment">-- tensors in `params`. It is a generalization of `TF.gather`, where</span><span> </span><span id="line-37"></span><span class="hs-comment">-- `params` is interpreted as a partition of a larger embedding</span><span> </span><span id="line-38"></span><span class="hs-comment">-- tensor.</span><span> </span><span id="line-39"></span><span class="hs-comment">--</span><span> </span><span id="line-40"></span><span class="hs-comment">-- The partition_strategy is "mod", we assign each id to partition</span><span> </span><span id="line-41"></span><span class="hs-comment">-- `p = id % len(params)`. For instance,</span><span> </span><span id="line-42"></span><span class="hs-comment">-- 13 ids are split across 5 partitions as:</span><span> </span><span id="line-43"></span><span class="hs-comment">-- `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`</span><span> </span><span id="line-44"></span><span class="hs-comment">--</span><span> </span><span id="line-45"></span><span class="hs-comment">-- The results of the lookup are concatenated into a dense</span><span> </span><span id="line-46"></span><span class="hs-comment">-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.</span><span> </span><span id="line-47"></span><span class="annot"><a href="TensorFlow.EmbeddingOps.html#embeddingLookup"><span class="hs-identifier hs-type">embeddingLookup</span></a></span><span> </span><span class="hs-glyph">::</span><span> </span><span class="hs-keyword">forall</span><span> </span><span id="local-6989586621679157529"><span class="annot"><a href="#local-6989586621679157529"><span class="hs-identifier hs-type">a</span></a></span></span><span> </span><span id="local-6989586621679157528"><span class="annot"><a href="#local-6989586621679157528"><span class="hs-identifier hs-type">b</span></a></span></span><span> </span><span id="local-6989586621679157527"><span class="annot"><a href="#local-6989586621679157527"><span class="hs-identifier hs-type">v1</span></a></span></span><span> </span><span id="local-6989586621679157526"><span class="annot"><a href="#local-6989586621679157526"><span class="hs-identifier hs-type">v2</span></a></span></span><span> </span><span id="local-6989586621679157525"><span class="annot"><a href="#local-6989586621679157525"><span class="hs-identifier hs-type">m</span></a></span></span><span> </span><span class="hs-operator">.</span><span> </span><span id="line-48"></span><span> </span><span class="hs-special">(</span><span> </span><span class="annot"><span class="hs-identifier hs-type">MonadBuild</span></span><span> </span><span class="annot"><a href="#local-6989586621679157525"><span class="hs-identifier hs-type">m</span></a></span><span> </span><span id="line-49"></span><span> </span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier hs-type">Rendered</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="hs-identifier hs-type">Tensor</span></span><span> </span><span class="annot"><a href="#local-6989586621679157527"><span class="hs-identifier hs-type">v1</span></a></span><span class="hs-special">)</span><span> </span><span id="line-50"></span><span> </span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier hs-type">TensorType</span></span><span> </span><span class="annot"><a href="#local-6989586621679157529"><span class="hs-identifier hs-type">a</span></a></span><span> </span><span id="line-51"></span><span> </span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier hs-type">OneOf</span></span><span> </span><span class="hs-special">'</span><span class="hs-special">[</span><span class="annot"><span class="hs-identifier hs-type">Int64</span></span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier hs-type">Int32</span></span><span class="hs-special">]</span><span> </span><span class="annot"><a href="#local-6989586621679157528"><span class="hs-identifier hs-type">b</span></a></span><span> </span><span id="line-52"></span><span> </span><span class="hs-special">,</span><span> </span><span class="annot"><span class="hs-identifier hs-type">Num</span></span><span> </span><span class="annot"><a href="#local-6989586621679157528"><span class="hs-identifier hs-type">b</span></a></span><span> </span><span id="line-53"></span><span> </span><span class="hs-special">)</span><span> </span><span id="line-54"></span><span> </span><span class="hs-glyph">=></span><span> </span><span class="hs-special">[</span><span class="annot"><span class="hs-identifier hs-type">Tensor</span></span><span> </span><span class="annot"><a href="#local-6989586621679157527"><span class="hs-identifier hs-type">v1</span></a></span><span> </span><span class="annot"><a href="#local-6989586621679157529"><span class="hs-identifier hs-type">a</span></a></span><span class="hs-special">]</span><span> </span><span id="line-55"></span><span> </span><span class="hs-comment">-- ^ A list of tensors which can be concatenated along</span><span> </span><span id="line-56"></span><span> </span><span class="hs-comment">-- dimension 0. Each `Tensor` must be appropriately</span><span> </span><span id="line-57"></span><span> </span><span class="hs-comment">-- sized for `mod` partition strategy.</span><span> </span><span id="line-58"></span><span> </span><span class="hs-glyph">-></span><span> </span><span class="annot"><span class="hs-identifier hs-type">Tensor</span></span><span> </span><span class="annot"><a href="#local-6989586621679157526"><span class="hs-identifier hs-type">v2</span></a></span><span> </span><span class="annot"><a href="#local-6989586621679157528"><span class="hs-identifier hs-type">b</span></a></span><span> </span><span id="line-59"></span><span> </span><span class="hs-comment">-- ^ A `Tensor` with type `int32` or `int64`</span><span> </span><span id="line-60"></span><span> </span><span class="hs-comment">-- containing the ids to be looked up in `params`.</span><span> </span><span id="line-61"></span><span> </span><span class="hs-comment">-- The ids are required to have fewer than 2^31</span><span> </span><span id="line-62"></span><span> </span><span class="hs-comment">-- entries.</span><span> </span><span id="line-63"></span><span> </span><span class="hs-glyph">-></span><span> </span><span class="annot"><a href="#local-6989586621679157525"><span class="hs-identifier hs-type">m</span></a></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="hs-identifier hs-type">Tensor</span></span><span> </span><span class="annot"><span class="hs-identifier hs-type">Value</span></span><span> </span><span class="annot"><a href="#local-6989586621679157529"><span class="hs-identifier hs-type">a</span></a></span><span class="hs-special">)</span><span> </span><span id="line-64"></span><span> </span><span class="hs-comment">-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.</span><span> </span><span id="line-65"></span><span id="embeddingLookup"><span class="annot"><span class="annottext">embeddingLookup :: [Tensor v1 a] -> Tensor v2 b -> m (Tensor Value a) </span><a href="TensorFlow.EmbeddingOps.html#embeddingLookup"><span class="hs-identifier hs-var hs-var">embeddingLookup</span></a></span></span><span> </span><span class="hs-special">[</span><span id="local-6989586621679157524"><span class="annot"><span class="annottext">p0 :: Tensor v1 a </span><a href="#local-6989586621679157524"><span class="hs-identifier hs-var">p0</span></a></span></span><span class="hs-special">]</span><span> </span><span id="local-6989586621679157523"><span class="annot"><span class="annottext">ids :: Tensor v2 b </span><a href="#local-6989586621679157523"><span class="hs-identifier hs-var">ids</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Tensor v1 a -> m (Tensor Value a) -> m (Tensor Value a) forall (m :: * -> *) (t :: * -> *) b a. (MonadBuild m, Rendered t) => t b -> m a -> m a </span><span class="hs-identifier hs-var">colocateWith</span></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a </span><a href="#local-6989586621679157524"><span class="hs-identifier hs-var">p0</span></a></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">Tensor Build a -> m (Tensor Value a) forall (m :: * -> *) a. MonadBuild m => Tensor Build a -> m (Tensor Value a) </span><span class="hs-identifier hs-var">render</span></span><span> </span><span class="annot"><span class="annottext">(Tensor Build a -> m (Tensor Value a)) -> Tensor Build a -> m (Tensor Value a) forall a b. (a -> b) -> a -> b </span><span class="hs-operator hs-var">$</span></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a -> Tensor v2 b -> Tensor Build a forall (v'1 :: * -> *) (v'2 :: * -> *) tparams tindices. (TensorType tparams, OneOf '[Int32, Int64] tindices) => Tensor v'1 tparams -> Tensor v'2 tindices -> Tensor Build tparams </span><span class="hs-identifier hs-var">CoreOps.gather</span></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a </span><a href="#local-6989586621679157524"><span class="hs-identifier hs-var">p0</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor v2 b </span><a href="#local-6989586621679157523"><span class="hs-identifier hs-var">ids</span></a></span><span class="hs-special">)</span><span> </span><span id="line-66"></span><span class="annot"><a href="TensorFlow.EmbeddingOps.html#embeddingLookup"><span class="hs-identifier hs-var">embeddingLookup</span></a></span><span> </span><span id="local-6989586621679157521"><span class="annot"><span class="annottext">params :: [Tensor v1 a] </span><a href="#local-6989586621679157521"><span class="hs-identifier hs-var">params</span></a></span></span><span class="hs-glyph">@</span><span class="hs-special">(</span><span id="local-6989586621679157520"><span class="annot"><span class="annottext">p0 :: Tensor v1 a </span><a href="#local-6989586621679157520"><span class="hs-identifier hs-var">p0</span></a></span></span><span> </span><span class="annot"><span class="hs-glyph hs-type">:</span></span><span> </span><span class="hs-identifier">_</span><span class="hs-special">)</span><span> </span><span id="local-6989586621679157519"><span class="annot"><span class="annottext">ids :: Tensor v2 b </span><a href="#local-6989586621679157519"><span class="hs-identifier hs-var">ids</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="hs-keyword">do</span><span> </span><span id="line-67"></span><span> </span><span class="hs-comment">-- Do np separate lookups, finding embeddings for plist[p] in params[p]</span><span> </span><span id="line-68"></span><span> </span><span id="local-6989586621679157518"><span class="annot"><span class="annottext">[Tensor Value a] </span><a href="#local-6989586621679157518"><span class="hs-identifier hs-var">partitionedResult</span></a></span></span><span> </span><span class="hs-glyph"><-</span><span> </span><span class="annot"><span class="annottext">(Tensor v1 a -> Tensor Build b -> m (Tensor Value a)) -> [Tensor v1 a] -> [Tensor Build b] -> m [Tensor Value a] forall (m :: * -> *) a b c. Applicative m => (a -> b -> m c) -> [a] -> [b] -> m [c] </span><span class="hs-identifier hs-var">zipWithM</span></span><span> </span><span id="line-69"></span><span> </span><span class="hs-special">(</span><span class="hs-glyph">\</span><span id="local-6989586621679157517"><span class="annot"><span class="annottext">p :: Tensor v1 a </span><a href="#local-6989586621679157517"><span class="hs-identifier hs-var">p</span></a></span></span><span> </span><span id="local-6989586621679157516"><span class="annot"><span class="annottext">g :: Tensor Build b </span><a href="#local-6989586621679157516"><span class="hs-identifier hs-var">g</span></a></span></span><span> </span><span class="hs-glyph">-></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a -> m (Tensor Value a) -> m (Tensor Value a) forall (m :: * -> *) (t :: * -> *) b a. (MonadBuild m, Rendered t) => t b -> m a -> m a </span><span class="hs-identifier hs-var">colocateWith</span></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a </span><a href="#local-6989586621679157517"><span class="hs-identifier hs-var">p</span></a></span><span> </span><span class="annot"><span class="annottext">(m (Tensor Value a) -> m (Tensor Value a)) -> m (Tensor Value a) -> m (Tensor Value a) forall a b. (a -> b) -> a -> b </span><span class="hs-operator hs-var">$</span></span><span> </span><span class="annot"><span class="annottext">Tensor Build a -> m (Tensor Value a) forall (m :: * -> *) a. MonadBuild m => Tensor Build a -> m (Tensor Value a) </span><span class="hs-identifier hs-var">render</span></span><span> </span><span class="annot"><span class="annottext">(Tensor Build a -> m (Tensor Value a)) -> Tensor Build a -> m (Tensor Value a) forall a b. (a -> b) -> a -> b </span><span class="hs-operator hs-var">$</span></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a -> Tensor Build b -> Tensor Build a forall (v'1 :: * -> *) (v'2 :: * -> *) tparams tindices. (TensorType tparams, OneOf '[Int32, Int64] tindices) => Tensor v'1 tparams -> Tensor v'2 tindices -> Tensor Build tparams </span><span class="hs-identifier hs-var">CoreOps.gather</span></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a </span><a href="#local-6989586621679157517"><span class="hs-identifier hs-var">p</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor Build b </span><a href="#local-6989586621679157516"><span class="hs-identifier hs-var">g</span></a></span><span class="hs-special">)</span><span> </span><span id="line-70"></span><span> </span><span class="annot"><span class="annottext">[Tensor v1 a] </span><a href="#local-6989586621679157521"><span class="hs-identifier hs-var">params</span></a></span><span> </span><span class="annot"><span class="annottext">[Tensor Build b] </span><a href="#local-6989586621679157515"><span class="hs-identifier hs-var">gatherIds</span></a></span><span> </span><span id="line-71"></span><span> </span><span class="hs-keyword">let</span><span> </span><span id="local-6989586621679157514"><span class="annot"><span class="annottext">unshapedResult :: Tensor Build a </span><a href="#local-6989586621679157514"><span class="hs-identifier hs-var hs-var">unshapedResult</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">[Tensor Build Int32] -> [Tensor Value a] -> Tensor Build a forall (v'1 :: * -> *) (v'2 :: * -> *) t. TensorType t => [Tensor v'1 Int32] -> [Tensor v'2 t] -> Tensor Build t </span><span class="hs-identifier hs-var">CoreOps.dynamicStitch</span></span><span> </span><span class="annot"><span class="annottext">[Tensor Build Int32] forall t. (t /= Int8, t /= Int16, t /= Word8, t /= ByteString, t /= Bool, t /= Word16, t /= Float, t /= Double, TensorType t, Num t) => [Tensor Build t] </span><a href="#local-6989586621679157512"><span class="hs-identifier hs-var">pindices</span></a></span><span> </span><span class="annot"><span class="annottext">[Tensor Value a] </span><a href="#local-6989586621679157518"><span class="hs-identifier hs-var">partitionedResult</span></a></span><span> </span><span id="line-72"></span><span> </span><span class="hs-comment">-- Shape restoration is not as optimal as it would be with client</span><span> </span><span id="line-73"></span><span> </span><span class="hs-comment">-- side shape tracking.</span><span> </span><span id="line-74"></span><span> </span><span id="local-6989586621679157511"><span class="annot"><span class="annottext">Tensor Value Int32 </span><a href="#local-6989586621679157511"><span class="hs-identifier hs-var">paramShape</span></a></span></span><span> </span><span class="hs-glyph"><-</span><span> </span><span class="annot"><span class="annottext">Tensor v1 a -> m (Tensor Value Int32) -> m (Tensor Value Int32) forall (m :: * -> *) (t :: * -> *) b a. (MonadBuild m, Rendered t) => t b -> m a -> m a </span><span class="hs-identifier hs-var">colocateWith</span></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a </span><a href="#local-6989586621679157520"><span class="hs-identifier hs-var">p0</span></a></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">Tensor Build Int32 -> m (Tensor Value Int32) forall (m :: * -> *) a. MonadBuild m => Tensor Build a -> m (Tensor Value a) </span><span class="hs-identifier hs-var">render</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">Tensor v1 a -> Tensor Build Int32 forall t (v :: * -> *). TensorType t => Tensor v t -> Tensor Build Int32 </span><a href="TensorFlow.Ops.html#shape"><span class="hs-identifier hs-var">shape</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor v1 a </span><a href="#local-6989586621679157520"><span class="hs-identifier hs-var">p0</span></a></span><span class="hs-special">)</span><span class="hs-special">)</span><span> </span><span id="line-75"></span><span> </span><span class="hs-keyword">let</span><span> </span><span id="local-6989586621679157510"><span class="annot"><span class="annottext">finalShape :: Tensor Build Int32 </span><a href="#local-6989586621679157510"><span class="hs-identifier hs-var hs-var">finalShape</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Tensor Build Int32 -> [Tensor Build Int32] -> Tensor Build Int32 forall (v'1 :: * -> *) (v'2 :: * -> *) t. TensorType t => Tensor v'1 Int32 -> [Tensor v'2 t] -> Tensor Build t </span><span class="hs-identifier hs-var">CoreOps.concat</span></span><span> </span><span class="annot"><span class="hs-number">0</span></span><span> </span><span class="hs-special">[</span><span class="annot"><span class="annottext">Tensor v2 b -> Tensor Build Int32 forall t (v :: * -> *). TensorType t => Tensor v t -> Tensor Build Int32 </span><a href="TensorFlow.Ops.html#shape"><span class="hs-identifier hs-var">shape</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor v2 b </span><a href="#local-6989586621679157519"><span class="hs-identifier hs-var">ids</span></a></span><span class="hs-special">,</span><span> </span><span class="annot"><span class="annottext">Tensor Build Int32 </span><a href="#local-6989586621679157508"><span class="hs-identifier hs-var">tailShape</span></a></span><span class="hs-special">]</span><span> </span><span id="line-76"></span><span> </span><span id="local-6989586621679157508"><span class="annot"><span class="annottext">tailShape :: Tensor Build Int32 </span><a href="#local-6989586621679157508"><span class="hs-identifier hs-var hs-var">tailShape</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Tensor Value Int32 -> Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32 forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t index. (TensorType t, OneOf '[Int32, Int64] index) => Tensor v'1 t -> Tensor v'2 index -> Tensor v'3 index -> Tensor Build t </span><span class="hs-identifier hs-var">CoreOps.slice</span></span><span> </span><span class="annot"><span class="annottext">Tensor Value Int32 </span><a href="#local-6989586621679157511"><span class="hs-identifier hs-var">paramShape</span></a></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">Int32 -> Tensor Build Int32 </span><a href="#local-6989586621679157506"><span class="hs-identifier hs-var">singleton</span></a></span><span> </span><span class="annot"><span class="hs-number">1</span></span><span class="hs-special">)</span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">Int32 -> Tensor Build Int32 </span><a href="#local-6989586621679157506"><span class="hs-identifier hs-var">singleton</span></a></span><span> </span><span class="hs-special">(</span><span class="hs-glyph">-</span><span class="annot"><span class="hs-number">1</span></span><span class="hs-special">)</span><span class="hs-special">)</span><span> </span><span id="line-77"></span><span> </span><span class="annot"><span class="annottext">Tensor Build a -> m (Tensor Value a) forall (m :: * -> *) a. MonadBuild m => Tensor Build a -> m (Tensor Value a) </span><span class="hs-identifier hs-var">render</span></span><span> </span><span class="annot"><span class="annottext">(Tensor Build a -> m (Tensor Value a)) -> Tensor Build a -> m (Tensor Value a) forall a b. (a -> b) -> a -> b </span><span class="hs-operator hs-var">$</span></span><span> </span><span class="annot"><span class="annottext">Tensor Build a -> Tensor Build Int32 -> Tensor Build a forall (v'1 :: * -> *) (v'2 :: * -> *) tparams tindices. (TensorType tparams, OneOf '[Int32, Int64] tindices) => Tensor v'1 tparams -> Tensor v'2 tindices -> Tensor Build tparams </span><span class="hs-identifier hs-var">CoreOps.reshape</span></span><span> </span><span class="annot"><span class="annottext">Tensor Build a </span><a href="#local-6989586621679157514"><span class="hs-identifier hs-var">unshapedResult</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor Build Int32 </span><a href="#local-6989586621679157510"><span class="hs-identifier hs-var">finalShape</span></a></span><span> </span><span id="line-78"></span><span> </span><span class="hs-keyword">where</span><span> </span><span id="line-79"></span><span> </span><span class="hs-comment">-- Avoids genericLength here which would be evaluated by TF.</span><span> </span><span id="line-80"></span><span> </span><span id="local-6989586621679157504"><span class="annot"><span class="annottext">np :: b </span><a href="#local-6989586621679157504"><span class="hs-identifier hs-var hs-var">np</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Int -> b forall a b. (Integral a, Num b) => a -> b </span><span class="hs-identifier hs-var">fromIntegral</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">[Tensor v1 a] -> Int forall (t :: * -> *) a. Foldable t => t a -> Int </span><span class="hs-identifier hs-var">length</span></span><span> </span><span class="annot"><span class="annottext">[Tensor v1 a] </span><a href="#local-6989586621679157521"><span class="hs-identifier hs-var">params</span></a></span><span class="hs-special">)</span><span> </span><span id="line-81"></span><span> </span><span id="local-6989586621679157502"><span class="annot"><span class="annottext">flatIds :: Tensor Build b </span><a href="#local-6989586621679157502"><span class="hs-identifier hs-var hs-var">flatIds</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Tensor v2 b -> Tensor Build Int32 -> Tensor Build b forall (v'1 :: * -> *) (v'2 :: * -> *) tparams tindices. (TensorType tparams, OneOf '[Int32, Int64] tindices) => Tensor v'1 tparams -> Tensor v'2 tindices -> Tensor Build tparams </span><span class="hs-identifier hs-var">CoreOps.reshape</span></span><span> </span><span class="annot"><span class="annottext">Tensor v2 b </span><a href="#local-6989586621679157519"><span class="hs-identifier hs-var">ids</span></a></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">Int32 -> Tensor Build Int32 </span><a href="#local-6989586621679157506"><span class="hs-identifier hs-var">singleton</span></a></span><span> </span><span class="hs-special">(</span><span class="hs-glyph">-</span><span class="annot"><span class="hs-number">1</span></span><span class="hs-special">)</span><span class="hs-special">)</span><span> </span><span id="line-82"></span><span> </span><span id="local-6989586621679157501"><span class="annot"><span class="annottext">pAssignments :: Tensor Build dstT </span><a href="#local-6989586621679157501"><span class="hs-identifier hs-var hs-var">pAssignments</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Tensor Build b -> Tensor Build dstT forall (v'1 :: * -> *) srcT dstT. (TensorType srcT, TensorType dstT) => Tensor v'1 srcT -> Tensor Build dstT </span><span class="hs-identifier hs-var">CoreOps.cast</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">Tensor Build b </span><a href="#local-6989586621679157502"><span class="hs-identifier hs-var">flatIds</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor Build b -> Tensor Build b -> Tensor Build b forall (v'1 :: * -> *) (v'2 :: * -> *) t. OneOf '[Int32, Int64, Word16, Double, Float] t => Tensor v'1 t -> Tensor v'2 t -> Tensor Build t </span><span class="hs-operator hs-var">`CoreOps.mod`</span></span><span> </span><span class="annot"><span class="annottext">Tensor Build b forall b. Num b => b </span><a href="#local-6989586621679157504"><span class="hs-identifier hs-var">np</span></a></span><span class="hs-special">)</span><span> </span><span id="line-83"></span><span> </span><span id="local-6989586621679157498"><span class="annot"><span class="annottext">newIds :: Tensor Build b </span><a href="#local-6989586621679157498"><span class="hs-identifier hs-var hs-var">newIds</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Tensor Build b </span><a href="#local-6989586621679157502"><span class="hs-identifier hs-var">flatIds</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor Build b -> Tensor Build b -> Tensor Build b forall (v'1 :: * -> *) (v'2 :: * -> *) t. OneOf '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t => Tensor v'1 t -> Tensor v'2 t -> Tensor Build t </span><span class="hs-operator hs-var">`CoreOps.div`</span></span><span> </span><span class="annot"><span class="annottext">Tensor Build b forall b. Num b => b </span><a href="#local-6989586621679157504"><span class="hs-identifier hs-var">np</span></a></span><span> </span><span id="line-84"></span><span> </span><span id="local-6989586621679157496"><span class="annot"><span class="annottext">originalIndices :: Tensor Build tidx </span><a href="#local-6989586621679157496"><span class="hs-identifier hs-var hs-var">originalIndices</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Tensor Build tidx -> Tensor Build tidx -> Tensor Build tidx -> Tensor Build tidx forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) tidx. OneOf '[Int32, Int64, Word16, Double, Float] tidx => Tensor v'1 tidx -> Tensor v'2 tidx -> Tensor v'3 tidx -> Tensor Build tidx </span><span class="hs-identifier hs-var">CoreOps.range</span></span><span> </span><span class="annot"><span class="hs-number">0</span></span><span> </span><span class="hs-special">(</span><span class="annot"><span class="annottext">Tensor Build b -> Tensor Build tidx forall (v'1 :: * -> *) t out_type. (TensorType t, OneOf '[Int32, Int64] out_type) => Tensor v'1 t -> Tensor Build out_type </span><span class="hs-identifier hs-var">CoreOps.size</span></span><span> </span><span class="annot"><span class="annottext">Tensor Build b </span><a href="#local-6989586621679157502"><span class="hs-identifier hs-var">flatIds</span></a></span><span class="hs-special">)</span><span> </span><span class="annot"><span class="hs-number">1</span></span><span> </span><span id="line-85"></span><span> </span><span class="hs-comment">-- Partition list of ids based on assignments into np separate lists</span><span> </span><span id="line-86"></span><span> </span><span id="local-6989586621679157515"><span class="annot"><span class="annottext">gatherIds :: [Tensor Build b] </span><a href="#local-6989586621679157515"><span class="hs-identifier hs-var hs-var">gatherIds</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Int64 -> Tensor Build b -> Tensor Build Int32 -> [Tensor Build b] forall (v'1 :: * -> *) (v'2 :: * -> *) t. TensorType t => Int64 -> Tensor v'1 t -> Tensor v'2 Int32 -> [Tensor Build t] </span><span class="hs-identifier hs-var">CoreOps.dynamicPartition</span></span><span> </span><span class="annot"><span class="annottext">Int64 forall b. Num b => b </span><a href="#local-6989586621679157504"><span class="hs-identifier hs-var">np</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor Build b </span><a href="#local-6989586621679157498"><span class="hs-identifier hs-var">newIds</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor Build Int32 forall dstT. TensorType dstT => Tensor Build dstT </span><a href="#local-6989586621679157501"><span class="hs-identifier hs-var">pAssignments</span></a></span><span> </span><span id="line-87"></span><span> </span><span class="hs-comment">-- Similarly, partition the original indices.</span><span> </span><span id="line-88"></span><span> </span><span id="local-6989586621679157512"><span class="annot"><span class="annottext">pindices :: [Tensor Build t] </span><a href="#local-6989586621679157512"><span class="hs-identifier hs-var hs-var">pindices</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">Int64 -> Tensor Build t -> Tensor Build Int32 -> [Tensor Build t] forall (v'1 :: * -> *) (v'2 :: * -> *) t. TensorType t => Int64 -> Tensor v'1 t -> Tensor v'2 Int32 -> [Tensor Build t] </span><span class="hs-identifier hs-var">CoreOps.dynamicPartition</span></span><span> </span><span class="annot"><span class="annottext">Int64 forall b. Num b => b </span><a href="#local-6989586621679157504"><span class="hs-identifier hs-var">np</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor Build t forall tidx. (tidx /= Int8, tidx /= Int16, tidx /= Word8, tidx /= ByteString, tidx /= Bool, tidx /= Word16, tidx /= Float, tidx /= Double, TensorType tidx, Num tidx) => Tensor Build tidx </span><a href="#local-6989586621679157496"><span class="hs-identifier hs-var">originalIndices</span></a></span><span> </span><span class="annot"><span class="annottext">Tensor Build Int32 forall dstT. TensorType dstT => Tensor Build dstT </span><a href="#local-6989586621679157501"><span class="hs-identifier hs-var">pAssignments</span></a></span><span> </span><span id="line-89"></span><span> </span><span id="local-6989586621679157506"><span class="annot"><span class="annottext">singleton :: Int32 -> Tensor Build Int32 </span><a href="#local-6989586621679157506"><span class="hs-identifier hs-var hs-var">singleton</span></a></span></span><span> </span><span id="local-6989586621679157492"><span class="annot"><span class="annottext">i :: Int32 </span><a href="#local-6989586621679157492"><span class="hs-identifier hs-var">i</span></a></span></span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">[Int32] -> Tensor Build Int32 forall a. TensorType a => [a] -> Tensor Build a </span><a href="TensorFlow.Ops.html#vector"><span class="hs-identifier hs-var">vector</span></a></span><span> </span><span class="hs-special">[</span><span class="annot"><span class="annottext">Int32 </span><a href="#local-6989586621679157492"><span class="hs-identifier hs-var">i</span></a></span><span> </span><span class="hs-glyph">::</span><span> </span><span class="annot"><span class="hs-identifier hs-type">Int32</span></span><span class="hs-special">]</span><span> </span><span id="line-90"></span><span> </span><span id="line-91"></span><span class="annot"><a href="TensorFlow.EmbeddingOps.html#embeddingLookup"><span class="hs-identifier hs-var">embeddingLookup</span></a></span><span> </span><span class="hs-special">[</span><span class="hs-special">]</span><span> </span><span class="hs-identifier">_</span><span> </span><span class="hs-glyph">=</span><span> </span><span class="annot"><span class="annottext">[Char] -> m (Tensor Value a) forall a. HasCallStack => [Char] -> a </span><span class="hs-identifier hs-var">error</span></span><span> </span><span class="annot"><span class="hs-string">"embeddingLookup requires params to be non empty"</span></span><span> </span><span id="line-92"></span></pre></body></html>