mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-27 11:15:03 +01:00
Avoid computing gradients for incidental nodes (#238)
This commit is contained in:
parent
96f1c88327
commit
c0f87dc0bc
2 changed files with 89 additions and 13 deletions
|
@ -36,6 +36,7 @@ import Data.Int (Int32, Int64)
|
|||
import Data.Foldable (foldlM)
|
||||
import Data.List (foldl', sortBy)
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.IntSet as IntSet
|
||||
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
|
||||
import Data.Ord (comparing)
|
||||
import Data.ProtoLens.TextFormat (showMessage)
|
||||
|
@ -165,6 +166,11 @@ gradients y xs = build $ do
|
|||
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
|
||||
. flip Map.lookup
|
||||
let (gr, nodeMap) = createGraph yName nodeDefLookup
|
||||
xnodes = mapMaybe (\x -> nodeMap ^. (at . outputNodeName . renderedOutput $ x)) xs
|
||||
-- make a set of the nodes reachable from the xnodes
|
||||
-- The xnodes are not part of this set (unless reachable from another xnode)
|
||||
reachableSet = computeReachableSet xnodes gr
|
||||
|
||||
-- Set gradient of y to one.
|
||||
-- TODO: nicer
|
||||
let initPending :: Map.Map FGL.Node (PendingGradients a)
|
||||
|
@ -175,7 +181,7 @@ gradients y xs = build $ do
|
|||
.~ [yOne]
|
||||
)
|
||||
-- Calculate the gradients of y w.r.t. each node in the graph.
|
||||
gradientMap <- graphGrads gr initPending
|
||||
gradientMap <- graphGrads gr reachableSet initPending
|
||||
-- Lookup the gradients for each x.
|
||||
forM xs $ \x ->
|
||||
let Output i xName = renderedOutput x
|
||||
|
@ -183,6 +189,13 @@ gradients y xs = build $ do
|
|||
n <- nodeMap ^. at xName
|
||||
gradientMap ^. at n . nonEmpty . outputIxAt i
|
||||
|
||||
-- | Compute a set of nodes reachable from the start nodes
|
||||
--
|
||||
-- the start nodes are excluded, unless reachable from another start node
|
||||
computeReachableSet :: [FGL.Node] -> Graph -> IntSet.IntSet
|
||||
computeReachableSet vs g =
|
||||
IntSet.fromList $ concatMap (drop 1 . FGL.preorder) (FGL.dff vs g)
|
||||
|
||||
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
|
||||
outputIxAt = intAt . unOutputIx
|
||||
|
||||
|
@ -245,16 +258,15 @@ nonEmpty = anon mempty null
|
|||
-- | Calculate the gradients for every node in a graph.
|
||||
graphGrads :: forall a. GradientCompatible a
|
||||
=> Graph
|
||||
-> IntSet.IntSet
|
||||
-> Map FGL.Node (PendingGradients a)
|
||||
-- ^ Initial gradients (usually just 1 for the node of interest).
|
||||
-> Build (Map FGL.Node (Gradients a))
|
||||
graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrder
|
||||
graphGrads gr reachableSet initPending = view gradientsResult <$> foldlM go initState nodeOrder
|
||||
where
|
||||
initState = GradientsState initPending Map.empty
|
||||
-- Reverse topological sort.
|
||||
-- TODO(fmayle): Filter out nodes that are not successors of any x in xs to
|
||||
-- avoid calculating gradients that won't be used.
|
||||
nodeOrder = FGL.topsort $ FGL.grev gr
|
||||
nodeOrder = FGL.topsort . FGL.grev $ gr
|
||||
go :: GradientsState a -> Int -> Build (GradientsState a)
|
||||
go state node = do
|
||||
-- Aggregate the accumulated gradients for this node.
|
||||
|
@ -263,11 +275,17 @@ graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrd
|
|||
if null outputGrads
|
||||
then pure state
|
||||
else do
|
||||
let ctx = FGL.context gr node
|
||||
inputGrads <- calculateInputGrads ctx outputGrads gr
|
||||
-- Calculate the gradients for each of the node's inputs.
|
||||
let nextState = state & gradientsResult %~ Map.insert node outputGrads
|
||||
pure $ updatePendingGradients ctx inputGrads nextState
|
||||
-- Only consider nodes that are reachable from the inputs to
|
||||
-- avoid calculating gradients that won't be used.
|
||||
if node `IntSet.member` reachableSet
|
||||
then do
|
||||
let ctx = FGL.context gr node
|
||||
inputGrads <- calculateInputGrads ctx outputGrads gr
|
||||
-- Calculate the gradients for each of the node's inputs.
|
||||
pure $ updatePendingGradients ctx inputGrads nextState
|
||||
else
|
||||
pure nextState
|
||||
|
||||
-- | Reduce accumulated gradients for each output to one Tensor.
|
||||
sumPendingGradient :: GradientCompatible a
|
||||
|
@ -839,11 +857,8 @@ opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
|
|||
-- through each read.
|
||||
opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
|
||||
|
||||
-- TODO(fmayle): These can go away if we properly prune the graph.
|
||||
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
||||
opGrad "Placeholder" _ _ _ = []
|
||||
opGrad "VarHandleOp" _ _ _ = []
|
||||
opGrad "Variable" _ _ _ = []
|
||||
|
||||
opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz]
|
||||
where
|
||||
|
|
|
@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
|
|||
import Control.Monad.IO.Class (liftIO)
|
||||
|
||||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape)
|
||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag)
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
|
||||
import qualified TensorFlow.Output as TF
|
||||
|
@ -123,6 +123,65 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
|
|||
]
|
||||
sort expected @=? sort ops
|
||||
|
||||
testGradientIncidental :: Test
|
||||
testGradientIncidental = testCase "testGradientIncidental" $ do
|
||||
let grads = do
|
||||
x <- TF.render $ TF.scalar (3 :: Float)
|
||||
b <- TF.render $ TF.scalar (4 :: Float)
|
||||
w <- TF.render $ TF.diag $ TF.vector [ 1.0 :: Float ]
|
||||
let incidental = b `TF.mul` w
|
||||
let y = (x `TF.mul` b) `TF.add` incidental
|
||||
TF.gradients y [x]
|
||||
|
||||
-- Assert that the gradients are right.
|
||||
[dx] <- TF.runSession $ grads >>= TF.run
|
||||
4 @=? TF.unScalar dx
|
||||
-- Assert that the graph has the expected ops.
|
||||
let graphDef = TF.asGraphDef grads
|
||||
putStrLn $ showMessage graphDef
|
||||
let ops = graphDef ^.. node . traverse . op
|
||||
expected = [ "Add"
|
||||
, "BroadcastGradientArgs"
|
||||
, "BroadcastGradientArgs"
|
||||
, "Const"
|
||||
, "Const"
|
||||
, "Const"
|
||||
, "Const"
|
||||
, "Diag"
|
||||
, "Fill"
|
||||
, "Mul"
|
||||
, "Mul"
|
||||
, "Mul"
|
||||
, "Mul"
|
||||
, "Reshape"
|
||||
, "Reshape"
|
||||
, "Reshape"
|
||||
, "Reshape"
|
||||
, "Shape"
|
||||
, "Shape"
|
||||
, "Shape"
|
||||
, "Shape"
|
||||
, "Shape"
|
||||
, "Sum"
|
||||
, "Sum"
|
||||
, "Sum"
|
||||
, "Sum"
|
||||
]
|
||||
sort expected @=? sort ops
|
||||
|
||||
testGradientPruning :: Test
|
||||
testGradientPruning = testCase "testGradientPruning" $ do
|
||||
let grads = do
|
||||
x <- TF.render $ TF.scalar (3 :: Float)
|
||||
b <- TF.render $ TF.scalar (4 :: Float)
|
||||
bx <- TF.render $ b `TF.mul` x
|
||||
let y = bx `TF.add` b
|
||||
TF.gradients y [x, bx]
|
||||
|
||||
-- Assert that the gradients are right.
|
||||
[dx, dxb] <- TF.runSession $ grads >>= TF.run
|
||||
4 @=? TF.unScalar dx
|
||||
1 @=? TF.unScalar dxb
|
||||
|
||||
-- Test that identical "stateful" ops work with createGraph.
|
||||
testCreateGraphStateful :: Test
|
||||
|
@ -545,6 +604,8 @@ main :: IO ()
|
|||
main = defaultMain
|
||||
[ testGradientSimple
|
||||
, testGradientDisconnected
|
||||
, testGradientIncidental
|
||||
, testGradientPruning
|
||||
, testCreateGraphStateful
|
||||
, testCreateGraphNameScopes
|
||||
, testDiamond
|
||||
|
|
Loading…
Add table
Reference in a new issue