mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Avoid computing gradients for incidental nodes (#238)
This commit is contained in:
parent
96f1c88327
commit
c0f87dc0bc
2 changed files with 89 additions and 13 deletions
|
@ -36,6 +36,7 @@ import Data.Int (Int32, Int64)
|
||||||
import Data.Foldable (foldlM)
|
import Data.Foldable (foldlM)
|
||||||
import Data.List (foldl', sortBy)
|
import Data.List (foldl', sortBy)
|
||||||
import Data.Map.Strict (Map)
|
import Data.Map.Strict (Map)
|
||||||
|
import qualified Data.IntSet as IntSet
|
||||||
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
|
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
|
||||||
import Data.Ord (comparing)
|
import Data.Ord (comparing)
|
||||||
import Data.ProtoLens.TextFormat (showMessage)
|
import Data.ProtoLens.TextFormat (showMessage)
|
||||||
|
@ -165,6 +166,11 @@ gradients y xs = build $ do
|
||||||
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
|
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
|
||||||
. flip Map.lookup
|
. flip Map.lookup
|
||||||
let (gr, nodeMap) = createGraph yName nodeDefLookup
|
let (gr, nodeMap) = createGraph yName nodeDefLookup
|
||||||
|
xnodes = mapMaybe (\x -> nodeMap ^. (at . outputNodeName . renderedOutput $ x)) xs
|
||||||
|
-- make a set of the nodes reachable from the xnodes
|
||||||
|
-- The xnodes are not part of this set (unless reachable from another xnode)
|
||||||
|
reachableSet = computeReachableSet xnodes gr
|
||||||
|
|
||||||
-- Set gradient of y to one.
|
-- Set gradient of y to one.
|
||||||
-- TODO: nicer
|
-- TODO: nicer
|
||||||
let initPending :: Map.Map FGL.Node (PendingGradients a)
|
let initPending :: Map.Map FGL.Node (PendingGradients a)
|
||||||
|
@ -175,7 +181,7 @@ gradients y xs = build $ do
|
||||||
.~ [yOne]
|
.~ [yOne]
|
||||||
)
|
)
|
||||||
-- Calculate the gradients of y w.r.t. each node in the graph.
|
-- Calculate the gradients of y w.r.t. each node in the graph.
|
||||||
gradientMap <- graphGrads gr initPending
|
gradientMap <- graphGrads gr reachableSet initPending
|
||||||
-- Lookup the gradients for each x.
|
-- Lookup the gradients for each x.
|
||||||
forM xs $ \x ->
|
forM xs $ \x ->
|
||||||
let Output i xName = renderedOutput x
|
let Output i xName = renderedOutput x
|
||||||
|
@ -183,6 +189,13 @@ gradients y xs = build $ do
|
||||||
n <- nodeMap ^. at xName
|
n <- nodeMap ^. at xName
|
||||||
gradientMap ^. at n . nonEmpty . outputIxAt i
|
gradientMap ^. at n . nonEmpty . outputIxAt i
|
||||||
|
|
||||||
|
-- | Compute a set of nodes reachable from the start nodes
|
||||||
|
--
|
||||||
|
-- the start nodes are excluded, unless reachable from another start node
|
||||||
|
computeReachableSet :: [FGL.Node] -> Graph -> IntSet.IntSet
|
||||||
|
computeReachableSet vs g =
|
||||||
|
IntSet.fromList $ concatMap (drop 1 . FGL.preorder) (FGL.dff vs g)
|
||||||
|
|
||||||
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
|
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
|
||||||
outputIxAt = intAt . unOutputIx
|
outputIxAt = intAt . unOutputIx
|
||||||
|
|
||||||
|
@ -245,16 +258,15 @@ nonEmpty = anon mempty null
|
||||||
-- | Calculate the gradients for every node in a graph.
|
-- | Calculate the gradients for every node in a graph.
|
||||||
graphGrads :: forall a. GradientCompatible a
|
graphGrads :: forall a. GradientCompatible a
|
||||||
=> Graph
|
=> Graph
|
||||||
|
-> IntSet.IntSet
|
||||||
-> Map FGL.Node (PendingGradients a)
|
-> Map FGL.Node (PendingGradients a)
|
||||||
-- ^ Initial gradients (usually just 1 for the node of interest).
|
-- ^ Initial gradients (usually just 1 for the node of interest).
|
||||||
-> Build (Map FGL.Node (Gradients a))
|
-> Build (Map FGL.Node (Gradients a))
|
||||||
graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrder
|
graphGrads gr reachableSet initPending = view gradientsResult <$> foldlM go initState nodeOrder
|
||||||
where
|
where
|
||||||
initState = GradientsState initPending Map.empty
|
initState = GradientsState initPending Map.empty
|
||||||
-- Reverse topological sort.
|
-- Reverse topological sort.
|
||||||
-- TODO(fmayle): Filter out nodes that are not successors of any x in xs to
|
nodeOrder = FGL.topsort . FGL.grev $ gr
|
||||||
-- avoid calculating gradients that won't be used.
|
|
||||||
nodeOrder = FGL.topsort $ FGL.grev gr
|
|
||||||
go :: GradientsState a -> Int -> Build (GradientsState a)
|
go :: GradientsState a -> Int -> Build (GradientsState a)
|
||||||
go state node = do
|
go state node = do
|
||||||
-- Aggregate the accumulated gradients for this node.
|
-- Aggregate the accumulated gradients for this node.
|
||||||
|
@ -263,11 +275,17 @@ graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrd
|
||||||
if null outputGrads
|
if null outputGrads
|
||||||
then pure state
|
then pure state
|
||||||
else do
|
else do
|
||||||
let ctx = FGL.context gr node
|
|
||||||
inputGrads <- calculateInputGrads ctx outputGrads gr
|
|
||||||
-- Calculate the gradients for each of the node's inputs.
|
|
||||||
let nextState = state & gradientsResult %~ Map.insert node outputGrads
|
let nextState = state & gradientsResult %~ Map.insert node outputGrads
|
||||||
pure $ updatePendingGradients ctx inputGrads nextState
|
-- Only consider nodes that are reachable from the inputs to
|
||||||
|
-- avoid calculating gradients that won't be used.
|
||||||
|
if node `IntSet.member` reachableSet
|
||||||
|
then do
|
||||||
|
let ctx = FGL.context gr node
|
||||||
|
inputGrads <- calculateInputGrads ctx outputGrads gr
|
||||||
|
-- Calculate the gradients for each of the node's inputs.
|
||||||
|
pure $ updatePendingGradients ctx inputGrads nextState
|
||||||
|
else
|
||||||
|
pure nextState
|
||||||
|
|
||||||
-- | Reduce accumulated gradients for each output to one Tensor.
|
-- | Reduce accumulated gradients for each output to one Tensor.
|
||||||
sumPendingGradient :: GradientCompatible a
|
sumPendingGradient :: GradientCompatible a
|
||||||
|
@ -839,11 +857,8 @@ opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
|
||||||
-- through each read.
|
-- through each read.
|
||||||
opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
|
opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
|
||||||
|
|
||||||
-- TODO(fmayle): These can go away if we properly prune the graph.
|
|
||||||
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
||||||
opGrad "Placeholder" _ _ _ = []
|
|
||||||
opGrad "VarHandleOp" _ _ _ = []
|
opGrad "VarHandleOp" _ _ _ = []
|
||||||
opGrad "Variable" _ _ _ = []
|
|
||||||
|
|
||||||
opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz]
|
opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz]
|
||||||
where
|
where
|
||||||
|
|
|
@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
|
||||||
import qualified TensorFlow.Core as TF
|
import qualified TensorFlow.Core as TF
|
||||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape)
|
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag)
|
||||||
import qualified TensorFlow.Gradient as TF
|
import qualified TensorFlow.Gradient as TF
|
||||||
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
|
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
|
||||||
import qualified TensorFlow.Output as TF
|
import qualified TensorFlow.Output as TF
|
||||||
|
@ -123,6 +123,65 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
|
||||||
]
|
]
|
||||||
sort expected @=? sort ops
|
sort expected @=? sort ops
|
||||||
|
|
||||||
|
testGradientIncidental :: Test
|
||||||
|
testGradientIncidental = testCase "testGradientIncidental" $ do
|
||||||
|
let grads = do
|
||||||
|
x <- TF.render $ TF.scalar (3 :: Float)
|
||||||
|
b <- TF.render $ TF.scalar (4 :: Float)
|
||||||
|
w <- TF.render $ TF.diag $ TF.vector [ 1.0 :: Float ]
|
||||||
|
let incidental = b `TF.mul` w
|
||||||
|
let y = (x `TF.mul` b) `TF.add` incidental
|
||||||
|
TF.gradients y [x]
|
||||||
|
|
||||||
|
-- Assert that the gradients are right.
|
||||||
|
[dx] <- TF.runSession $ grads >>= TF.run
|
||||||
|
4 @=? TF.unScalar dx
|
||||||
|
-- Assert that the graph has the expected ops.
|
||||||
|
let graphDef = TF.asGraphDef grads
|
||||||
|
putStrLn $ showMessage graphDef
|
||||||
|
let ops = graphDef ^.. node . traverse . op
|
||||||
|
expected = [ "Add"
|
||||||
|
, "BroadcastGradientArgs"
|
||||||
|
, "BroadcastGradientArgs"
|
||||||
|
, "Const"
|
||||||
|
, "Const"
|
||||||
|
, "Const"
|
||||||
|
, "Const"
|
||||||
|
, "Diag"
|
||||||
|
, "Fill"
|
||||||
|
, "Mul"
|
||||||
|
, "Mul"
|
||||||
|
, "Mul"
|
||||||
|
, "Mul"
|
||||||
|
, "Reshape"
|
||||||
|
, "Reshape"
|
||||||
|
, "Reshape"
|
||||||
|
, "Reshape"
|
||||||
|
, "Shape"
|
||||||
|
, "Shape"
|
||||||
|
, "Shape"
|
||||||
|
, "Shape"
|
||||||
|
, "Shape"
|
||||||
|
, "Sum"
|
||||||
|
, "Sum"
|
||||||
|
, "Sum"
|
||||||
|
, "Sum"
|
||||||
|
]
|
||||||
|
sort expected @=? sort ops
|
||||||
|
|
||||||
|
testGradientPruning :: Test
|
||||||
|
testGradientPruning = testCase "testGradientPruning" $ do
|
||||||
|
let grads = do
|
||||||
|
x <- TF.render $ TF.scalar (3 :: Float)
|
||||||
|
b <- TF.render $ TF.scalar (4 :: Float)
|
||||||
|
bx <- TF.render $ b `TF.mul` x
|
||||||
|
let y = bx `TF.add` b
|
||||||
|
TF.gradients y [x, bx]
|
||||||
|
|
||||||
|
-- Assert that the gradients are right.
|
||||||
|
[dx, dxb] <- TF.runSession $ grads >>= TF.run
|
||||||
|
4 @=? TF.unScalar dx
|
||||||
|
1 @=? TF.unScalar dxb
|
||||||
|
|
||||||
-- Test that identical "stateful" ops work with createGraph.
|
-- Test that identical "stateful" ops work with createGraph.
|
||||||
testCreateGraphStateful :: Test
|
testCreateGraphStateful :: Test
|
||||||
|
@ -545,6 +604,8 @@ main :: IO ()
|
||||||
main = defaultMain
|
main = defaultMain
|
||||||
[ testGradientSimple
|
[ testGradientSimple
|
||||||
, testGradientDisconnected
|
, testGradientDisconnected
|
||||||
|
, testGradientIncidental
|
||||||
|
, testGradientPruning
|
||||||
, testCreateGraphStateful
|
, testCreateGraphStateful
|
||||||
, testCreateGraphNameScopes
|
, testCreateGraphNameScopes
|
||||||
, testDiamond
|
, testDiamond
|
||||||
|
|
Loading…
Reference in a new issue