-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}

-- | Parallel lookups on the list of tensors.
module TensorFlow.EmbeddingOps where

import Control.Monad (zipWithM)
import Data.Int (Int32, Int64)
import TensorFlow.Build (MonadBuild)
import TensorFlow.Ops (shape, vector)  -- Also Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value, Rendered, colocateWith, render)
import TensorFlow.Types (OneOf, TensorType)
import qualified TensorFlow.GenOps.Core as CoreOps

-- | Looks up `ids` in a list of embedding tensors.
--
-- This function is used to perform parallel lookups on the list of
-- tensors in `params`. It is a generalization of `TF.gather`, where
-- `params` is interpreted as a partitioning of a larger embedding
-- tensor.
--
-- The partition strategy is "mod": each id is assigned to partition
-- `p = id % len(params)` and refers to row `id / len(params)` (integer
-- division) within that partition. For instance, 13 ids are split
-- across 5 partitions as:
-- `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`
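--
-- That split can be reproduced with plain Haskell lists (an illustrative
-- sketch, not library code):
--
-- > [[i | i <- [0..12], i `mod` 5 == p] | p <- [0..4]]
-- > -- [[0,5,10],[1,6,11],[2,7,12],[3,8],[4,9]]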
--
-- The results of the lookup are concatenated into a dense
-- tensor. The returned tensor has shape `shape(ids) + shape(params[0])[1:]`.
embeddingLookup :: forall a b v1 v2 m .
                   ( MonadBuild m
                   , Rendered (Tensor v1)
                   , TensorType a
                   , OneOf '[Int64, Int32] b
                   , Num b
                   )
                => [Tensor v1 a]
                -- ^ A list of tensors which can be concatenated along
                -- dimension 0. Each `Tensor` must be appropriately
                -- sized for the `mod` partition strategy.
                -> Tensor v2 b
                -- ^ A `Tensor` of type `int32` or `int64`
                -- containing the ids to be looked up in `params`.
                -- The ids are required to have fewer than 2^31
                -- entries.
                -> m (Tensor Value a)
                -- ^ A dense tensor with shape `shape(ids) + shape(params[0])[1:]`.
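-- For instance, looking up ids of shape [8, 4] in three partitions whose
-- rows have width 50 yields a result of shape [8, 4, 50].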
embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
embeddingLookup params@(p0 : _) ids = do
    -- Do np separate lookups, finding embeddings for plist[p] in params[p].
    partitionedResult <- zipWithM
                         (\p g -> colocateWith p $ render $ CoreOps.gather p g)
                         params gatherIds
    let unshapedResult = CoreOps.dynamicStitch pindices partitionedResult
    -- Shape restoration is not as optimal as it would be with client-side
    -- shape tracking.
    paramShape <- colocateWith p0 (render (shape p0))
    let finalShape = CoreOps.concat 0 [shape ids, tailShape]
        tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1))
    render $ CoreOps.reshape unshapedResult finalShape
  where
    -- Avoids genericLength here, which would be evaluated by TF.  Thanks to
    -- NoMonomorphismRestriction, np stays polymorphic: it is used both as
    -- the Int64 partition count and as a scalar Tensor below.
    np = fromIntegral (length params)
    flatIds = CoreOps.reshape ids (singleton (-1))
    -- Each id's partition is id % np; its row within that partition is
    -- id / np (integer division).
    pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np)
    newIds = flatIds `CoreOps.div` np
    originalIndices = CoreOps.range 0 (CoreOps.size flatIds) 1
    -- Partition the list of ids based on assignments into np separate lists.
    gatherIds = CoreOps.dynamicPartition np newIds pAssignments
    -- Similarly partition the original indices, so dynamicStitch can
    -- reassemble the per-partition results in the callers' order.
    pindices = CoreOps.dynamicPartition np originalIndices pAssignments
    singleton i = vector [i :: Int32]
embeddingLookup [] _ = error "embeddingLookup requires params to be non-empty"
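
-- A usage sketch (illustrative only, not part of the library;
-- `lookupExample` is a hypothetical name, and the helpers come from
-- TensorFlow.Core and TensorFlow.Ops):
--
-- > import Data.Int (Int32)
-- > import qualified Data.Vector as V
-- > import qualified TensorFlow.Core as TF
-- > import qualified TensorFlow.Ops as TF
-- >
-- > -- A 4-row, 2-column embedding table split "mod"-style across two
-- > -- partitions: partition 0 holds rows {0, 2}, partition 1 holds {1, 3}.
-- > lookupExample :: TF.Session (V.Vector Float)
-- > lookupExample = do
-- >     p0 <- TF.render $ TF.constant (TF.Shape [2, 2]) [0, 1, 4, 5]
-- >     p1 <- TF.render $ TF.constant (TF.Shape [2, 2]) [2, 3, 6, 7]
-- >     result <- embeddingLookup [p0, p1] (TF.vector [0, 3 :: Int32])
-- >     TF.run result
-- >
-- > -- TF.runSession lookupExample ==> [0.0, 1.0, 6.0, 7.0]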