Avoid computing gradients for incidental nodes (#238)

Christian Berentsen 2019-04-11 20:17:19 +02:00 committed by fkm3
parent 96f1c88327
commit c0f87dc0bc
2 changed files with 89 additions and 13 deletions
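
In short: `gradients y xs` previously emitted gradient ops for every node contributing to `y`, even nodes with no path from any requested `x` (the "incidental" nodes of the title). With this change, backpropagation only descends into nodes reachable from the `x`s. A minimal sketch of the scenario, in the style of the tests added below (an illustration, not code from the commit):

    let grads = do
            x <- TF.render $ TF.scalar (3 :: Float)
            b <- TF.render $ TF.scalar (4 :: Float)
            w <- TF.render $ TF.scalar (5 :: Float)
            -- b*w feeds y but has no path from x, so after this change
            -- no gradient ops are built for it when only dy/dx is requested.
            let y = (x `TF.mul` b) `TF.add` (b `TF.mul` w)
            TF.gradients y [x]
    [dx] <- TF.runSession $ grads >>= TF.run
    -- TF.unScalar dx == 4 (= b)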

View File

@@ -36,6 +36,7 @@ import Data.Int (Int32, Int64)
import Data.Foldable (foldlM)
import Data.List (foldl', sortBy)
import Data.Map.Strict (Map)
import qualified Data.IntSet as IntSet
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
import Data.Ord (comparing)
import Data.ProtoLens.TextFormat (showMessage)
@@ -165,6 +166,11 @@ gradients y xs = build $ do
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
. flip Map.lookup
let (gr, nodeMap) = createGraph yName nodeDefLookup
xnodes = mapMaybe (\x -> nodeMap ^. (at . outputNodeName . renderedOutput $ x)) xs
-- Make a set of the nodes reachable from the xnodes.
-- The xnodes themselves are not part of this set (unless reachable from another xnode).
reachableSet = computeReachableSet xnodes gr
-- Set gradient of y to one.
-- TODO: nicer
let initPending :: Map.Map FGL.Node (PendingGradients a)
@@ -175,7 +181,7 @@ gradients y xs = build $ do
.~ [yOne]
)
-- Calculate the gradients of y w.r.t. each node in the graph.
gradientMap <- graphGrads gr initPending
gradientMap <- graphGrads gr reachableSet initPending
-- Lookup the gradients for each x.
forM xs $ \x ->
let Output i xName = renderedOutput x
@@ -183,6 +189,13 @@ gradients y xs = build $ do
n <- nodeMap ^. at xName
gradientMap ^. at n . nonEmpty . outputIxAt i
-- | Compute the set of nodes reachable from the start nodes.
--
-- The start nodes themselves are excluded, unless reachable from another start node.
computeReachableSet :: [FGL.Node] -> Graph -> IntSet.IntSet
computeReachableSet vs g =
IntSet.fromList $ concatMap (drop 1 . FGL.preorder) (FGL.dff vs g)
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
outputIxAt = intAt . unOutputIx
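
For intuition, here is a tiny self-contained sketch of what `computeReachableSet` does, written against fgl directly. It uses `Data.Tree.flatten` as the preorder flattening, since the import providing `FGL.preorder` is not visible in this hunk:

    import Data.Graph.Inductive.Graph (mkGraph)
    import Data.Graph.Inductive.PatriciaTree (Gr)
    import qualified Data.Graph.Inductive.Query.DFS as DFS
    import qualified Data.IntSet as IntSet
    import Data.Tree (flatten)

    -- A toy graph: 1 -> 2 -> 3, plus an unconnected node 4.
    toy :: Gr () ()
    toy = mkGraph [(n, ()) | n <- [1 .. 4]] [(1, 2, ()), (2, 3, ())]

    -- Depth-first forest from the start nodes; dropping each tree's
    -- root excludes the start nodes themselves.
    reachableFrom :: [Int] -> IntSet.IntSet
    reachableFrom vs = IntSet.fromList $ concatMap (drop 1 . flatten) (DFS.dff vs toy)

    -- reachableFrom [1] == IntSet.fromList [2,3]  (1 is dropped, 4 is never visited)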
@@ -245,16 +258,15 @@ nonEmpty = anon mempty null
-- | Calculate the gradients for every node in a graph.
graphGrads :: forall a. GradientCompatible a
=> Graph
-> IntSet.IntSet
-> Map FGL.Node (PendingGradients a)
-- ^ Initial gradients (usually just 1 for the node of interest).
-> Build (Map FGL.Node (Gradients a))
graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrder
graphGrads gr reachableSet initPending = view gradientsResult <$> foldlM go initState nodeOrder
where
initState = GradientsState initPending Map.empty
-- Reverse topological sort.
-- TODO(fmayle): Filter out nodes that are not successors of any x in xs to
-- avoid calculating gradients that won't be used.
nodeOrder = FGL.topsort $ FGL.grev gr
nodeOrder = FGL.topsort . FGL.grev $ gr
go :: GradientsState a -> Int -> Build (GradientsState a)
go state node = do
-- Aggregate the accumulated gradients for this node.
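
A reminder of why the `nodeOrder` definition above works (a sketch, not from the commit): topologically sorting the edge-reversed graph yields nodes in reverse data-flow order, so `y` is processed first and every node sees its consumers' gradients before its own turn:

    import Data.Graph.Inductive.Basic (grev)
    import Data.Graph.Inductive.Graph (mkGraph)
    import Data.Graph.Inductive.PatriciaTree (Gr)
    import Data.Graph.Inductive.Query.DFS (topsort)

    -- Data-flow chain: node 1 feeds node 2, which feeds node 3 (the "y").
    chain :: Gr () ()
    chain = mkGraph [(n, ()) | n <- [1 .. 3]] [(1, 2, ()), (2, 3, ())]

    -- topsort chain        == [1,2,3]
    -- topsort (grev chain) == [3,2,1], i.e. y first, inputs last.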
@@ -263,11 +275,17 @@ graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrder
if null outputGrads
then pure state
else do
let ctx = FGL.context gr node
inputGrads <- calculateInputGrads ctx outputGrads gr
-- Calculate the gradients for each of the node's inputs.
let nextState = state & gradientsResult %~ Map.insert node outputGrads
pure $ updatePendingGradients ctx inputGrads nextState
-- Only consider nodes that are reachable from the inputs to
-- avoid calculating gradients that won't be used.
if node `IntSet.member` reachableSet
then do
let ctx = FGL.context gr node
inputGrads <- calculateInputGrads ctx outputGrads gr
-- Calculate the gradients for each of the node's inputs.
pure $ updatePendingGradients ctx inputGrads nextState
else
pure nextState
-- | Reduce accumulated gradients for each output to one Tensor.
sumPendingGradient :: GradientCompatible a
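
Note on the reachability branch in `go` above: a node outside `reachableSet` still has its accumulated output gradients summed and recorded in `gradientsResult`, but `calculateInputGrads` and `updatePendingGradients` are skipped, so no backpropagation ops are built for its inputs. This implements exactly the filtering the removed TODO(fmayle) comment asked for.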
@@ -839,11 +857,8 @@ opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
-- through each read.
opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
-- TODO(fmayle): These can go away if we properly prune the graph.
opGrad "Const" _ _ _ = [Nothing, Nothing]
opGrad "Placeholder" _ _ _ = []
opGrad "VarHandleOp" _ _ _ = []
opGrad "Variable" _ _ _ = []
opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz]
where
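
(For the `Sqrt` case shown as trailing context: by the chain rule, d(sqrt x)/dx = 1 / (2 * sqrt x), so `sq'` is presumably that factor; its `where` body is truncated in this view.)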

View File

@@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
import Control.Monad.IO.Class (liftIO)
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape)
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag)
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
import qualified TensorFlow.Output as TF
@@ -123,6 +123,65 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
]
sort expected @=? sort ops
testGradientIncidental :: Test
testGradientIncidental = testCase "testGradientIncidental" $ do
let grads = do
x <- TF.render $ TF.scalar (3 :: Float)
b <- TF.render $ TF.scalar (4 :: Float)
w <- TF.render $ TF.diag $ TF.vector [ 1.0 :: Float ]
let incidental = b `TF.mul` w
let y = (x `TF.mul` b) `TF.add` incidental
TF.gradients y [x]
-- Assert that the gradients are right.
[dx] <- TF.runSession $ grads >>= TF.run
4 @=? TF.unScalar dx
-- Assert that the graph has the expected ops.
let graphDef = TF.asGraphDef grads
putStrLn $ showMessage graphDef
let ops = graphDef ^.. node . traverse . op
expected = [ "Add"
, "BroadcastGradientArgs"
, "BroadcastGradientArgs"
, "Const"
, "Const"
, "Const"
, "Const"
, "Diag"
, "Fill"
, "Mul"
, "Mul"
, "Mul"
, "Mul"
, "Reshape"
, "Reshape"
, "Reshape"
, "Reshape"
, "Shape"
, "Shape"
, "Shape"
, "Shape"
, "Shape"
, "Sum"
, "Sum"
, "Sum"
, "Sum"
]
sort expected @=? sort ops
testGradientPruning :: Test
testGradientPruning = testCase "testGradientPruning" $ do
let grads = do
x <- TF.render $ TF.scalar (3 :: Float)
b <- TF.render $ TF.scalar (4 :: Float)
bx <- TF.render $ b `TF.mul` x
let y = bx `TF.add` b
TF.gradients y [x, bx]
-- Assert that the gradients are right.
[dx, dxb] <- TF.runSession $ grads >>= TF.run
4 @=? TF.unScalar dx
1 @=? TF.unScalar dxb
-- Test that identical "stateful" ops work with createGraph.
testCreateGraphStateful :: Test
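
A quick check of the two new tests' assertions: in testGradientIncidental, y = x*b + b*w, so dy/dx = b = 4, and with pruning the b*w branch contributes only its forward ops (the single Diag and one of the Muls) to the expected op list, no gradient ops. In testGradientPruning, y = bx + b with bx rendered as its own node, so dy/dx = b = 4 and dy/dbx = 1.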
@@ -545,6 +604,8 @@ main :: IO ()
main = defaultMain
[ testGradientSimple
, testGradientDisconnected
, testGradientIncidental
, testGradientPruning
, testCreateGraphStateful
, testCreateGraphNameScopes
, testDiamond