-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}

module TensorFlow.Gradient
    ( GradientCompatible
    , gradients
    ) where

import Control.Monad (forM, zipWithM)
import Control.Monad.State.Strict (State, evalState, gets, modify)
import Data.ByteString (ByteString)
import Data.Complex (Complex)
import Data.Default (def)
import Data.Int (Int32, Int64)
import Data.Foldable (foldlM)
import Data.List (foldl', sortBy)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
import Data.Ord (comparing)
import Data.ProtoLens.TextFormat (showMessage)
import Data.Set (Set)
import Data.Text (Text)
import Data.Tuple (swap)
import Lens.Family2 (Lens', view, (&), (^.), (.~), (%~))
import Lens.Family2.State.Strict (uses)
import Lens.Family2.Stock (at, intAt)
import Lens.Family2.Unchecked (lens, iso)
import Prelude hiding (sum)
import Text.Printf (printf)
import qualified Data.Graph.Inductive.Basic as FGL
import qualified Data.Graph.Inductive.Graph as FGL
import qualified Data.Graph.Inductive.PatriciaTree as FGL
import qualified Data.Graph.Inductive.Query.DFS as FGL
import qualified Data.IntMap.Strict as IntMap
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Text as Text

import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Build
    ( MonadBuild
    , Build
    , build
    , renderedNodeDefs
    , opDef
    , opAttr
    , opInputs
    )
import TensorFlow.BuildOp
import TensorFlow.Ops
    ( addN
    , broadcastGradientArgs
    , expandDims
    , fill
    , matMul
    , matMul'
    , reducedShape
    , reluGrad
    , reshape
    , scalar
    , shape
    , softmaxCrossEntropyWithLogits
    , sum
    , scalarize
    , vector
    , zerosLike
    )
import TensorFlow.Output
    ( NodeName(..)
    , Output(..)
    , OutputIx(..)
    , outputIndex
    )
import TensorFlow.Tensor
    ( Tensor(..)
    , Value
    , render
    , expr
    , Rendered
    , tensorNodeName
    , renderedOutput
    , renderValue
    , ToTensor(..)
    )
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef
    (NodeDef, attr, input, op, name)

type GradientCompatible a =
    -- TODO(fmayle): MaxPoolGrad doesn't support Double for some reason.
    (Num a, OneOf '[ Float, Complex Float, Complex Double ] a)

-- TODO(fmayle): Support control flow.
-- TODO(fmayle): Support gate_gradients-like option to avoid race conditions.
-- TODO(fmayle): Do we need to consider control inputs? See _PendingCount in
-- tensorflow/python/ops/gradients.py.
-- TODO(fmayle): Maybe store the gradient functions and numOutputs on the OpDef.


-- | Gradient of @y@ w.r.t. each element of @xs@.
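--
-- A hypothetical usage sketch (@loss@ and @w@ are illustrative names, not
-- part of this module):
--
-- > [dw] <- gradients loss [w]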
gradients :: forall a v1 t m . ( MonadBuild m
                               , Rendered t
                               , ToTensor t
                               , GradientCompatible a
                               )
          => Tensor v1 a  -- ^ The output of the graph.
          -> [t a]        -- ^ Tensors for which gradients are computed.
          -> m [Tensor Value a]
gradients y xs = build $ do
    -- The gradients are computed using "reverse accumulation", similarly to
    -- what is described here:
    -- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation
    --
    -- The code is summarised as follows:
    --
    -- 1. Create an fgl graph of the relevant nodes (ops) and edges (tensors).
    -- 2. Initialize the gradient of y to 1 (∂y/∂y = 1) and the rest of the
    --    tensors' gradients to nothing.
    -- 3. Process the nodes in reverse topological order (i.e. each node comes
    --    after every node that consumes its outputs, so that the output
    --    gradients for a node have been completely calculated before it is
    --    processed):
    --      a. Record the gradient for each of the node's output tensors (∂y/∂w
    --         for each output tensor w).
    --      b. Calculate the gradient of y w.r.t. each of the node's input
    --         tensors using the gradients of the node's output tensors.
    --
    --         Written differently, for each output tensor w and input tensor v:
    --           ∂y/∂w = ...            (calculated in previous steps)
    --           ∂w/∂v = ...            (op specific)
    --           ∂y/∂v = ∂y/∂w * ∂w/∂v  (technically, if tensor v is an input
    --                                   to multiple nodes, then this is only
    --                                   part of ∂y/∂v)
    --
    -- 4. Look up the recorded gradient for each x in xs.
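    --
    -- As a concrete instance of step 3: for y = relu(m), processing the
    -- Relu node records its output gradient (initially ∂y/∂y = 1) and, via
    -- the "Relu" case of opGrad below, emits ∂y/∂m = reluGrad 1 m as a
    -- pending gradient for the node that produced m.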

    y' <- renderValue y
    let yName = tensorNodeName y'
    yOne <- render $ fill (shape y') (scalar 1)
    -- TODO(fmayle): Move this into Build.hs and call it unsafeNodeDefFromName?
    nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $
        (\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
        . flip Map.lookup
    let (gr, nodeMap) = createGraph yName nodeDefLookup
    -- Set gradient of y to one.
    -- TODO: nicer
    let initPending :: Map.Map FGL.Node (PendingGradients a)
            = Map.empty & (at (nodeMap Map.! yName)
                                . nonEmpty
                                . outputIxAt (outputIndex $ renderedOutput y')
                                . nonEmpty
                                .~ [yOne]
                                )
    -- Calculate the gradients of y w.r.t. each node in the graph.
    gradientMap <- graphGrads gr initPending
    -- Look up the gradients for each x.
    forM xs $ \x ->
        let Output i xName = renderedOutput x
        in maybe (render $ zerosLike $ toTensor x) return $ do
            n <- nodeMap ^. at xName
            gradientMap ^. at n . nonEmpty . outputIxAt i

outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
outputIxAt = intAt . unOutputIx

-- | Incomplete gradients of a node's outputs.
--
-- The lists represent partial sums. The key is an OutputIx sans newtype.
type PendingGradients a = IntMap.IntMap [Tensor Value a]

-- | Gradients of a node's outputs. The key is an OutputIx sans newtype.
-- TODO: precache the rendering?
type Gradients a = IntMap.IntMap (Tensor Value a)

-- | Graph of TensorFlow operations.
type Graph = FGL.Gr NodeDef EdgeLabel

-- | Data associated with an edge.
--
-- Pair of
--   1. Output index of a tensor from the source node.
--   2. Input index that the tensor connects to on the destination node.
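--
-- For example, an edge labeled (1, 0) from node s to node d means that
-- output 1 of s feeds input 0 of d.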
type EdgeLabel = (OutputIx, OutputIx)


-- | State used for calculating gradients.
data GradientsState a = GradientsState
                      { _gradientsPending :: !(Map FGL.Node (PendingGradients a))
                      , _gradientsResult  :: !(Map FGL.Node (Gradients a))
                      }

gradientsPending :: Lens' (GradientsState a) (Map FGL.Node (PendingGradients a))
gradientsPending = lens _gradientsPending (\x y -> x { _gradientsPending = y })

gradientsResult :: Lens' (GradientsState a) (Map FGL.Node (Gradients a))
gradientsResult = lens _gradientsResult (\x y -> x { _gradientsResult = y })


-- TODO(fmayle): Use something like Data.List.Safe.
-- | Safe version of (!!).
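--
-- For example:
--
-- > [10, 20, 30] `safeIndex` 1    == Just 20
-- > [10, 20, 30] `safeIndex` 5    == Nothing
-- > [10, 20, 30] `safeIndex` (-1) == Nothing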
safeIndex :: [a] -> Int -> Maybe a
_      `safeIndex` n | n < 0 = Nothing
[]     `safeIndex` _         = Nothing
(x:_)  `safeIndex` 0         = Just x
(_:xs) `safeIndex` n         = xs `safeIndex` (n-1)

-- Copy of http://hackage.haskell.org/package/lens-3.9.0.2/docs/Control-Lens-Iso.html#v%3anon
anon :: a -> (a -> Bool) -> Lens' (Maybe a) a
anon a p = iso (fromMaybe a) go where
  go b | p b       = Nothing
       | otherwise = Just b

non :: Eq a => a -> Lens' (Maybe a) a
non a = anon a (a==)

-- | Lens that defaults Nothing to mempty.
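--
-- For instance, at list type:
--
-- > view nonEmpty Nothing       == []
-- > view nonEmpty (Just [3])    == [3]
-- > (Just [3] & nonEmpty .~ []) == Nothing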
nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v)
nonEmpty = anon mempty null

-- TODO: strictness (e.g., foldlM')

-- | Calculate the gradients for every node in a graph.
graphGrads :: forall a. GradientCompatible a
           => Graph
           -> Map FGL.Node (PendingGradients a)
           -- ^ Initial gradients (usually just 1 for the node of interest).
           -> Build (Map FGL.Node (Gradients a))
graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrder
  where
    initState = GradientsState initPending Map.empty
    -- Reverse topological sort.
    -- TODO(fmayle): Filter out nodes that are not successors of any x in xs to
    -- avoid calculating gradients that won't be used.
    nodeOrder = FGL.topsort $ FGL.grev gr
    go :: GradientsState a -> Int -> Build (GradientsState a)
    go state node = do
        -- Aggregate the accumulated gradients for this node.
        outputGrads <-
                sumPendingGradient (state ^. gradientsPending . at node . nonEmpty)
        if null outputGrads
           then pure state
           else do
              let ctx = FGL.context gr node
              -- Calculate the gradients for each of the node's inputs.
              inputGrads <- calculateInputGrads ctx outputGrads gr
              let nextState = state & gradientsResult %~ Map.insert node outputGrads
              pure $ updatePendingGradients ctx inputGrads nextState

-- | Reduce accumulated gradients for each output to one Tensor.
sumPendingGradient :: GradientCompatible a
                   => PendingGradients a -> Build (Gradients a)
sumPendingGradient = sequence . IntMap.mapMaybe f
  where
    f [] = Nothing
    f [x] = Just (pure x)
    f xs = Just (render $ addN xs)


-- | Calculate the gradients of a node's input tensors.
--
-- This is mostly just a wrapper around opGrad.
calculateInputGrads :: forall a. GradientCompatible a
                    => FGL.Context NodeDef EdgeLabel
                    -> Gradients a  -- ^ Output gradients of the node.
                    -> Graph
                    -> Build [Maybe (Tensor Value a)]
calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr = do
    fullOutGrads <- fullOutputGrads (numOutputs nodeDef) (nodeDefName nodeDef)
                        outputGrads
    traverse (traverse render) $ opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads
  where
    -- Create a tensor from an edge (technically an Output, but it seems less
    -- confusing to refer to it as a tensor here).
    edgeToTensor :: (EdgeLabel, FGL.Node) -> Output
    edgeToTensor ((i, _), n) =
        case FGL.lab gr n of
            Just edgeNodeDef -> Output i (NodeName $ edgeNodeDef ^. name)
            Nothing -> error $ "calculateInputGrads: missing input node for "
                               ++ Text.unpack (nodeDef ^. name)
    -- Input tensors, sorted by input index.
    inputTensors = map edgeToTensor $ sortBy (comparing (snd . fst)) inputEdges

-- | Convert a Map of gradients to a list, with zeros for missing outputs.
fullOutputGrads :: (TensorType a, Num a)
                => OutputIx  -- ^ Number of outputs.
                -> NodeName
                -> Gradients a
                -> Build [Tensor Value a]
fullOutputGrads n o gs =
    mapM (\i -> maybe (render $ zero i) return (gs ^. outputIxAt i)) [0..n-1]
  where
    -- A tensor of zeros with the same shape as the i'th output.
    zero i = zerosLike $ toT (Output i o)


-- | Update the pending gradients of a node's inputs.
updatePendingGradients :: forall a. (TensorType a, Num a)
                       => FGL.Context NodeDef EdgeLabel
                       -> [Maybe (Tensor Value a)]
                       -- ^ Gradient of each input tensor.
                       -> GradientsState a
                       -> GradientsState a
updatePendingGradients (inputEdges, _, nodeDef, _) inputGrads initState =
    foldl' go initState inputEdges
  where
    go :: GradientsState a -> (EdgeLabel, FGL.Node) -> GradientsState a
    go state ((outIndex, OutputIx inIndex), node) =
        case maybeGradient of
            Nothing -> state
            Just g ->
                -- Add to the list of pending gradients for this tensor.
                state & gradientsPending
                      . at node
                      . nonEmpty
                      . outputIxAt outIndex
                      . nonEmpty
                      %~ (g:)
      where
        badSizeErr = error $ printf "updatePendingGradients: bad input index \
                                    \%d for inputGrads of length %d in %s"
                                    inIndex (length inputGrads)
                                    (show (nodeDef ^. name))
        maybeGradient = fromMaybe badSizeErr (safeIndex inputGrads inIndex)


-- | Create a graph that includes a node and its transitive dependencies.
createGraph :: NodeName -> (NodeName -> NodeDef)
            -> (Graph, Map NodeName FGL.Node)
createGraph nodeName nodeDefLookup = (FGL.nmap nodeDefLookup graph, nodeMap)
  where
    -- Parse a tensor name.
    parseTensorName :: Text -> Maybe (NodeName, OutputIx)
    parseTensorName n
        | Text.null n        = error "parseTensorName: empty name"
        | Text.head n == '^' = Nothing  -- Control edge
        | otherwise          =
            let (nm, indexStr) = Text.breakOn ":" n
                index | Text.null indexStr = 0
                      | otherwise = read $ Text.unpack $ Text.tail indexStr
            in Just (NodeName nm, OutputIx index)

    -- Build a map from node name to outward edges.
    --
    -- The state is the set of visited nodes.
    collect :: Maybe (NodeName, OutputIx, OutputIx)
            -> NodeName
            -> State (Set NodeName)
                     (Map NodeName [(NodeName, OutputIx, OutputIx)])
    collect outgoingEdge nm = do
        let nextLookup = Map.singleton nm (maybeToList outgoingEdge)
        seen <- gets (Set.member nm)
        modify (Set.insert nm)
        if seen
            then pure nextLookup
            else do
                let inputs = nodeDefLookup nm ^. input
                    recurse inIndex (parentName, outIndex) =
                        collect (Just (nm, outIndex, inIndex)) parentName
                subEdgeLookups <-
                    zipWithM recurse [0..] $ mapMaybe parseTensorName inputs
                pure $ Map.unionsWith (++) (nextLookup:subEdgeLookups)

    edgeLookup = evalState (collect Nothing nodeName) Set.empty
    -- Associate an ID with each node name.
    nodeMap = Map.fromList $ zip (Map.keys edgeLookup) [0..]
    -- Create the graph.
    graph = FGL.mkGraph (swap <$> Map.toList nodeMap)
                        [ (nodeMap Map.! n, nodeMap Map.! m, (i, j))
                        | (n, edges) <- Map.toList edgeLookup
                        , (m, i, j) <- edges
                        ]

-- | Function to compute the gradient of y w.r.t. each input.
--
-- Let y be an arbitrary tensor
-- and [w_0, ..., w_n] be the output tensors of a node
-- and [v_0, ..., v_n] be the input tensors of the same node.
--
-- Given [∂y/∂w_0, ..., ∂y/∂w_n] and [v_0, ..., v_n], a GradientFunc computes
-- [∂y/∂v_0, ..., ∂y/∂v_n] for a particular op type.
--
-- A Nothing gradient is equivalent to zero (but allows for short circuiting
-- computation when all the gradients for something are Nothing).
type GradientFunc a = NodeDef
                    -> [Output]
                    -- ^ Input tensors.
                    -> [Tensor Value a]
                    -- ^ Gradient of y w.r.t. each output tensor.
                    -> [Maybe (Tensor Build a)]
                    -- ^ Gradient of y w.r.t. each input tensor.
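--
-- For example, the "Neg" case of opGrad below is a GradientFunc: given the
-- single output gradient dz = ∂y/∂w for w = negate v, it returns
-- [Just (negate dz)] = [∂y/∂v].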


-- TODO(fmayle): Assert the type is correct.
-- | Create a Tensor from an Output.
toT :: Output -> Tensor Build a
toT = Tensor . pure


-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for
-- simple slicing operations.
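--
-- For example, for a rank-1 tensor @t@, @flatSlice t 1 2@ yields the
-- elements at indices 1 and 2, and @flatSlice t 1 (-1)@ yields everything
-- from index 1 onward.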
flatSlice :: forall v1 t . TensorType t
         => Tensor v1 t    -- ^ __input__
         -> Int32          -- ^ __begin__: specifies the offset into the first dimension of
                           -- 'input' to slice from.
         -> Int32          -- ^ __size__: specifies the number of elements of the first dimension
                           -- of 'input' to slice. If size is -1, all remaining elements in the dimension
                           -- are included in the slice (i.e. this is equivalent to setting
                           -- size = input.dim_size(0) - begin).
         -> Tensor Build t -- ^ __output__
flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])

nodeDefName :: NodeDef -> NodeName
nodeDefName = NodeName . view name

-- | Gradient helper for binary component-wise operations.
-- See https://github.com/tensorflow/tensorflow/blob/e9de087fa7f59c39bbe12ac2c83c5547c83f746c/tensorflow/core/ops/math_grad.cc#L329
gradForBinaryCwise :: ( OneOf '[ Int32, Int64, Float, Double, Complex Float, Complex Double ] t
                      )
                   => (Tensor v1 t, Tensor v1 t)
                   -> (Tensor v1 t, Tensor v1 t)
                   -> [ Maybe (Tensor Build t) ]
gradForBinaryCwise (x, gx) (y, gy) =
    [ Just dx
    , Just dy ]
  where
    dx = reshape (sum gx rx) sx
    dy = reshape (sum gy ry) sy
    sx = shape x
    sy = shape y
    (rx, ry) = broadcastGradientArgs sx sy
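-- For example, broadcasting x of shape [2, 3] against y of shape [3]:
-- broadcastGradientArgs gives rx = [] and ry = [0], so dy sums the incoming
-- gradient over axis 0 before reshaping it back to [3], while dx keeps the
-- full shape.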

-- | The gradient function for an op type.
--
-- These implementations should match their python counterparts in:
-- third_party/tensorflow/python/ops/*_grad.py
opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a

opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x]
opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]

opGrad "Concat" _ _ix [dy]
    -- Concat concatenates input tensors
    --   x1 of shape s1 = [k1, ..., ki_1, ..., kn]
    --   x2 of shape s2 = [k1, ..., ki_2, ..., kn]
    --    .           .     .          .        .
    --    .           .     .          .        .
    --    .           .     .          .        .
    --   xm of shape sm = [k1, ..., ki_m, ..., kn]
    --  along dimension i to an output tensor
    --   y  of shape sy = [k1, ..., k, ..., kn]
    --  where k = sum ki = sum [ki_1,...,ki_m]
    --
    --  The incoming gradient dy from backpropagation is simply
    --   forwarded, split across the input tensors, yielding dx.
    --   Forwarded gradients have shapes s = [s1, ..., sm].
    | m == 1    = Nothing : [Just $ expr dy]
    | otherwise = Nothing : map Just (dx `reshapeZip` s)
  where
    reshapeZip = zipWith reshape
    dx = CoreOps.splitV (fromIntegral m) dy ki _i
    s  :: [Tensor Build Int32]
    s  = map shape x
    x  :: [Tensor Build a]
    x  = map toT $ tail _ix
    -- i: concat dimension. Adjusted modulo n to handle negative indices.
    _i = toT (head _ix) `CoreOps.floorMod` n
    i  = reshape _i $ vector [1 :: Int32]
    -- sizes along concatenated dimension
    ki :: Tensor Build Int32
    ki = CoreOps.concat 0 $ map (\t -> CoreOps.slice t i $ vector [1 :: Int32]) s
    m  = length x
    n  = CoreOps.rank (head x)

opGrad "Square" _ [toT -> x] [dz] =
    -- TODO(fmayle): Handle complex numbers.
    -- TODO(fmayle): The python code makes dz a control dependency of the 2*x
    -- (for performance reasons?). Will need to put these functions in the Build
    -- monad to replicate that.
    [Just $ dz `CoreOps.mul` (2 * x)]

opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
    -- TODO(fmayle): The python version uses a better performance implementation
    -- when the shape is known without having to run the graph.
    -- TODO(fmayle): We shouldn't convert the result to a dense tensor. Sparse
    -- tensor support will require some thinking.
    [ Just $ CoreOps.unsortedSegmentSum values indices' numRows
    , Nothing
    ]
  where
    -- TODO(gnezdo): Use colocateWith but it requires Build monad.
    denseShape = shape (x :: Tensor Build a)
    numRows = scalarize $ flatSlice denseShape 0 1
    valuesShape = CoreOps.concat 0 [ allDimensions
                                   , flatSlice denseShape 1 (-1)
                                   ]
    values = reshape dz valuesShape
    -- TODO(fmayle): This could be either Int32 or Int64.
    indices' = reshape indices allDimensions :: Tensor Build Int32

opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
    [Just $ indicators `CoreOps.div` numSelected * dz', Nothing]
  where
    sx = shape (x :: Tensor Build a)
    outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32)
    y = CoreOps.max x indices
    y' = reshape y outputShapeKeptDims
    dz' = reshape dz outputShapeKeptDims
    indicators = CoreOps.cast $ CoreOps.equal y' x
    numSelected = reshape (sum indicators indices) outputShapeKeptDims
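-- For example, reducing x = [3, 1, 3] with max gives y = 3; indicators is
-- [1, 0, 1] and numSelected is 2, so an incoming gradient dz is split
-- evenly between the tied maxima: [dz/2, 0, dz/2].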

-- Min and Max have identical gradient implementations.
opGrad "Min" u v w = opGrad "Max" u v w

-- Element wise maximum gradient
-- See https://github.com/tensorflow/tensorflow/blob/e9de087fa7f59c39bbe12ac2c83c5547c83f746c/tensorflow/core/ops/math_grad.cc#L473
opGrad "Maximum" _ [toT -> x, toT -> y] [dz] =
    gradForBinaryCwise (x, gx) (y, gy)
  where
    xmask = CoreOps.greaterEqual x y
    gx = CoreOps.select xmask dz (CoreOps.zerosLike dz)
    gy = CoreOps.select (CoreOps.logicalNot xmask) dz (CoreOps.zerosLike dz)

opGrad "Sum" _ [toT -> x, toT -> indices] [dz] =
    [ Just $ CoreOps.tile grad tileScaling, Nothing ]
  where
    -- TODO(gnezdo): Implement the fast-path from math_grad._SumGrad.
    sx = shape (x :: Tensor Build a)
    outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32)
    tileScaling = safeShapeDiv sx outputShapeKeptDims
    grad = reshape dz outputShapeKeptDims
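-- For example, summing x of shape [2, 3] over axis 1: outputShapeKeptDims
-- is [2, 1] and tileScaling is [1, 3], so the incoming gradient (shape [2])
-- is reshaped to [2, 1] and tiled back out to the input shape [2, 3].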

opGrad "Mean" u v@[toT -> x, _] w@[dz0] =
    [Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing]
  where
    [Just dz, Nothing] = opGrad "Sum" u v w
    inputShape = shape (x :: Tensor Build a)
    -- The incoming gradient dz0 has the shape of the Mean op's output, so
    -- inputSize / outputSize is the number of elements averaged together.
    -- (The tiled Sum gradient dz has the input shape and would always give
    -- factor = 1.)
    outputShape = shape (dz0 :: Tensor Value a)
    -- TODO(fmayle): Add fast path when shape is known.
    inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape
    outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape
    factor = safeShapeDiv inputSize outputSize
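-- Continuing the "Sum" example above: averaging x of shape [2, 3] over
-- axis 1 gives factor = 6 / 2 = 3, the number of elements averaged
-- together, so each summed gradient entry is divided by 3.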

opGrad "Add" _ [toT -> x, toT -> y] [dz] =
    [ Just $ reshape (sum dz rx) sx
    , Just $ reshape (sum dz ry) sy ]
  where
    sx = shape (x :: Tensor Build a)
    sy = shape (y :: Tensor Build a)
    (rx, ry) = broadcastGradientArgs sx sy

-- Copies the incoming gradient to every input (AddN does no broadcasting).
opGrad "AddN" _ inputs [dz] =
    map ((const . Just . expr) dz) inputs

opGrad "Sub" u v w =
    [Just x, Just (-y)]
  where
    [Just x, Just y] = opGrad "Add" u v w

opGrad "SoftmaxCrossEntropyWithLogits" _ [toT -> x, toT -> y] [dz, _] =
    [ Just $ expandDims dz (-1) * snd (softmaxCrossEntropyWithLogits x y)
    , Nothing ]

opGrad "Mul" _ [toT -> x, toT -> y] [dz] =
    -- TODO(fmayle): Handle complex numbers.
    [ Just $ reshape (sum (dz `CoreOps.mul` y) rx) sx
    , Just $ reshape (sum (x `CoreOps.mul` dz) ry) sy ]
  where
    sx = shape (x :: Tensor Build a)
    sy = shape (y :: Tensor Build a)
    (rx, ry) = broadcastGradientArgs sx sy

opGrad "Div" _ [toT -> x, toT -> y] [dz] =
    -- TODO(fmayle): Handle complex numbers.
    -- TODO(gnezdo): Provide Fractional instance and use '/' instead of div.
    [ Just $ reshape (sum (dz `CoreOps.div` y) rx) sx
    , Just $ reshape (sum (dz `CoreOps.mul` (negate x `CoreOps.div` (y * y)))
                         ry)
                sy
    ]
  where
    sx = shape (x :: Tensor Build a)
    sy = shape (y :: Tensor Build a)
    (rx, ry) = broadcastGradientArgs sx sy

opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
    let transposeA = lookupAttr nodeDef "transpose_a"
        transposeB = lookupAttr nodeDef "transpose_b"
        transAttrs a b =
            (opAttr "transpose_a" .~ a) . (opAttr "transpose_b" .~ b)
    in case (transposeA, transposeB) of
       (False, False) ->
           [ Just $ matMul' (transAttrs False True) dz y
           , Just $ matMul' (transAttrs True False) x dz]
       (False, True) ->
           [ Just $ matMul dz y
           , Just $ matMul' (transAttrs True False) dz x]
       (True, False) ->
           [ Just $ matMul' (transAttrs False True) y dz
           , Just $ matMul x dz]
       (True, True) ->
           [ Just $ matMul' (transAttrs True True) y dz
           , Just $ matMul' (transAttrs True True) dz x]
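-- E.g. the (False, False) case is the standard identity for z = x * y:
-- dx = dz * y^T and dy = x^T * dz, expressed via matMul' with the
-- transpose_a/transpose_b attributes.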

opGrad "Transpose" _ [_, toT -> p] [dz] =
    [ Just $ CoreOps.transpose dz
            (CoreOps.invertPermutation p :: Tensor Build Int32)
    , Nothing
    ]

opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
    [ Just $ CoreOps.conv2DBackpropInput'
                ((opAttr "strides" .~ strides)
                    . (opAttr "padding" .~ padding)
                    . (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
                    . (opAttr "data_format" .~ dataFormat))
                (shape x) y dz
    , Just $ CoreOps.conv2DBackpropFilter'
                ((opAttr "strides" .~ strides)
                    . (opAttr "padding" .~ padding)
                    . (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
                    . (opAttr "data_format" .~ dataFormat))
                x (shape y) dz
    ]
  where
    strides = lookupAttr nodeDef "strides" :: [Int64]
    padding = lookupAttr nodeDef "padding" :: ByteString
    useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
    dataFormat = lookupAttr nodeDef "data_format" :: ByteString

opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
    [ Nothing
    , Just $ CoreOps.conv2DBackpropFilter'
                ((opAttr "strides" .~ strides)
                    . (opAttr "padding" .~ padding)
                    . (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
                    . (opAttr "data_format" .~ dataFormat))
                dz (shape x) y
    , Just $ CoreOps.conv2D'
                ((opAttr "strides" .~ strides)
                    . (opAttr "padding" .~ padding)
                    . (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
                    . (opAttr "data_format" .~ dataFormat))
                dz x
    ]
  where
    strides = lookupAttr nodeDef "strides" :: [Int64]
    padding = lookupAttr nodeDef "padding" :: ByteString
    useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
    dataFormat = lookupAttr nodeDef "data_format" :: ByteString

opGrad "MaxPool" nodeDef [toT -> x] [dz] =
    [ Just $ CoreOps.maxPoolGrad'
                ((opAttr "ksize" .~ ksize)
                    . (opAttr "strides" .~ strides)
                    . (opAttr "padding" .~ padding)
                    . (opAttr "data_format" .~ dataFormat))
                x output dz
    ]
  where
    output :: Tensor Build a
    output = toT $ Output 0 (nodeDefName nodeDef)
    ksize = lookupAttr nodeDef "ksize" :: [Int64]
    strides = lookupAttr nodeDef "strides" :: [Int64]
    padding = lookupAttr nodeDef "padding" :: ByteString
    dataFormat = lookupAttr nodeDef "data_format" :: ByteString

opGrad "Reshape" _ [toT -> x, _] [dz] =
    [Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]

opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
opGrad "TruncatedNormal" _ _ _ = [Nothing]

opGrad "RefIdentity" _ _ [dz] = [Just $ expr dz]
opGrad "Cast" nodeDef _ [dz] = [Just reverseCast]
  where
    -- TODO(gnezdo): too permissive, python only allows float types as src_type.
    reverseCast =
        pureOp [] $ pure (opDef "Cast"
                 & opAttr "DstT" .~ (lookupAttr nodeDef "SrcT" :: ByteString)
                 & opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString)
                 & opInputs .~ [renderedOutput dz])

opGrad "DynamicStitch" nodeDef inputs [dz] =
    replicate halfLen Nothing ++ valuesGrads
  where
    halfLen =
        let len = length inputs
            half = len `div` 2
        in if 2 * half == len
           then half
           else error ("Uneven input size " ++ show (len, showMessage nodeDef))
    valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Build Int32)
                  | idx <- take halfLen inputs
                  ]

opGrad "DynamicPartition" nodeDef [toT -> xs, toT -> indices] dz =
    [ Just reconstructed, Nothing ]
  where
    reconstructed = CoreOps.reshape stitched
                    (CoreOps.shape (xs :: Tensor Build a) :: Tensor Build Int32)
    stitched = CoreOps.dynamicStitch partitionedIndices dz
    partitionedIndices = CoreOps.dynamicPartition np originalIndices indices
    np = lookupAttr nodeDef "num_partitions" :: Int64
    originalIndices =
        CoreOps.reshape (CoreOps.range 0 (CoreOps.size indices) 1) prefixShape
    prefixShape = shapeInt32 indices
    shapeInt32 t = CoreOps.shape t :: Tensor Build Int32

opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
    [ Nothing
    , Just $ CoreOps.select c dz zeros
    , Just $ CoreOps.select c zeros dz
    ]
  where zeros = CoreOps.zerosLike x

-- TODO(gnezdo): Unlike Python, no control dependency on dz.
opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz `CoreOps.mul` CoreOps.inv x ]
-- TODO(gnezdo): Reuse the output instead of doing another exp,
-- though, it is probably CSE'd away anyway.
opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz `CoreOps.mul` CoreOps.exp x ]
opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
    [ Just $ CoreOps.unsortedSegmentSum
             (CoreOps.gather dz (t :: Tensor Build Int32))
             (y :: Tensor Build Int32) inputRows
    , Nothing
    , Nothing
    ]
  where inputRows = flatSlice (shape (x :: Tensor Build a)) 0 1

opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
opGrad "LabelWeights" _ _ _ = [Nothing]
opGrad "Size" _ _ _ = [Nothing]

-- TODO (jcberentsen): Python implementation uses set_shape for
-- static shape inference, which is unsupported.
-- TODO: implement support for static shape inference
opGrad "Tile" _ [toT -> x, toT -> multiples] [dz] =
    [Just inputGrad, Nothing]
  where
    inputGrad = sum reshapedDz axes
    inputShape = shape (x :: Tensor Build a)
    packed = CoreOps.pack [multiples, inputShape]
    perm = vector [1, 0 :: Int32]
    splitShape = CoreOps.reshape (CoreOps.transpose packed perm) allDimensions
    axes = CoreOps.range 0 (CoreOps.size splitShape) (2 :: Tensor Build Int32)
    reshapedDz = CoreOps.reshape dz splitShape

opGrad "ZerosLike" _ _ _ = [Nothing]
opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
  where
    rx = rangeOfRank dz

-- Treat read ops as an identity function on the variable. This allows us to
-- take gradients w.r.t. the variable handle instead of the result of a read
-- op. If a variable is read multiple times, the gradients will propagate back
-- through each read.
opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]

-- TODO(fmayle): These can go away if we properly prune the graph.
opGrad "Const" _ _ _ = [Nothing, Nothing]
opGrad "Placeholder" _ _ _ = []
opGrad "VarHandleOp" _ _ _ = []
opGrad "Variable" _ _ _ = []

opGrad n nodeDef ins grads =
    error $ "no gradient implemented for " ++
            show (n, length ins, length grads, showMessage nodeDef, ins)

-- | The number of outputs for an op type.
numOutputs :: NodeDef -> OutputIx
numOutputs o =
    case o ^. op of
        "Abs" -> 1
        "Add" -> 1
        "AddN" -> 1
        "Cast" -> 1
        "Const" -> 1
        "Concat" -> 1
        "Conv2D" -> 1
        "Conv2DBackpropInput" -> 1
        "Div" -> 1
        "DynamicStitch" -> 1
        "DynamicPartition" ->
            fromIntegral (lookupAttr o "num_partitions" :: Int64)
        "Exp" -> 1
        "Gather" -> 1
        "LabelClasses" -> 1
        "LabelWeights" -> 1
        "Log" -> 1
        "MatMul" -> 1
        "Max" -> 1
        "Maximum" -> 1
        "MaxPool" -> 1
        "Mean" -> 1
        "Min" -> 1
        "Mul" -> 1
        "Neg" -> 1
        "Placeholder" -> 1
        "OneHot" -> 1
        "ReadVariableOp" -> 1
        "RefIdentity" -> 1
        "Relu" -> 1
        "ReluGrad" -> 1
        "Reshape" -> 1
        "Select" -> 1
        "Size" -> 1
        "SoftmaxCrossEntropyWithLogits" -> 2
        "Square" -> 1
        "SparseSegmentSum" -> 1
        "Sub" -> 1
        "Sum" -> 1
        "Tile" -> 1
        "Transpose" -> 1
        "TruncatedNormal" -> 1
        "VarHandleOp" -> 1
        "Variable" -> 1
        "ZerosLike" -> 1
        "Fill" -> 1
        _ -> error $ "numOutputs not implemented for " ++ show (o ^. op)

-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`.
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Build Int32
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
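-- For example, dividing shape [6, 0] by shape [3, 0] gives [2, 0]: the
-- inner 'CoreOps.maximum y 1' turns each zero denominator into 1, so
-- 0 / 0 yields 0.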

allDimensions :: Tensor Build Int32
allDimensions = vector [-1 :: Int32]

rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Build Int32
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1

lookupAttr :: Attribute a1 => NodeDef -> Text -> a1
lookupAttr nodeDef attrName = nodeDef ^. attr . at attrName . non def . attrLens