diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs
index cb12aee..b3bc73c 100644
--- a/tensorflow-ops/src/TensorFlow/Gradient.hs
+++ b/tensorflow-ops/src/TensorFlow/Gradient.hs
@@ -36,6 +36,7 @@ import Data.Int (Int32, Int64)
 import Data.Foldable (foldlM)
 import Data.List (foldl', sortBy)
 import Data.Map.Strict (Map)
+import qualified Data.IntSet as IntSet
 import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
 import Data.Ord (comparing)
 import Data.ProtoLens.TextFormat (showMessage)
@@ -165,6 +166,11 @@ gradients y xs = build $ do
         (\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
         . flip Map.lookup
     let (gr, nodeMap) = createGraph yName nodeDefLookup
+        xnodes = mapMaybe (\x -> nodeMap ^. (at . outputNodeName . renderedOutput $ x)) xs
+        -- Make a set of the nodes reachable from the xnodes.
+        -- The xnodes are not part of this set (unless reachable from another xnode).
+        reachableSet = computeReachableSet xnodes gr
+
     -- Set gradient of y to one.
     -- TODO: nicer
     let initPending :: Map.Map FGL.Node (PendingGradients a)
@@ -175,7 +181,7 @@ gradients y xs = build $ do
                            .~ [yOne]
                           )
     -- Calculate the gradients of y w.r.t. each node in the graph.
-    gradientMap <- graphGrads gr initPending
+    gradientMap <- graphGrads gr reachableSet initPending
     -- Lookup the gradients for each x.
     forM xs $ \x ->
         let Output i xName = renderedOutput x
@@ -183,6 +189,13 @@
             n <- nodeMap ^. at xName
             gradientMap ^. at n . nonEmpty . outputIxAt i
 
+-- | Compute a set of nodes reachable from the start nodes.
+--
+-- The start nodes are excluded, unless reachable from another start node.
+computeReachableSet :: [FGL.Node] -> Graph -> IntSet.IntSet
+computeReachableSet vs g =
+    IntSet.fromList $ concatMap (drop 1 . FGL.preorder) (FGL.dff vs g)
+
 outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
 outputIxAt = intAt . unOutputIx
 
@@ -245,16 +258,15 @@ nonEmpty = anon mempty null
 -- | Calculate the gradients for every node in a graph.
 graphGrads :: forall a. GradientCompatible a
            => Graph
+           -> IntSet.IntSet
           -> Map FGL.Node (PendingGradients a)
           -- ^ Initial gradients (usually just 1 for the node of interest).
           -> Build (Map FGL.Node (Gradients a))
-graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrder
+graphGrads gr reachableSet initPending = view gradientsResult <$> foldlM go initState nodeOrder
   where
     initState = GradientsState initPending Map.empty
     -- Reverse topological sort.
-    -- TODO(fmayle): Filter out nodes that are not successors of any x in xs to
-    -- avoid calculating gradients that won't be used.
-    nodeOrder = FGL.topsort $ FGL.grev gr
+    nodeOrder = FGL.topsort . FGL.grev $ gr
     go :: GradientsState a -> Int -> Build (GradientsState a)
     go state node = do
         -- Aggregate the accumulated gradients for this node.
@@ -263,11 +275,17 @@ graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrd
         if null outputGrads
             then pure state
             else do
-                let ctx = FGL.context gr node
-                inputGrads <- calculateInputGrads ctx outputGrads gr
-                -- Calculate the gradients for each of the node's inputs.
                 let nextState = state & gradientsResult %~ Map.insert node outputGrads
-                pure $ updatePendingGradients ctx inputGrads nextState
+                -- Only consider nodes that are reachable from the inputs to
+                -- avoid calculating gradients that won't be used.
+                if node `IntSet.member` reachableSet
+                    then do
+                        let ctx = FGL.context gr node
+                        inputGrads <- calculateInputGrads ctx outputGrads gr
+                        -- Calculate the gradients for each of the node's inputs.
+                        pure $ updatePendingGradients ctx inputGrads nextState
+                    else
+                        pure nextState
 
 -- | Reduce accumulated gradients for each output to one Tensor.
 sumPendingGradient :: GradientCompatible a
@@ -839,11 +857,8 @@ opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
 -- through each read.
 opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
 
--- TODO(fmayle): These can go away if we properly prune the graph.
 opGrad "Const" _ _ _ = [Nothing, Nothing]
-opGrad "Placeholder" _ _ _ = []
 opGrad "VarHandleOp" _ _ _ = []
-opGrad "Variable" _ _ _ = []
 
 opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz]
   where
diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs
index dd36480..4bc1a45 100644
--- a/tensorflow-ops/tests/GradientTest.hs
+++ b/tensorflow-ops/tests/GradientTest.hs
@@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
 import Control.Monad.IO.Class (liftIO)
 
 import qualified TensorFlow.Core as TF
-import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape)
+import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag)
 import qualified TensorFlow.Gradient as TF
 import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
 import qualified TensorFlow.Output as TF
@@ -123,6 +123,65 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
             ]
     sort expected @=? sort ops
 
+testGradientIncidental :: Test
+testGradientIncidental = testCase "testGradientIncidental" $ do
+    let grads = do
+            x <- TF.render $ TF.scalar (3 :: Float)
+            b <- TF.render $ TF.scalar (4 :: Float)
+            w <- TF.render $ TF.diag $ TF.vector [ 1.0 :: Float ]
+            let incidental = b `TF.mul` w
+            let y = (x `TF.mul` b) `TF.add` incidental
+            TF.gradients y [x]
+
+    -- Assert that the gradients are right.
+    [dx] <- TF.runSession $ grads >>= TF.run
+    4 @=? TF.unScalar dx
+    -- Assert that the graph has the expected ops.
+    let graphDef = TF.asGraphDef grads
+    putStrLn $ showMessage graphDef
+    let ops = graphDef ^.. node . traverse . op
+        expected = [ "Add"
+                   , "BroadcastGradientArgs"
+                   , "BroadcastGradientArgs"
+                   , "Const"
+                   , "Const"
+                   , "Const"
+                   , "Const"
+                   , "Diag"
+                   , "Fill"
+                   , "Mul"
+                   , "Mul"
+                   , "Mul"
+                   , "Mul"
+                   , "Reshape"
+                   , "Reshape"
+                   , "Reshape"
+                   , "Reshape"
+                   , "Shape"
+                   , "Shape"
+                   , "Shape"
+                   , "Shape"
+                   , "Shape"
+                   , "Sum"
+                   , "Sum"
+                   , "Sum"
+                   , "Sum"
+                   ]
+    sort expected @=? sort ops
+
+testGradientPruning :: Test
+testGradientPruning = testCase "testGradientPruning" $ do
+    let grads = do
+            x <- TF.render $ TF.scalar (3 :: Float)
+            b <- TF.render $ TF.scalar (4 :: Float)
+            bx <- TF.render $ b `TF.mul` x
+            let y = bx `TF.add` b
+            TF.gradients y [x, bx]
+
+    -- Assert that the gradients are right.
+    [dx, dxb] <- TF.runSession $ grads >>= TF.run
+    4 @=? TF.unScalar dx
+    1 @=? TF.unScalar dxb
 
 -- Test that identical "stateful" ops work with createGraph.
 testCreateGraphStateful :: Test
@@ -545,6 +604,8 @@ main :: IO ()
 main = defaultMain
             [ testGradientSimple
             , testGradientDisconnected
+            , testGradientIncidental
+            , testGradientPruning
             , testCreateGraphStateful
             , testCreateGraphNameScopes
             , testDiamond
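
The pruning in this patch hinges on computeReachableSet: it takes the depth-first forest rooted at the gradient targets (the xnodes), flattens each tree in preorder, drops the roots, and keeps the remaining nodes in an IntSet; graphGrads then only propagates pending gradients into a node's inputs when that node is in the set. Below is a minimal standalone sketch of that reachability step; the toy graph, the reachableFrom name, and the main wrapper are illustrative only and not part of the patch.

import qualified Data.Graph.Inductive.Basic as FGL
import qualified Data.Graph.Inductive.Graph as FGL
import qualified Data.Graph.Inductive.PatriciaTree as FGL
import qualified Data.Graph.Inductive.Query.DFS as FGL
import qualified Data.IntSet as IntSet

-- Everything reachable from the start nodes; each start node is dropped
-- (it is the root of its own depth-first tree) unless another start node
-- reaches it, mirroring the shape of computeReachableSet in the patch.
reachableFrom :: [FGL.Node] -> FGL.Gr a b -> IntSet.IntSet
reachableFrom vs g =
    IntSet.fromList $ concatMap (drop 1 . FGL.preorder) (FGL.dff vs g)

main :: IO ()
main = do
    -- Toy graph: 1 -> 2 -> 3, plus an unrelated branch 4 -> 3.
    let g :: FGL.Gr () ()
        g = FGL.mkGraph [(n, ()) | n <- [1 .. 4]]
                        [(1, 2, ()), (2, 3, ()), (4, 3, ())]
    -- Starting from node 1, only nodes 2 and 3 land in the set; node 1 itself
    -- and the unrelated node 4 do not, so a gradient walk guarded by this set
    -- would skip them.
    print $ IntSet.toList $ reachableFrom [1] g  -- prints [2,3]

This is the behaviour testGradientIncidental exercises: the b `TF.mul` w branch feeds y but is not reachable from x, so the gradient walk never expands it, and the expected op list contains the forward Diag but no ops from its gradient.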