1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-23 03:19:44 +01:00

Fix lens errors

This commit is contained in:
Johannes Maier 2023-03-03 16:24:18 +01:00 committed by fkm3
parent 00aeb23b1e
commit bfd8de5582

View file

@ -168,7 +168,7 @@ gradients y xs = build $ do
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x)) (\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
. flip Map.lookup . flip Map.lookup
let (gr, nodeMap) = createGraph yName nodeDefLookup let (gr, nodeMap) = createGraph yName nodeDefLookup
xnodes = mapMaybe (\x -> nodeMap ^. (at . outputNodeName . renderedOutput $ x)) xs xnodes = mapMaybe (\x -> nodeMap ^. (at $ outputNodeName $ renderedOutput x)) xs
-- make a set of the nodes reachable from the xnodes -- make a set of the nodes reachable from the xnodes
-- The xnodes are not part of this set (unless reachable from another xnode) -- The xnodes are not part of this set (unless reachable from another xnode)
reachableSet = computeReachableSet xnodes gr reachableSet = computeReachableSet xnodes gr
@ -199,7 +199,8 @@ computeReachableSet vs g =
IntSet.fromList $ concatMap (drop 1 . FGL.preorder) (FGL.dff vs g) IntSet.fromList $ concatMap (drop 1 . FGL.preorder) (FGL.dff vs g)
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v) outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
outputIxAt = intAt . unOutputIx -- NOTE: point-free notation leads to unification problems here
outputIxAt x = intAt (unOutputIx x)
-- | Incomplete gradients of a node's outputs. -- | Incomplete gradients of a node's outputs.
-- --