{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
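
-- | Reverse-mode gradient computation over the rendered TensorFlow graph:
-- starting from an output tensor, its ancestors are visited in reverse
-- topological order and per-op gradient rules are applied to accumulate the
-- gradient of each requested input.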
module TensorFlow.Gradient
( GradientCompatible
, gradients
) where
import Control.Monad (forM, zipWithM)
import Control.Monad.State.Strict (State, evalState, gets, modify)
import Data.ByteString (ByteString)
import Data.Complex (Complex)
import Data.Default (def)
import Data.Int (Int32, Int64)
import Data.Foldable (foldlM)
import Data.List (foldl', sortBy)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
import Data.Ord (comparing)
import Data.ProtoLens.TextFormat (showMessage)
import Data.Set (Set)
import Data.Text (Text)
import Data.Tuple (swap)
import Lens.Family2 (Lens', view, (&), (^.), (.~), (%~))
import Lens.Family2.State.Strict (uses)
import Lens.Family2.Stock (at, intAt)
import Lens.Family2.Unchecked (lens, iso)
import Prelude hiding (sum)
import Text.Printf (printf)
import qualified Data.Graph.Inductive.Basic as FGL
import qualified Data.Graph.Inductive.Graph as FGL
import qualified Data.Graph.Inductive.PatriciaTree as FGL
import qualified Data.Graph.Inductive.Query.DFS as FGL
import qualified Data.IntMap.Strict as IntMap
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Text as Text
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Build
( MonadBuild
, Build
, build
, renderedNodeDefs
, opDef
, opAttr
, opInputs
)
import TensorFlow.BuildOp
import TensorFlow.Ops
( addN
, broadcastGradientArgs
, expandDims
, fill
, matMul
, matMul'
, reducedShape
, reluGrad
, reshape
, scalar
, shape
, softmaxCrossEntropyWithLogits
, sum
, scalarize
, vector
, zerosLike
)
import TensorFlow.Output
( NodeName(..)
, Output(..)
, OutputIx(..)
, outputIndex
)
import TensorFlow.Tensor
( Tensor(..)
, Value
, render
, expr
, Rendered
, tensorNodeName
, renderedOutput
, renderValue
, ToTensor(..)
)
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef
(NodeDef, attr, input, op, name)
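
-- | Element types for which gradients can be computed.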
type GradientCompatible a =
(Num a, OneOf '[ Float, Complex Float, Complex Double ] a)
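
-- | Gradient of @y@ w.r.t. each element of @xs@, returned as rendered
-- tensors (a tensor of zeros is returned for an @x@ that @y@ does not
-- depend on).
--
-- A minimal usage sketch; @loss@ and @w@ are illustrative names for a
-- previously built scalar expression and a rendered tensor (e.g. a
-- variable) it depends on:
--
-- > [dw] <- gradients loss [w]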
gradients :: forall a v1 t m . ( MonadBuild m
, Rendered t
, ToTensor t
, GradientCompatible a
)
=> Tensor v1 a
-> [t a]
-> m [Tensor Value a]
gradients y xs = build $ do
y' <- renderValue y
let yName = tensorNodeName y'
yOne <- render $ fill (shape y') (scalar 1)
nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
. flip Map.lookup
let (gr, nodeMap) = createGraph yName nodeDefLookup
let initPending :: Map.Map FGL.Node (PendingGradients a)
= Map.empty & (at (nodeMap Map.! yName)
. nonEmpty
. outputIxAt (outputIndex $ renderedOutput y')
. nonEmpty
.~ [yOne]
)
gradientMap <- graphGrads gr initPending
forM xs $ \x ->
let Output i xName = renderedOutput x
in maybe (render $ zerosLike $ toTensor x) return $ do
n <- nodeMap ^. at xName
gradientMap ^. at n . nonEmpty . outputIxAt i
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
outputIxAt = intAt . unOutputIx
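
-- | Partially accumulated gradients for each output of a node, keyed by
-- output index; the lists are summed by 'sumPendingGradient'.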
type PendingGradients a = IntMap.IntMap [Tensor Value a]
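
-- | Summed gradients for each output of a node, keyed by output index.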
type Gradients a = IntMap.IntMap (Tensor Value a)
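
-- | Graph of TensorFlow operations, labelled with their 'NodeDef's.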
type Graph = FGL.Gr NodeDef EdgeLabel
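
-- | Edge data: the output index on the source node paired with the input
-- index it feeds on the destination node.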
type EdgeLabel = (OutputIx, OutputIx)
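
-- | State threaded through the graph walk: gradients still being collected
-- for each node, and the summed results computed so far.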
data GradientsState a = GradientsState
{ _gradientsPending :: !(Map FGL.Node (PendingGradients a))
, _gradientsResult :: !(Map FGL.Node (Gradients a))
}
gradientsPending :: Lens' (GradientsState a) (Map FGL.Node (PendingGradients a))
gradientsPending = lens _gradientsPending (\x y -> x { _gradientsPending = y })
gradientsResult :: Lens' (GradientsState a) (Map FGL.Node (Gradients a))
gradientsResult = lens _gradientsResult (\x y -> x { _gradientsResult = y })
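
-- | Safe list indexing: 'Nothing' for negative or out-of-range indices.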
safeIndex :: [a] -> Int -> Maybe a
_ `safeIndex` n | n < 0 = Nothing
[] `safeIndex` _ = Nothing
(x:_) `safeIndex` 0 = Just x
(_:xs) `safeIndex` n = xs `safeIndex` (n-1)
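
-- Local equivalents of the 'anon' and 'non' combinators from the lens
-- package, used to treat a missing map entry as an empty value.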
anon :: a -> (a -> Bool) -> Lens' (Maybe a) a
anon a p = iso (fromMaybe a) go where
go b | p b = Nothing
| otherwise = Just b
non :: Eq a => a -> Lens' (Maybe a) a
non a = anon a (a==)
nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v)
nonEmpty = anon mempty null
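
-- | Walk the graph in reverse topological order (outputs towards inputs),
-- summing each node's pending gradients and propagating them to the nodes
-- that feed it; returns the summed gradients per node.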
graphGrads :: forall a. GradientCompatible a
=> Graph
-> Map FGL.Node (PendingGradients a)
-> Build (Map FGL.Node (Gradients a))
graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrder
where
initState = GradientsState initPending Map.empty
nodeOrder = FGL.topsort $ FGL.grev gr
go :: GradientsState a -> Int -> Build (GradientsState a)
go state node = do
outputGrads <-
sumPendingGradient (state ^. gradientsPending . at node . nonEmpty)
if null outputGrads
then pure state
else do
let ctx = FGL.context gr node
inputGrads <- calculateInputGrads ctx outputGrads gr
let nextState = state & gradientsResult %~ Map.insert node outputGrads
pure $ updatePendingGradients ctx inputGrads nextState
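
-- | Sum and render the pending gradients of a node's outputs; outputs with
-- no pending gradients are dropped.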
sumPendingGradient :: GradientCompatible a
=> PendingGradients a -> Build (Gradients a)
sumPendingGradient = sequence . IntMap.mapMaybe f
where
f [] = Nothing
f [x] = Just (pure x)
f xs = Just (render $ addN xs)
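
-- | Compute the gradients of a node's inputs from the gradients of its
-- outputs by applying the op-specific rule from 'opGrad'.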
calculateInputGrads :: forall a. GradientCompatible a
=> FGL.Context NodeDef EdgeLabel
-> Gradients a
-> Graph
-> Build [Maybe (Tensor Value a)]
calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr = do
fullOutGrads <- fullOutputGrads (numOutputs nodeDef) (nodeDefName nodeDef)
outputGrads
traverse (traverse render) $ opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads
where
edgeToTensor :: (EdgeLabel, FGL.Node) -> Output
edgeToTensor ((i, _), n) =
case FGL.lab gr n of
Just edgeNodeDef -> Output i (NodeName $ edgeNodeDef ^. name)
Nothing -> error $ "calculateInputGrads: missing input node for "
++ Text.unpack (nodeDef ^. name)
inputTensors = map edgeToTensor $ sortBy (comparing (snd . fst)) inputEdges
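
-- | Pad the known output gradients to one gradient per output, substituting
-- zeros for outputs whose gradient is missing.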
fullOutputGrads :: (TensorType a, Num a)
=> OutputIx
-> NodeName
-> Gradients a
-> Build [Tensor Value a]
fullOutputGrads n o gs =
mapM (\i -> maybe (render $ zero i) return (gs ^. outputIxAt i)) [0..n-1]
where
zero i = zerosLike $ toT (Output i o)
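
-- | Push each computed input gradient onto the pending gradients of the
-- node that produced the corresponding input.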
updatePendingGradients :: forall a. (TensorType a, Num a)
=> FGL.Context NodeDef EdgeLabel
-> [Maybe (Tensor Value a)]
-> GradientsState a
-> GradientsState a
updatePendingGradients (inputEdges, _, nodeDef, _) inputGrads initState =
foldl' go initState inputEdges
where
go :: GradientsState a -> (EdgeLabel, FGL.Node) -> GradientsState a
go state ((outIndex, OutputIx inIndex), node) =
case maybeGradient of
Nothing -> state
Just g ->
state & gradientsPending
. at node
. nonEmpty
. outputIxAt outIndex
. nonEmpty
%~ (g:)
where
badSizeErr = error $ printf "updatePendingGradients: bad input index \
\%d for inputGrads of length %d in %s"
inIndex (length inputGrads)
(show (nodeDef ^. name))
maybeGradient = fromMaybe badSizeErr (safeIndex inputGrads inIndex)
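
-- | Build an FGL graph of the given node's ancestors by following tensor
-- inputs (control inputs, prefixed with a caret, are ignored), along with a
-- map from node name to graph node.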
createGraph :: NodeName -> (NodeName -> NodeDef)
-> (Graph, Map NodeName FGL.Node)
createGraph nodeName nodeDefLookup = (FGL.nmap nodeDefLookup graph, nodeMap)
where
parseTensorName :: Text -> Maybe (NodeName, OutputIx)
parseTensorName n
| Text.null n = error "parseTensorName: empty name"
| Text.head n == '^' = Nothing
| otherwise =
let (nm, indexStr) = Text.breakOn ":" n
index | Text.null indexStr = 0
| otherwise = read $ Text.unpack $ Text.tail indexStr
in Just (NodeName nm, OutputIx index)
collect :: Maybe (NodeName, OutputIx, OutputIx)
-> NodeName
-> State (Set NodeName)
(Map NodeName [(NodeName, OutputIx, OutputIx)])
collect outgoingEdge nm = do
let nextLookup = Map.singleton nm (maybeToList outgoingEdge)
seen <- gets (Set.member nm)
modify (Set.insert nm)
if seen
then pure nextLookup
else do
let inputs = nodeDefLookup nm ^. input
recurse inIndex (parentName, outIndex) =
collect (Just (nm, outIndex, inIndex)) parentName
subEdgeLookups <-
zipWithM recurse [0..] $ mapMaybe parseTensorName inputs
pure $ Map.unionsWith (++) (nextLookup:subEdgeLookups)
edgeLookup = evalState (collect Nothing nodeName) Set.empty
nodeMap = Map.fromList $ zip (Map.keys edgeLookup) [0..]
graph = FGL.mkGraph (swap <$> Map.toList nodeMap)
[ (nodeMap Map.! n, nodeMap Map.! m, (i, j))
| (n, edges) <- Map.toList edgeLookup
, (m, i, j) <- edges
]
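
-- | Rule computing the gradient of each of an op's inputs, given the op's
-- 'NodeDef', its inputs, and the gradients of its outputs; 'Nothing' marks
-- an input that receives no gradient.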
type GradientFunc a = NodeDef
-> [Output]
-> [Tensor Value a]
-> [Maybe (Tensor Build a)]
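
-- | Wrap an 'Output' as a pure 'Tensor' expression.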
toT :: Output -> Tensor Build a
toT = Tensor . pure
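
-- | Slice of a vector (rank-1 tensor) starting at @begin@; a @size@ of -1
-- takes everything from @begin@ onward.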
flatSlice :: forall v1 t . TensorType t
=> Tensor v1 t
-> Int32
-> Int32
-> Tensor Build t
flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])
nodeDefName :: NodeDef -> NodeName
nodeDefName = NodeName . view name
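
-- | Gradients of a broadcasting binary element-wise op: each input's
-- gradient is summed over the broadcast dimensions and reshaped back to
-- that input's shape.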
gradForBinaryCwise :: ( OneOf '[ Int32, Int64, Float, Double, Complex Float, Complex Double ] t
)
=> (Tensor v1 t, Tensor v1 t)
-> (Tensor v1 t, Tensor v1 t)
-> [ Maybe (Tensor Build t) ]
gradForBinaryCwise (x, gx) (y, gy) =
[ Just dx
, Just dy ]
where
dx = reshape (sum gx rx) sx
dy = reshape (sum gy ry) sy
sx = shape x
sy = shape y
(rx, ry) = broadcastGradientArgs sx sy
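
-- | The per-op gradient rules. Each clause returns one entry per input of
-- the op, in order; 'Nothing' marks inputs that get no gradient (shapes,
-- indices and other non-differentiable arguments).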
opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a
opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x]
opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
opGrad "Concat" _ _ix [dy]
| m == 1 = Nothing : [Just $ expr dy]
| otherwise = Nothing : map Just (dx `reshapeZip` s)
where
reshapeZip = zipWith reshape
dx = CoreOps.splitV (fromIntegral m) dy ki _i
s :: [Tensor Build Int32]
s = map shape x
x :: [Tensor Build a]
x = map toT $ tail _ix
_i = toT (head _ix) `CoreOps.floorMod` n
i = reshape _i $ vector [1 :: Int32]
ki :: Tensor Build Int32
ki = CoreOps.concat 0 $ map (\t -> CoreOps.slice t i $ vector [1 :: Int32]) s
m = length x
n = CoreOps.rank (head x)
opGrad "Square" _ [toT -> x] [dz] =
[Just $ dz `CoreOps.mul` (2 * x)]
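-- Gather: scatter the incoming gradient back into the rows of the params
-- tensor with a segment sum over the gathered indices.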
opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
[ Just $ CoreOps.unsortedSegmentSum values indices' numRows
, Nothing
]
where
denseShape = shape (x :: Tensor Build a)
numRows = scalarize $ flatSlice denseShape 0 1
valuesShape = CoreOps.concat 0 [ allDimensions
, flatSlice denseShape 1 (-1)
]
values = reshape dz valuesShape
indices' = reshape indices allDimensions :: Tensor Build Int32
opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
[Just $ indicators `CoreOps.div` numSelected * dz', Nothing]
where
sx = shape (x :: Tensor Build a)
outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32)
y = CoreOps.max x indices
y' = reshape y outputShapeKeptDims
dz' = reshape dz outputShapeKeptDims
indicators = CoreOps.cast $ CoreOps.equal y' x
numSelected = reshape (sum indicators indices) outputShapeKeptDims
opGrad "Min" u v w = opGrad "Max" u v w
opGrad "Maximum" _ [toT -> x, toT -> y] [dz] =
gradForBinaryCwise (x, gx) (y, gy)
where
xmask = CoreOps.greaterEqual x y
gx = CoreOps.select xmask dz (CoreOps.zerosLike dz)
gy = CoreOps.select (CoreOps.logicalNot xmask) dz (CoreOps.zerosLike dz)
opGrad "Sum" _ [toT -> x, toT -> indices] [dz] =
[ Just $ CoreOps.tile grad tileScaling, Nothing ]
where
sx = shape (x :: Tensor Build a)
outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32)
tileScaling = safeShapeDiv sx outputShapeKeptDims
grad = reshape dz outputShapeKeptDims
opGrad "Mean" u v@[toT -> x, _] w =
[Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing]
where
[Just dz, Nothing] = opGrad "Sum" u v w
inputShape = shape (x :: Tensor Build a)
outputShape = shape (dz :: Tensor Build a)
inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape
outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape
factor = safeShapeDiv inputSize outputSize
opGrad "Add" _ [toT -> x, toT -> y] [dz] =
[ Just $ reshape (sum dz rx) sx
, Just $ reshape (sum dz ry) sy ]
where
sx = shape (x :: Tensor Build a)
sy = shape (y :: Tensor Build a)
(rx, ry) = broadcastGradientArgs sx sy
opGrad "AddN" _ inputs [dz] =
map ((const . Just . expr) dz) inputs
opGrad "Sub" u v w =
[Just x, Just (-y)]
where
[Just x, Just y] = opGrad "Add" u v w
opGrad "SoftmaxCrossEntropyWithLogits" _ [toT -> x, toT -> y] [dz, _] =
[ Just $ expandDims dz (-1) * snd (softmaxCrossEntropyWithLogits x y)
, Nothing ]
opGrad "Mul" _ [toT -> x, toT -> y] [dz] =
[ Just $ reshape (sum (dz `CoreOps.mul` y) rx) sx
, Just $ reshape (sum (x `CoreOps.mul` dz) ry) sy ]
where
sx = shape (x :: Tensor Build a)
sy = shape (y :: Tensor Build a)
(rx, ry) = broadcastGradientArgs sx sy
opGrad "Div" _ [toT -> x, toT -> y] [dz] =
[ Just $ reshape (sum (dz `CoreOps.div` y) rx) sx
, Just $ reshape (sum (dz `CoreOps.mul` (negate x `CoreOps.div` (y * y)))
ry)
sy
]
where
sx = shape (x :: Tensor Build a)
sy = shape (y :: Tensor Build a)
(rx, ry) = broadcastGradientArgs sx sy
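-- MatMul: for z = x `matMul` y with no transposes, dx = dz `matMul`
-- transpose y and dy = transpose x `matMul` dz; the remaining cases adjust
-- the transpose attributes accordingly.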
opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
let transposeA = lookupAttr nodeDef "transpose_a"
transposeB = lookupAttr nodeDef "transpose_b"
transAttrs a b =
(opAttr "transpose_a" .~ a) . (opAttr "transpose_b" .~ b)
in case (transposeA, transposeB) of
(False, False) ->
[ Just $ matMul' (transAttrs False True) dz y
, Just $ matMul' (transAttrs True False) x dz]
(False, True) ->
[ Just $ matMul dz y
, Just $ matMul' (transAttrs True False) dz x]
(True, False) ->
[ Just $ matMul' (transAttrs False True) y dz
, Just $ matMul x dz]
(True, True) ->
[ Just $ matMul' (transAttrs True True) y dz
, Just $ matMul' (transAttrs True True) dz x]
opGrad "Transpose" _ [_, toT -> p] [dz] =
[ Just $ CoreOps.transpose dz
(CoreOps.invertPermutation p :: Tensor Build Int32)
, Nothing
]
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
[ Just $ CoreOps.conv2DBackpropInput'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat))
(shape x) y dz
, Just $ CoreOps.conv2DBackpropFilter'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat))
x (shape y) dz
]
where
strides = lookupAttr nodeDef "strides" :: [Int64]
padding = lookupAttr nodeDef "padding" :: ByteString
useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
[ Nothing
, Just $ CoreOps.conv2DBackpropFilter'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat))
dz (shape x) y
, Just $ CoreOps.conv2D'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat))
dz x
]
where
strides = lookupAttr nodeDef "strides" :: [Int64]
padding = lookupAttr nodeDef "padding" :: ByteString
useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
opGrad "MaxPool" nodeDef [toT -> x] [dz] =
[ Just $ CoreOps.maxPoolGrad'
((opAttr "ksize" .~ ksize)
. (opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat))
x output dz
]
where
output :: Tensor Build a
output = toT $ Output 0 (nodeDefName nodeDef)
ksize = lookupAttr nodeDef "ksize" :: [Int64]
strides = lookupAttr nodeDef "strides" :: [Int64]
padding = lookupAttr nodeDef "padding" :: ByteString
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
opGrad "Reshape" _ [toT -> x, _] [dz] =
[Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
opGrad "TruncatedNormal" _ _ _ = [Nothing]
opGrad "RefIdentity" _ _ [dz] = [Just $ expr dz]
opGrad "Cast" nodeDef _ [dz] = [Just reverseCast]
where
reverseCast =
pureOp [] $ pure (opDef "Cast"
& opAttr "DstT" .~ (lookupAttr nodeDef "SrcT" :: ByteString)
& opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString)
& opInputs .~ [renderedOutput dz])
opGrad "DynamicStitch" nodeDef inputs [dz] =
replicate halfLen Nothing ++ valuesGrads
where
halfLen =
let len = length inputs
half = len `div` 2
in if 2 * half == len
then half
else error ("Uneven input size " ++ show (len, showMessage nodeDef))
valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Build Int32)
| idx <- take halfLen inputs
]
opGrad "DynamicPartition" nodeDef [toT -> xs, toT -> indices] dz =
[ Just reconstructed, Nothing ]
where
reconstructed = CoreOps.reshape stitched
(CoreOps.shape (xs :: Tensor Build a) :: Tensor Build Int32)
stitched = CoreOps.dynamicStitch partitionedIndices dz
partitionedIndices = CoreOps.dynamicPartition np originalIndices indices
np = lookupAttr nodeDef "num_partitions" :: Int64
originalIndices =
CoreOps.reshape (CoreOps.range 0 (CoreOps.size indices) 1) prefixShape
prefixShape = shapeInt32 indices
shapeInt32 t = CoreOps.shape t :: Tensor Build Int32
opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
[ Nothing
, Just $ CoreOps.select c dz zeros
, Just $ CoreOps.select c zeros dz
]
where zeros = CoreOps.zerosLike x
opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz `CoreOps.mul` CoreOps.inv x ]
opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz `CoreOps.mul` CoreOps.exp x ]
opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
[ Just $ CoreOps.unsortedSegmentSum
(CoreOps.gather dz (t :: Tensor Build Int32))
(y :: Tensor Build Int32) inputRows
, Nothing
, Nothing
]
where inputRows = flatSlice (shape (x :: Tensor Build a)) 0 1
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
opGrad "LabelWeights" _ _ _ = [Nothing]
opGrad "Size" _ _ _ = [Nothing]
opGrad "Tile" _ [toT -> x, toT -> multiples] [dz] =
[Just inputGrad, Nothing]
where
inputGrad = sum reshapedDz axes
inputShape = shape (x :: Tensor Build a)
packed = CoreOps.pack [multiples, inputShape]
perm = vector [1, 0 :: Int32]
splitShape = CoreOps.reshape (CoreOps.transpose packed perm) allDimensions
axes = CoreOps.range 0 (CoreOps.size splitShape) (2 :: Tensor Build Int32)
reshapedDz = CoreOps.reshape dz splitShape
opGrad "ZerosLike" _ _ _ = [Nothing]
opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
where
rx = rangeOfRank dz
opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
opGrad "Const" _ _ _ = [Nothing, Nothing]
opGrad "Placeholder" _ _ _ = []
opGrad "VarHandleOp" _ _ _ = []
opGrad "Variable" _ _ _ = []
opGrad n nodeDef ins grads =
error $ "no gradient implemented for " ++
show (n, length ins, length grads, showMessage nodeDef, ins)
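
-- | Number of outputs of each supported op, needed when padding the output
-- gradients in 'calculateInputGrads'.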
numOutputs :: NodeDef -> OutputIx
numOutputs o =
case o ^. op of
"Abs" -> 1
"Add" -> 1
"AddN" -> 1
"Cast" -> 1
"Const" -> 1
"Concat" -> 1
"Conv2D" -> 1
"Conv2DBackpropInput" -> 1
"Div" -> 1
"DynamicStitch" -> 1
"DynamicPartition" ->
fromIntegral (lookupAttr o "num_partitions" :: Int64)
"Exp" -> 1
"Gather" -> 1
"LabelClasses" -> 1
"LabelWeights" -> 1
"Log" -> 1
"MatMul" -> 1
"Max" -> 1
"Maximum" -> 1
"MaxPool" -> 1
"Mean" -> 1
"Min" -> 1
"Mul" -> 1
"Neg" -> 1
"Placeholder" -> 1
"OneHot" -> 1
"ReadVariableOp" -> 1
"RefIdentity" -> 1
"Relu" -> 1
"ReluGrad" -> 1
"Reshape" -> 1
"Select" -> 1
"Size" -> 1
"SoftmaxCrossEntropyWithLogits" -> 2
"Square" -> 1
"SparseSegmentSum" -> 1
"Sub" -> 1
"Sum" -> 1
"Tile" -> 1
"Transpose" -> 1
"TruncatedNormal" -> 1
"VarHandleOp" -> 1
"Variable" -> 1
"ZerosLike" -> 1
"Fill" -> 1
_ -> error $ "numOutputs not implemented for " ++ show (o ^. op)
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Build Int32
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
allDimensions :: Tensor Build Int32
allDimensions = vector [-1 :: Int32]
rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Build Int32
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1
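
-- | Look up an attribute on a 'NodeDef', falling back to the attribute's
-- default value when it is absent.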
lookupAttr :: Attribute a1 => NodeDef -> Text -> a1
lookupAttr nodeDef attrName = nodeDef ^. attr . at attrName . non def . attrLens