{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}

module TensorFlow.Gradient
    ( gradients
    ) where
import Control.Monad (forM, zipWithM)
import Control.Monad.State.Strict (State, evalState, gets, modify)
import Data.ByteString (ByteString)
import Data.Complex (Complex)
import Data.Default (def)
import Data.Int (Int32, Int64)
import Data.List (foldl', sortBy)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
import Data.Ord (comparing)
import Data.ProtoLens.TextFormat (showMessage)
import Data.Set (Set)
import Data.Text (Text)
import Data.Tuple (swap)
import Lens.Family2 (Lens', (&), (^.), (.~), (%~))
import Lens.Family2.State.Strict (uses)
import Lens.Family2.Stock (at, intAt)
import Lens.Family2.Unchecked (lens, iso)
import Prelude hiding (sum)
import Text.Printf (printf)
import qualified Data.Graph.Inductive.Basic as FGL
import qualified Data.Graph.Inductive.Graph as FGL
import qualified Data.Graph.Inductive.PatriciaTree as FGL
import qualified Data.Graph.Inductive.Query.DFS as FGL
import qualified Data.IntMap.Strict as IntMap
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Text as Text
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Build
( Build
, render
, renderNodeName
, renderedNodeDefs
, opDef
, opAttr
)
import TensorFlow.BuildOp
import TensorFlow.Ops
( addN
, broadcastGradientArgs
, expandDims
, fill
, matMul
, reducedShape
, reluGrad
, reshape
, scalar
, shape
, softmaxCrossEntropyWithLogits
, sum
, scalarize
, vector
, zerosLike
)
import TensorFlow.Output
( NodeName(..)
, Op (Rendered)
, Output(..)
, OutputIx(..)
, outputIndex
)
import TensorFlow.Tensor
( Tensor(..)
, TensorKind (ValueKind)
, Value
, tensorOutput
, tensorAttr
)
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef
(NodeDef, attr, input, op, name)
type GradientCompatible a =
(Num a, OneOf '[ Float, Complex Float, Complex Double ] a)
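-- | Gradient of @y@ w.r.t. each element of @xs@.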
gradients :: forall a v1 v2 . ( Num (Tensor v1 a)
, v1 ~ Value
, GradientCompatible a
)
=> Tensor v1 a
-> [Tensor v2 a]
-> Build [Tensor Value a]
gradients y xs = do
yName <- renderNodeName y
nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
. flip Map.lookup
let (gr, nodeMap) = createGraph yName nodeDefLookup
let initPending :: Map.Map FGL.Node (PendingGradients a)
initPending = Map.empty & at (nodeMap Map.! yName)
. nonEmpty
. outputIxAt (y ^. tensorOutput . outputIndex)
. nonEmpty
.~ [fill (shape y) (scalar 1)]
gradientMap <- graphGrads gr initPending
forM xs $ \x -> do
xName <- renderNodeName x
render $ fromMaybe (zerosLike x) $ do
n <- nodeMap ^. at xName
let i = x ^. tensorOutput . outputIndex
gradientMap ^. at n . nonEmpty . outputIxAt i
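-- A rough usage sketch (illustrative only; it assumes the 'Num' instance for
-- value tensors provided by TensorFlow.Ops and the op re-exports imported
-- above, and has not been checked against this exact snapshot of the API):
--
-- > -- d/dx of sum (x * x) is 2*x, so this evaluates to [2, 4, 6].
-- > squareGrads :: Build [Tensor Value Float]
-- > squareGrads = do
-- >     let x = vector [1, 2, 3 :: Float]
-- >     y <- render $ sum (x * x) (scalar (0 :: Int32))
-- >     gradients y [x]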
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
outputIxAt = intAt . unOutputIx
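-- Intermediate structures used while walking the graph backwards:
--
--   * 'PendingGradients' accumulates the partial gradient contributions
--     flowing into each output of a node (keyed by the raw 'OutputIx').
--   * 'Gradients' holds the summed gradient per output once a node has been
--     processed.
--   * 'Graph' is the data-flow graph of 'NodeDef's; an 'EdgeLabel' @(i, j)@
--     connects output @i@ of the source node to input @j@ of the sink node.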
type PendingGradients a = IntMap.IntMap [Tensor Value a]
type Gradients a = IntMap.IntMap (Tensor Value a)
type Graph = FGL.Gr NodeDef EdgeLabel
type EdgeLabel = (OutputIx, OutputIx)
data GradientsState a = GradientsState
{ _gradientsPending :: !(Map FGL.Node (PendingGradients a))
, _gradientsResult :: !(Map FGL.Node (Gradients a))
}
gradientsPending :: Lens' (GradientsState a) (Map FGL.Node (PendingGradients a))
gradientsPending = lens _gradientsPending (\x y -> x { _gradientsPending = y })
gradientsResult :: Lens' (GradientsState a) (Map FGL.Node (Gradients a))
gradientsResult = lens _gradientsResult (\x y -> x { _gradientsResult = y })
safeIndex :: [a] -> Int -> Maybe a
_ `safeIndex` n | n < 0 = Nothing
[] `safeIndex` _ = Nothing
(x:_) `safeIndex` 0 = Just x
(_:xs) `safeIndex` n = xs `safeIndex` (n-1)
anon :: a -> (a -> Bool) -> Lens' (Maybe a) a
anon a p = iso (fromMaybe a) go where
go b | p b = Nothing
| otherwise = Just b
non :: Eq a => a -> Lens' (Maybe a) a
non a = anon a (a==)
nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v)
nonEmpty = anon mempty null
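-- | Compute the gradients for every node in the graph by processing nodes in
-- reverse topological order: sum the pending gradients of a node's outputs,
-- record them, and push the resulting input gradients onto the pending
-- gradients of its predecessors.  Nodes with no pending gradients are
-- skipped.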
graphGrads :: forall a. GradientCompatible a
=> Graph
-> Map FGL.Node (PendingGradients a)
-> Build (Map FGL.Node (Gradients a))
graphGrads gr initPending = pure (foldl' go initState nodeOrder ^. gradientsResult)
where
initState = GradientsState initPending Map.empty
nodeOrder = FGL.topsort $ FGL.grev gr
go state node =
let outputGrads =
sumPendingGradient (state ^. gradientsPending . at node . nonEmpty)
in if null outputGrads
then state
else
let nextState = state & gradientsResult %~ Map.insert node outputGrads
ctx = FGL.context gr node
in updatePendingGradients
ctx
(calculateInputGrads ctx outputGrads gr)
nextState
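-- | Reduce the pending gradient contributions for each output to a single
-- tensor, summing with 'addN' when there is more than one contribution.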
sumPendingGradient :: GradientCompatible a
=> PendingGradients a -> Gradients a
sumPendingGradient = IntMap.mapMaybe f
where
f [] = Nothing
f [x] = Just x
f xs = Just (addN xs)
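-- | Compute the gradients of a node's input tensors from the gradients of
-- its outputs by dispatching on the op name via 'opGrad'.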
calculateInputGrads :: forall a. GradientCompatible a
=> FGL.Context NodeDef EdgeLabel
-> Gradients a
-> Graph
-> [Maybe (Tensor Value a)]
calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr =
opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads
where
fullOutGrads =
fullOutputGrads (numOutputs nodeDef) (Rendered nodeDef) outputGrads
edgeToTensor :: (EdgeLabel, FGL.Node) -> Output
edgeToTensor ((i, _), n) =
case FGL.lab gr n of
Just edgeNodeDef -> Output i (Rendered edgeNodeDef)
Nothing -> error $ "calculateInputGrads: missing input node for "
++ Text.unpack (nodeDef ^. name)
inputTensors = map edgeToTensor $ sortBy (comparing (snd . fst)) inputEdges
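-- | The gradients of a node's outputs in output-index order, substituting a
-- zero tensor for any output that has no gradient (e.g. an output that is
-- not used downstream).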
fullOutputGrads :: (TensorType a, Num a)
=> OutputIx
-> Op
-> Gradients a
-> [Tensor Value a]
fullOutputGrads n o gs =
map (\i -> fromMaybe (zero i) (gs ^. outputIxAt i)) [0..n-1]
where
zero i = zerosLike $ toT (Output i o)
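-- | Given the gradients just computed for a node's inputs, prepend each one
-- to the pending gradients of the corresponding predecessor node, at that
-- predecessor's output index.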
updatePendingGradients :: forall a. (TensorType a, Num a)
=> FGL.Context NodeDef EdgeLabel
-> [Maybe (Tensor Value a)]
-> GradientsState a
-> GradientsState a
updatePendingGradients (inputEdges, _, nodeDef, _) inputGrads initState =
foldl' go initState inputEdges
where
go :: GradientsState a -> (EdgeLabel, FGL.Node) -> GradientsState a
go state ((outIndex, OutputIx inIndex), node) =
case maybeGradient of
Nothing -> state
Just g ->
state & gradientsPending
. at node
. nonEmpty
. outputIxAt outIndex
. nonEmpty
%~ (g:)
where
badSizeErr = error $ printf "updatePendingGradients: bad input index \
\%d for inputGrads of length %d in %s"
inIndex (length inputGrads)
(show (nodeDef ^. name))
maybeGradient = fromMaybe badSizeErr (safeIndex inputGrads inIndex)
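-- | Create a graph containing the given node and its transitive
-- dependencies, along with a map from node names to FGL node ids.  Control
-- edges (inputs starting with @^@) are ignored.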
createGraph :: NodeName -> (NodeName -> NodeDef)
-> (Graph, Map NodeName FGL.Node)
createGraph nodeName nodeDefLookup = (FGL.nmap nodeDefLookup graph, nodeMap)
where
parseTensorName :: Text -> Maybe (NodeName, OutputIx)
parseTensorName n
| Text.null n = error "parseTensorName: empty name"
| Text.head n == '^' = Nothing
| otherwise =
let (nm, indexStr) = Text.breakOn ":" n
index | Text.null indexStr = 0
| otherwise = read $ Text.unpack $ Text.tail indexStr
in Just (NodeName nm, OutputIx index)
collect :: Maybe (NodeName, OutputIx, OutputIx)
-> NodeName
-> State (Set NodeName)
(Map NodeName [(NodeName, OutputIx, OutputIx)])
collect outgoingEdge nm = do
let nextLookup = Map.singleton nm (maybeToList outgoingEdge)
seen <- gets (Set.member nm)
modify (Set.insert nm)
if seen
then pure nextLookup
else do
let inputs = nodeDefLookup nm ^. input
recurse inIndex (parentName, outIndex) =
collect (Just (nm, outIndex, inIndex)) parentName
subEdgeLookups <-
zipWithM recurse [0..] $ mapMaybe parseTensorName inputs
pure $ Map.unionsWith (++) (nextLookup:subEdgeLookups)
edgeLookup = evalState (collect Nothing nodeName) Set.empty
nodeMap = Map.fromList $ zip (Map.keys edgeLookup) [0..]
graph = FGL.mkGraph (swap <$> Map.toList nodeMap)
[ (nodeMap Map.! n, nodeMap Map.! m, (i, j))
| (n, edges) <- Map.toList edgeLookup
, (m, i, j) <- edges
]
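-- | A per-op gradient rule: given the op's 'NodeDef', the 'Output's it
-- consumes, and the gradients of its outputs, produce the gradient for each
-- input ('Nothing' for inputs that don't need one, such as shapes or integer
-- indices).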
type GradientFunc a = NodeDef
-> [Output]
-> [Tensor Value a]
-> [Maybe (Tensor Value a)]
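-- | Wrap an 'Output' as a value 'Tensor'.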
toT :: Output -> Tensor Value a
toT = Tensor ValueKind
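-- | A 1-D slice of @size@ elements of a vector starting at @begin@
-- (TensorFlow slice semantics: a size of @-1@ means "through the end").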
flatSlice :: forall v1 t . (TensorType t)
=> Tensor v1 t
-> Int32
-> Int32
-> Tensor Value t
flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])
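-- | The gradient rules themselves, keyed by op name.  Each equation returns
-- one entry per input of the op; the rules largely mirror the gradients
-- registered in TensorFlow's Python library.
--
-- A new rule takes the same shape.  For example, a hypothetical entry for
-- "Sqrt" (not implemented here; 'numOutputs' would also need a matching
-- case, and this sketch assumes @CoreOps.sqrt@ is available for the element
-- type) might read:
--
-- > opGrad "Sqrt" _ [toT -> x] [dz] =
-- >     [Just $ dz `CoreOps.div` (2 * CoreOps.sqrt x)]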
opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a
opGrad "Abs" _ [toT -> x] [dz] = [Just $ dz * signum x]
opGrad "Neg" _ [_] [dz] = [Just $ dz]
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
opGrad "Square" _ [toT -> x] [dz] =
[Just $ dz * (2 * x)]
opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
[ Just $ CoreOps.unsortedSegmentSum values indices' numRows
, Nothing
]
where
denseShape = shape (x :: Tensor Value a)
numRows = scalarize $ flatSlice denseShape 0 1
valuesShape = CoreOps.concat 0 [ allDimensions
, flatSlice denseShape 1 (-1)
]
values = reshape dz valuesShape
indices' = reshape indices allDimensions :: Tensor Value Int32
opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
[Just $ indicators `CoreOps.div` numSelected * dz', Nothing]
where
sx = shape (x :: Tensor Value a)
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
x' = reshape x outputShapeKeptDims
dz' = reshape dz outputShapeKeptDims
indicators = CoreOps.cast $ CoreOps.equal x' x
numSelected = reshape (sum indicators indices) outputShapeKeptDims
opGrad "Min" u v w = opGrad "Max" u v w
opGrad "Sum" _ [toT -> x, toT -> indices] [dz] =
[ Just $ CoreOps.tile grad tileScaling, Nothing ]
where
sx = shape (x :: Tensor Value a)
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
tileScaling = safeShapeDiv sx outputShapeKeptDims
grad = reshape dz outputShapeKeptDims
opGrad "Mean" u v@[toT -> x, _] w =
[Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing]
where
[Just dz, Nothing] = opGrad "Sum" u v w
inputShape = shape (x :: Tensor Value a)
-- The scaling factor comes from the shape of the op's output (equivalently,
-- of the incoming gradient), not from the already-tiled Sum gradient.
outputShape = shape (head w :: Tensor Value a)
inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape
outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape
factor = safeShapeDiv inputSize outputSize
opGrad "Add" _ [toT -> x, toT -> y] [dz] =
[ Just $ reshape (sum dz rx) sx
, Just $ reshape (sum dz ry) sy ]
where
sx = shape (x :: Tensor Value a)
sy = shape (y :: Tensor Value a)
(rx, ry) = broadcastGradientArgs sx sy
opGrad "Sub" u v w =
[Just x, Just (-y)]
where
[Just x, Just y] = opGrad "Add" u v w
opGrad "SoftmaxCrossEntropyWithLogits" _ [toT -> x, toT -> y] [dz, _] =
[ Just $ expandDims dz (-1) * snd (softmaxCrossEntropyWithLogits x y)
, Nothing ]
opGrad "Mul" _ [toT -> x, toT -> y] [dz] =
[ Just $ reshape (sum (dz * y) rx) sx
, Just $ reshape (sum (x * dz) ry) sy ]
where
sx = shape (x :: Tensor Value a)
sy = shape (y :: Tensor Value a)
(rx, ry) = broadcastGradientArgs sx sy
opGrad "Div" _ [toT -> x, toT -> y] [dz] =
[ Just $ reshape (sum (dz `CoreOps.div` y) rx) sx
, Just $ reshape (sum (dz * (negate x `CoreOps.div` (y * y))) ry) sy
]
where
sx = shape (x :: Tensor Value a)
sy = shape (y :: Tensor Value a)
(rx, ry) = broadcastGradientArgs sx sy
opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
let transposeA = lookupAttr nodeDef "transpose_a"
transposeB = lookupAttr nodeDef "transpose_b"
transAttrs a b =
(tensorAttr "transpose_a" .~ a) . (tensorAttr "transpose_b" .~ b)
in case (transposeA, transposeB) of
(False, False) ->
[ Just $ (dz `matMul` y) & transAttrs False True
, Just $ (x `matMul` dz) & transAttrs True False ]
(False, True) ->
[ Just $ dz `matMul` y
, Just $ (dz `matMul` x) & transAttrs True False ]
(True, False) ->
[ Just $ (y `matMul` dz) & transAttrs False True
, Just $ x `matMul` dz ]
(True, True) ->
[ Just $ (y `matMul` dz) & transAttrs True True
, Just $ (dz `matMul` x) & transAttrs True True ]
opGrad "Transpose" _ [_, toT -> p] [dz] =
[ Just $ CoreOps.transpose dz
(CoreOps.invertPermutation p :: Tensor Value Int32)
, Nothing
]
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
[ Just $ CoreOps.conv2DBackpropInput (shape x) y dz
& tensorAttr "strides" .~ strides
& tensorAttr "padding" .~ padding
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
& tensorAttr "data_format" .~ dataFormat
, Just $ CoreOps.conv2DBackpropFilter x (shape y) dz
& tensorAttr "strides" .~ strides
& tensorAttr "padding" .~ padding
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
& tensorAttr "data_format" .~ dataFormat
]
where
strides = lookupAttr nodeDef "strides" :: [Int64]
padding = lookupAttr nodeDef "padding" :: ByteString
useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
opGrad "MaxPool" nodeDef [toT -> x] [dz] =
[ Just $ CoreOps.maxPoolGrad x output dz
& tensorAttr "ksize" .~ ksize
& tensorAttr "strides" .~ strides
& tensorAttr "padding" .~ padding
& tensorAttr "data_format" .~ dataFormat
]
where
output :: Tensor Value a
output = toT $ Output 0 (Rendered nodeDef)
ksize = lookupAttr nodeDef "ksize" :: [Int64]
strides = lookupAttr nodeDef "strides" :: [Int64]
padding = lookupAttr nodeDef "padding" :: ByteString
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
opGrad "Reshape" _ [toT -> x, _] [dz] =
[Just $ reshape dz $ shape (x :: Tensor Value a), Nothing]
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
opGrad "TruncatedNormal" _ _ _ = [Nothing]
opGrad "RefIdentity" _ _ [dz] = [Just dz]
opGrad "Cast" nodeDef _ [dz] = [Just reverseCast]
where
reverseCast =
buildOp (opDef "Cast"
& opAttr "DstT" .~ (lookupAttr nodeDef "SrcT" :: ByteString)
& opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString))
dz
opGrad "DynamicStitch" nodeDef inputs [dz] =
replicate halfLen Nothing ++ valuesGrads
where
halfLen =
let len = length inputs
half = len `div` 2
in if 2 * half == len
then half
else error ("Uneven input size " ++ show (len, showMessage nodeDef))
valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Value Int32)
| idx <- take halfLen inputs
]
opGrad "DynamicPartition" nodeDef [toT -> xs, toT -> indices] dz =
[ Just reconstructed, Nothing ]
where
reconstructed = CoreOps.reshape stitched
(CoreOps.shape (xs :: Tensor Value a) :: Tensor Value Int32)
stitched = CoreOps.dynamicStitch partitionedIndices dz
partitionedIndices = CoreOps.dynamicPartition np originalIndices indices
np = lookupAttr nodeDef "num_partitions" :: Int64
originalIndices =
CoreOps.reshape (CoreOps.range 0 (CoreOps.size indices) 1) prefixShape
prefixShape = shapeInt32 indices
shapeInt32 = CoreOps.shape :: Tensor Value Int32 -> Tensor Value Int32
opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
[ Nothing
, Just $ CoreOps.select c dz zeros
, Just $ CoreOps.select c zeros dz
]
where zeros = CoreOps.zerosLike x
opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.inv x ]
opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.exp x ]
opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
[ Just $ CoreOps.unsortedSegmentSum
(CoreOps.gather dz (t :: Tensor Value Int32))
(y :: Tensor Value Int32) inputRows
, Nothing
, Nothing
]
where inputRows = flatSlice (shape (x :: Tensor Value a)) 0 1
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
opGrad "LabelWeights" _ _ _ = [Nothing]
opGrad "Size" _ _ _ = [Nothing]
opGrad "ZerosLike" _ _ _ = [Nothing]
opGrad "Const" _ _ _ = [Nothing, Nothing]
opGrad "Placeholder" _ _ _ = []
opGrad "Variable" _ _ _ = []
opGrad n nodeDef ins grads =
error $ "no gradient implemented for " ++
show (n, length ins, length grads, showMessage nodeDef, ins)
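-- | The number of outputs of each supported op, used by 'fullOutputGrads' to
-- pad missing output gradients with zeros.  Any op handled by 'opGrad' needs
-- an entry here as well.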
numOutputs :: NodeDef -> OutputIx
numOutputs o =
case o ^. op of
"Abs" -> 1
"Add" -> 1
"Cast" -> 1
"Const" -> 1
"Conv2D" -> 1
"Div" -> 1
"DynamicStitch" -> 1
"DynamicPartition" ->
fromIntegral (lookupAttr o "num_partitions" :: Int64)
"Exp" -> 1
"Gather" -> 1
"LabelClasses" -> 1
"LabelWeights" -> 1
"Log" -> 1
"MatMul" -> 1
"Max" -> 1
"MaxPool" -> 1
"Mean" -> 1
"Min" -> 1
"Mul" -> 1
"Neg" -> 1
"Placeholder" -> 1
"OneHot" -> 1
"RefIdentity" -> 1
"Relu" -> 1
"Reshape" -> 1
"Select" -> 1
"Size" -> 1
"SoftmaxCrossEntropyWithLogits" -> 2
"Square" -> 1
"SparseSegmentSum" -> 1
"Sub" -> 1
"Sum" -> 1
"Transpose" -> 1
"TruncatedNormal" -> 1
"Variable" -> 1
"ZerosLike" -> 1
_ -> error $ "numOutputs not implemented for " ++ show (o ^. op)
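-- | Elementwise division of shape vectors that treats a zero denominator as
-- one, so empty dimensions don't cause a division by zero.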
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Value Int32
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
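-- | The shape value @[-1]@, which ops such as 'reshape' interpret as
-- "flatten into a single dimension".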
allDimensions :: Tensor Value Int32
allDimensions = vector [-1 :: Int32]
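-- | @[0 .. rank x - 1]@, i.e. every dimension index of the given tensor.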
rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Value Int32
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1
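-- | Look up an attribute of a 'NodeDef' by name, falling back to the
-- attribute's protobuf default when it is unset.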
lookupAttr :: Attribute a1 => NodeDef -> Text -> a1
lookupAttr nodeDef attrName = nodeDef ^. attr . at attrName . non def . attrLens