mirror of
https://github.com/tensorflow/haskell.git
synced 2024-12-03 08:19:44 +01:00
727 lines
161 KiB
HTML
727 lines
161 KiB
HTML
<?xml version="1.0" encoding="UTF-8"?>
|
|
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">
|
|
<html xmlns="http://www.w3.org/1999/xhtml" xml:lang="en" lang="en">
|
|
<head>
|
|
<!-- Generated by HsColour, http://code.haskell.org/~malcolm/hscolour/ -->
|
|
<title>src/TensorFlow/Gradient.hs</title>
|
|
<link type='text/css' rel='stylesheet' href='hscolour.css' />
|
|
</head>
|
|
<body>
|
|
<pre><a name="line-1"></a><span class='hs-comment'>-- Copyright 2016 TensorFlow authors.</span>
|
|
<a name="line-2"></a><span class='hs-comment'>--</span>
|
|
<a name="line-3"></a><span class='hs-comment'>-- Licensed under the Apache License, Version 2.0 (the "License");</span>
|
|
<a name="line-4"></a><span class='hs-comment'>-- you may not use this file except in compliance with the License.</span>
|
|
<a name="line-5"></a><span class='hs-comment'>-- You may obtain a copy of the License at</span>
|
|
<a name="line-6"></a><span class='hs-comment'>--</span>
|
|
<a name="line-7"></a><span class='hs-comment'>-- <a href="http://www.apache.org/licenses/LICENSE-2.0">http://www.apache.org/licenses/LICENSE-2.0</a></span>
|
|
<a name="line-8"></a><span class='hs-comment'>--</span>
|
|
<a name="line-9"></a><span class='hs-comment'>-- Unless required by applicable law or agreed to in writing, software</span>
|
|
<a name="line-10"></a><span class='hs-comment'>-- distributed under the License is distributed on an "AS IS" BASIS,</span>
|
|
<a name="line-11"></a><span class='hs-comment'>-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
|
<a name="line-12"></a><span class='hs-comment'>-- See the License for the specific language governing permissions and</span>
|
|
<a name="line-13"></a><span class='hs-comment'>-- limitations under the License.</span>
|
|
<a name="line-14"></a>
|
|
<a name="line-15"></a><span class='hs-comment'>{-# LANGUAGE ConstraintKinds #-}</span>
|
|
<a name="line-16"></a><span class='hs-comment'>{-# LANGUAGE DataKinds #-}</span>
|
|
<a name="line-17"></a><span class='hs-comment'>{-# LANGUAGE FlexibleContexts #-}</span>
|
|
<a name="line-18"></a><span class='hs-comment'>{-# LANGUAGE OverloadedStrings #-}</span>
|
|
<a name="line-19"></a><span class='hs-comment'>{-# LANGUAGE RankNTypes #-}</span>
|
|
<a name="line-20"></a><span class='hs-comment'>{-# LANGUAGE ScopedTypeVariables #-}</span>
|
|
<a name="line-21"></a><span class='hs-comment'>{-# LANGUAGE TypeFamilies #-}</span>
|
|
<a name="line-22"></a><span class='hs-comment'>{-# LANGUAGE ViewPatterns #-}</span>
|
|
<a name="line-23"></a>
|
|
<a name="line-24"></a><span class='hs-keyword'>module</span> <span class='hs-conid'>TensorFlow</span><span class='hs-varop'>.</span><span class='hs-conid'>Gradient</span>
|
|
<a name="line-25"></a> <span class='hs-layout'>(</span> <span class='hs-varid'>gradients</span>
|
|
<a name="line-26"></a> <span class='hs-layout'>)</span> <span class='hs-keyword'>where</span>
|
|
<a name="line-27"></a>
|
|
<a name="line-28"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Control</span><span class='hs-varop'>.</span><span class='hs-conid'>Monad</span> <span class='hs-layout'>(</span><span class='hs-varid'>forM</span><span class='hs-layout'>,</span> <span class='hs-varid'>zipWithM</span><span class='hs-layout'>)</span>
|
|
<a name="line-29"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Control</span><span class='hs-varop'>.</span><span class='hs-conid'>Monad</span><span class='hs-varop'>.</span><span class='hs-conid'>State</span><span class='hs-varop'>.</span><span class='hs-conid'>Strict</span> <span class='hs-layout'>(</span><span class='hs-conid'>State</span><span class='hs-layout'>,</span> <span class='hs-varid'>evalState</span><span class='hs-layout'>,</span> <span class='hs-varid'>gets</span><span class='hs-layout'>,</span> <span class='hs-varid'>modify</span><span class='hs-layout'>)</span>
|
|
<a name="line-30"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>ByteString</span> <span class='hs-layout'>(</span><span class='hs-conid'>ByteString</span><span class='hs-layout'>)</span>
|
|
<a name="line-31"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Complex</span> <span class='hs-layout'>(</span><span class='hs-conid'>Complex</span><span class='hs-layout'>)</span>
|
|
<a name="line-32"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Default</span> <span class='hs-layout'>(</span><span class='hs-varid'>def</span><span class='hs-layout'>)</span>
|
|
<a name="line-33"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Int</span> <span class='hs-layout'>(</span><span class='hs-conid'>Int32</span><span class='hs-layout'>,</span> <span class='hs-conid'>Int64</span><span class='hs-layout'>)</span>
|
|
<a name="line-34"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>List</span> <span class='hs-layout'>(</span><span class='hs-varid'>foldl'</span><span class='hs-layout'>,</span> <span class='hs-varid'>sortBy</span><span class='hs-layout'>)</span>
|
|
<a name="line-35"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Map</span><span class='hs-varop'>.</span><span class='hs-conid'>Strict</span> <span class='hs-layout'>(</span><span class='hs-conid'>Map</span><span class='hs-layout'>)</span>
|
|
<a name="line-36"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Maybe</span> <span class='hs-layout'>(</span><span class='hs-varid'>fromMaybe</span><span class='hs-layout'>,</span> <span class='hs-varid'>maybeToList</span><span class='hs-layout'>,</span> <span class='hs-varid'>mapMaybe</span><span class='hs-layout'>)</span>
|
|
<a name="line-37"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Ord</span> <span class='hs-layout'>(</span><span class='hs-varid'>comparing</span><span class='hs-layout'>)</span>
|
|
<a name="line-38"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>ProtoLens</span><span class='hs-varop'>.</span><span class='hs-conid'>TextFormat</span> <span class='hs-layout'>(</span><span class='hs-varid'>showMessage</span><span class='hs-layout'>)</span>
|
|
<a name="line-39"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Set</span> <span class='hs-layout'>(</span><span class='hs-conid'>Set</span><span class='hs-layout'>)</span>
|
|
<a name="line-40"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Text</span> <span class='hs-layout'>(</span><span class='hs-conid'>Text</span><span class='hs-layout'>)</span>
|
|
<a name="line-41"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Tuple</span> <span class='hs-layout'>(</span><span class='hs-varid'>swap</span><span class='hs-layout'>)</span>
|
|
<a name="line-42"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Lens</span><span class='hs-varop'>.</span><span class='hs-conid'>Family2</span> <span class='hs-layout'>(</span><span class='hs-conid'>Lens'</span><span class='hs-layout'>,</span> <span class='hs-layout'>(</span><span class='hs-varop'>&amp;</span><span class='hs-layout'>)</span><span class='hs-layout'>,</span> <span class='hs-layout'>(</span><span class='hs-varop'>^.</span><span class='hs-layout'>)</span><span class='hs-layout'>,</span> <span class='hs-layout'>(</span><span class='hs-varop'>.~</span><span class='hs-layout'>)</span><span class='hs-layout'>,</span> <span class='hs-layout'>(</span><span class='hs-varop'>%~</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span>
|
|
<a name="line-43"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Lens</span><span class='hs-varop'>.</span><span class='hs-conid'>Family2</span><span class='hs-varop'>.</span><span class='hs-conid'>State</span><span class='hs-varop'>.</span><span class='hs-conid'>Strict</span> <span class='hs-layout'>(</span><span class='hs-varid'>uses</span><span class='hs-layout'>)</span>
|
|
<a name="line-44"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Lens</span><span class='hs-varop'>.</span><span class='hs-conid'>Family2</span><span class='hs-varop'>.</span><span class='hs-conid'>Stock</span> <span class='hs-layout'>(</span><span class='hs-varid'>at</span><span class='hs-layout'>,</span> <span class='hs-varid'>intAt</span><span class='hs-layout'>)</span>
|
|
<a name="line-45"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Lens</span><span class='hs-varop'>.</span><span class='hs-conid'>Family2</span><span class='hs-varop'>.</span><span class='hs-conid'>Unchecked</span> <span class='hs-layout'>(</span><span class='hs-varid'>lens</span><span class='hs-layout'>,</span> <span class='hs-varid'>iso</span><span class='hs-layout'>)</span>
|
|
<a name="line-46"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Prelude</span> <span class='hs-varid'>hiding</span> <span class='hs-layout'>(</span><span class='hs-varid'>sum</span><span class='hs-layout'>)</span>
|
|
<a name="line-47"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Text</span><span class='hs-varop'>.</span><span class='hs-conid'>Printf</span> <span class='hs-layout'>(</span><span class='hs-varid'>printf</span><span class='hs-layout'>)</span>
|
|
<a name="line-48"></a><span class='hs-keyword'>import</span> <span class='hs-keyword'>qualified</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Graph</span><span class='hs-varop'>.</span><span class='hs-conid'>Inductive</span><span class='hs-varop'>.</span><span class='hs-conid'>Basic</span> <span class='hs-keyword'>as</span> <span class='hs-conid'>FGL</span>
|
|
<a name="line-49"></a><span class='hs-keyword'>import</span> <span class='hs-keyword'>qualified</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Graph</span><span class='hs-varop'>.</span><span class='hs-conid'>Inductive</span><span class='hs-varop'>.</span><span class='hs-conid'>Graph</span> <span class='hs-keyword'>as</span> <span class='hs-conid'>FGL</span>
|
|
<a name="line-50"></a><span class='hs-keyword'>import</span> <span class='hs-keyword'>qualified</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Graph</span><span class='hs-varop'>.</span><span class='hs-conid'>Inductive</span><span class='hs-varop'>.</span><span class='hs-conid'>PatriciaTree</span> <span class='hs-keyword'>as</span> <span class='hs-conid'>FGL</span>
|
|
<a name="line-51"></a><span class='hs-keyword'>import</span> <span class='hs-keyword'>qualified</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Graph</span><span class='hs-varop'>.</span><span class='hs-conid'>Inductive</span><span class='hs-varop'>.</span><span class='hs-conid'>Query</span><span class='hs-varop'>.</span><span class='hs-conid'>DFS</span> <span class='hs-keyword'>as</span> <span class='hs-conid'>FGL</span>
|
|
<a name="line-52"></a><span class='hs-keyword'>import</span> <span class='hs-keyword'>qualified</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>IntMap</span><span class='hs-varop'>.</span><span class='hs-conid'>Strict</span> <span class='hs-keyword'>as</span> <span class='hs-conid'>IntMap</span>
|
|
<a name="line-53"></a><span class='hs-keyword'>import</span> <span class='hs-keyword'>qualified</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Map</span><span class='hs-varop'>.</span><span class='hs-conid'>Strict</span> <span class='hs-keyword'>as</span> <span class='hs-conid'>Map</span>
|
|
<a name="line-54"></a><span class='hs-keyword'>import</span> <span class='hs-keyword'>qualified</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Set</span> <span class='hs-keyword'>as</span> <span class='hs-conid'>Set</span>
|
|
<a name="line-55"></a><span class='hs-keyword'>import</span> <span class='hs-keyword'>qualified</span> <span class='hs-conid'>Data</span><span class='hs-varop'>.</span><span class='hs-conid'>Text</span> <span class='hs-keyword'>as</span> <span class='hs-conid'>Text</span>
|
|
<a name="line-56"></a>
|
|
<a name="line-57"></a><span class='hs-keyword'>import</span> <span class='hs-keyword'>qualified</span> <span class='hs-conid'>TensorFlow</span><span class='hs-varop'>.</span><span class='hs-conid'>GenOps</span><span class='hs-varop'>.</span><span class='hs-conid'>Core</span> <span class='hs-keyword'>as</span> <span class='hs-conid'>CoreOps</span>
|
|
<a name="line-58"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>TensorFlow</span><span class='hs-varop'>.</span><span class='hs-conid'>Build</span>
|
|
<a name="line-59"></a> <span class='hs-layout'>(</span> <span class='hs-conid'>Build</span>
|
|
<a name="line-60"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>render</span>
|
|
<a name="line-61"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>renderNodeName</span>
|
|
<a name="line-62"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>renderedNodeDefs</span>
|
|
<a name="line-63"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>opDef</span>
|
|
<a name="line-64"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>opAttr</span>
|
|
<a name="line-65"></a> <span class='hs-layout'>)</span>
|
|
<a name="line-66"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>TensorFlow</span><span class='hs-varop'>.</span><span class='hs-conid'>BuildOp</span>
|
|
<a name="line-67"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>TensorFlow</span><span class='hs-varop'>.</span><span class='hs-conid'>Ops</span>
|
|
<a name="line-68"></a> <span class='hs-layout'>(</span> <span class='hs-varid'>addN</span>
|
|
<a name="line-69"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>broadcastGradientArgs</span>
|
|
<a name="line-70"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>expandDims</span>
|
|
<a name="line-71"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>fill</span>
|
|
<a name="line-72"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>matMul</span>
|
|
<a name="line-73"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>reducedShape</span>
|
|
<a name="line-74"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>reluGrad</span>
|
|
<a name="line-75"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>reshape</span>
|
|
<a name="line-76"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>scalar</span>
|
|
<a name="line-77"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>shape</span>
|
|
<a name="line-78"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>softmaxCrossEntropyWithLogits</span>
|
|
<a name="line-79"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>sum</span>
|
|
<a name="line-80"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>scalarize</span>
|
|
<a name="line-81"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>vector</span>
|
|
<a name="line-82"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>zerosLike</span>
|
|
<a name="line-83"></a> <span class='hs-layout'>)</span>
|
|
<a name="line-84"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>TensorFlow</span><span class='hs-varop'>.</span><span class='hs-conid'>Output</span>
|
|
<a name="line-85"></a> <span class='hs-layout'>(</span> <span class='hs-conid'>NodeName</span><span class='hs-layout'>(</span><span class='hs-keyglyph'>..</span><span class='hs-layout'>)</span>
|
|
<a name="line-86"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Op</span> <span class='hs-layout'>(</span><span class='hs-conid'>Rendered</span><span class='hs-layout'>)</span>
|
|
<a name="line-87"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Output</span><span class='hs-layout'>(</span><span class='hs-keyglyph'>..</span><span class='hs-layout'>)</span>
|
|
<a name="line-88"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>OutputIx</span><span class='hs-layout'>(</span><span class='hs-keyglyph'>..</span><span class='hs-layout'>)</span>
|
|
<a name="line-89"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>outputIndex</span>
|
|
<a name="line-90"></a> <span class='hs-layout'>)</span>
|
|
<a name="line-91"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>TensorFlow</span><span class='hs-varop'>.</span><span class='hs-conid'>Tensor</span>
|
|
<a name="line-92"></a> <span class='hs-layout'>(</span> <span class='hs-conid'>Tensor</span><span class='hs-layout'>(</span><span class='hs-keyglyph'>..</span><span class='hs-layout'>)</span>
|
|
<a name="line-93"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>TensorKind</span> <span class='hs-layout'>(</span><span class='hs-conid'>ValueKind</span><span class='hs-layout'>)</span>
|
|
<a name="line-94"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Value</span>
|
|
<a name="line-95"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>tensorOutput</span>
|
|
<a name="line-96"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>tensorAttr</span>
|
|
<a name="line-97"></a> <span class='hs-layout'>)</span>
|
|
<a name="line-98"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>TensorFlow</span><span class='hs-varop'>.</span><span class='hs-conid'>Types</span> <span class='hs-layout'>(</span><span class='hs-conid'>Attribute</span><span class='hs-layout'>,</span> <span class='hs-conid'>OneOf</span><span class='hs-layout'>,</span> <span class='hs-conid'>TensorType</span><span class='hs-layout'>,</span> <span class='hs-varid'>attrLens</span><span class='hs-layout'>)</span>
|
|
<a name="line-99"></a><span class='hs-keyword'>import</span> <span class='hs-conid'>Proto</span><span class='hs-varop'>.</span><span class='hs-conid'>Tensorflow</span><span class='hs-varop'>.</span><span class='hs-conid'>Core</span><span class='hs-varop'>.</span><span class='hs-conid'>Framework</span><span class='hs-varop'>.</span><span class='hs-conid'>NodeDef</span>
|
|
<a name="line-100"></a> <span class='hs-layout'>(</span><span class='hs-conid'>NodeDef</span><span class='hs-layout'>,</span> <span class='hs-varid'>attr</span><span class='hs-layout'>,</span> <span class='hs-varid'>input</span><span class='hs-layout'>,</span> <span class='hs-varid'>op</span><span class='hs-layout'>,</span> <span class='hs-varid'>name</span><span class='hs-layout'>)</span>
|
|
<a name="line-101"></a>
|
|
<a name="line-102"></a><a name="GradientCompatible"></a><span class='hs-keyword'>type</span> <span class='hs-conid'>GradientCompatible</span> <span class='hs-varid'>a</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-103"></a> <span class='hs-comment'>-- TODO(fmayle): MaxPoolGrad doesn't support Double for some reason.</span>
|
|
<a name="line-104"></a> <span class='hs-layout'>(</span><span class='hs-conid'>Num</span> <span class='hs-varid'>a</span><span class='hs-layout'>,</span> <span class='hs-conid'>OneOf</span> <span class='hs-chr'>'</span><span class='hs-keyglyph'>[</span> <span class='hs-conid'>Float</span><span class='hs-layout'>,</span> <span class='hs-conid'>Complex</span> <span class='hs-conid'>Float</span><span class='hs-layout'>,</span> <span class='hs-conid'>Complex</span> <span class='hs-conid'>Double</span> <span class='hs-keyglyph'>]</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-105"></a>
|
|
<a name="line-106"></a><span class='hs-comment'>-- TODO(fmayle): Support control flow.</span>
|
|
<a name="line-107"></a><span class='hs-comment'>-- TODO(fmayle): Support gate_gradients-like option to avoid race conditions.</span>
|
|
<a name="line-108"></a><span class='hs-comment'>-- TODO(fmayle): Do we need to consider control inputs? See _PendingCount in</span>
|
|
<a name="line-109"></a><span class='hs-comment'>-- tensorflow/python/ops/gradients.py.</span>
|
|
<a name="line-110"></a><span class='hs-comment'>-- TODO(fmayle): Maybe store the gradient functions and numOutputs on the OpDef.</span>
|
|
<a name="line-111"></a>
|
|
<a name="line-112"></a>
|
|
<a name="line-113"></a><a name="gradients"></a><span class='hs-comment'>-- | Gradient of @y@ w.r.t. each element of @xs@.</span>
|
|
<a name="line-114"></a><span class='hs-definition'>gradients</span> <span class='hs-keyglyph'>::</span> <span class='hs-keyword'>forall</span> <span class='hs-varid'>a</span> <span class='hs-varid'>v1</span> <span class='hs-varid'>v2</span> <span class='hs-varop'>.</span> <span class='hs-layout'>(</span> <span class='hs-conid'>Num</span> <span class='hs-layout'>(</span><span class='hs-conid'>Tensor</span> <span class='hs-varid'>v1</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-115"></a> <span class='hs-comment'>-- TODO(gnezdo): remove indirect constraint.</span>
|
|
<a name="line-116"></a> <span class='hs-comment'>-- It's a wart inherited from Num instance.</span>
|
|
<a name="line-117"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>v1</span> <span class='hs-keyglyph'>~</span> <span class='hs-conid'>Value</span>
|
|
<a name="line-118"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>GradientCompatible</span> <span class='hs-varid'>a</span>
|
|
<a name="line-119"></a> <span class='hs-layout'>)</span>
|
|
<a name="line-120"></a> <span class='hs-keyglyph'>=></span> <span class='hs-conid'>Tensor</span> <span class='hs-varid'>v1</span> <span class='hs-varid'>a</span> <span class='hs-comment'>-- ^ The output of the graph.</span>
|
|
<a name="line-121"></a> <span class='hs-keyglyph'>-></span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Tensor</span> <span class='hs-varid'>v2</span> <span class='hs-varid'>a</span><span class='hs-keyglyph'>]</span> <span class='hs-comment'>-- ^ Tensors for which gradients are computed.</span>
|
|
<a name="line-122"></a> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Build</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-123"></a><span class='hs-definition'>gradients</span> <span class='hs-varid'>y</span> <span class='hs-varid'>xs</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyword'>do</span>
|
|
<a name="line-124"></a> <span class='hs-comment'>-- The gradients are computed using "reverse accumulation", similarly to</span>
|
|
<a name="line-125"></a> <span class='hs-comment'>-- what is described here:</span>
|
|
<a name="line-126"></a> <span class='hs-comment'>-- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation</span>
|
|
<a name="line-127"></a> <span class='hs-comment'>--</span>
|
|
<a name="line-128"></a> <span class='hs-comment'>-- The code is summarised as follows:</span>
|
|
<a name="line-129"></a> <span class='hs-comment'>--</span>
|
|
<a name="line-130"></a> <span class='hs-comment'>-- 1. Create an fgl graph of the relevant nodes (ops) and edges (tensors).</span>
|
|
<a name="line-131"></a> <span class='hs-comment'>-- 2. Initialize the gradient of y to 1 (∂y/∂y = 1) and the other tensors'</span>
|
|
<a name="line-132"></a> <span class='hs-comment'>-- gradients to nothing.</span>
|
|
<a name="line-133"></a> <span class='hs-comment'>-- 3. Process the nodes in reverse topological order (i.e. each node comes</span>
|
|
<a name="line-134"></a> <span class='hs-comment'>-- after all of its outputs so that the output gradients for a node have</span>
|
|
<a name="line-135"></a> <span class='hs-comment'>-- been completely calculated before it is processed):</span>
|
|
<a name="line-136"></a> <span class='hs-comment'>-- a. Record the gradient for each of the node's output tensors (∂y/∂w</span>
|
|
<a name="line-137"></a> <span class='hs-comment'>-- for each output tensor w).</span>
|
|
<a name="line-138"></a> <span class='hs-comment'>-- b. Calculate the gradient of y w.r.t. each of the node's input</span>
|
|
<a name="line-139"></a> <span class='hs-comment'>-- tensors using the gradients of the node's output tensors.</span>
|
|
<a name="line-140"></a> <span class='hs-comment'>--</span>
|
|
<a name="line-141"></a> <span class='hs-comment'>-- Written differently, for each output tensor w and input tensor v:</span>
|
|
<a name="line-142"></a> <span class='hs-comment'>-- ∂y/∂w = ... (calculated in previous steps)</span>
|
|
<a name="line-143"></a> <span class='hs-comment'>-- ∂w/∂v = ... (op specific)</span>
|
|
<a name="line-144"></a> <span class='hs-comment'>-- ∂y/∂v = ∂y/∂w * ∂w/∂v (technically, if tensor v is an input</span>
|
|
<a name="line-145"></a> <span class='hs-comment'>-- to multiple nodes, then this is only</span>
|
|
<a name="line-146"></a> <span class='hs-comment'>-- part of ∂y/∂v)</span>
|
|
<a name="line-147"></a> <span class='hs-comment'>--</span>
|
|
<a name="line-148"></a> <span class='hs-comment'>-- 4. Look up the recorded gradient for each x in xs.</span>
|
|
<a name="line-149"></a>
|
|
<a name="line-150"></a> <span class='hs-varid'>yName</span> <span class='hs-keyglyph'>&lt;-</span> <span class='hs-varid'>renderNodeName</span> <span class='hs-varid'>y</span>
|
|
<a name="line-151"></a> <span class='hs-comment'>-- TODO(fmayle): Move this into Build.hs and call it unsafeNodeDefFromName?</span>
|
|
<a name="line-152"></a> <span class='hs-varid'>nodeDefLookup</span> <span class='hs-keyglyph'>::</span> <span class='hs-layout'>(</span><span class='hs-conid'>NodeName</span> <span class='hs-keyglyph'>-&gt;</span> <span class='hs-conid'>NodeDef</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>&lt;-</span> <span class='hs-varid'>uses</span> <span class='hs-varid'>renderedNodeDefs</span> <span class='hs-varop'>$</span>
|
|
<a name="line-153"></a> <span class='hs-layout'>(</span><span class='hs-keyglyph'>\</span><span class='hs-varid'>f</span> <span class='hs-varid'>x</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>fromMaybe</span> <span class='hs-layout'>(</span><span class='hs-varid'>error</span> <span class='hs-varop'>$</span> <span class='hs-str'>"no NodeDef found for "</span> <span class='hs-varop'>++</span> <span class='hs-varid'>show</span> <span class='hs-varid'>x</span><span class='hs-layout'>)</span> <span class='hs-layout'>(</span><span class='hs-varid'>f</span> <span class='hs-varid'>x</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span>
|
|
<a name="line-154"></a> <span class='hs-varop'>.</span> <span class='hs-varid'>flip</span> <span class='hs-conid'>Map</span><span class='hs-varop'>.</span><span class='hs-varid'>lookup</span>
|
|
<a name="line-155"></a> <span class='hs-keyword'>let</span> <span class='hs-layout'>(</span><span class='hs-varid'>gr</span><span class='hs-layout'>,</span> <span class='hs-varid'>nodeMap</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>createGraph</span> <span class='hs-varid'>yName</span> <span class='hs-varid'>nodeDefLookup</span>
|
|
<a name="line-156"></a> <span class='hs-comment'>-- Set gradient of y to one.</span>
|
|
<a name="line-157"></a> <span class='hs-keyword'>let</span> <span class='hs-varid'>initPending</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Map</span><span class='hs-varop'>.</span><span class='hs-conid'>Map</span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-conid'>Node</span> <span class='hs-layout'>(</span><span class='hs-conid'>PendingGradients</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-158"></a> <span class='hs-varid'>initPending</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>Map</span><span class='hs-varop'>.</span><span class='hs-varid'>empty</span> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>at</span> <span class='hs-layout'>(</span><span class='hs-varid'>nodeMap</span> <span class='hs-conid'>Map</span><span class='hs-varop'>.!</span> <span class='hs-varid'>yName</span><span class='hs-layout'>)</span>
|
|
<a name="line-159"></a> <span class='hs-varop'>.</span> <span class='hs-varid'>nonEmpty</span>
|
|
<a name="line-160"></a> <span class='hs-varop'>.</span> <span class='hs-varid'>outputIxAt</span> <span class='hs-layout'>(</span><span class='hs-varid'>y</span> <span class='hs-varop'>^.</span> <span class='hs-varid'>tensorOutput</span> <span class='hs-varop'>.</span> <span class='hs-varid'>outputIndex</span><span class='hs-layout'>)</span>
|
|
<a name="line-161"></a> <span class='hs-varop'>.</span> <span class='hs-varid'>nonEmpty</span>
|
|
<a name="line-162"></a> <span class='hs-varop'>.~</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>fill</span> <span class='hs-layout'>(</span><span class='hs-varid'>shape</span> <span class='hs-varid'>y</span><span class='hs-layout'>)</span> <span class='hs-layout'>(</span><span class='hs-varid'>scalar</span> <span class='hs-num'>1</span><span class='hs-layout'>)</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-163"></a> <span class='hs-comment'>-- Calculate the gradients of y w.r.t. each node in the graph.</span>
|
|
<a name="line-164"></a> <span class='hs-varid'>gradientMap</span> <span class='hs-keyglyph'>&lt;-</span> <span class='hs-varid'>graphGrads</span> <span class='hs-varid'>gr</span> <span class='hs-varid'>initPending</span>
|
|
<a name="line-165"></a> <span class='hs-comment'>-- Look up the gradients for each x.</span>
|
|
<a name="line-166"></a> <span class='hs-varid'>forM</span> <span class='hs-varid'>xs</span> <span class='hs-varop'>$</span> <span class='hs-keyglyph'>\</span><span class='hs-varid'>x</span> <span class='hs-keyglyph'>-></span> <span class='hs-keyword'>do</span>
|
|
<a name="line-167"></a> <span class='hs-varid'>xName</span> <span class='hs-keyglyph'>&lt;-</span> <span class='hs-varid'>renderNodeName</span> <span class='hs-varid'>x</span>
|
|
<a name="line-168"></a> <span class='hs-varid'>render</span> <span class='hs-varop'>$</span> <span class='hs-varid'>fromMaybe</span> <span class='hs-layout'>(</span><span class='hs-varid'>zerosLike</span> <span class='hs-varid'>x</span><span class='hs-layout'>)</span> <span class='hs-varop'>$</span> <span class='hs-keyword'>do</span>
|
|
<a name="line-169"></a> <span class='hs-varid'>n</span> <span class='hs-keyglyph'>&lt;-</span> <span class='hs-varid'>nodeMap</span> <span class='hs-varop'>^.</span> <span class='hs-varid'>at</span> <span class='hs-varid'>xName</span>
|
|
<a name="line-170"></a> <span class='hs-keyword'>let</span> <span class='hs-varid'>i</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>x</span> <span class='hs-varop'>^.</span> <span class='hs-varid'>tensorOutput</span> <span class='hs-varop'>.</span> <span class='hs-varid'>outputIndex</span>
|
|
<a name="line-171"></a> <span class='hs-varid'>gradientMap</span> <span class='hs-varop'>^.</span> <span class='hs-varid'>at</span> <span class='hs-varid'>n</span> <span class='hs-varop'>.</span> <span class='hs-varid'>nonEmpty</span> <span class='hs-varop'>.</span> <span class='hs-varid'>outputIxAt</span> <span class='hs-varid'>i</span>
|
|
<a name="line-172"></a>
|
|
<a name="line-173"></a><a name="outputIxAt"></a><span class='hs-definition'>outputIxAt</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>OutputIx</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Lens'</span> <span class='hs-layout'>(</span><span class='hs-conid'>IntMap</span><span class='hs-varop'>.</span><span class='hs-conid'>IntMap</span> <span class='hs-varid'>v</span><span class='hs-layout'>)</span> <span class='hs-layout'>(</span><span class='hs-conid'>Maybe</span> <span class='hs-varid'>v</span><span class='hs-layout'>)</span>
|
|
<a name="line-174"></a><span class='hs-definition'>outputIxAt</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>intAt</span> <span class='hs-varop'>.</span> <span class='hs-varid'>unOutputIx</span>
|
|
<a name="line-175"></a>
|
|
<a name="line-176"></a><a name="PendingGradients"></a><span class='hs-comment'>-- | Incomplete gradients of a node's outputs.</span>
|
|
<a name="line-177"></a><a name="PendingGradients"></a><span class='hs-comment'>--</span>
|
|
<a name="line-178"></a><a name="PendingGradients"></a><span class='hs-comment'>-- The lists represent partial sums. The key is an OutputIx sans newtype.</span>
|
|
<a name="line-179"></a><a name="PendingGradients"></a><span class='hs-keyword'>type</span> <span class='hs-conid'>PendingGradients</span> <span class='hs-varid'>a</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>IntMap</span><span class='hs-varop'>.</span><span class='hs-conid'>IntMap</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-180"></a>
|
|
<a name="line-181"></a><a name="Gradients"></a><span class='hs-comment'>-- | Gradients of a node's outputs. The key is an OutputIx sans newtype.</span>
|
|
<a name="line-182"></a><a name="Gradients"></a><span class='hs-keyword'>type</span> <span class='hs-conid'>Gradients</span> <span class='hs-varid'>a</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>IntMap</span><span class='hs-varop'>.</span><span class='hs-conid'>IntMap</span> <span class='hs-layout'>(</span><span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-183"></a>
|
|
<a name="line-184"></a><a name="Graph"></a><span class='hs-comment'>-- | Graph of TensorFlow operations.</span>
|
|
<a name="line-185"></a><a name="Graph"></a><span class='hs-keyword'>type</span> <span class='hs-conid'>Graph</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-conid'>Gr</span> <span class='hs-conid'>NodeDef</span> <span class='hs-conid'>EdgeLabel</span>
|
|
<a name="line-186"></a>
|
|
<a name="line-187"></a><a name="EdgeLabel"></a><span class='hs-comment'>-- | Data associated with an edge.</span>
|
|
<a name="line-188"></a><a name="EdgeLabel"></a><span class='hs-comment'>--</span>
|
|
<a name="line-189"></a><a name="EdgeLabel"></a><span class='hs-comment'>-- Pair of</span>
|
|
<a name="line-190"></a><a name="EdgeLabel"></a><span class='hs-comment'>-- 1. Output index of a tensor from the source node.</span>
|
|
<a name="line-191"></a><a name="EdgeLabel"></a><span class='hs-comment'>-- 2. Input index that the tensor connects to on the destination node.</span>
|
|
<a name="line-192"></a><a name="EdgeLabel"></a><span class='hs-keyword'>type</span> <span class='hs-conid'>EdgeLabel</span> <span class='hs-keyglyph'>=</span> <span class='hs-layout'>(</span><span class='hs-conid'>OutputIx</span><span class='hs-layout'>,</span> <span class='hs-conid'>OutputIx</span><span class='hs-layout'>)</span>
|
|
<a name="line-193"></a>
|
|
<a name="line-194"></a>
|
|
<a name="line-195"></a><a name="GradientsState"></a><span class='hs-comment'>-- | State used for calculating gradients.</span>
|
|
<a name="line-196"></a><a name="GradientsState"></a><span class='hs-keyword'>data</span> <span class='hs-conid'>GradientsState</span> <span class='hs-varid'>a</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>GradientsState</span>
|
|
<a name="line-197"></a> <span class='hs-layout'>{</span> <span class='hs-sel'>_gradientsPending</span> <span class='hs-keyglyph'>::</span> <span class='hs-varop'>!</span><span class='hs-layout'>(</span><span class='hs-conid'>Map</span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-conid'>Node</span> <span class='hs-layout'>(</span><span class='hs-conid'>PendingGradients</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span>
|
|
<a name="line-198"></a> <span class='hs-layout'>,</span> <span class='hs-sel'>_gradientsResult</span> <span class='hs-keyglyph'>::</span> <span class='hs-varop'>!</span><span class='hs-layout'>(</span><span class='hs-conid'>Map</span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-conid'>Node</span> <span class='hs-layout'>(</span><span class='hs-conid'>Gradients</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span>
|
|
<a name="line-199"></a> <span class='hs-layout'>}</span>
|
|
<a name="line-200"></a>
|
|
<a name="line-201"></a><a name="gradientsPending"></a><span class='hs-definition'>gradientsPending</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Lens'</span> <span class='hs-layout'>(</span><span class='hs-conid'>GradientsState</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span> <span class='hs-layout'>(</span><span class='hs-conid'>Map</span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-conid'>Node</span> <span class='hs-layout'>(</span><span class='hs-conid'>PendingGradients</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span>
|
|
<a name="line-202"></a><span class='hs-definition'>gradientsPending</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>lens</span> <span class='hs-sel'>_gradientsPending</span> <span class='hs-layout'>(</span><span class='hs-keyglyph'>\</span><span class='hs-varid'>x</span> <span class='hs-varid'>y</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span> <span class='hs-layout'>{</span> <span class='hs-sel'>_gradientsPending</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>y</span> <span class='hs-layout'>}</span><span class='hs-layout'>)</span>
|
|
<a name="line-203"></a>
|
|
<a name="line-204"></a><a name="gradientsResult"></a><span class='hs-definition'>gradientsResult</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Lens'</span> <span class='hs-layout'>(</span><span class='hs-conid'>GradientsState</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span> <span class='hs-layout'>(</span><span class='hs-conid'>Map</span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-conid'>Node</span> <span class='hs-layout'>(</span><span class='hs-conid'>Gradients</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span>
|
|
<a name="line-205"></a><span class='hs-definition'>gradientsResult</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>lens</span> <span class='hs-sel'>_gradientsResult</span> <span class='hs-layout'>(</span><span class='hs-keyglyph'>\</span><span class='hs-varid'>x</span> <span class='hs-varid'>y</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span> <span class='hs-layout'>{</span> <span class='hs-sel'>_gradientsResult</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>y</span> <span class='hs-layout'>}</span><span class='hs-layout'>)</span>
|
|
<a name="line-206"></a>
|
|
<a name="line-207"></a>
|
|
<a name="line-208"></a><a name="safeIndex"></a><span class='hs-comment'>-- TODO(fmayle): Use something like Data.List.Safe.</span>
|
|
<a name="line-209"></a><span class='hs-comment'>-- | Safe version of (!!).</span>
|
|
<a name="line-210"></a><span class='hs-definition'>safeIndex</span> <span class='hs-keyglyph'>::</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>a</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Int</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Maybe</span> <span class='hs-varid'>a</span>
|
|
<a name="line-211"></a><span class='hs-keyword'>_</span> <span class='hs-varop'>`safeIndex`</span> <span class='hs-varid'>n</span> <span class='hs-keyglyph'>|</span> <span class='hs-varid'>n</span> <span class='hs-varop'>&lt;</span> <span class='hs-num'>0</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>Nothing</span>
|
|
<a name="line-212"></a><span class='hs-conid'>[]</span> <span class='hs-varop'>`safeIndex`</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>Nothing</span>
|
|
<a name="line-213"></a><a name="safeIndex"></a><span class='hs-layout'>(</span><span class='hs-varid'>x</span><span class='hs-conop'>:</span><span class='hs-keyword'>_</span><span class='hs-layout'>)</span> <span class='hs-varop'>`safeIndex`</span> <span class='hs-num'>0</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>Just</span> <span class='hs-varid'>x</span>
|
|
<a name="line-214"></a><a name="safeIndex"></a><span class='hs-layout'>(</span><span class='hs-keyword'>_</span><span class='hs-conop'>:</span><span class='hs-varid'>xs</span><span class='hs-layout'>)</span> <span class='hs-varop'>`safeIndex`</span> <span class='hs-varid'>n</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>xs</span> <span class='hs-varop'>`safeIndex`</span> <span class='hs-layout'>(</span><span class='hs-varid'>n</span><span class='hs-varop'>-</span><span class='hs-num'>1</span><span class='hs-layout'>)</span>
|
|
<a name="line-215"></a>
|
|
<a name="line-216"></a><a name="anon"></a><span class='hs-comment'>-- Copy of <a href="http://hackage.haskell.org/package/lens-3.9.0.2/docs/Control-Lens-Iso.html#v%3anon">http://hackage.haskell.org/package/lens-3.9.0.2/docs/Control-Lens-Iso.html#v%3anon</a></span>
|
|
<a name="line-217"></a><span class='hs-definition'>anon</span> <span class='hs-keyglyph'>::</span> <span class='hs-varid'>a</span> <span class='hs-keyglyph'>-></span> <span class='hs-layout'>(</span><span class='hs-varid'>a</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Bool</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Lens'</span> <span class='hs-layout'>(</span><span class='hs-conid'>Maybe</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span> <span class='hs-varid'>a</span>
|
|
<a name="line-218"></a><span class='hs-definition'>anon</span> <span class='hs-varid'>a</span> <span class='hs-varid'>p</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>iso</span> <span class='hs-layout'>(</span><span class='hs-varid'>fromMaybe</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span> <span class='hs-varid'>go</span> <span class='hs-keyword'>where</span>
|
|
<a name="line-219"></a> <span class='hs-varid'>go</span> <span class='hs-varid'>b</span> <span class='hs-keyglyph'>|</span> <span class='hs-varid'>p</span> <span class='hs-varid'>b</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>Nothing</span>
|
|
<a name="line-220"></a> <span class='hs-keyglyph'>|</span> <span class='hs-varid'>otherwise</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>Just</span> <span class='hs-varid'>b</span>
|
|
<a name="line-221"></a>
|
|
<a name="line-222"></a><a name="non"></a><span class='hs-definition'>non</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Eq</span> <span class='hs-varid'>a</span> <span class='hs-keyglyph'>=></span> <span class='hs-varid'>a</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Lens'</span> <span class='hs-layout'>(</span><span class='hs-conid'>Maybe</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span> <span class='hs-varid'>a</span>
|
|
<a name="line-223"></a><span class='hs-definition'>non</span> <span class='hs-varid'>a</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>anon</span> <span class='hs-varid'>a</span> <span class='hs-layout'>(</span><span class='hs-varid'>a</span><span class='hs-varop'>==</span><span class='hs-layout'>)</span>
|
|
<a name="line-224"></a>
|
|
<a name="line-225"></a><a name="nonEmpty"></a><span class='hs-comment'>-- | Lens that defaults Nothing to mempty.</span>
|
|
<a name="line-226"></a><span class='hs-definition'>nonEmpty</span> <span class='hs-keyglyph'>::</span> <span class='hs-layout'>(</span><span class='hs-conid'>Monoid</span> <span class='hs-layout'>(</span><span class='hs-varid'>t</span> <span class='hs-varid'>v</span><span class='hs-layout'>)</span><span class='hs-layout'>,</span> <span class='hs-conid'>Foldable</span> <span class='hs-varid'>t</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>=></span> <span class='hs-conid'>Lens'</span> <span class='hs-layout'>(</span><span class='hs-conid'>Maybe</span> <span class='hs-layout'>(</span><span class='hs-varid'>t</span> <span class='hs-varid'>v</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span> <span class='hs-layout'>(</span><span class='hs-varid'>t</span> <span class='hs-varid'>v</span><span class='hs-layout'>)</span>
|
|
<a name="line-227"></a><span class='hs-definition'>nonEmpty</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>anon</span> <span class='hs-varid'>mempty</span> <span class='hs-varid'>null</span>
|
|
<a name="line-228"></a>
|
|
<a name="line-229"></a><a name="graphGrads"></a><span class='hs-comment'>-- | Calculate the gradients for every node in a graph.</span>
|
|
<a name="line-230"></a><span class='hs-definition'>graphGrads</span> <span class='hs-keyglyph'>::</span> <span class='hs-keyword'>forall</span> <span class='hs-varid'>a</span><span class='hs-varop'>.</span> <span class='hs-conid'>GradientCompatible</span> <span class='hs-varid'>a</span>
|
|
<a name="line-231"></a> <span class='hs-keyglyph'>=></span> <span class='hs-conid'>Graph</span>
|
|
<a name="line-232"></a> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Map</span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-conid'>Node</span> <span class='hs-layout'>(</span><span class='hs-conid'>PendingGradients</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-233"></a> <span class='hs-comment'>-- ^ Initial gradients (usually just 1 for the node of interest).</span>
|
|
<a name="line-234"></a> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Build</span> <span class='hs-layout'>(</span><span class='hs-conid'>Map</span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-conid'>Node</span> <span class='hs-layout'>(</span><span class='hs-conid'>Gradients</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span>
|
|
<a name="line-235"></a><span class='hs-definition'>graphGrads</span> <span class='hs-varid'>gr</span> <span class='hs-varid'>initPending</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>pure</span> <span class='hs-layout'>(</span><span class='hs-varid'>foldl'</span> <span class='hs-varid'>go</span> <span class='hs-varid'>initState</span> <span class='hs-varid'>nodeOrder</span> <span class='hs-varop'>^.</span> <span class='hs-varid'>gradientsResult</span><span class='hs-layout'>)</span>
|
|
<a name="line-236"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-237"></a> <span class='hs-varid'>initState</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>GradientsState</span> <span class='hs-varid'>initPending</span> <span class='hs-conid'>Map</span><span class='hs-varop'>.</span><span class='hs-varid'>empty</span>
|
|
<a name="line-238"></a> <span class='hs-comment'>-- Reverse topological sort.</span>
|
|
<a name="line-239"></a> <span class='hs-comment'>-- TODO(fmayle): Filter out nodes that are not successors of any x in xs to</span>
|
|
<a name="line-240"></a> <span class='hs-comment'>-- avoid calculating gradients that won't be used.</span>
|
|
<a name="line-241"></a> <span class='hs-varid'>nodeOrder</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-varid'>topsort</span> <span class='hs-varop'>$</span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-varid'>grev</span> <span class='hs-varid'>gr</span>
|
|
<a name="line-242"></a> <span class='hs-varid'>go</span> <span class='hs-varid'>state</span> <span class='hs-varid'>node</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-243"></a> <span class='hs-comment'>-- Aggregate the accumulated gradients for this node.</span>
|
|
<a name="line-244"></a> <span class='hs-keyword'>let</span> <span class='hs-varid'>outputGrads</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-245"></a> <span class='hs-varid'>sumPendingGradient</span> <span class='hs-layout'>(</span><span class='hs-varid'>state</span> <span class='hs-varop'>^.</span> <span class='hs-varid'>gradientsPending</span> <span class='hs-varop'>.</span> <span class='hs-varid'>at</span> <span class='hs-varid'>node</span> <span class='hs-varop'>.</span> <span class='hs-varid'>nonEmpty</span><span class='hs-layout'>)</span>
|
|
<a name="line-246"></a> <span class='hs-keyword'>in</span> <span class='hs-keyword'>if</span> <span class='hs-varid'>null</span> <span class='hs-varid'>outputGrads</span>
|
|
<a name="line-247"></a> <span class='hs-keyword'>then</span> <span class='hs-varid'>state</span>
|
|
<a name="line-248"></a> <span class='hs-keyword'>else</span>
|
|
<a name="line-249"></a> <span class='hs-comment'>-- Calculate the gradients for each of the node's inputs.</span>
|
|
<a name="line-250"></a> <span class='hs-keyword'>let</span> <span class='hs-varid'>nextState</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>state</span> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>gradientsResult</span> <span class='hs-varop'>%~</span> <span class='hs-conid'>Map</span><span class='hs-varop'>.</span><span class='hs-varid'>insert</span> <span class='hs-varid'>node</span> <span class='hs-varid'>outputGrads</span>
|
|
<a name="line-251"></a> <span class='hs-varid'>ctx</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-varid'>context</span> <span class='hs-varid'>gr</span> <span class='hs-varid'>node</span>
|
|
<a name="line-252"></a> <span class='hs-keyword'>in</span> <span class='hs-varid'>updatePendingGradients</span>
|
|
<a name="line-253"></a> <span class='hs-varid'>ctx</span>
|
|
<a name="line-254"></a> <span class='hs-layout'>(</span><span class='hs-varid'>calculateInputGrads</span> <span class='hs-varid'>ctx</span> <span class='hs-varid'>outputGrads</span> <span class='hs-varid'>gr</span><span class='hs-layout'>)</span>
|
|
<a name="line-255"></a> <span class='hs-varid'>nextState</span>
|
|
<a name="line-256"></a>
|
|
<a name="line-257"></a><a name="sumPendingGradient"></a><span class='hs-comment'>-- | Reduce accumulated gradients for each output to one Tensor.</span>
|
|
<a name="line-258"></a><span class='hs-definition'>sumPendingGradient</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>GradientCompatible</span> <span class='hs-varid'>a</span>
|
|
<a name="line-259"></a> <span class='hs-keyglyph'>=></span> <span class='hs-conid'>PendingGradients</span> <span class='hs-varid'>a</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Gradients</span> <span class='hs-varid'>a</span>
|
|
<a name="line-260"></a><span class='hs-definition'>sumPendingGradient</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>IntMap</span><span class='hs-varop'>.</span><span class='hs-varid'>mapMaybe</span> <span class='hs-varid'>f</span>
|
|
<a name="line-261"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-262"></a> <span class='hs-varid'>f</span> <span class='hs-conid'>[]</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>Nothing</span>
|
|
<a name="line-263"></a> <span class='hs-varid'>f</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>x</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>Just</span> <span class='hs-varid'>x</span>
|
|
<a name="line-264"></a> <span class='hs-varid'>f</span> <span class='hs-varid'>xs</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>Just</span> <span class='hs-layout'>(</span><span class='hs-varid'>addN</span> <span class='hs-varid'>xs</span><span class='hs-layout'>)</span>
|
|
<a name="line-265"></a>
|
|
<a name="line-266"></a>
|
|
<a name="line-267"></a><a name="calculateInputGrads"></a><span class='hs-comment'>-- | Calculate the gradients of a node's input tensors.</span>
|
|
<a name="line-268"></a><span class='hs-comment'>--</span>
|
|
<a name="line-269"></a><span class='hs-comment'>-- This is mostly just a wrapper around opGrad.</span>
|
|
<a name="line-270"></a><span class='hs-definition'>calculateInputGrads</span> <span class='hs-keyglyph'>::</span> <span class='hs-keyword'>forall</span> <span class='hs-varid'>a</span><span class='hs-varop'>.</span> <span class='hs-conid'>GradientCompatible</span> <span class='hs-varid'>a</span>
|
|
<a name="line-271"></a> <span class='hs-keyglyph'>=></span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-conid'>Context</span> <span class='hs-conid'>NodeDef</span> <span class='hs-conid'>EdgeLabel</span>
|
|
<a name="line-272"></a> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Gradients</span> <span class='hs-varid'>a</span> <span class='hs-comment'>-- ^ Output gradients of the node.</span>
|
|
<a name="line-273"></a> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Graph</span>
|
|
<a name="line-274"></a> <span class='hs-keyglyph'>-></span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Maybe</span> <span class='hs-layout'>(</span><span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-275"></a><span class='hs-definition'>calculateInputGrads</span> <span class='hs-layout'>(</span><span class='hs-varid'>inputEdges</span><span class='hs-layout'>,</span> <span class='hs-keyword'>_</span><span class='hs-layout'>,</span> <span class='hs-varid'>nodeDef</span><span class='hs-layout'>,</span> <span class='hs-keyword'>_</span><span class='hs-layout'>)</span> <span class='hs-varid'>outputGrads</span> <span class='hs-varid'>gr</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-276"></a> <span class='hs-varid'>opGrad</span> <span class='hs-layout'>(</span><span class='hs-varid'>nodeDef</span> <span class='hs-varop'>^.</span> <span class='hs-varid'>op</span><span class='hs-layout'>)</span> <span class='hs-varid'>nodeDef</span> <span class='hs-varid'>inputTensors</span> <span class='hs-varid'>fullOutGrads</span>
|
|
<a name="line-277"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-278"></a> <span class='hs-varid'>fullOutGrads</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-279"></a> <span class='hs-varid'>fullOutputGrads</span> <span class='hs-layout'>(</span><span class='hs-varid'>numOutputs</span> <span class='hs-varid'>nodeDef</span><span class='hs-layout'>)</span> <span class='hs-layout'>(</span><span class='hs-conid'>Rendered</span> <span class='hs-varid'>nodeDef</span><span class='hs-layout'>)</span> <span class='hs-varid'>outputGrads</span>
|
|
<a name="line-280"></a> <span class='hs-comment'>-- Create a tensor from an edge (technically an Output, but it seems less</span>
|
|
<a name="line-281"></a> <span class='hs-comment'>-- confusing to refer to it as a tensor here).</span>
|
|
<a name="line-282"></a> <span class='hs-varid'>edgeToTensor</span> <span class='hs-keyglyph'>::</span> <span class='hs-layout'>(</span><span class='hs-conid'>EdgeLabel</span><span class='hs-layout'>,</span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-conid'>Node</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Output</span>
|
|
<a name="line-283"></a> <span class='hs-varid'>edgeToTensor</span> <span class='hs-layout'>(</span><span class='hs-layout'>(</span><span class='hs-varid'>i</span><span class='hs-layout'>,</span> <span class='hs-keyword'>_</span><span class='hs-layout'>)</span><span class='hs-layout'>,</span> <span class='hs-varid'>n</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-284"></a> <span class='hs-keyword'>case</span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-varid'>lab</span> <span class='hs-varid'>gr</span> <span class='hs-varid'>n</span> <span class='hs-keyword'>of</span>
|
|
<a name="line-285"></a> <span class='hs-conid'>Just</span> <span class='hs-varid'>edgeNodeDef</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Output</span> <span class='hs-varid'>i</span> <span class='hs-layout'>(</span><span class='hs-conid'>Rendered</span> <span class='hs-varid'>edgeNodeDef</span><span class='hs-layout'>)</span>
|
|
<a name="line-286"></a> <span class='hs-conid'>Nothing</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>error</span> <span class='hs-varop'>$</span> <span class='hs-str'>"calculateInputGrads: missing input node for "</span>
|
|
<a name="line-287"></a> <span class='hs-varop'>++</span> <span class='hs-conid'>Text</span><span class='hs-varop'>.</span><span class='hs-varid'>unpack</span> <span class='hs-layout'>(</span><span class='hs-varid'>nodeDef</span> <span class='hs-varop'>^.</span> <span class='hs-varid'>name</span><span class='hs-layout'>)</span>
|
|
<a name="line-288"></a> <span class='hs-comment'>-- Input tensors, sorted by input index.</span>
|
|
<a name="line-289"></a> <span class='hs-varid'>inputTensors</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>map</span> <span class='hs-varid'>edgeToTensor</span> <span class='hs-varop'>$</span> <span class='hs-varid'>sortBy</span> <span class='hs-layout'>(</span><span class='hs-varid'>comparing</span> <span class='hs-layout'>(</span><span class='hs-varid'>snd</span> <span class='hs-varop'>.</span> <span class='hs-varid'>fst</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span> <span class='hs-varid'>inputEdges</span>
|
|
<a name="line-290"></a>
|
|
<a name="line-291"></a><a name="fullOutputGrads"></a><span class='hs-comment'>-- | Convert a Map of gradients to a list, with zeros for missing outputs.</span>
|
|
<a name="line-292"></a><span class='hs-definition'>fullOutputGrads</span> <span class='hs-keyglyph'>::</span> <span class='hs-layout'>(</span><span class='hs-conid'>TensorType</span> <span class='hs-varid'>a</span><span class='hs-layout'>,</span> <span class='hs-conid'>Num</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-293"></a> <span class='hs-keyglyph'>=></span> <span class='hs-conid'>OutputIx</span> <span class='hs-comment'>-- ^ Number of outputs.</span>
|
|
<a name="line-294"></a> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Op</span>
|
|
<a name="line-295"></a> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Gradients</span> <span class='hs-varid'>a</span>
|
|
<a name="line-296"></a> <span class='hs-keyglyph'>-></span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-297"></a><span class='hs-definition'>fullOutputGrads</span> <span class='hs-varid'>n</span> <span class='hs-varid'>o</span> <span class='hs-varid'>gs</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-298"></a> <span class='hs-varid'>map</span> <span class='hs-layout'>(</span><span class='hs-keyglyph'>\</span><span class='hs-varid'>i</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>fromMaybe</span> <span class='hs-layout'>(</span><span class='hs-varid'>zero</span> <span class='hs-varid'>i</span><span class='hs-layout'>)</span> <span class='hs-layout'>(</span><span class='hs-varid'>gs</span> <span class='hs-varop'>^.</span> <span class='hs-varid'>outputIxAt</span> <span class='hs-varid'>i</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>[</span><span class='hs-num'>0</span><span class='hs-keyglyph'>..</span><span class='hs-varid'>n</span><span class='hs-varop'>-</span><span class='hs-num'>1</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-299"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-300"></a> <span class='hs-comment'>-- A tensor of zeros with the same shape as the i'th output.</span>
|
|
<a name="line-301"></a> <span class='hs-varid'>zero</span> <span class='hs-varid'>i</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>zerosLike</span> <span class='hs-varop'>$</span> <span class='hs-varid'>toT</span> <span class='hs-layout'>(</span><span class='hs-conid'>Output</span> <span class='hs-varid'>i</span> <span class='hs-varid'>o</span><span class='hs-layout'>)</span>
|
|
<a name="line-302"></a>
|
|
<a name="line-303"></a>
|
|
<a name="line-304"></a><a name="updatePendingGradients"></a><span class='hs-comment'>-- | Update the pending gradients of a node's inputs.</span>
|
|
<!-- Listing of the Haskell function updatePendingGradients (listing lines 305 to 331).
     What the listed code does: it folds `go` over the context's inputEdges with foldl',
     starting from initState. For each edge, the index wrapped in OutputIx selects an entry
     of inputGrads through safeIndex. A Nothing entry leaves the state untouched; a Just g
     is consed onto the list reached through gradientsPending, at node, nonEmpty,
     outputIxAt outIndex, nonEmpty. When safeIndex yields no entry, `error` is called with
     the printf message shown on listing lines 327 to 330.
     Markup note: the lens operator on listing line 320 is written as a character reference
     (amp) because a raw ampersand in text is not well-formed XML. --><a name="line-305"></a><span class='hs-definition'>updatePendingGradients</span> <span class='hs-keyglyph'>::</span> <span class='hs-keyword'>forall</span> <span class='hs-varid'>a</span><span class='hs-varop'>.</span> <span class='hs-layout'>(</span><span class='hs-conid'>TensorType</span> <span class='hs-varid'>a</span><span class='hs-layout'>,</span> <span class='hs-conid'>Num</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-306"></a> <span class='hs-keyglyph'>=></span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-conid'>Context</span> <span class='hs-conid'>NodeDef</span> <span class='hs-conid'>EdgeLabel</span>
|
|
<a name="line-307"></a> <span class='hs-keyglyph'>-></span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Maybe</span> <span class='hs-layout'>(</span><span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-308"></a> <span class='hs-comment'>-- ^ Gradient of each input tensor.</span>
|
|
<a name="line-309"></a> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>GradientsState</span> <span class='hs-varid'>a</span>
|
|
<a name="line-310"></a> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>GradientsState</span> <span class='hs-varid'>a</span>
|
|
<a name="line-311"></a><span class='hs-definition'>updatePendingGradients</span> <span class='hs-layout'>(</span><span class='hs-varid'>inputEdges</span><span class='hs-layout'>,</span> <span class='hs-keyword'>_</span><span class='hs-layout'>,</span> <span class='hs-varid'>nodeDef</span><span class='hs-layout'>,</span> <span class='hs-keyword'>_</span><span class='hs-layout'>)</span> <span class='hs-varid'>inputGrads</span> <span class='hs-varid'>initState</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-312"></a> <span class='hs-varid'>foldl'</span> <span class='hs-varid'>go</span> <span class='hs-varid'>initState</span> <span class='hs-varid'>inputEdges</span>
|
|
<a name="line-313"></a> <span class='hs-keyword'>where</span>
|
|
<!-- `go` handles one (edge label, node) pair; the edge label is itself a pair whose second
     component is unwrapped from OutputIx and used as the index into inputGrads. --><a name="line-314"></a> <span class='hs-varid'>go</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>GradientsState</span> <span class='hs-varid'>a</span> <span class='hs-keyglyph'>-></span> <span class='hs-layout'>(</span><span class='hs-conid'>EdgeLabel</span><span class='hs-layout'>,</span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-conid'>Node</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>GradientsState</span> <span class='hs-varid'>a</span>
|
|
<a name="line-315"></a> <span class='hs-varid'>go</span> <span class='hs-varid'>state</span> <span class='hs-layout'>(</span><span class='hs-layout'>(</span><span class='hs-varid'>outIndex</span><span class='hs-layout'>,</span> <span class='hs-conid'>OutputIx</span> <span class='hs-varid'>inIndex</span><span class='hs-layout'>)</span><span class='hs-layout'>,</span> <span class='hs-varid'>node</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-316"></a> <span class='hs-keyword'>case</span> <span class='hs-varid'>maybeGradient</span> <span class='hs-keyword'>of</span>
|
|
<a name="line-317"></a> <span class='hs-conid'>Nothing</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>state</span>
|
|
<a name="line-318"></a> <span class='hs-conid'>Just</span> <span class='hs-varid'>g</span> <span class='hs-keyglyph'>-></span>
|
|
<a name="line-319"></a> <span class='hs-comment'>-- Add to the list of pending gradients for this tensor.</span>
|
|
<a name="line-320"></a> <span class='hs-varid'>state</span> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>gradientsPending</span>
|
|
<a name="line-321"></a> <span class='hs-varop'>.</span> <span class='hs-varid'>at</span> <span class='hs-varid'>node</span>
|
|
<a name="line-322"></a> <span class='hs-varop'>.</span> <span class='hs-varid'>nonEmpty</span>
|
|
<a name="line-323"></a> <span class='hs-varop'>.</span> <span class='hs-varid'>outputIxAt</span> <span class='hs-varid'>outIndex</span>
|
|
<a name="line-324"></a> <span class='hs-varop'>.</span> <span class='hs-varid'>nonEmpty</span>
|
|
<a name="line-325"></a> <span class='hs-varop'>%~</span> <span class='hs-layout'>(</span><span class='hs-varid'>g</span><span class='hs-conop'>:</span><span class='hs-layout'>)</span>
|
|
<a name="line-326"></a> <span class='hs-keyword'>where</span>
|
|
<!-- badSizeErr is the fallback passed to fromMaybe on listing line 331; it reports the
     offending index, the length of inputGrads and the shown node name. --><a name="line-327"></a> <span class='hs-varid'>badSizeErr</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>error</span> <span class='hs-varop'>$</span> <span class='hs-varid'>printf</span> <span class='hs-str'>"updatePendingGradients: bad input index \
|
|
<a name="line-328"></a> \%d for inputGrads of length %d in %s"</span>
|
|
<a name="line-329"></a> <span class='hs-varid'>inIndex</span> <span class='hs-layout'>(</span><span class='hs-varid'>length</span> <span class='hs-varid'>inputGrads</span><span class='hs-layout'>)</span>
|
|
<a name="line-330"></a> <span class='hs-layout'>(</span><span class='hs-varid'>show</span> <span class='hs-layout'>(</span><span class='hs-varid'>nodeDef</span> <span class='hs-varop'>^.</span> <span class='hs-varid'>name</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span>
|
|
<a name="line-331"></a> <span class='hs-varid'>maybeGradient</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>fromMaybe</span> <span class='hs-varid'>badSizeErr</span> <span class='hs-layout'>(</span><span class='hs-varid'>safeIndex</span> <span class='hs-varid'>inputGrads</span> <span class='hs-varid'>inIndex</span><span class='hs-layout'>)</span>
|
|
<a name="line-332"></a>
|
|
<a name="line-333"></a>
|
|
<!-- Listing of the Haskell function createGraph (listing lines 334 to 379).
     What the listed code does: given a node name and a lookup from node name to NodeDef,
     it returns the pair (FGL.nmap nodeDefLookup graph, nodeMap).
       * parseTensorName: calls `error` on an empty name, yields Nothing for a name whose
         first character is '^' (commented as a control edge), and otherwise splits at ":"
         with Text.breakOn; a missing suffix gives index 0, else the text after the colon
         is parsed with `read`.
       * collect: runs in State over a Set of visited NodeNames and returns a Map from
         node name to a list of (NodeName, OutputIx, OutputIx) triples.
       * nodeMap: pairs Map.keys edgeLookup with [0..].
       * graph: built by FGL.mkGraph from the swapped nodeMap pairs and one edge per triple.
     Markup note: every less-than sign that belongs to the Haskell text (the bind arrow and
     the fmap operator) is written as a character reference, because a raw less-than sign
     in text is not well-formed XML. --><a name="line-334"></a><a name="createGraph"></a><span class='hs-comment'>-- | Create a graph that includes a node and its transitive dependencies.</span>
|
|
<a name="line-335"></a><span class='hs-definition'>createGraph</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>NodeName</span> <span class='hs-keyglyph'>-></span> <span class='hs-layout'>(</span><span class='hs-conid'>NodeName</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>NodeDef</span><span class='hs-layout'>)</span>
|
|
<a name="line-336"></a> <span class='hs-keyglyph'>-></span> <span class='hs-layout'>(</span><span class='hs-conid'>Graph</span><span class='hs-layout'>,</span> <span class='hs-conid'>Map</span> <span class='hs-conid'>NodeName</span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-conid'>Node</span><span class='hs-layout'>)</span>
|
|
<a name="line-337"></a><span class='hs-definition'>createGraph</span> <span class='hs-varid'>nodeName</span> <span class='hs-varid'>nodeDefLookup</span> <span class='hs-keyglyph'>=</span> <span class='hs-layout'>(</span><span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-varid'>nmap</span> <span class='hs-varid'>nodeDefLookup</span> <span class='hs-varid'>graph</span><span class='hs-layout'>,</span> <span class='hs-varid'>nodeMap</span><span class='hs-layout'>)</span>
|
|
<a name="line-338"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-339"></a> <span class='hs-comment'>-- Parse a tensor name.</span>
|
|
<a name="line-340"></a> <span class='hs-varid'>parseTensorName</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Text</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Maybe</span> <span class='hs-layout'>(</span><span class='hs-conid'>NodeName</span><span class='hs-layout'>,</span> <span class='hs-conid'>OutputIx</span><span class='hs-layout'>)</span>
|
|
<a name="line-341"></a> <span class='hs-varid'>parseTensorName</span> <span class='hs-varid'>n</span>
|
|
<a name="line-342"></a> <span class='hs-keyglyph'>|</span> <span class='hs-conid'>Text</span><span class='hs-varop'>.</span><span class='hs-varid'>null</span> <span class='hs-varid'>n</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>error</span> <span class='hs-str'>"parseTensorName: empty name"</span>
|
|
<a name="line-343"></a> <span class='hs-keyglyph'>|</span> <span class='hs-conid'>Text</span><span class='hs-varop'>.</span><span class='hs-varid'>head</span> <span class='hs-varid'>n</span> <span class='hs-varop'>==</span> <span class='hs-chr'>'^'</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>Nothing</span> <span class='hs-comment'>-- Control edge</span>
|
|
<a name="line-344"></a> <span class='hs-keyglyph'>|</span> <span class='hs-varid'>otherwise</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-345"></a> <span class='hs-keyword'>let</span> <span class='hs-layout'>(</span><span class='hs-varid'>nm</span><span class='hs-layout'>,</span> <span class='hs-varid'>indexStr</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>Text</span><span class='hs-varop'>.</span><span class='hs-varid'>breakOn</span> <span class='hs-str'>":"</span> <span class='hs-varid'>n</span>
|
|
<a name="line-346"></a> <span class='hs-varid'>index</span> <span class='hs-keyglyph'>|</span> <span class='hs-conid'>Text</span><span class='hs-varop'>.</span><span class='hs-varid'>null</span> <span class='hs-varid'>indexStr</span> <span class='hs-keyglyph'>=</span> <span class='hs-num'>0</span>
|
|
<a name="line-347"></a> <span class='hs-keyglyph'>|</span> <span class='hs-varid'>otherwise</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>read</span> <span class='hs-varop'>$</span> <span class='hs-conid'>Text</span><span class='hs-varop'>.</span><span class='hs-varid'>unpack</span> <span class='hs-varop'>$</span> <span class='hs-conid'>Text</span><span class='hs-varop'>.</span><span class='hs-varid'>tail</span> <span class='hs-varid'>indexStr</span>
|
|
<a name="line-348"></a> <span class='hs-keyword'>in</span> <span class='hs-conid'>Just</span> <span class='hs-layout'>(</span><span class='hs-conid'>NodeName</span> <span class='hs-varid'>nm</span><span class='hs-layout'>,</span> <span class='hs-conid'>OutputIx</span> <span class='hs-varid'>index</span><span class='hs-layout'>)</span>
|
|
<a name="line-349"></a>
|
|
<a name="line-350"></a> <span class='hs-comment'>-- Build a map from node name to outward edges.</span>
|
|
<a name="line-351"></a> <span class='hs-comment'>--</span>
|
|
<a name="line-352"></a> <span class='hs-comment'>-- The state is the set of visited nodes.</span>
|
|
<a name="line-353"></a> <span class='hs-varid'>collect</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Maybe</span> <span class='hs-layout'>(</span><span class='hs-conid'>NodeName</span><span class='hs-layout'>,</span> <span class='hs-conid'>OutputIx</span><span class='hs-layout'>,</span> <span class='hs-conid'>OutputIx</span><span class='hs-layout'>)</span>
|
|
<a name="line-354"></a> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>NodeName</span>
|
|
<a name="line-355"></a> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>State</span> <span class='hs-layout'>(</span><span class='hs-conid'>Set</span> <span class='hs-conid'>NodeName</span><span class='hs-layout'>)</span>
|
|
<a name="line-356"></a> <span class='hs-layout'>(</span><span class='hs-conid'>Map</span> <span class='hs-conid'>NodeName</span> <span class='hs-keyglyph'>[</span><span class='hs-layout'>(</span><span class='hs-conid'>NodeName</span><span class='hs-layout'>,</span> <span class='hs-conid'>OutputIx</span><span class='hs-layout'>,</span> <span class='hs-conid'>OutputIx</span><span class='hs-layout'>)</span><span class='hs-keyglyph'>]</span><span class='hs-layout'>)</span>
|
|
<!-- collect always records the (optional) outgoing edge for nm and marks nm as visited.
     A node that was already visited returns just that singleton map; a new node also
     recurses into each parsed input, numbering them from 0 in list order (names for which
     parseTensorName yields Nothing are dropped by mapMaybe before numbering), and merges
     the resulting maps with Map.unionsWith (++). --><a name="line-357"></a> <span class='hs-varid'>collect</span> <span class='hs-varid'>outgoingEdge</span> <span class='hs-varid'>nm</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyword'>do</span>
|
|
<a name="line-358"></a> <span class='hs-keyword'>let</span> <span class='hs-varid'>nextLookup</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>Map</span><span class='hs-varop'>.</span><span class='hs-varid'>singleton</span> <span class='hs-varid'>nm</span> <span class='hs-layout'>(</span><span class='hs-varid'>maybeToList</span> <span class='hs-varid'>outgoingEdge</span><span class='hs-layout'>)</span>
|
|
<a name="line-359"></a> <span class='hs-varid'>seen</span> <span class='hs-keyglyph'>&lt;-</span> <span class='hs-varid'>gets</span> <span class='hs-layout'>(</span><span class='hs-conid'>Set</span><span class='hs-varop'>.</span><span class='hs-varid'>member</span> <span class='hs-varid'>nm</span><span class='hs-layout'>)</span>
|
|
<a name="line-360"></a> <span class='hs-varid'>modify</span> <span class='hs-layout'>(</span><span class='hs-conid'>Set</span><span class='hs-varop'>.</span><span class='hs-varid'>insert</span> <span class='hs-varid'>nm</span><span class='hs-layout'>)</span>
|
|
<a name="line-361"></a> <span class='hs-keyword'>if</span> <span class='hs-varid'>seen</span>
|
|
<a name="line-362"></a> <span class='hs-keyword'>then</span> <span class='hs-varid'>pure</span> <span class='hs-varid'>nextLookup</span>
|
|
<a name="line-363"></a> <span class='hs-keyword'>else</span> <span class='hs-keyword'>do</span>
|
|
<a name="line-364"></a> <span class='hs-keyword'>let</span> <span class='hs-varid'>inputs</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>nodeDefLookup</span> <span class='hs-varid'>nm</span> <span class='hs-varop'>^.</span> <span class='hs-varid'>input</span>
|
|
<a name="line-365"></a> <span class='hs-varid'>recurse</span> <span class='hs-varid'>inIndex</span> <span class='hs-layout'>(</span><span class='hs-varid'>parentName</span><span class='hs-layout'>,</span> <span class='hs-varid'>outIndex</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-366"></a> <span class='hs-varid'>collect</span> <span class='hs-layout'>(</span><span class='hs-conid'>Just</span> <span class='hs-layout'>(</span><span class='hs-varid'>nm</span><span class='hs-layout'>,</span> <span class='hs-varid'>outIndex</span><span class='hs-layout'>,</span> <span class='hs-varid'>inIndex</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span> <span class='hs-varid'>parentName</span>
|
|
<a name="line-367"></a> <span class='hs-varid'>subEdgeLookups</span> <span class='hs-keyglyph'>&lt;-</span>
|
|
<a name="line-368"></a> <span class='hs-varid'>zipWithM</span> <span class='hs-varid'>recurse</span> <span class='hs-keyglyph'>[</span><span class='hs-num'>0</span><span class='hs-keyglyph'>..</span><span class='hs-keyglyph'>]</span> <span class='hs-varop'>$</span> <span class='hs-varid'>mapMaybe</span> <span class='hs-varid'>parseTensorName</span> <span class='hs-varid'>inputs</span>
|
|
<a name="line-369"></a> <span class='hs-varid'>pure</span> <span class='hs-varop'>$</span> <span class='hs-conid'>Map</span><span class='hs-varop'>.</span><span class='hs-varid'>unionsWith</span> <span class='hs-layout'>(</span><span class='hs-varop'>++</span><span class='hs-layout'>)</span> <span class='hs-layout'>(</span><span class='hs-varid'>nextLookup</span><span class='hs-conop'>:</span><span class='hs-varid'>subEdgeLookups</span><span class='hs-layout'>)</span>
|
|
<a name="line-370"></a>
|
|
<a name="line-371"></a> <span class='hs-varid'>edgeLookup</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>evalState</span> <span class='hs-layout'>(</span><span class='hs-varid'>collect</span> <span class='hs-conid'>Nothing</span> <span class='hs-varid'>nodeName</span><span class='hs-layout'>)</span> <span class='hs-conid'>Set</span><span class='hs-varop'>.</span><span class='hs-varid'>empty</span>
|
|
<a name="line-372"></a> <span class='hs-comment'>-- Associate an ID with each node name.</span>
|
|
<a name="line-373"></a> <span class='hs-varid'>nodeMap</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>Map</span><span class='hs-varop'>.</span><span class='hs-varid'>fromList</span> <span class='hs-varop'>$</span> <span class='hs-varid'>zip</span> <span class='hs-layout'>(</span><span class='hs-conid'>Map</span><span class='hs-varop'>.</span><span class='hs-varid'>keys</span> <span class='hs-varid'>edgeLookup</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>[</span><span class='hs-num'>0</span><span class='hs-keyglyph'>..</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-374"></a> <span class='hs-comment'>-- Create the graph.</span>
|
|
<a name="line-375"></a> <span class='hs-varid'>graph</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>FGL</span><span class='hs-varop'>.</span><span class='hs-varid'>mkGraph</span> <span class='hs-layout'>(</span><span class='hs-varid'>swap</span> <span class='hs-varop'>&lt;$&gt;</span> <span class='hs-conid'>Map</span><span class='hs-varop'>.</span><span class='hs-varid'>toList</span> <span class='hs-varid'>nodeMap</span><span class='hs-layout'>)</span>
|
|
<a name="line-376"></a> <span class='hs-keyglyph'>[</span> <span class='hs-layout'>(</span><span class='hs-varid'>nodeMap</span> <span class='hs-conid'>Map</span><span class='hs-varop'>.!</span> <span class='hs-varid'>n</span><span class='hs-layout'>,</span> <span class='hs-varid'>nodeMap</span> <span class='hs-conid'>Map</span><span class='hs-varop'>.!</span> <span class='hs-varid'>m</span><span class='hs-layout'>,</span> <span class='hs-layout'>(</span><span class='hs-varid'>i</span><span class='hs-layout'>,</span> <span class='hs-varid'>j</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span>
|
|
<a name="line-377"></a> <span class='hs-keyglyph'>|</span> <span class='hs-layout'>(</span><span class='hs-varid'>n</span><span class='hs-layout'>,</span> <span class='hs-varid'>edges</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>&lt;-</span> <span class='hs-conid'>Map</span><span class='hs-varop'>.</span><span class='hs-varid'>toList</span> <span class='hs-varid'>edgeLookup</span>
|
|
<a name="line-378"></a> <span class='hs-layout'>,</span> <span class='hs-layout'>(</span><span class='hs-varid'>m</span><span class='hs-layout'>,</span> <span class='hs-varid'>i</span><span class='hs-layout'>,</span> <span class='hs-varid'>j</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>&lt;-</span> <span class='hs-varid'>edges</span>
|
|
<a name="line-379"></a> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-380"></a>
|
|
<!-- Listing of the Haskell type synonym GradientFunc (listing lines 381 to 398).
     As listed, GradientFunc a is a function taking a NodeDef, a list of Output values
     (the input tensors) and a list of Tensor Value a (gradients with respect to each
     output), and returning a list of Maybe (Tensor Value a) (gradients with respect to
     each input).
     Markup note: the named anchor "GradientFunc" appears once, on the first line of the
     definition. Anchor names must be unique within a document, so the per-line line-NNN
     anchors are kept and the repeated definition anchors are not. --><a name="line-381"></a><a name="GradientFunc"></a><span class='hs-comment'>-- | Function to compute the gradient of y w.r.t. each input.</span>
|
|
<a name="line-382"></a><span class='hs-comment'>--</span>
|
|
<a name="line-383"></a><span class='hs-comment'>-- Let y be an arbitrary tensor</span>
|
|
<a name="line-384"></a><span class='hs-comment'>-- and [w_0, ..., w_n] be the output tensors of a node</span>
|
|
<a name="line-385"></a><span class='hs-comment'>-- and [v_0, ..., v_n] be the input tensors of the same node.</span>
|
|
<a name="line-386"></a><span class='hs-comment'>--</span>
|
|
<a name="line-387"></a><span class='hs-comment'>-- Given [∂y/∂w_0, ..., ∂y/∂w_n] and [v_0, ..., v_n], a GradientFunc computes</span>
|
|
<a name="line-388"></a><span class='hs-comment'>-- [∂y/∂v_0, ..., ∂y/∂v_n] for a particular op type.</span>
|
|
<a name="line-389"></a><span class='hs-comment'>--</span>
|
|
<a name="line-390"></a><span class='hs-comment'>-- A Nothing gradient is equivalent to zero (but allows for short circuiting</span>
|
|
<a name="line-391"></a><span class='hs-comment'>-- computation when all the gradients for something are Nothing).</span>
|
|
<a name="line-392"></a><span class='hs-keyword'>type</span> <span class='hs-conid'>GradientFunc</span> <span class='hs-varid'>a</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>NodeDef</span>
|
|
<a name="line-393"></a> <span class='hs-keyglyph'>-></span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Output</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-394"></a> <span class='hs-comment'>-- ^ Input tensors.</span>
|
|
<a name="line-395"></a> <span class='hs-keyglyph'>-></span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-396"></a> <span class='hs-comment'>-- ^ Gradient of y w.r.t. each output tensor.</span>
|
|
<a name="line-397"></a> <span class='hs-keyglyph'>-></span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Maybe</span> <span class='hs-layout'>(</span><span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-398"></a> <span class='hs-comment'>-- ^ Gradient of y w.r.t. each input tensor.</span>
|
|
<a name="line-399"></a>
|
|
<a name="line-400"></a>
|
|
<!-- Listing of the Haskell function toT (listing lines 401 to 404).
     As listed, toT is the partially applied constructor `Tensor ValueKind`, so it wraps an
     Output without inspecting it. The element type `a` in the result is left free by the
     signature; the TODO in the listing notes that the type is not checked. --><a name="line-401"></a><a name="toT"></a><span class='hs-comment'>-- TODO(fmayle): Assert the type is correct.</span>
|
|
<a name="line-402"></a><span class='hs-comment'>-- | Create a Tensor from an Output.</span>
|
|
<a name="line-403"></a><span class='hs-definition'>toT</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Output</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span>
|
|
<a name="line-404"></a><span class='hs-definition'>toT</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>ValueKind</span>
|
|
<a name="line-405"></a>
|
|
<a name="line-406"></a>
|
|
<!-- Listing of the Haskell function flatSlice (listing lines 407 to 418).
     As listed, flatSlice t begin size calls CoreOps.slice on t, passing `vector [begin]`
     and `vector [size]`, i.e. each Int32 scalar is wrapped in a one-element list before
     being handed to `vector`. The meaning of a size of -1 is described by the listing's
     own comment and is not established by the code shown here. --><a name="line-407"></a><a name="flatSlice"></a><span class='hs-comment'>-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for</span>
|
|
<a name="line-408"></a><span class='hs-comment'>-- simple slicing operations.</span>
|
|
<a name="line-409"></a><span class='hs-definition'>flatSlice</span> <span class='hs-keyglyph'>::</span> <span class='hs-keyword'>forall</span> <span class='hs-varid'>v1</span> <span class='hs-varid'>t</span> <span class='hs-varop'>.</span> <span class='hs-layout'>(</span><span class='hs-conid'>TensorType</span> <span class='hs-varid'>t</span><span class='hs-layout'>)</span>
|
|
<a name="line-410"></a> <span class='hs-keyglyph'>=></span> <span class='hs-conid'>Tensor</span> <span class='hs-varid'>v1</span> <span class='hs-varid'>t</span> <span class='hs-comment'>-- ^ __input__</span>
|
|
<a name="line-411"></a> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Int32</span> <span class='hs-comment'>-- ^ __begin__: specifies the offset into the first dimension of</span>
|
|
<a name="line-412"></a> <span class='hs-comment'>-- 'input' to slice from.</span>
|
|
<a name="line-413"></a> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Int32</span> <span class='hs-comment'>-- ^ __size__: specifies the number of elements of the first dimension</span>
|
|
<a name="line-414"></a> <span class='hs-comment'>-- of 'input' to slice. If size is -1, all remaining elements in the dimension</span>
|
|
<a name="line-415"></a> <span class='hs-comment'>-- are included in the slice (i.e. this is equivalent to setting</span>
|
|
<a name="line-416"></a> <span class='hs-comment'>-- size = input.dim_size(0) - begin).</span>
|
|
<a name="line-417"></a> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>t</span> <span class='hs-comment'>-- ^ __output__</span>
|
|
<a name="line-418"></a><span class='hs-definition'>flatSlice</span> <span class='hs-varid'>t</span> <span class='hs-varid'>begin</span> <span class='hs-varid'>size</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>slice</span> <span class='hs-varid'>t</span> <span class='hs-layout'>(</span><span class='hs-varid'>vector</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>begin</span><span class='hs-keyglyph'>]</span><span class='hs-layout'>)</span> <span class='hs-layout'>(</span><span class='hs-varid'>vector</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>size</span><span class='hs-keyglyph'>]</span><span class='hs-layout'>)</span>
|
|
<a name="line-419"></a>
|
|
<a name="line-420"></a>
|
|
<!-- Listing of the Haskell function opGrad: signature plus the clauses for "Abs", "Neg",
     "Relu" and "Square" (listing lines 421 to 436). Further clauses follow this span.
     opGrad maps an op-type Text to a GradientFunc a, under a GradientCompatible a
     constraint. Each clause listed here matches exactly one input and one output gradient
     dz, and returns a one-element list:
       * "Abs":    Just (dz * signum x)
       * "Neg":    Just (negated dz)
       * "Relu":   Just (reluGrad dz x)
       * "Square": Just (dz * (2 * x))
     The `toT` to x patterns are view patterns: the Output is passed through toT before
     being bound. --><a name="line-421"></a><a name="opGrad"></a><span class='hs-comment'>-- | The gradient function for an op type.</span>
|
|
<a name="line-422"></a><span class='hs-comment'>--</span>
|
|
<a name="line-423"></a><span class='hs-comment'>-- These implementations should match their python counterparts in:</span>
|
|
<a name="line-424"></a><span class='hs-comment'>-- third_party/tensorflow/python/ops/*_grad.py</span>
|
|
<a name="line-425"></a><span class='hs-definition'>opGrad</span> <span class='hs-keyglyph'>::</span> <span class='hs-keyword'>forall</span> <span class='hs-varid'>a</span> <span class='hs-varop'>.</span> <span class='hs-conid'>GradientCompatible</span> <span class='hs-varid'>a</span> <span class='hs-keyglyph'>=></span> <span class='hs-conid'>Text</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>GradientFunc</span> <span class='hs-varid'>a</span>
|
|
<a name="line-426"></a>
|
|
<a name="line-427"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Abs"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>dz</span> <span class='hs-varop'>*</span> <span class='hs-varid'>signum</span> <span class='hs-varid'>x</span><span class='hs-keyglyph'>]</span>
|
|
<!-- NOTE(review): the unary minus below carries the hs-comment class rather than an
     operator class; this looks like a highlighter quirk, not a real comment. Confirm
     against the Haskell source before restyling. --><a name="line-428"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Neg"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-keyword'>_</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-comment'>-</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-429"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Relu"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>reluGrad</span> <span class='hs-varid'>dz</span> <span class='hs-varid'>x</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-430"></a>
|
|
<a name="line-431"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Square"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-432"></a> <span class='hs-comment'>-- TODO(fmayle): Handle complex numbers.</span>
|
|
<a name="line-433"></a> <span class='hs-comment'>-- TODO(fmayle): The python code makes dz a control dependency of the 2*x</span>
|
|
<a name="line-434"></a> <span class='hs-comment'>-- (for performance reasons?). Will need to put these functions in the Build</span>
|
|
<a name="line-435"></a> <span class='hs-comment'>-- monad to replicate that.</span>
|
|
<a name="line-436"></a> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>dz</span> <span class='hs-varop'>*</span> <span class='hs-layout'>(</span><span class='hs-num'>2</span> <span class='hs-varop'>*</span> <span class='hs-varid'>x</span><span class='hs-layout'>)</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-437"></a>
|
|
<!-- Listing of the "Gather" clause of opGrad (listing lines 438 to 455).
     The clause matches two inputs (x and indices, both passed through toT) and one output
     gradient dz, and returns a two-element list: a Just for the first input and Nothing
     for the second (the indices input).
     How the first gradient is assembled, as listed:
       * denseShape  = shape of x
       * numRows     = scalarize applied to flatSlice denseShape 0 1
       * valuesShape = CoreOps.concat 0 of allDimensions and flatSlice denseShape 1 (-1)
       * values      = dz reshaped to valuesShape
       * indices'    = indices reshaped with allDimensions, annotated as Tensor Value Int32
       * result      = CoreOps.unsortedSegmentSum values indices' numRows
     allDimensions and scalarize are defined outside this span. --><a name="line-438"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Gather"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-layout'>,</span> <span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>indices</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-439"></a> <span class='hs-comment'>-- TODO(fmayle): The python version uses a better performance implementation</span>
|
|
<a name="line-440"></a> <span class='hs-comment'>-- when the shape is known without having to run the graph.</span>
|
|
<a name="line-441"></a> <span class='hs-comment'>-- TODO(fmayle): We shouldn't convert the result to a dense tensor. Sparse</span>
|
|
<a name="line-442"></a> <span class='hs-comment'>-- tensor support will require some thinking.</span>
|
|
<a name="line-443"></a> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>unsortedSegmentSum</span> <span class='hs-varid'>values</span> <span class='hs-varid'>indices'</span> <span class='hs-varid'>numRows</span>
|
|
<a name="line-444"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Nothing</span>
|
|
<a name="line-445"></a> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-446"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-447"></a> <span class='hs-comment'>-- TODO(gnezdo): Use colocateWith but it requires Build monad.</span>
|
|
<a name="line-448"></a> <span class='hs-varid'>denseShape</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>shape</span> <span class='hs-layout'>(</span><span class='hs-varid'>x</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-449"></a> <span class='hs-varid'>numRows</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>scalarize</span> <span class='hs-varop'>$</span> <span class='hs-varid'>flatSlice</span> <span class='hs-varid'>denseShape</span> <span class='hs-num'>0</span> <span class='hs-num'>1</span>
|
|
<a name="line-450"></a> <span class='hs-varid'>valuesShape</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>concat</span> <span class='hs-num'>0</span> <span class='hs-keyglyph'>[</span> <span class='hs-varid'>allDimensions</span>
|
|
<a name="line-451"></a> <span class='hs-layout'>,</span> <span class='hs-varid'>flatSlice</span> <span class='hs-varid'>denseShape</span> <span class='hs-num'>1</span> <span class='hs-layout'>(</span><span class='hs-comment'>-</span><span class='hs-num'>1</span><span class='hs-layout'>)</span>
|
|
<a name="line-452"></a> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-453"></a> <span class='hs-varid'>values</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>reshape</span> <span class='hs-varid'>dz</span> <span class='hs-varid'>valuesShape</span>
|
|
<a name="line-454"></a> <span class='hs-comment'>-- TODO(fmayle): This could be either Int32 or Int64.</span>
|
|
<a name="line-455"></a> <span class='hs-varid'>indices'</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>reshape</span> <span class='hs-varid'>indices</span> <span class='hs-varid'>allDimensions</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-conid'>Int32</span>
|
|
<a name="line-456"></a>
|
|
<a name="line-457"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Max"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-layout'>,</span> <span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>indices</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-458"></a> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>indicators</span> <span class='hs-varop'>`</span><span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>div</span><span class='hs-varop'>`</span> <span class='hs-varid'>numSelected</span> <span class='hs-varop'>*</span> <span class='hs-varid'>dz'</span><span class='hs-layout'>,</span> <span class='hs-conid'>Nothing</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-459"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-460"></a> <span class='hs-varid'>sx</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>shape</span> <span class='hs-layout'>(</span><span class='hs-varid'>x</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-461"></a> <span class='hs-varid'>outputShapeKeptDims</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>reducedShape</span> <span class='hs-varid'>sx</span> <span class='hs-layout'>(</span><span class='hs-varid'>indices</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-conid'>Int32</span><span class='hs-layout'>)</span>
|
|
<a name="line-462"></a> <span class='hs-varid'>x'</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>reshape</span> <span class='hs-varid'>x</span> <span class='hs-varid'>outputShapeKeptDims</span>
|
|
<a name="line-463"></a> <span class='hs-varid'>dz'</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>reshape</span> <span class='hs-varid'>dz</span> <span class='hs-varid'>outputShapeKeptDims</span>
|
|
<a name="line-464"></a> <span class='hs-varid'>indicators</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>cast</span> <span class='hs-varop'>$</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>equal</span> <span class='hs-varid'>x'</span> <span class='hs-varid'>x</span>
|
|
<a name="line-465"></a> <span class='hs-varid'>numSelected</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>reshape</span> <span class='hs-layout'>(</span><span class='hs-varid'>sum</span> <span class='hs-varid'>indicators</span> <span class='hs-varid'>indices</span><span class='hs-layout'>)</span> <span class='hs-varid'>outputShapeKeptDims</span>
|
|
<a name="line-466"></a>
|
|
<a name="line-467"></a><span class='hs-comment'>-- Min and Max have identical gradient implementations.</span>
|
|
<a name="line-468"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Min"</span> <span class='hs-varid'>u</span> <span class='hs-varid'>v</span> <span class='hs-varid'>w</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>opGrad</span> <span class='hs-str'>"Max"</span> <span class='hs-varid'>u</span> <span class='hs-varid'>v</span> <span class='hs-varid'>w</span>
|
|
<a name="line-469"></a>
|
|
<a name="line-470"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Sum"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-layout'>,</span> <span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>indices</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-471"></a> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>tile</span> <span class='hs-varid'>grad</span> <span class='hs-varid'>tileScaling</span><span class='hs-layout'>,</span> <span class='hs-conid'>Nothing</span> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-472"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-473"></a> <span class='hs-comment'>-- TODO(gnezdo): Implement the fast-path from math_grad._SumGrad.</span>
|
|
<a name="line-474"></a> <span class='hs-varid'>sx</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>shape</span> <span class='hs-layout'>(</span><span class='hs-varid'>x</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-475"></a> <span class='hs-varid'>outputShapeKeptDims</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>reducedShape</span> <span class='hs-varid'>sx</span> <span class='hs-layout'>(</span><span class='hs-varid'>indices</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-conid'>Int32</span><span class='hs-layout'>)</span>
|
|
<a name="line-476"></a> <span class='hs-varid'>tileScaling</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>safeShapeDiv</span> <span class='hs-varid'>sx</span> <span class='hs-varid'>outputShapeKeptDims</span>
|
|
<a name="line-477"></a> <span class='hs-varid'>grad</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>reshape</span> <span class='hs-varid'>dz</span> <span class='hs-varid'>outputShapeKeptDims</span>
|
|
<a name="line-478"></a>
|
|
<a name="line-479"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Mean"</span> <span class='hs-varid'>u</span> <span class='hs-varid'>v</span><span class='hs-keyglyph'>@</span><span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-layout'>,</span> <span class='hs-keyword'>_</span><span class='hs-keyglyph'>]</span> <span class='hs-varid'>w</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-480"></a> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>dz</span> <span class='hs-varop'>`</span><span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>div</span><span class='hs-varop'>`</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>cast</span> <span class='hs-varid'>factor</span><span class='hs-layout'>,</span> <span class='hs-conid'>Nothing</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-481"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-482"></a> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Just</span> <span class='hs-varid'>dz</span><span class='hs-layout'>,</span> <span class='hs-conid'>Nothing</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>opGrad</span> <span class='hs-str'>"Sum"</span> <span class='hs-varid'>u</span> <span class='hs-varid'>v</span> <span class='hs-varid'>w</span>
|
|
<a name="line-483"></a> <span class='hs-varid'>inputShape</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>shape</span> <span class='hs-layout'>(</span><span class='hs-varid'>x</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-484"></a> <span class='hs-varid'>outputShape</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>shape</span> <span class='hs-layout'>(</span><span class='hs-varid'>dz</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-485"></a> <span class='hs-comment'>-- TODO(fmayle): Add fast path when shape is known.</span>
|
|
<a name="line-486"></a> <span class='hs-varid'>inputSize</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>prod</span> <span class='hs-varid'>inputShape</span> <span class='hs-varop'>$</span> <span class='hs-varid'>rangeOfRank</span> <span class='hs-varid'>inputShape</span>
|
|
<a name="line-487"></a> <span class='hs-varid'>outputSize</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>prod</span> <span class='hs-varid'>outputShape</span> <span class='hs-varop'>$</span> <span class='hs-varid'>rangeOfRank</span> <span class='hs-varid'>outputShape</span>
|
|
<a name="line-488"></a> <span class='hs-varid'>factor</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>safeShapeDiv</span> <span class='hs-varid'>inputSize</span> <span class='hs-varid'>outputSize</span>
|
|
<a name="line-489"></a>
|
|
<a name="line-490"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Add"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-layout'>,</span> <span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>y</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-491"></a> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>reshape</span> <span class='hs-layout'>(</span><span class='hs-varid'>sum</span> <span class='hs-varid'>dz</span> <span class='hs-varid'>rx</span><span class='hs-layout'>)</span> <span class='hs-varid'>sx</span>
|
|
<a name="line-492"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>reshape</span> <span class='hs-layout'>(</span><span class='hs-varid'>sum</span> <span class='hs-varid'>dz</span> <span class='hs-varid'>ry</span><span class='hs-layout'>)</span> <span class='hs-varid'>sy</span> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-493"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-494"></a> <span class='hs-varid'>sx</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>shape</span> <span class='hs-layout'>(</span><span class='hs-varid'>x</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-495"></a> <span class='hs-varid'>sy</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>shape</span> <span class='hs-layout'>(</span><span class='hs-varid'>y</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-496"></a> <span class='hs-layout'>(</span><span class='hs-varid'>rx</span><span class='hs-layout'>,</span> <span class='hs-varid'>ry</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>broadcastGradientArgs</span> <span class='hs-varid'>sx</span> <span class='hs-varid'>sy</span>
|
|
<a name="line-497"></a>
|
|
<a name="line-498"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Sub"</span> <span class='hs-varid'>u</span> <span class='hs-varid'>v</span> <span class='hs-varid'>w</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-499"></a> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Just</span> <span class='hs-varid'>x</span><span class='hs-layout'>,</span> <span class='hs-conid'>Just</span> <span class='hs-layout'>(</span><span class='hs-varop'>-</span><span class='hs-varid'>y</span><span class='hs-layout'>)</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-500"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-501"></a> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Just</span> <span class='hs-varid'>x</span><span class='hs-layout'>,</span> <span class='hs-conid'>Just</span> <span class='hs-varid'>y</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>opGrad</span> <span class='hs-str'>"Add"</span> <span class='hs-varid'>u</span> <span class='hs-varid'>v</span> <span class='hs-varid'>w</span>
|
|
<a name="line-502"></a>
|
|
<a name="line-503"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"SoftmaxCrossEntropyWithLogits"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-layout'>,</span> <span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>y</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-layout'>,</span> <span class='hs-keyword'>_</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-504"></a> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>expandDims</span> <span class='hs-varid'>dz</span> <span class='hs-layout'>(</span><span class='hs-varop'>-</span><span class='hs-num'>1</span><span class='hs-layout'>)</span> <span class='hs-varop'>*</span> <span class='hs-varid'>snd</span> <span class='hs-layout'>(</span><span class='hs-varid'>softmaxCrossEntropyWithLogits</span> <span class='hs-varid'>x</span> <span class='hs-varid'>y</span><span class='hs-layout'>)</span>
|
|
<a name="line-505"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Nothing</span> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-506"></a>
|
|
<a name="line-507"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Mul"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-layout'>,</span> <span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>y</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-508"></a> <span class='hs-comment'>-- TODO(fmayle): Handle complex numbers.</span>
|
|
<a name="line-509"></a> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>reshape</span> <span class='hs-layout'>(</span><span class='hs-varid'>sum</span> <span class='hs-layout'>(</span><span class='hs-varid'>dz</span> <span class='hs-varop'>*</span> <span class='hs-varid'>y</span><span class='hs-layout'>)</span> <span class='hs-varid'>rx</span><span class='hs-layout'>)</span> <span class='hs-varid'>sx</span>
|
|
<a name="line-510"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>reshape</span> <span class='hs-layout'>(</span><span class='hs-varid'>sum</span> <span class='hs-layout'>(</span><span class='hs-varid'>x</span> <span class='hs-varop'>*</span> <span class='hs-varid'>dz</span><span class='hs-layout'>)</span> <span class='hs-varid'>ry</span><span class='hs-layout'>)</span> <span class='hs-varid'>sy</span> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-511"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-512"></a> <span class='hs-varid'>sx</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>shape</span> <span class='hs-layout'>(</span><span class='hs-varid'>x</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-513"></a> <span class='hs-varid'>sy</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>shape</span> <span class='hs-layout'>(</span><span class='hs-varid'>y</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-514"></a> <span class='hs-layout'>(</span><span class='hs-varid'>rx</span><span class='hs-layout'>,</span> <span class='hs-varid'>ry</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>broadcastGradientArgs</span> <span class='hs-varid'>sx</span> <span class='hs-varid'>sy</span>
|
|
<a name="line-515"></a>
|
|
<a name="line-516"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Div"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-layout'>,</span> <span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>y</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-517"></a> <span class='hs-comment'>-- TODO(fmayle): Handle complex numbers.</span>
|
|
<a name="line-518"></a> <span class='hs-comment'>-- TODO(gnezdo): Provide Fractional instance and use '/' instead of div.</span>
|
|
<a name="line-519"></a> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>reshape</span> <span class='hs-layout'>(</span><span class='hs-varid'>sum</span> <span class='hs-layout'>(</span><span class='hs-varid'>dz</span> <span class='hs-varop'>`</span><span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>div</span><span class='hs-varop'>`</span> <span class='hs-varid'>y</span><span class='hs-layout'>)</span> <span class='hs-varid'>rx</span><span class='hs-layout'>)</span> <span class='hs-varid'>sx</span>
|
|
<a name="line-520"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>reshape</span> <span class='hs-layout'>(</span><span class='hs-varid'>sum</span> <span class='hs-layout'>(</span><span class='hs-varid'>dz</span> <span class='hs-varop'>*</span> <span class='hs-layout'>(</span><span class='hs-varid'>negate</span> <span class='hs-varid'>x</span> <span class='hs-varop'>`</span><span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>div</span><span class='hs-varop'>`</span> <span class='hs-layout'>(</span><span class='hs-varid'>y</span> <span class='hs-varop'>*</span> <span class='hs-varid'>y</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span> <span class='hs-varid'>ry</span><span class='hs-layout'>)</span> <span class='hs-varid'>sy</span>
|
|
<a name="line-521"></a> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-522"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-523"></a> <span class='hs-varid'>sx</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>shape</span> <span class='hs-layout'>(</span><span class='hs-varid'>x</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-524"></a> <span class='hs-varid'>sy</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>shape</span> <span class='hs-layout'>(</span><span class='hs-varid'>y</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span>
|
|
<a name="line-525"></a> <span class='hs-layout'>(</span><span class='hs-varid'>rx</span><span class='hs-layout'>,</span> <span class='hs-varid'>ry</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>broadcastGradientArgs</span> <span class='hs-varid'>sx</span> <span class='hs-varid'>sy</span>
|
|
<a name="line-526"></a>
|
|
<a name="line-527"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"MatMul"</span> <span class='hs-varid'>nodeDef</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-layout'>,</span> <span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>y</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-528"></a> <span class='hs-keyword'>let</span> <span class='hs-varid'>transposeA</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>lookupAttr</span> <span class='hs-varid'>nodeDef</span> <span class='hs-str'>"transpose_a"</span>
|
|
<a name="line-529"></a> <span class='hs-varid'>transposeB</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>lookupAttr</span> <span class='hs-varid'>nodeDef</span> <span class='hs-str'>"transpose_b"</span>
|
|
<a name="line-530"></a> <span class='hs-varid'>transAttrs</span> <span class='hs-varid'>a</span> <span class='hs-varid'>b</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-531"></a> <span class='hs-layout'>(</span><span class='hs-varid'>tensorAttr</span> <span class='hs-str'>"transpose_a"</span> <span class='hs-varop'>.~</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span> <span class='hs-varop'>.</span> <span class='hs-layout'>(</span><span class='hs-varid'>tensorAttr</span> <span class='hs-str'>"transpose_b"</span> <span class='hs-varop'>.~</span> <span class='hs-varid'>b</span><span class='hs-layout'>)</span>
|
|
<a name="line-532"></a> <span class='hs-keyword'>in</span> <span class='hs-keyword'>case</span> <span class='hs-layout'>(</span><span class='hs-varid'>transposeA</span><span class='hs-layout'>,</span> <span class='hs-varid'>transposeB</span><span class='hs-layout'>)</span> <span class='hs-keyword'>of</span>
|
|
<a name="line-533"></a> <span class='hs-layout'>(</span><span class='hs-conid'>False</span><span class='hs-layout'>,</span> <span class='hs-conid'>False</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>-></span>
|
|
<a name="line-534"></a> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-layout'>(</span><span class='hs-varid'>dz</span> <span class='hs-varop'>`matMul`</span> <span class='hs-varid'>y</span><span class='hs-layout'>)</span> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>transAttrs</span> <span class='hs-conid'>False</span> <span class='hs-conid'>True</span>
|
|
<a name="line-535"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-layout'>(</span><span class='hs-varid'>x</span> <span class='hs-varop'>`matMul`</span> <span class='hs-varid'>dz</span><span class='hs-layout'>)</span> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>transAttrs</span> <span class='hs-conid'>True</span> <span class='hs-conid'>False</span> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-536"></a> <span class='hs-layout'>(</span><span class='hs-conid'>False</span><span class='hs-layout'>,</span> <span class='hs-conid'>True</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>-></span>
|
|
<a name="line-537"></a> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>dz</span> <span class='hs-varop'>`matMul`</span> <span class='hs-varid'>y</span>
|
|
<a name="line-538"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-layout'>(</span><span class='hs-varid'>x</span> <span class='hs-varop'>`matMul`</span> <span class='hs-varid'>dz</span><span class='hs-layout'>)</span> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>transAttrs</span> <span class='hs-conid'>True</span> <span class='hs-conid'>False</span> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-539"></a> <span class='hs-layout'>(</span><span class='hs-conid'>True</span><span class='hs-layout'>,</span> <span class='hs-conid'>False</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>-></span>
|
|
<a name="line-540"></a> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-layout'>(</span><span class='hs-varid'>dz</span> <span class='hs-varop'>`matMul`</span> <span class='hs-varid'>y</span><span class='hs-layout'>)</span> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>transAttrs</span> <span class='hs-conid'>False</span> <span class='hs-conid'>True</span>
|
|
<a name="line-541"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>x</span> <span class='hs-varop'>`matMul`</span> <span class='hs-varid'>dz</span> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-542"></a> <span class='hs-layout'>(</span><span class='hs-conid'>True</span><span class='hs-layout'>,</span> <span class='hs-conid'>True</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>-></span>
|
|
<a name="line-543"></a> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-layout'>(</span><span class='hs-varid'>dz</span> <span class='hs-varop'>`matMul`</span> <span class='hs-varid'>y</span><span class='hs-layout'>)</span> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>transAttrs</span> <span class='hs-conid'>True</span> <span class='hs-conid'>True</span>
|
|
<a name="line-544"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-layout'>(</span><span class='hs-varid'>x</span> <span class='hs-varop'>`matMul`</span> <span class='hs-varid'>dz</span><span class='hs-layout'>)</span> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>transAttrs</span> <span class='hs-conid'>True</span> <span class='hs-conid'>True</span> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-545"></a>
|
|
<a name="line-546"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Transpose"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-keyword'>_</span><span class='hs-layout'>,</span> <span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>p</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-547"></a> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>transpose</span> <span class='hs-varid'>dz</span>
|
|
<a name="line-548"></a> <span class='hs-layout'>(</span><span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>invertPermutation</span> <span class='hs-varid'>p</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-conid'>Int32</span><span class='hs-layout'>)</span>
|
|
<a name="line-549"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Nothing</span>
|
|
<a name="line-550"></a> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-551"></a>
|
|
<a name="line-552"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Conv2D"</span> <span class='hs-varid'>nodeDef</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-layout'>,</span> <span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>y</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-553"></a> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>conv2DBackpropInput</span> <span class='hs-layout'>(</span><span class='hs-varid'>shape</span> <span class='hs-varid'>x</span><span class='hs-layout'>)</span> <span class='hs-varid'>y</span> <span class='hs-varid'>dz</span>
|
|
<a name="line-554"></a> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>tensorAttr</span> <span class='hs-str'>"strides"</span> <span class='hs-varop'>.~</span> <span class='hs-varid'>strides</span>
|
|
<a name="line-555"></a> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>tensorAttr</span> <span class='hs-str'>"padding"</span> <span class='hs-varop'>.~</span> <span class='hs-varid'>padding</span>
|
|
<a name="line-556"></a> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>tensorAttr</span> <span class='hs-str'>"use_cudnn_on_gpu"</span> <span class='hs-varop'>.~</span> <span class='hs-varid'>useCudnnOnGpu</span>
|
|
<a name="line-557"></a> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>tensorAttr</span> <span class='hs-str'>"data_format"</span> <span class='hs-varop'>.~</span> <span class='hs-varid'>dataFormat</span>
|
|
<a name="line-558"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>conv2DBackpropFilter</span> <span class='hs-varid'>x</span> <span class='hs-layout'>(</span><span class='hs-varid'>shape</span> <span class='hs-varid'>y</span><span class='hs-layout'>)</span> <span class='hs-varid'>dz</span>
|
|
<a name="line-559"></a> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>tensorAttr</span> <span class='hs-str'>"strides"</span> <span class='hs-varop'>.~</span> <span class='hs-varid'>strides</span>
|
|
<a name="line-560"></a> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>tensorAttr</span> <span class='hs-str'>"padding"</span> <span class='hs-varop'>.~</span> <span class='hs-varid'>padding</span>
|
|
<a name="line-561"></a> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>tensorAttr</span> <span class='hs-str'>"use_cudnn_on_gpu"</span> <span class='hs-varop'>.~</span> <span class='hs-varid'>useCudnnOnGpu</span>
|
|
<a name="line-562"></a> <span class='hs-varop'>&amp;</span> <span class='hs-varid'>tensorAttr</span> <span class='hs-str'>"data_format"</span> <span class='hs-varop'>.~</span> <span class='hs-varid'>dataFormat</span>
|
|
<a name="line-563"></a> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-564"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-565"></a> <span class='hs-varid'>strides</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>lookupAttr</span> <span class='hs-varid'>nodeDef</span> <span class='hs-str'>"strides"</span> <span class='hs-keyglyph'>::</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Int64</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-566"></a> <span class='hs-varid'>padding</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>lookupAttr</span> <span class='hs-varid'>nodeDef</span> <span class='hs-str'>"padding"</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>ByteString</span>
|
|
<a name="line-567"></a> <span class='hs-varid'>useCudnnOnGpu</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>lookupAttr</span> <span class='hs-varid'>nodeDef</span> <span class='hs-str'>"use_cudnn_on_gpu"</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Bool</span>
|
|
<a name="line-568"></a> <span class='hs-varid'>dataFormat</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>lookupAttr</span> <span class='hs-varid'>nodeDef</span> <span class='hs-str'>"data_format"</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>ByteString</span>
|
|
<a name="line-569"></a>
|
|
<a name="line-570"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"MaxPool"</span> <span class='hs-varid'>nodeDef</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-571"></a> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>maxPoolGrad</span> <span class='hs-varid'>x</span> <span class='hs-varid'>output</span> <span class='hs-varid'>dz</span>
|
|
<a name="line-572"></a> <span class='hs-varop'>&</span> <span class='hs-varid'>tensorAttr</span> <span class='hs-str'>"ksize"</span> <span class='hs-varop'>.~</span> <span class='hs-varid'>ksize</span>
|
|
<a name="line-573"></a> <span class='hs-varop'>&</span> <span class='hs-varid'>tensorAttr</span> <span class='hs-str'>"strides"</span> <span class='hs-varop'>.~</span> <span class='hs-varid'>strides</span>
|
|
<a name="line-574"></a> <span class='hs-varop'>&</span> <span class='hs-varid'>tensorAttr</span> <span class='hs-str'>"padding"</span> <span class='hs-varop'>.~</span> <span class='hs-varid'>padding</span>
|
|
<a name="line-575"></a> <span class='hs-varop'>&</span> <span class='hs-varid'>tensorAttr</span> <span class='hs-str'>"data_format"</span> <span class='hs-varop'>.~</span> <span class='hs-varid'>dataFormat</span>
|
|
<a name="line-576"></a> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-577"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-578"></a> <span class='hs-varid'>output</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span>
|
|
<a name="line-579"></a> <span class='hs-varid'>output</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>toT</span> <span class='hs-varop'>$</span> <span class='hs-conid'>Output</span> <span class='hs-num'>0</span> <span class='hs-layout'>(</span><span class='hs-conid'>Rendered</span> <span class='hs-varid'>nodeDef</span><span class='hs-layout'>)</span>
|
|
<a name="line-580"></a> <span class='hs-varid'>ksize</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>lookupAttr</span> <span class='hs-varid'>nodeDef</span> <span class='hs-str'>"ksize"</span> <span class='hs-keyglyph'>::</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Int64</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-581"></a> <span class='hs-varid'>strides</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>lookupAttr</span> <span class='hs-varid'>nodeDef</span> <span class='hs-str'>"strides"</span> <span class='hs-keyglyph'>::</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Int64</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-582"></a> <span class='hs-varid'>padding</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>lookupAttr</span> <span class='hs-varid'>nodeDef</span> <span class='hs-str'>"padding"</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>ByteString</span>
|
|
<a name="line-583"></a> <span class='hs-varid'>dataFormat</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>lookupAttr</span> <span class='hs-varid'>nodeDef</span> <span class='hs-str'>"data_format"</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>ByteString</span>
|
|
<a name="line-584"></a>
|
|
<a name="line-585"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Reshape"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-layout'>,</span> <span class='hs-keyword'>_</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-586"></a> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>reshape</span> <span class='hs-varid'>dz</span> <span class='hs-varop'>$</span> <span class='hs-varid'>shape</span> <span class='hs-layout'>(</span><span class='hs-varid'>x</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span><span class='hs-layout'>,</span> <span class='hs-conid'>Nothing</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-587"></a>
|
|
<a name="line-588"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"OneHot"</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Nothing</span><span class='hs-layout'>,</span> <span class='hs-conid'>Nothing</span><span class='hs-layout'>,</span> <span class='hs-conid'>Nothing</span><span class='hs-layout'>,</span> <span class='hs-conid'>Nothing</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-589"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"TruncatedNormal"</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Nothing</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-590"></a>
|
|
<a name="line-591"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"RefIdentity"</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Just</span> <span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-592"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Cast"</span> <span class='hs-varid'>nodeDef</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Just</span> <span class='hs-varid'>reverseCast</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-593"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-594"></a> <span class='hs-comment'>-- TODO(gnezdo): too permissive; Python only allows float types as src_type.</span>
|
|
<a name="line-595"></a> <span class='hs-varid'>reverseCast</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-596"></a> <span class='hs-varid'>buildOp</span> <span class='hs-layout'>(</span><span class='hs-varid'>opDef</span> <span class='hs-str'>"Cast"</span>
|
|
<a name="line-597"></a> <span class='hs-varop'>&</span> <span class='hs-varid'>opAttr</span> <span class='hs-str'>"DstT"</span> <span class='hs-varop'>.~</span> <span class='hs-layout'>(</span><span class='hs-varid'>lookupAttr</span> <span class='hs-varid'>nodeDef</span> <span class='hs-str'>"SrcT"</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>ByteString</span><span class='hs-layout'>)</span>
|
|
<a name="line-598"></a> <span class='hs-varop'>&</span> <span class='hs-varid'>opAttr</span> <span class='hs-str'>"SrcT"</span> <span class='hs-varop'>.~</span> <span class='hs-layout'>(</span><span class='hs-varid'>lookupAttr</span> <span class='hs-varid'>nodeDef</span> <span class='hs-str'>"DstT"</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>ByteString</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span>
|
|
<a name="line-599"></a> <span class='hs-varid'>dz</span>
|
|
<a name="line-600"></a>
|
|
<a name="line-601"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"DynamicStitch"</span> <span class='hs-varid'>nodeDef</span> <span class='hs-varid'>inputs</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-602"></a> <span class='hs-varid'>replicate</span> <span class='hs-varid'>halfLen</span> <span class='hs-conid'>Nothing</span> <span class='hs-varop'>++</span> <span class='hs-varid'>valuesGrads</span>
|
|
<a name="line-603"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-604"></a> <span class='hs-varid'>halfLen</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-605"></a> <span class='hs-keyword'>let</span> <span class='hs-varid'>len</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>length</span> <span class='hs-varid'>inputs</span>
|
|
<a name="line-606"></a> <span class='hs-varid'>half</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>len</span> <span class='hs-varop'>`div`</span> <span class='hs-num'>2</span>
|
|
<a name="line-607"></a> <span class='hs-keyword'>in</span> <span class='hs-keyword'>if</span> <span class='hs-num'>2</span> <span class='hs-varop'>*</span> <span class='hs-varid'>half</span> <span class='hs-varop'>==</span> <span class='hs-varid'>len</span>
|
|
<a name="line-608"></a> <span class='hs-keyword'>then</span> <span class='hs-varid'>half</span>
|
|
<a name="line-609"></a> <span class='hs-keyword'>else</span> <span class='hs-varid'>error</span> <span class='hs-layout'>(</span><span class='hs-str'>"Uneven input size "</span> <span class='hs-varop'>++</span> <span class='hs-varid'>show</span> <span class='hs-layout'>(</span><span class='hs-varid'>len</span><span class='hs-layout'>,</span> <span class='hs-varid'>showMessage</span> <span class='hs-varid'>nodeDef</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span>
|
|
<a name="line-610"></a> <span class='hs-varid'>valuesGrads</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>gather</span> <span class='hs-varid'>dz</span> <span class='hs-layout'>(</span><span class='hs-varid'>toT</span> <span class='hs-varid'>idx</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-conid'>Int32</span><span class='hs-layout'>)</span>
|
|
<a name="line-611"></a> <span class='hs-keyglyph'>|</span> <span class='hs-varid'>idx</span> <span class='hs-keyglyph'><-</span> <span class='hs-varid'>take</span> <span class='hs-varid'>halfLen</span> <span class='hs-varid'>inputs</span>
|
|
<a name="line-612"></a> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-613"></a>
|
|
<a name="line-614"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"DynamicPartition"</span> <span class='hs-varid'>nodeDef</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>xs</span><span class='hs-layout'>,</span> <span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>indices</span><span class='hs-keyglyph'>]</span> <span class='hs-varid'>dz</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-615"></a> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varid'>reconstructed</span><span class='hs-layout'>,</span> <span class='hs-conid'>Nothing</span> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-616"></a> <span class='hs-keyword'>where</span>
|
|
<a name="line-617"></a> <span class='hs-varid'>reconstructed</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>reshape</span> <span class='hs-varid'>stitched</span>
|
|
<a name="line-618"></a> <span class='hs-layout'>(</span><span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>shape</span> <span class='hs-layout'>(</span><span class='hs-varid'>xs</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-conid'>Int32</span><span class='hs-layout'>)</span>
|
|
<a name="line-619"></a> <span class='hs-varid'>stitched</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>dynamicStitch</span> <span class='hs-varid'>partitionedIndices</span> <span class='hs-varid'>dz</span>
|
|
<a name="line-620"></a> <span class='hs-varid'>partitionedIndices</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>dynamicPartition</span> <span class='hs-varid'>np</span> <span class='hs-varid'>originalIndices</span> <span class='hs-varid'>indices</span>
|
|
<a name="line-621"></a> <span class='hs-varid'>np</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>lookupAttr</span> <span class='hs-varid'>nodeDef</span> <span class='hs-str'>"num_partitions"</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Int64</span>
|
|
<a name="line-622"></a> <span class='hs-varid'>originalIndices</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-623"></a> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>reshape</span> <span class='hs-layout'>(</span><span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>range</span> <span class='hs-num'>0</span> <span class='hs-layout'>(</span><span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>size</span> <span class='hs-varid'>indices</span><span class='hs-layout'>)</span> <span class='hs-num'>1</span><span class='hs-layout'>)</span> <span class='hs-varid'>prefixShape</span>
|
|
<a name="line-624"></a> <span class='hs-varid'>prefixShape</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>shapeInt32</span> <span class='hs-varid'>indices</span>
|
|
<a name="line-625"></a> <span class='hs-varid'>shapeInt32</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>shape</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-conid'>Int32</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-conid'>Int32</span>
|
|
<a name="line-626"></a>
|
|
<a name="line-627"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Select"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>c</span><span class='hs-layout'>,</span> <span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-layout'>,</span> <span class='hs-keyword'>_</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-628"></a> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Nothing</span>
|
|
<a name="line-629"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>select</span> <span class='hs-varid'>c</span> <span class='hs-varid'>dz</span> <span class='hs-varid'>zeros</span>
|
|
<a name="line-630"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>select</span> <span class='hs-varid'>c</span> <span class='hs-varid'>zeros</span> <span class='hs-varid'>dz</span>
|
|
<a name="line-631"></a> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-632"></a> <span class='hs-keyword'>where</span> <span class='hs-varid'>zeros</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>zerosLike</span> <span class='hs-varid'>x</span>
|
|
<a name="line-633"></a>
|
|
<a name="line-634"></a><span class='hs-comment'>-- TODO(gnezdo): Unlike Python, no control dependency on dz.</span>
|
|
<a name="line-635"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Log"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>dz</span> <span class='hs-varop'>*</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>inv</span> <span class='hs-varid'>x</span> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-636"></a><span class='hs-comment'>-- TODO(gnezdo): Reuse the output instead of doing another exp,</span>
|
|
<a name="line-637"></a><span class='hs-comment'>-- though it is probably CSE'd away anyway.</span>
|
|
<a name="line-638"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Exp"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-varid'>dz</span> <span class='hs-varop'>*</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>exp</span> <span class='hs-varid'>x</span> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-639"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"SparseSegmentSum"</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>x</span><span class='hs-layout'>,</span> <span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>y</span><span class='hs-layout'>,</span> <span class='hs-varid'>toT</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>t</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>[</span><span class='hs-varid'>dz</span><span class='hs-keyglyph'>]</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-640"></a> <span class='hs-keyglyph'>[</span> <span class='hs-conid'>Just</span> <span class='hs-varop'>$</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>unsortedSegmentSum</span>
|
|
<a name="line-641"></a> <span class='hs-layout'>(</span><span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>gather</span> <span class='hs-varid'>dz</span> <span class='hs-layout'>(</span><span class='hs-varid'>t</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-conid'>Int32</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span>
|
|
<a name="line-642"></a> <span class='hs-layout'>(</span><span class='hs-varid'>y</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-conid'>Int32</span><span class='hs-layout'>)</span> <span class='hs-varid'>inputRows</span>
|
|
<a name="line-643"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Nothing</span>
|
|
<a name="line-644"></a> <span class='hs-layout'>,</span> <span class='hs-conid'>Nothing</span>
|
|
<a name="line-645"></a> <span class='hs-keyglyph'>]</span>
|
|
<a name="line-646"></a> <span class='hs-keyword'>where</span> <span class='hs-varid'>inputRows</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>flatSlice</span> <span class='hs-layout'>(</span><span class='hs-varid'>shape</span> <span class='hs-layout'>(</span><span class='hs-varid'>x</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-varid'>a</span><span class='hs-layout'>)</span><span class='hs-layout'>)</span> <span class='hs-num'>0</span> <span class='hs-num'>1</span>
|
|
<a name="line-647"></a>
|
|
<a name="line-648"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"LabelClasses"</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Nothing</span><span class='hs-layout'>,</span> <span class='hs-conid'>Nothing</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-649"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"LabelWeights"</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Nothing</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-650"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Size"</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Nothing</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-651"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"ZerosLike"</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Nothing</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-652"></a>
|
|
<a name="line-653"></a><span class='hs-comment'>-- TODO(fmayle): These can go away if we properly prune the graph.</span>
|
|
<a name="line-654"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Const"</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>=</span> <span class='hs-keyglyph'>[</span><span class='hs-conid'>Nothing</span><span class='hs-layout'>,</span> <span class='hs-conid'>Nothing</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-655"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Placeholder"</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>[]</span>
|
|
<a name="line-656"></a><span class='hs-definition'>opGrad</span> <span class='hs-str'>"Variable"</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>[]</span>
|
|
<a name="line-657"></a>
|
|
<a name="line-658"></a><span class='hs-definition'>opGrad</span> <span class='hs-varid'>n</span> <span class='hs-varid'>nodeDef</span> <span class='hs-varid'>ins</span> <span class='hs-varid'>grads</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-659"></a> <span class='hs-varid'>error</span> <span class='hs-varop'>$</span> <span class='hs-str'>"no gradient implemented for "</span> <span class='hs-varop'>++</span>
|
|
<a name="line-660"></a> <span class='hs-varid'>show</span> <span class='hs-layout'>(</span><span class='hs-varid'>n</span><span class='hs-layout'>,</span> <span class='hs-varid'>length</span> <span class='hs-varid'>ins</span><span class='hs-layout'>,</span> <span class='hs-varid'>length</span> <span class='hs-varid'>grads</span><span class='hs-layout'>,</span> <span class='hs-varid'>showMessage</span> <span class='hs-varid'>nodeDef</span><span class='hs-layout'>,</span> <span class='hs-varid'>ins</span><span class='hs-layout'>)</span>
|
|
<a name="line-661"></a>
|
|
<a name="line-662"></a><a name="numOutputs"></a><span class='hs-comment'>-- | The number of outputs for an op type.</span>
|
|
<a name="line-663"></a><span class='hs-definition'>numOutputs</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>NodeDef</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>OutputIx</span>
|
|
<a name="line-664"></a><span class='hs-definition'>numOutputs</span> <span class='hs-varid'>o</span> <span class='hs-keyglyph'>=</span>
|
|
<a name="line-665"></a> <span class='hs-keyword'>case</span> <span class='hs-varid'>o</span> <span class='hs-varop'>^.</span> <span class='hs-varid'>op</span> <span class='hs-keyword'>of</span>
|
|
<a name="line-666"></a> <span class='hs-str'>"Abs"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-667"></a> <span class='hs-str'>"Add"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-668"></a> <span class='hs-str'>"Cast"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-669"></a> <span class='hs-str'>"Const"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-670"></a> <span class='hs-str'>"Conv2D"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-671"></a> <span class='hs-str'>"Div"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-672"></a> <span class='hs-str'>"DynamicStitch"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-673"></a> <span class='hs-str'>"DynamicPartition"</span> <span class='hs-keyglyph'>-></span>
|
|
<a name="line-674"></a> <span class='hs-varid'>fromIntegral</span> <span class='hs-layout'>(</span><span class='hs-varid'>lookupAttr</span> <span class='hs-varid'>o</span> <span class='hs-str'>"num_partitions"</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Int64</span><span class='hs-layout'>)</span>
|
|
<a name="line-675"></a> <span class='hs-str'>"Exp"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-676"></a> <span class='hs-str'>"Gather"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-677"></a> <span class='hs-str'>"LabelClasses"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-678"></a> <span class='hs-str'>"LabelWeights"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-679"></a> <span class='hs-str'>"Log"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-680"></a> <span class='hs-str'>"MatMul"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-681"></a> <span class='hs-str'>"Max"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-682"></a> <span class='hs-str'>"MaxPool"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-683"></a> <span class='hs-str'>"Mean"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-684"></a> <span class='hs-str'>"Min"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-685"></a> <span class='hs-str'>"Mul"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-686"></a> <span class='hs-str'>"Neg"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-687"></a> <span class='hs-str'>"Placeholder"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-688"></a> <span class='hs-str'>"OneHot"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-689"></a> <span class='hs-str'>"RefIdentity"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-690"></a> <span class='hs-str'>"Relu"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-691"></a> <span class='hs-str'>"Reshape"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-692"></a> <span class='hs-str'>"Select"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-693"></a> <span class='hs-str'>"Size"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-694"></a> <span class='hs-str'>"SoftmaxCrossEntropyWithLogits"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>2</span>
|
|
<a name="line-695"></a> <span class='hs-str'>"Square"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-696"></a> <span class='hs-str'>"SparseSegmentSum"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-697"></a> <span class='hs-str'>"Sub"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-698"></a> <span class='hs-str'>"Sum"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-699"></a> <span class='hs-str'>"Transpose"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-700"></a> <span class='hs-str'>"TruncatedNormal"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-701"></a> <span class='hs-str'>"Variable"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-702"></a> <span class='hs-str'>"ZerosLike"</span> <span class='hs-keyglyph'>-></span> <span class='hs-num'>1</span>
|
|
<a name="line-703"></a> <span class='hs-keyword'>_</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>error</span> <span class='hs-varop'>$</span> <span class='hs-str'>"numOutputs not implemented for "</span> <span class='hs-varop'>++</span> <span class='hs-varid'>show</span> <span class='hs-layout'>(</span><span class='hs-varid'>o</span> <span class='hs-varop'>^.</span> <span class='hs-varid'>op</span><span class='hs-layout'>)</span>
|
|
<a name="line-704"></a>
|
|
<a name="line-705"></a><a name="safeShapeDiv"></a><span class='hs-comment'>-- Divides `x` by `y`, assuming `x, y >= 0` and treating `0 / 0` as `0`.</span>
|
|
<a name="line-706"></a><span class='hs-definition'>safeShapeDiv</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-varid'>v1</span> <span class='hs-conid'>Int32</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Tensor</span> <span class='hs-varid'>v2</span> <span class='hs-conid'>Int32</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-conid'>Int32</span>
|
|
<a name="line-707"></a><span class='hs-definition'>safeShapeDiv</span> <span class='hs-varid'>x</span> <span class='hs-varid'>y</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>x</span> <span class='hs-varop'>`</span><span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>div</span><span class='hs-varop'>`</span> <span class='hs-layout'>(</span><span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>maximum</span> <span class='hs-varid'>y</span> <span class='hs-num'>1</span><span class='hs-layout'>)</span>
|
|
<a name="line-708"></a>
|
|
<a name="line-709"></a><a name="allDimensions"></a><span class='hs-definition'>allDimensions</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-conid'>Int32</span>
|
|
<a name="line-710"></a><span class='hs-definition'>allDimensions</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>vector</span> <span class='hs-keyglyph'>[</span><span class='hs-varop'>-</span><span class='hs-num'>1</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Int32</span><span class='hs-keyglyph'>]</span>
|
|
<a name="line-711"></a>
|
|
<a name="line-712"></a><a name="rangeOfRank"></a><span class='hs-definition'>rangeOfRank</span> <span class='hs-keyglyph'>::</span> <span class='hs-keyword'>forall</span> <span class='hs-varid'>v1</span> <span class='hs-varid'>t</span><span class='hs-varop'>.</span> <span class='hs-conid'>TensorType</span> <span class='hs-varid'>t</span> <span class='hs-keyglyph'>=></span> <span class='hs-conid'>Tensor</span> <span class='hs-varid'>v1</span> <span class='hs-varid'>t</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Tensor</span> <span class='hs-conid'>Value</span> <span class='hs-conid'>Int32</span>
|
|
<a name="line-713"></a><span class='hs-definition'>rangeOfRank</span> <span class='hs-varid'>x</span> <span class='hs-keyglyph'>=</span> <span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>range</span> <span class='hs-num'>0</span> <span class='hs-layout'>(</span><span class='hs-conid'>CoreOps</span><span class='hs-varop'>.</span><span class='hs-varid'>rank</span> <span class='hs-varid'>x</span><span class='hs-layout'>)</span> <span class='hs-num'>1</span>
|
|
<a name="line-714"></a>
|
|
<a name="line-715"></a><a name="lookupAttr"></a><span class='hs-definition'>lookupAttr</span> <span class='hs-keyglyph'>::</span> <span class='hs-conid'>Attribute</span> <span class='hs-varid'>a1</span> <span class='hs-keyglyph'>=></span> <span class='hs-conid'>NodeDef</span> <span class='hs-keyglyph'>-></span> <span class='hs-conid'>Text</span> <span class='hs-keyglyph'>-></span> <span class='hs-varid'>a1</span>
|
|
<a name="line-716"></a><span class='hs-definition'>lookupAttr</span> <span class='hs-varid'>nodeDef</span> <span class='hs-varid'>attrName</span> <span class='hs-keyglyph'>=</span> <span class='hs-varid'>nodeDef</span> <span class='hs-varop'>^.</span> <span class='hs-varid'>attr</span> <span class='hs-varop'>.</span> <span class='hs-varid'>at</span> <span class='hs-varid'>attrName</span> <span class='hs-varop'>.</span> <span class='hs-varid'>non</span> <span class='hs-varid'>def</span> <span class='hs-varop'>.</span> <span class='hs-varid'>attrLens</span>
|
|
</pre></body>
|
|
</html>
|