-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}

module TensorFlow.Gradient
    ( gradients
    ) where

import Control.Monad (forM, zipWithM)
import Control.Monad.State.Strict (State, evalState, gets, modify)
import Data.ByteString (ByteString)
import Data.Complex (Complex)
import Data.Default (def)
import Data.Int (Int32, Int64)
import Data.List (foldl', sortBy)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
import Data.Ord (comparing)
import Data.ProtoLens.TextFormat (showMessage)
import Data.Set (Set)
import Data.Text (Text)
import Data.Tuple (swap)
import Lens.Family2 (Lens', (&), (^.), (.~), (%~))
import Lens.Family2.State.Strict (uses)
import Lens.Family2.Stock (at, intAt)
import Lens.Family2.Unchecked (lens, iso)
import Prelude hiding (sum)
import Text.Printf (printf)
import qualified Data.Graph.Inductive.Basic as FGL
import qualified Data.Graph.Inductive.Graph as FGL
import qualified Data.Graph.Inductive.PatriciaTree as FGL
import qualified Data.Graph.Inductive.Query.DFS as FGL
import qualified Data.IntMap.Strict as IntMap
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Text as Text

import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Build
    ( Build
    , render
    , renderNodeName
    , renderedNodeDefs
    , opDef
    , opAttr
    )
import TensorFlow.BuildOp
import TensorFlow.Ops
    ( addN
    , broadcastGradientArgs
    , expandDims
    , fill
    , matMul
    , reducedShape
    , reluGrad
    , reshape
    , scalar
    , shape
    , softmaxCrossEntropyWithLogits
    , sum
    , vector
    , zerosLike
    )
import TensorFlow.Output
    ( NodeName(..)
    , Op (Rendered)
    , Output(..)
    , OutputIx(..)
    , outputIndex
    )
import TensorFlow.Tensor
    ( Tensor(..)
    , TensorKind (ValueKind)
    , Value
    , tensorOutput
    , tensorAttr
    )
import TensorFlow.Types (OneOf, TensorType, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef
    (NodeDef, attr, input, op, name)

type GradientCompatible a =
    -- TODO(fmayle): MaxPoolGrad doesn't support Double for some reason.
    (Num a, OneOf '[ Float, Complex Float, Complex Double ] a)

-- TODO(fmayle): Support control flow.
-- TODO(fmayle): Support gate_gradients-like option to avoid race conditions.
-- TODO(fmayle): Do we need to consider control inputs? See _PendingCount in
-- tensorflow/python/ops/gradients.py.
-- TODO(fmayle): Maybe store the gradient functions and numOutputs on the OpDef.

-- | Gradient of @y@ w.r.t. each element of @xs@.
gradients :: forall a v1 v2 . ( Num (Tensor v1 a)
                                -- TODO(gnezdo): remove indirect constraint.
                                -- It's a wart inherited from Num instance.
                              , v1 ~ Value
                              , GradientCompatible a
                              )
          => Tensor v1 a     -- ^ The output of the graph.
          -> [Tensor v2 a]   -- ^ Tensors for which gradients are computed.
          -> Build [Tensor Value a]
gradients y xs = do
    -- The gradients are computed using "reverse accumulation", similarly to
    -- what is described here:
    -- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation
    --
    -- The code is summarised as follows:
    --
    -- 1. Create an fgl graph of the relevant nodes (ops) and edges (tensors).
    -- 2. Initialize the gradient of y to 1 (∂y/∂y = 1) and the rest of the
    --    tensors' gradients to nothing.
    -- 3. Process the nodes in reverse topological order (i.e. each node comes
    --    after all of its outputs so that the output gradients for a node have
    --    been completely calculated before it is processed):
    --      a. Record the gradient for each of the node's output tensors (∂y/∂w
    --         for each output tensor w).
    --      b. Calculate the gradient of y w.r.t. each of the node's input
    --         tensors using the gradients of the node's output tensors.
    --
    --    Written differently, for each output tensor w and input tensor v:
    --      ∂y/∂w = ... (calculated in previous steps)
    --      ∂w/∂v = ... (op specific)
    --      ∂y/∂v = ∂y/∂w * ∂w/∂v (technically, if tensor v is an input
    --                             to multiple nodes, then this is only
    --                             part of ∂y/∂v)
    --
    -- 4. Look up the recorded gradient for each x in xs.

    yName <- renderNodeName y
    -- TODO(fmayle): Move this into Build.hs and call it unsafeNodeDefFromName?
    nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $
        (\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
        . flip Map.lookup
    let (gr, nodeMap) = createGraph yName nodeDefLookup
    -- Set gradient of y to one.
    let initPending :: Map.Map FGL.Node (PendingGradients a)
        initPending = Map.empty & at (nodeMap Map.! yName)
                                . nonEmpty
                                . outputIxAt (y ^. tensorOutput . outputIndex)
                                . nonEmpty
                                .~ [fill (shape y) (scalar 1)]
    -- Calculate the gradients of y w.r.t. each node in the graph.
    gradientMap <- graphGrads gr initPending
    -- Look up the gradients for each x.
    forM xs $ \x -> do
        xName <- renderNodeName x
        render $ fromMaybe (zerosLike x) $ do
            n <- nodeMap ^. at xName
            let i = x ^. tensorOutput . outputIndex
            gradientMap ^. at n . nonEmpty . outputIxAt i

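-- A minimal usage sketch (illustrative only; this helper is not part of the
-- original API). It assumes the 'Num' instance for @Tensor Value Float@
-- exported by TensorFlow.Ops. Building @y = x * x@ from a rendered scalar @x@
-- and asking for ∂y/∂x should yield a tensor holding @2 * x@.
_gradientsUsageSketch :: Build [Tensor Value Float]
_gradientsUsageSketch = do
    -- Render x first so it is a named node that y transitively depends on.
    x <- render (scalar (3 :: Float))
    let y = x * x
    -- One gradient tensor is returned per requested input; here, ∂y/∂x.
    gradients y [x]
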
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
outputIxAt = intAt . unOutputIx

-- | Incomplete gradients of a node's outputs.
--
-- The lists represent partial sums. The key is an OutputIx sans newtype.
type PendingGradients a = IntMap.IntMap [Tensor Value a]

-- | Gradients of a node's outputs. The key is an OutputIx sans newtype.
type Gradients a = IntMap.IntMap (Tensor Value a)

-- | Graph of TensorFlow operations.
type Graph = FGL.Gr NodeDef EdgeLabel

-- | Data associated with an edge.
--
-- Pair of
--   1. Output index of a tensor from the source node.
--   2. Input index that the tensor connects to on the destination node.
type EdgeLabel = (OutputIx, OutputIx)

-- | State used for calculating gradients.
data GradientsState a = GradientsState
    { _gradientsPending :: !(Map FGL.Node (PendingGradients a))
    , _gradientsResult :: !(Map FGL.Node (Gradients a))
    }

gradientsPending :: Lens' (GradientsState a) (Map FGL.Node (PendingGradients a))
gradientsPending = lens _gradientsPending (\x y -> x { _gradientsPending = y })

gradientsResult :: Lens' (GradientsState a) (Map FGL.Node (Gradients a))
gradientsResult = lens _gradientsResult (\x y -> x { _gradientsResult = y })

-- TODO(fmayle): Use something like Data.List.Safe.
-- | Safe version of (!!).
safeIndex :: [a] -> Int -> Maybe a
_ `safeIndex` n | n < 0 = Nothing
[] `safeIndex` _ = Nothing
(x:_) `safeIndex` 0 = Just x
(_:xs) `safeIndex` n = xs `safeIndex` (n-1)

-- Copy of http://hackage.haskell.org/package/lens-3.9.0.2/docs/Control-Lens-Iso.html#v%3anon
anon :: a -> (a -> Bool) -> Lens' (Maybe a) a
anon a p = iso (fromMaybe a) go where
  go b | p b = Nothing
       | otherwise = Just b

non :: Eq a => a -> Lens' (Maybe a) a
non a = anon a (a==)

-- | Lens that defaults Nothing to mempty.
nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v)
nonEmpty = anon mempty null

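-- Worked examples for 'nonEmpty' (illustrative only): a missing entry reads
-- as 'mempty', and writing an empty value back collapses to 'Nothing'.
--
-- >>> Nothing ^. nonEmpty :: [Int]
-- []
-- >>> Just [1 :: Int] & nonEmpty %~ (2:)
-- Just [2,1]
-- >>> Just [1 :: Int] & nonEmpty .~ []
-- Nothing
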
-- | Calculate the gradients for every node in a graph.
graphGrads :: forall a. GradientCompatible a
           => Graph
           -> Map FGL.Node (PendingGradients a)
           -- ^ Initial gradients (usually just 1 for the node of interest).
           -> Build (Map FGL.Node (Gradients a))
graphGrads gr initPending = pure (foldl' go initState nodeOrder ^. gradientsResult)
  where
    initState = GradientsState initPending Map.empty
    -- Reverse topological sort.
    -- TODO(fmayle): Filter out nodes that are not successors of any x in xs to
    -- avoid calculating gradients that won't be used.
    nodeOrder = FGL.topsort $ FGL.grev gr
    go state node =
        -- Aggregate the accumulated gradients for this node.
        let outputGrads =
                sumPendingGradient (state ^. gradientsPending . at node . nonEmpty)
        in if null outputGrads
           then state
           else
                -- Calculate the gradients for each of the node's inputs.
                let nextState = state & gradientsResult %~ Map.insert node outputGrads
                    ctx = FGL.context gr node
                in updatePendingGradients
                       ctx
                       (calculateInputGrads ctx outputGrads gr)
                       nextState

-- | Reduce accumulated gradients for each output to one Tensor.
sumPendingGradient :: GradientCompatible a
                   => PendingGradients a -> Gradients a
sumPendingGradient = IntMap.mapMaybe f
  where
    f [] = Nothing
    f [x] = Just x
    f xs = Just (addN xs)

-- | Calculate the gradients of a node's input tensors.
--
-- This is mostly just a wrapper around opGrad.
calculateInputGrads :: forall a. GradientCompatible a
                    => FGL.Context NodeDef EdgeLabel
                    -> Gradients a  -- ^ Output gradients of the node.
                    -> Graph
                    -> [Maybe (Tensor Value a)]
calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr =
    opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads
  where
    fullOutGrads =
        fullOutputGrads (numOutputs nodeDef) (Rendered nodeDef) outputGrads
    -- Create a tensor from an edge (technically an Output, but it seems less
    -- confusing to refer to it as a tensor here).
    edgeToTensor :: (EdgeLabel, FGL.Node) -> Output
    edgeToTensor ((i, _), n) =
        case FGL.lab gr n of
            Just edgeNodeDef -> Output i (Rendered edgeNodeDef)
            Nothing -> error $ "calculateInputGrads: missing input node for "
                               ++ Text.unpack (nodeDef ^. name)
    -- Input tensors, sorted by input index.
    inputTensors = map edgeToTensor $ sortBy (comparing (snd . fst)) inputEdges

-- | Convert a Map of gradients to a list, with zeros for missing outputs.
fullOutputGrads :: (TensorType a, Num a)
                => OutputIx  -- ^ Number of outputs.
                -> Op
                -> Gradients a
                -> [Tensor Value a]
fullOutputGrads n o gs =
    map (\i -> fromMaybe (zero i) (gs ^. outputIxAt i)) [0..n-1]
  where
    -- A tensor of zeros with the same shape as the i'th output.
    zero i = zerosLike $ toT (Output i o)

-- | Update the pending gradients of a node's inputs.
updatePendingGradients :: forall a. (TensorType a, Num a)
                       => FGL.Context NodeDef EdgeLabel
                       -> [Maybe (Tensor Value a)]
                       -- ^ Gradient of each input tensor.
                       -> GradientsState a
                       -> GradientsState a
updatePendingGradients (inputEdges, _, nodeDef, _) inputGrads initState =
    foldl' go initState inputEdges
  where
    go :: GradientsState a -> (EdgeLabel, FGL.Node) -> GradientsState a
    go state ((outIndex, OutputIx inIndex), node) =
        case maybeGradient of
            Nothing -> state
            Just g ->
                -- Add to the list of pending gradients for this tensor.
                state & gradientsPending
                      . at node
                      . nonEmpty
                      . outputIxAt outIndex
                      . nonEmpty
                      %~ (g:)
      where
        badSizeErr = error $ printf "updatePendingGradients: bad input index \
                                    \%d for inputGrads of length %d in %s"
                                    inIndex (length inputGrads)
                                    (show (nodeDef ^. name))
        maybeGradient = fromMaybe badSizeErr (safeIndex inputGrads inIndex)

-- | Create a graph that includes a node and its transitive dependencies.
createGraph :: NodeName -> (NodeName -> NodeDef)
            -> (Graph, Map NodeName FGL.Node)
createGraph nodeName nodeDefLookup = (FGL.nmap nodeDefLookup graph, nodeMap)
  where
    -- Parse a tensor name.
    parseTensorName :: Text -> Maybe (NodeName, OutputIx)
    parseTensorName n
        | Text.null n = error "parseTensorName: empty name"
        | Text.head n == '^' = Nothing  -- Control edge
        | otherwise =
            let (nm, indexStr) = Text.breakOn ":" n
                index | Text.null indexStr = 0
                      | otherwise = read $ Text.unpack $ Text.tail indexStr
            in Just (NodeName nm, OutputIx index)

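    -- Illustrative examples of the naming convention handled above
    -- (TensorFlow "node:output_index" strings; "^node" marks a control edge):
    --
    --   parseTensorName "foo"   == Just (NodeName "foo", OutputIx 0)
    --   parseTensorName "foo:2" == Just (NodeName "foo", OutputIx 2)
    --   parseTensorName "^foo"  == Nothing
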
    -- Build a map from node name to outward edges.
    --
    -- The state is the set of visited nodes.
    collect :: Maybe (NodeName, OutputIx, OutputIx)
            -> NodeName
            -> State (Set NodeName)
                     (Map NodeName [(NodeName, OutputIx, OutputIx)])
    collect outgoingEdge nm = do
        let nextLookup = Map.singleton nm (maybeToList outgoingEdge)
        seen <- gets (Set.member nm)
        modify (Set.insert nm)
        if seen
            then pure nextLookup
            else do
                let inputs = nodeDefLookup nm ^. input
                    recurse inIndex (parentName, outIndex) =
                        collect (Just (nm, outIndex, inIndex)) parentName
                subEdgeLookups <-
                    zipWithM recurse [0..] $ mapMaybe parseTensorName inputs
                pure $ Map.unionsWith (++) (nextLookup:subEdgeLookups)

    edgeLookup = evalState (collect Nothing nodeName) Set.empty
    -- Associate an ID with each node name.
    nodeMap = Map.fromList $ zip (Map.keys edgeLookup) [0..]
    -- Create the graph.
    graph = FGL.mkGraph (swap <$> Map.toList nodeMap)
                        [ (nodeMap Map.! n, nodeMap Map.! m, (i, j))
                        | (n, edges) <- Map.toList edgeLookup
                        , (m, i, j) <- edges
                        ]

-- | Function to compute the gradient of y w.r.t. each input.
--
-- Let y be an arbitrary tensor
-- and [w_0, ..., w_n] be the output tensors of a node
-- and [v_0, ..., v_n] be the input tensors of the same node.
--
-- Given [∂y/∂w_0, ..., ∂y/∂w_n] and [v_0, ..., v_n], a GradientFunc computes
-- [∂y/∂v_0, ..., ∂y/∂v_n] for a particular op type.
--
-- A Nothing gradient is equivalent to zero (but allows for short circuiting
-- computation when all the gradients for something are Nothing).
type GradientFunc a = NodeDef
                    -> [Output]
                    -- ^ Input tensors.
                    -> [Tensor Value a]
                    -- ^ Gradient of y w.r.t. each output tensor.
                    -> [Maybe (Tensor Value a)]
                    -- ^ Gradient of y w.r.t. each input tensor.

-- TODO(fmayle): Assert the type is correct.
-- | Create a Tensor from an Output.
toT :: Output -> Tensor Value a
toT = Tensor ValueKind

-- | The gradient function for an op type.
--
-- These implementations should match their python counterparts in:
-- third_party/tensorflow/python/ops/*_grad.py
opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a

opGrad "Abs" _ [toT -> x] [dz] = [Just $ dz * signum x]
opGrad "Neg" _ [_] [dz] = [Just $ -dz]
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]

opGrad "Square" _ [toT -> x] [dz] =
    -- TODO(fmayle): Handle complex numbers.
    -- TODO(fmayle): The python code makes dz a control dependency of the 2*x
    -- (for performance reasons?). Will need to put these functions in the Build
    -- monad to replicate that.
    [Just $ dz * (2 * x)]

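-- The incoming gradient dz has one row per gathered index; it is scattered
-- back into a dense tensor of the original shape by summing rows that share
-- an index (unsortedSegmentSum). The indices input itself gets no gradient.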
opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
    -- TODO(fmayle): The python version uses a better-performing implementation
    -- when the shape is known without having to run the graph.
    -- TODO(fmayle): We shouldn't convert the result to a dense tensor. Sparse
    -- tensor support will require some thinking.
    [ Just $ CoreOps.unsortedSegmentSum values indices' numRows
    , Nothing
    ]
  where
    -- TODO(gnezdo): Use colocateWith but it requires Build monad.
    denseShape = shape (x :: Tensor Value a)
    numRows = CoreOps.slice denseShape 0 (1 :: Tensor Value Int32)
    valuesShape = CoreOps.concat 0 [
          allDimensions
        , CoreOps.slice denseShape 1 (-1 :: Tensor Value Int32)
        ]
    values = reshape dz valuesShape
    -- TODO(fmayle): This could be either Int32 or Int64.
    indices' = reshape indices allDimensions :: Tensor Value Int32

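-- For Max (and Min, below), the gradient flows only to the input elements
-- that equal the reduced extreme value; when several elements tie, the
-- incoming gradient is split evenly among them via the indicator count.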
opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
    [Just $ indicators `CoreOps.div` numSelected * dz', Nothing]
  where
    sx = shape (x :: Tensor Value a)
    outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
    x' = reshape x outputShapeKeptDims
    dz' = reshape dz outputShapeKeptDims
    indicators = CoreOps.cast $ CoreOps.equal x' x
    numSelected = reshape (sum indicators indices) outputShapeKeptDims

-- Min and Max have identical gradient implementations.
opGrad "Min" u v w = opGrad "Max" u v w

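-- The gradient of a sum reduction broadcasts dz back over the reduced
-- dimensions: dz is reshaped to keep the reduced axes as size 1 and then
-- tiled up to the input shape.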
opGrad "Sum" _ [toT -> x, toT -> indices] [dz] =
    [ Just $ CoreOps.tile grad tileScaling, Nothing ]
  where
    -- TODO(gnezdo): Implement the fast-path from math_grad._SumGrad.
    sx = shape (x :: Tensor Value a)
    outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
    tileScaling = safeShapeDiv sx outputShapeKeptDims
    grad = reshape dz outputShapeKeptDims

opGrad "Mean" u v@[toT -> x, _] w =
    [Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing]
  where
    [Just dz, Nothing] = opGrad "Sum" u v w
    inputShape = shape (x :: Tensor Value a)
    -- The number of reduced elements is prod(inputShape) / prod(outputShape).
    -- The output shape is taken from the incoming gradient (head w), which has
    -- the same shape as the Mean op's output; the Sum gradient dz above has
    -- already been tiled back up to the input shape and would give a factor
    -- of 1.
    outputShape = shape (head w :: Tensor Value a)
    -- TODO(fmayle): Add fast path when shape is known.
    inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape
    outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape
    factor = safeShapeDiv inputSize outputSize

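-- For broadcasting binary ops (Add, Sub, Mul, Div below), each input's
-- gradient is the incoming gradient summed over the axes along which that
-- input was broadcast (as reported by broadcastGradientArgs) and then
-- reshaped back to the input's own shape.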
opGrad "Add" _ [toT -> x, toT -> y] [dz] =
    [ Just $ reshape (sum dz rx) sx
    , Just $ reshape (sum dz ry) sy ]
  where
    sx = shape (x :: Tensor Value a)
    sy = shape (y :: Tensor Value a)
    (rx, ry) = broadcastGradientArgs sx sy

opGrad "Sub" u v w =
    [Just x, Just (-y)]
  where
    [Just x, Just y] = opGrad "Add" u v w

opGrad "SoftmaxCrossEntropyWithLogits" _ [toT -> x, toT -> y] [dz, _] =
    [ Just $ expandDims dz (-1) * snd (softmaxCrossEntropyWithLogits x y)
    , Nothing ]

opGrad "Mul" _ [toT -> x, toT -> y] [dz] =
    -- TODO(fmayle): Handle complex numbers.
    [ Just $ reshape (sum (dz * y) rx) sx
    , Just $ reshape (sum (x * dz) ry) sy ]
  where
    sx = shape (x :: Tensor Value a)
    sy = shape (y :: Tensor Value a)
    (rx, ry) = broadcastGradientArgs sx sy

opGrad "Div" _ [toT -> x, toT -> y] [dz] =
    -- TODO(fmayle): Handle complex numbers.
    -- TODO(gnezdo): Provide Fractional instance and use '/' instead of div.
    [ Just $ reshape (sum (dz `CoreOps.div` y) rx) sx
    , Just $ reshape (sum (dz * (negate x `CoreOps.div` (y * y))) ry) sy
    ]
  where
    sx = shape (x :: Tensor Value a)
    sy = shape (y :: Tensor Value a)
    (rx, ry) = broadcastGradientArgs sx sy

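-- For C = op_a(x) * op_b(y) (where op_a/op_b transpose their argument when
-- the corresponding attribute is set), the gradients follow the usual rules
-- dx = dz * y^T and dy = x^T * dz, adjusted per transpose flag so that each
-- result has the shape of the corresponding input; this mirrors _MatMulGrad
-- in tensorflow/python/ops/math_grad.py.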
opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
    let transposeA = lookupAttr nodeDef "transpose_a"
        transposeB = lookupAttr nodeDef "transpose_b"
        transAttrs a b =
            (tensorAttr "transpose_a" .~ a) . (tensorAttr "transpose_b" .~ b)
    in case (transposeA, transposeB) of
       (False, False) ->
           [ Just $ (dz `matMul` y) & transAttrs False True
           , Just $ (x `matMul` dz) & transAttrs True False ]
       (False, True) ->
           [ Just $ dz `matMul` y
           , Just $ (dz `matMul` x) & transAttrs True False ]
       (True, False) ->
           [ Just $ (y `matMul` dz) & transAttrs False True
           , Just $ x `matMul` dz ]
       (True, True) ->
           [ Just $ (y `matMul` dz) & transAttrs True True
           , Just $ (dz `matMul` x) & transAttrs True True ]

opGrad "Transpose" _ [_, toT -> p] [dz] =
    [ Just $ CoreOps.transpose dz
            (CoreOps.invertPermutation p :: Tensor Value Int32)
    , Nothing
    ]

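-- Conv2D and MaxPool (below) delegate to their dedicated backprop kernels,
-- forwarding the forward op's attributes so the backward pass uses the same
-- strides, padding, and data format.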
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
    [ Just $ CoreOps.conv2DBackpropInput (shape x) y dz
          & tensorAttr "strides" .~ strides
          & tensorAttr "padding" .~ padding
          & tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
          & tensorAttr "data_format" .~ dataFormat
    , Just $ CoreOps.conv2DBackpropFilter x (shape y) dz
          & tensorAttr "strides" .~ strides
          & tensorAttr "padding" .~ padding
          & tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
          & tensorAttr "data_format" .~ dataFormat
    ]
  where
    strides = lookupAttr nodeDef "strides" :: [Int64]
    padding = lookupAttr nodeDef "padding" :: ByteString
    useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
    dataFormat = lookupAttr nodeDef "data_format" :: ByteString

opGrad "MaxPool" nodeDef [toT -> x] [dz] =
    [ Just $ CoreOps.maxPoolGrad x output dz
          & tensorAttr "ksize" .~ ksize
          & tensorAttr "strides" .~ strides
          & tensorAttr "padding" .~ padding
          & tensorAttr "data_format" .~ dataFormat
    ]
  where
    output :: Tensor Value a
    output = toT $ Output 0 (Rendered nodeDef)
    ksize = lookupAttr nodeDef "ksize" :: [Int64]
    strides = lookupAttr nodeDef "strides" :: [Int64]
    padding = lookupAttr nodeDef "padding" :: ByteString
    dataFormat = lookupAttr nodeDef "data_format" :: ByteString

opGrad "Reshape" _ [toT -> x, _] [dz] =
    [Just $ reshape dz $ shape (x :: Tensor Value a), Nothing]

opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
opGrad "TruncatedNormal" _ _ _ = [Nothing]

opGrad "RefIdentity" _ _ [dz] = [Just dz]
opGrad "Cast" nodeDef _ [dz] = [Just reverseCast]
  where
    -- TODO(gnezdo): too permissive, python only allows float types as src_type.
    reverseCast =
        buildOp (opDef "Cast"
                 & opAttr "DstT" .~ (lookupAttr nodeDef "SrcT" :: ByteString)
                 & opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString))
                dz

opGrad "DynamicStitch" nodeDef inputs [dz] =
    replicate halfLen Nothing ++ valuesGrads
  where
    halfLen =
        let len = length inputs
            half = len `div` 2
        in if 2 * half == len
           then half
           else error ("Uneven input size " ++ show (len, showMessage nodeDef))
    valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Value Int32)
                  | idx <- take halfLen inputs
                  ]

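-- DynamicPartition's gradient inverts the partitioning: the original element
-- positions are partitioned the same way as the data, then dynamicStitch
-- scatters each partition's incoming gradient back to those positions, and
-- the result is reshaped to the input's shape.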
opGrad "DynamicPartition" nodeDef [toT -> xs, toT -> indices] dz =
    [ Just reconstructed, Nothing ]
  where
    reconstructed = CoreOps.reshape stitched
                    (CoreOps.shape (xs :: Tensor Value a) :: Tensor Value Int32)
    stitched = CoreOps.dynamicStitch partitionedIndices dz
    partitionedIndices = CoreOps.dynamicPartition np originalIndices indices
    np = lookupAttr nodeDef "num_partitions" :: Int64
    originalIndices =
        CoreOps.reshape (CoreOps.range 0 (CoreOps.size indices) 1) prefixShape
    prefixShape = shapeInt32 indices
    shapeInt32 = CoreOps.shape :: Tensor Value Int32 -> Tensor Value Int32

opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
    [ Nothing
    , Just $ CoreOps.select c dz zeros
    , Just $ CoreOps.select c zeros dz
    ]
  where zeros = CoreOps.zerosLike x

-- TODO(gnezdo): Unlike Python, no control dependency on dz.
opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.inv x ]
-- TODO(gnezdo): Reuse the output instead of doing another exp,
-- though it is probably CSE'd away anyway.
opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.exp x ]
opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
    [ Just $ CoreOps.unsortedSegmentSum
             (CoreOps.gather dz (t :: Tensor Value Int32))
             (y :: Tensor Value Int32) inputRows
    , Nothing
    , Nothing
    ]
  where inputRows = CoreOps.slice (shape (x :: Tensor Value a)) (scalar (0 :: Int32)) (scalar 1)

opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
opGrad "LabelWeights" _ _ _ = [Nothing]
opGrad "Size" _ _ _ = [Nothing]
opGrad "ZerosLike" _ _ _ = [Nothing]

-- TODO(fmayle): These can go away if we properly prune the graph.
opGrad "Const" _ _ _ = [Nothing, Nothing]
opGrad "Placeholder" _ _ _ = []
opGrad "Variable" _ _ _ = []

opGrad n nodeDef ins grads =
    error $ "no gradient implemented for " ++
            show (n, length ins, length grads, showMessage nodeDef, ins)

-- | The number of outputs for an op type.
numOutputs :: NodeDef -> OutputIx
numOutputs o =
    case o ^. op of
        "Abs" -> 1
        "Add" -> 1
        "Cast" -> 1
        "Const" -> 1
        "Conv2D" -> 1
        "Div" -> 1
        "DynamicStitch" -> 1
        "DynamicPartition" ->
            fromIntegral (lookupAttr o "num_partitions" :: Int64)
        "Exp" -> 1
        "Gather" -> 1
        "LabelClasses" -> 1
        "LabelWeights" -> 1
        "Log" -> 1
        "MatMul" -> 1
        "Max" -> 1
        "MaxPool" -> 1
        "Mean" -> 1
        "Min" -> 1
        "Mul" -> 1
        "Neg" -> 1
        "Placeholder" -> 1
        "OneHot" -> 1
        "RefIdentity" -> 1
        "Relu" -> 1
        "Reshape" -> 1
        "Select" -> 1
        "Size" -> 1
        "SoftmaxCrossEntropyWithLogits" -> 2
        "Square" -> 1
        "SparseSegmentSum" -> 1
        "Sub" -> 1
        "Sum" -> 1
        "Transpose" -> 1
        "TruncatedNormal" -> 1
        "Variable" -> 1
        "ZerosLike" -> 1
        _ -> error $ "numOutputs not implemented for " ++ show (o ^. op)

-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`.
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)

allDimensions = vector [-1 :: Int32]

rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1

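-- | Look up an attribute on a NodeDef, falling back to the default AttrValue
-- when the attribute is not set on the node.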
lookupAttr nodeDef attrName = nodeDef ^. attr . at attrName . non def . attrLens