{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TypeApplications #-}
module TensorFlow.Gradient
( GradientCompatible
, gradients
) where
import Control.Monad (forM, zipWithM)
import Control.Monad.State.Strict (State, evalState, gets, modify)
import Data.ByteString (ByteString)
import Data.Complex (Complex)
import Data.ProtoLens.Default(def)
import Data.Int (Int32, Int64)
import Data.Foldable (foldlM)
import Data.List (foldl', sortBy)
import Data.Map.Strict (Map)
import qualified Data.IntSet as IntSet
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
import Data.Ord (comparing)
import Data.ProtoLens.TextFormat (showMessage)
import Data.Set (Set)
import Data.Text (Text)
import Data.Tuple (swap)
import Lens.Family2 (Lens', view, (&), (^.), (.~), (%~), under)
import Lens.Family2.State.Strict (uses)
import Lens.Family2.Stock (at, intAt)
import Lens.Family2.Unchecked (lens, adapter)
import Prelude hiding (sum, tanh)
import Text.Printf (printf)
import qualified Data.Graph.Inductive.Basic as FGL
import qualified Data.Graph.Inductive.Graph as FGL
import qualified Data.Graph.Inductive.PatriciaTree as FGL
import qualified Data.Graph.Inductive.Query.DFS as FGL
import qualified Data.IntMap.Strict as IntMap
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Text as Text
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Build
( MonadBuild
, Build
, build
, renderedNodeDefs
, opDef
, opAttr
, opInputs
)
import TensorFlow.BuildOp
import TensorFlow.Ops
( addN
, broadcastGradientArgs
, expandDims
, fill
, matMul
, matMul'
, reducedShape
, reluGrad
, tanh
, tanhGrad
, reshape
, scalar
, shape
, softmaxCrossEntropyWithLogits
, sum
, sigmoid
, sigmoidGrad
, scalarize
, vector
, zerosLike
)
import TensorFlow.Output
( NodeName(..)
, Output(..)
, OutputIx(..)
, outputIndex
)
import TensorFlow.Tensor
( Tensor(..)
, Value
, render
, expr
, Rendered
, tensorNodeName
, renderedOutput
, renderValue
, ToTensor(..)
)
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef)
import Proto.Tensorflow.Core.Framework.NodeDef_Fields
( attr, input, op, name)
-- | Constraint on the element type of tensors that 'gradients' can
-- differentiate: the type must be numeric and be one of 'Float',
-- @'Complex' 'Float'@ or @'Complex' 'Double'@.
type GradientCompatible a =
    ( Num a
    , OneOf '[ Float, Complex Float, Complex Double ] a
    )
-- | Gradient of @y@ with respect to each element of @xs@.
--
-- The computation is reverse accumulation over the rendered graph:
--
-- 1. Build an fgl graph of the ops that @y@ depends on ('createGraph').
-- 2. Seed the pending gradient of @y@'s output with a tensor of ones shaped
--    like @y@.
-- 3. Let 'graphGrads' walk the graph in reverse topological order, summing
--    each node's pending output gradients and pushing gradients to its
--    inputs. Only nodes downstream of some @x@ ('computeReachableSet')
--    propagate further.
-- 4. Read back the summed gradient for each @x@. An @x@ for which no
--    gradient was recorded gets a tensor of zeros shaped like @x@.
gradients :: forall a v1 t m . ( MonadBuild m
                               , Rendered t
                               , ToTensor t
                               , GradientCompatible a
                               )
          => Tensor v1 a  -- ^ The output of the graph.
          -> [t a]        -- ^ Tensors for which gradients are computed.
          -> m [Tensor Value a]
gradients y xs = build $ do
    yValue <- renderValue y
    seed <- render (fill (shape yValue) (scalar 1))
    -- Resolve node names against the NodeDefs rendered so far. A missing
    -- name is a programming error, so fail loudly with the offending name.
    lookupNodeDef :: (NodeName -> NodeDef) <- uses renderedNodeDefs $
        \defs nodeName ->
            fromMaybe (error $ "no NodeDef found for " ++ show nodeName)
                      (Map.lookup nodeName defs)
    let yNodeName = tensorNodeName yValue
        (graph, nodeIds) = createGraph yNodeName lookupNodeDef
        nodeIdOf tensor =
            Map.lookup (outputNodeName (renderedOutput tensor)) nodeIds
        -- Nodes downstream of the xs; only these need input gradients.
        reachable = computeReachableSet (mapMaybe nodeIdOf xs) graph
        -- NOTE(review): relies on createGraph assigning an id to yNodeName
        -- (the node it starts from); Map.! is partial otherwise.
        seedPending :: Map.Map FGL.Node (PendingGradients a)
        seedPending =
            Map.empty & ( at (nodeIds Map.! yNodeName)
                            . nonEmpty
                            . outputIxAt (outputIndex (renderedOutput yValue))
                            . nonEmpty
                            .~ [seed]
                        )
    gradientMap <- graphGrads graph reachable seedPending
    let gradientOf x = do
            let Output ix xNodeName = renderedOutput x
            nodeId <- Map.lookup xNodeName nodeIds
            perOutput <- Map.lookup nodeId gradientMap
            IntMap.lookup (unOutputIx ix) perOutput
    forM xs $ \x -> case gradientOf x of
        Just g  -> return g
        Nothing -> render (zerosLike (toTensor x))
-- | Compute the set of nodes reachable from the start nodes by following at
-- least one edge. Edges point from a producing node to the nodes that
-- consume its outputs.
--
-- A start node is in the result only if it is reachable from a start node
-- (assuming the graph is acyclic, as the topological sort in 'graphGrads'
-- requires, no node reaches itself).
--
-- The search starts at the successors of every start node, so the result
-- does not depend on the order of @vs@. The previous version built a single
-- DFS forest rooted at @vs@ and dropped each tree's root. A start node
-- reachable from another start node was then wrongly dropped whenever it
-- came first in @vs@, which cut gradient propagation through it.
computeReachableSet :: [FGL.Node] -> Graph -> IntSet.IntSet
computeReachableSet vs g =
    IntSet.fromList $ FGL.dfs (concatMap (FGL.suc g) vs) g
-- | Lens onto the entry of an 'IntMap' keyed by an output index, using the
-- raw 'Int' inside the 'OutputIx' newtype as the key.
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
outputIxAt ix = intAt (unOutputIx ix)
-- | Graph of TensorFlow operations. Nodes carry their 'NodeDef'; edges carry
-- an 'EdgeLabel'.
type Graph = FGL.Gr NodeDef EdgeLabel

-- | Label on an edge from a producing node to a consuming node, as the pair
-- (output index on the producer, input index on the consumer).
type EdgeLabel = (OutputIx, OutputIx)

-- | Summed gradient of each of a node's outputs. The key is the raw output
-- index (an 'OutputIx' without its newtype wrapper).
type Gradients a = IntMap.IntMap (Tensor Value a)

-- | Gradient contributions to each of a node's outputs that have not been
-- summed yet. The key is the raw output index.
type PendingGradients a = IntMap.IntMap [Tensor Value a]
-- | State threaded through the reverse-topological walk in 'graphGrads'.
-- Both fields are strict.
data GradientsState a = GradientsState
    { _gradientsPending :: !(Map FGL.Node (PendingGradients a))
      -- ^ Gradient contributions received by each node, not yet summed.
    , _gradientsResult  :: !(Map FGL.Node (Gradients a))
      -- ^ Summed output gradients of each node processed so far.
    }
-- | Lens onto the not-yet-summed gradient contributions of every node.
gradientsPending :: Lens' (GradientsState a) (Map FGL.Node (PendingGradients a))
gradientsPending = lens getPending setPending
  where
    getPending = _gradientsPending
    setPending st pending = st { _gradientsPending = pending }
-- | Lens onto the summed output gradients recorded for each processed node.
gradientsResult :: Lens' (GradientsState a) (Map FGL.Node (Gradients a))
gradientsResult = lens getResult setResult
  where
    getResult = _gradientsResult
    setResult st result = st { _gradientsResult = result }
-- | Total version of @(!!)@: returns 'Nothing' for a negative index or an
-- index past the end of the list.
safeIndex :: [a] -> Int -> Maybe a
safeIndex xs n
    | n < 0     = Nothing
    | otherwise = case drop n xs of
        (y:_) -> Just y
        []    -> Nothing
-- | Lens that reads 'Nothing' as the given fallback value. On write, any
-- value satisfying the predicate is stored back as 'Nothing'; every other
-- value is stored as 'Just'.
anon :: a -> (a -> Bool) -> Lens' (Maybe a) a
anon fallback isFallback = under (adapter (fromMaybe fallback) store)
  where
    store v = if isFallback v then Nothing else Just v
-- | Specialisation of 'anon' that treats values equal to the given default
-- as absent.
non :: Eq a => a -> Lens' (Maybe a) a
non dflt = anon dflt (dflt ==)
-- | Lens that reads 'Nothing' as an empty container ('mempty') and stores an
-- empty container (per 'null') back as 'Nothing'.
nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v)
nonEmpty = anon mempty null
-- | Calculate the summed output gradients of every node in a graph.
--
-- Nodes are visited in topological order of the reversed graph, so every
-- consumer of a node's outputs is visited before the node itself. At each
-- node:
--
-- * its pending contributions are summed ('sumPendingGradient');
-- * if nothing was summed, the state is left unchanged;
-- * otherwise the sums are recorded in the result. If the node is in
--   @reachableSet@, its input gradients are also computed and added to the
--   pending contributions of its inputs.
graphGrads :: forall a. GradientCompatible a
           => Graph
           -> IntSet.IntSet
           -- ^ Nodes whose input gradients are worth computing.
           -> Map FGL.Node (PendingGradients a)
           -- ^ Initial pending gradients (e.g. the seed for the output node).
           -> Build (Map FGL.Node (Gradients a))
graphGrads gr reachableSet initPending = do
    finalState <- foldlM step (GradientsState initPending Map.empty) visitOrder
    pure (finalState ^. gradientsResult)
  where
    visitOrder :: [FGL.Node]
    visitOrder = FGL.topsort (FGL.grev gr)

    -- Sum one node's pending gradients and record them, if there are any.
    step :: GradientsState a -> FGL.Node -> Build (GradientsState a)
    step st node = do
        summed <- sumPendingGradient (st ^. gradientsPending . at node . nonEmpty)
        if null summed
            then pure st
            else propagate node summed
                     (st & gradientsResult %~ Map.insert node summed)

    -- Push a node's gradients to its inputs, but only for reachable nodes,
    -- which avoids computing gradients that can never be used.
    propagate :: FGL.Node
              -> Gradients a
              -> GradientsState a
              -> Build (GradientsState a)
    propagate node summed st
        | node `IntSet.member` reachableSet = do
            let ctx = FGL.context gr node
            inputGrads <- calculateInputGrads ctx summed gr
            pure (updatePendingGradients ctx inputGrads st)
        | otherwise = pure st
-- | Reduce each output's list of pending contributions to a single tensor.
--
-- Outputs with no contributions are dropped. A single contribution is used
-- as is. Several contributions are added with 'addN' and the sum is
-- rendered. Outputs are processed in ascending key order.
sumPendingGradient :: GradientCompatible a
                   => PendingGradients a -> Build (Gradients a)
sumPendingGradient pending =
    traverse combine (IntMap.filter (not . null) pending)
  where
    combine [single] = pure single
    combine partials = render (addN partials)
-- | Calculate the gradients flowing into each input of a single node.
--
-- The node's output gradients are first padded to one tensor per output
-- ('fullOutputGrads' fills missing outputs with zeros). The op-specific
-- 'opGrad' then yields one optional gradient per input, and each gradient
-- present is rendered. Inputs are ordered by input index.
calculateInputGrads :: forall a. GradientCompatible a
                    => FGL.Context NodeDef EdgeLabel
                    -> Gradients a  -- ^ Summed output gradients of the node.
                    -> Graph
                    -> Build [Maybe (Tensor Value a)]
calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr = do
    paddedOutGrads <-
        fullOutputGrads (numOutputs nodeDef) (nodeDefName nodeDef) outputGrads
    let symbolic = opGrad (nodeDef ^. op) nodeDef orderedInputs paddedOutGrads
    mapM (mapM render) symbolic
  where
    -- Incoming edges sorted by the consumer-side input index, converted to
    -- the tensors they carry.
    orderedInputs :: [Output]
    orderedInputs = map edgeToTensor (sortBy byInputIndex inputEdges)

    byInputIndex :: (EdgeLabel, FGL.Node) -> (EdgeLabel, FGL.Node) -> Ordering
    byInputIndex = comparing (\((_, inIx), _) -> inIx)

    -- The tensor carried by an incoming edge: the given output of the
    -- producing node at the other end of the edge.
    edgeToTensor :: (EdgeLabel, FGL.Node) -> Output
    edgeToTensor ((outIx, _), srcNode) =
        maybe missingNode fromSource (FGL.lab gr srcNode)
      where
        fromSource srcDef = Output outIx (NodeName (srcDef ^. name))
        missingNode = error $ "calculateInputGrads: missing input node for "
                              ++ Text.unpack (nodeDef ^. name)
-- | Expand a sparse map of output gradients into a list with one entry per
-- output index @0 .. n-1@. A missing output gets a tensor of zeros shaped
-- like that output of node @o@.
fullOutputGrads :: (TensorType a, Num a)
                => OutputIx     -- ^ Number of outputs of the node.
                -> NodeName     -- ^ Node whose outputs these are.
                -> Gradients a
                -> Build [Tensor Value a]
fullOutputGrads n o gs = forM [0 .. n - 1] gradientOrZero
  where
    gradientOrZero i = case IntMap.lookup (unOutputIx i) gs of
        Just g  -> return g
        Nothing -> render (zerosLike (toT (Output i o)))
-- | Push the gradients computed for a node's inputs onto the pending lists
-- of the nodes that produced those inputs.
--
-- For every incoming edge @((outIx, OutputIx inIx), producer)@ in the
-- context, the gradient at position @inIx@ of @inputGrads@ is consed onto
-- the list kept under output @outIx@ of @producer@ in 'gradientsPending'.
-- A 'Nothing' gradient leaves the state unchanged. If 'safeIndex' yields
-- 'Nothing' for @inIx@ (presumably an out-of-range index), this calls
-- 'error'.
updatePendingGradients :: forall a. (TensorType a, Num a)
                       => FGL.Context NodeDef EdgeLabel
                       -> [Maybe (Tensor Value a)]
                       -> GradientsState a
                       -> GradientsState a
updatePendingGradients (inputEdges, _, nodeDef, _) inputGrads initState =
    foldl' addEdge initState inputEdges
  where
    addEdge :: GradientsState a -> (EdgeLabel, FGL.Node) -> GradientsState a
    addEdge st ((outIx, OutputIx inIx), producer) =
        maybe st push (gradientAt inIx)
      where
        -- Prepend to the pending list for this producer output, creating
        -- empty containers along the way as needed.
        push g = st & gradientsPending
                    . at producer
                    . nonEmpty
                    . outputIxAt outIx
                    . nonEmpty
                    %~ (g :)

    -- Gradient supplied for input @i@. The error thunk fires only when
    -- 'safeIndex' finds no element at @i@.
    gradientAt :: Int -> Maybe (Tensor Value a)
    gradientAt i = fromMaybe (badIndex i) (safeIndex inputGrads i)

    badIndex :: Int -> Maybe (Tensor Value a)
    badIndex i = error $ printf "updatePendingGradients: bad input index \
                                \%d for inputGrads of length %d in %s"
                                i (length inputGrads)
                                (show (nodeDef ^. name))
-- | Build the dependency graph of every node reachable from @nodeName@
-- through data inputs. Also return the map from node name to FGL node id.
--
-- Nodes are discovered with a depth-first walk over each node's @input@
-- field, which is looked up with @nodeDefLookup@. Control inputs (names
-- starting with @^@) are ignored. Each edge runs from the producing node
-- to the consuming node and is labelled with
-- @(producer output index, consumer input index)@.
createGraph :: NodeName -> (NodeName -> NodeDef)
            -> (Graph, Map NodeName FGL.Node)
createGraph nodeName nodeDefLookup = (FGL.nmap nodeDefLookup graph, nodeMap)
  where
    -- Parse a tensor name of the form @node@ or @node:index@; the index
    -- defaults to 0. Control inputs (@^node@) yield 'Nothing'.
    parseTensorName :: Text -> Maybe (NodeName, OutputIx)
    parseTensorName n
        | Text.null n        = error "parseTensorName: empty name"
        | Text.head n == '^' = Nothing
        | otherwise          =
            let (nm, indexStr) = Text.breakOn ":" n
                index | Text.null indexStr = 0
                      | otherwise          = parseIndex (Text.tail indexStr)
            in Just (NodeName nm, OutputIx index)
      where
        -- 'read' would abort with an opaque "Prelude.read: no parse".
        -- Report which tensor name was malformed instead. Surrounding
        -- whitespace is stripped to accept exactly what 'read' accepted.
        parseIndex :: Text -> Int
        parseIndex digits =
            case reads (Text.unpack (Text.strip digits)) of
                [(i, "")] -> i
                _ -> error $ "parseTensorName: bad output index in "
                             ++ show n

    -- Depth-first walk from @nm@. For each node it records the edges that
    -- leave it as (consumer, producer output index, consumer input index).
    -- The state is the set of already-visited nodes, so each node's inputs
    -- are expanded only once.
    collect :: Maybe (NodeName, OutputIx, OutputIx)
            -> NodeName
            -> State (Set NodeName)
                     (Map NodeName [(NodeName, OutputIx, OutputIx)])
    collect outgoingEdge nm = do
        let nextLookup = Map.singleton nm (maybeToList outgoingEdge)
        seen <- gets (Set.member nm)
        modify (Set.insert nm)
        if seen
            then pure nextLookup
            else do
                let inputs = nodeDefLookup nm ^. input
                    recurse inIndex (parentName, outIndex) =
                        collect (Just (nm, outIndex, inIndex)) parentName
                -- Input indices are numbered after control inputs are
                -- dropped. NOTE(review): this assumes control inputs come
                -- after all data inputs in @input@ — confirm.
                subEdgeLookups <-
                    zipWithM recurse [0..] $ mapMaybe parseTensorName inputs
                pure $ Map.unionsWith (++) (nextLookup : subEdgeLookups)

    -- Outgoing edges of every reachable node.
    edgeLookup = evalState (collect Nothing nodeName) Set.empty
    -- Dense FGL ids, assigned in node-name order.
    nodeMap = Map.fromList $ zip (Map.keys edgeLookup) [0..]
    graph = FGL.mkGraph (swap <$> Map.toList nodeMap)
                        [ (nodeMap Map.! n, nodeMap Map.! m, (i, j))
                        | (n, edges) <- Map.toList edgeLookup
                        , (m, i, j) <- edges
                        ]
-- | Gradient rule for a single op type.
--
-- Given the op's 'NodeDef', its input tensors, and the gradient of the
-- target with respect to each of the op's outputs, produce one entry per
-- input. The entry is @Just g@ for the gradient with respect to that input,
-- or 'Nothing' where no gradient is produced (e.g. the @indices@ input of
-- @Gather@).
type GradientFunc a =
       NodeDef                    -- ^ The op being differentiated.
    -> [Output]                   -- ^ The op's input tensors.
    -> [Tensor Value a]           -- ^ Gradient w.r.t. each output.
    -> [Maybe (Tensor Build a)]   -- ^ Gradient w.r.t. each input.
-- | Wrap an existing 'Output' as a 'Tensor' in the 'Build' context.
--
-- NOTE(review): the element type @a@ is unconstrained and is not checked
-- against the output's real dtype. Callers must choose the correct one.
toT :: Output -> Tensor Build a
toT out = Tensor (pure out)
-- | Convenience wrapper around 'CoreOps.slice' that takes scalar @begin@
-- and @size@ and packs each into a one-element vector.
--
-- Because both vectors have length 1, this presumably expects a rank-1
-- input (TF Slice needs begin/size lengths equal to the input rank). A
-- @size@ of -1 takes all remaining elements, per TF Slice semantics.
flatSlice :: forall v1 t . TensorType t
          => Tensor v1 t    -- ^ Tensor to slice.
          -> Int32          -- ^ Start offset.
          -> Int32          -- ^ Number of elements, or -1 for "the rest".
          -> Tensor Build t
flatSlice t begin size = CoreOps.slice t beginVec sizeVec
  where
    beginVec = vector [begin]
    sizeVec  = vector [size]
-- | The @name@ field of a 'NodeDef', wrapped as a 'NodeName'.
nodeDefName :: NodeDef -> NodeName
nodeDefName nodeDef = NodeName (nodeDef ^. name)
-- | Shared gradient step for broadcasting binary component-wise ops.
--
-- Each incoming gradient (@gx@ for operand @x@, @gy@ for @y@) is summed
-- over the axes that 'broadcastGradientArgs' reports for that operand. It
-- is then reshaped back to the operand's own shape, so the result matches
-- the operand's shape again.
gradForBinaryCwise :: ( OneOf '[ Int32, Int64, Float, Double, Complex Float, Complex Double ] t
                      )
                   => (Tensor v1 t, Tensor v1 t)
                   -> (Tensor v1 t, Tensor v1 t)
                   -> [ Maybe (Tensor Build t) ]
gradForBinaryCwise (x, gx) (y, gy) =
    let shapeX = shape x
        shapeY = shape y
        -- Axes along which each operand was broadcast.
        (axesX, axesY) = broadcastGradientArgs shapeX shapeY
    in map Just [ reshape (sum gx axesX) shapeX
                , reshape (sum gy axesY) shapeY
                ]
opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a
opGrad :: Text -> GradientFunc a
opGrad "Abs" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x] [dz :: Tensor Value a
dz] = [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a -> Tensor Build a
forall (v :: * -> *) a.
TensorKind v =>
Tensor v a -> Tensor Build a
expr Tensor Value a
dz Tensor Build a -> Tensor Build a -> Tensor Build a
forall a. Num a => a -> a -> a
* Tensor Build a -> Tensor Build a
forall a. Num a => a -> a
signum Tensor Build a
x]
opGrad "Neg" _ [_] [dz :: Tensor Value a
dz] = [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a -> Tensor Build a
forall a. Num a => a -> a
negate (Tensor Build a -> Tensor Build a)
-> Tensor Build a -> Tensor Build a
forall a b. (a -> b) -> a -> b
$ Tensor Value a -> Tensor Build a
forall (v :: * -> *) a.
TensorKind v =>
Tensor v a -> Tensor Build a
expr Tensor Value a
dz]
opGrad "Relu" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x] [dz :: Tensor Value a
dz] = [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
'[Int16, Int32, Int64, Int8, Word16, Word32, Word64, Word8, Double,
Float]
t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
reluGrad Tensor Value a
dz Tensor Build a
x]
opGrad "ReluGrad" _ [_, Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x ] [dz :: Tensor Value a
dz] = [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
'[Int16, Int32, Int64, Int8, Word16, Word32, Word64, Word8, Double,
Float]
t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
reluGrad Tensor Value a
dz Tensor Build a
x, Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) t.
TensorType t =>
Tensor v'1 t -> Tensor Build t
CoreOps.zerosLike Tensor Build a
x]
opGrad "Tanh" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x] [dz :: Tensor Value a
dz] = [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a -> Tensor Value a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf '[Complex Double, Complex Float, Word16, Double, Float] t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
tanhGrad (Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) t.
OneOf '[Complex Double, Complex Float, Word16, Double, Float] t =>
Tensor v'1 t -> Tensor Build t
tanh Tensor Build a
x) Tensor Value a
dz]
opGrad "Sigmoid" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x] [dz :: Tensor Value a
dz] = [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a -> Tensor Value a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf '[Complex Double, Complex Float, Word16, Double, Float] t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
sigmoidGrad (Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) t.
OneOf '[Complex Double, Complex Float, Word16, Double, Float] t =>
Tensor v'1 t -> Tensor Build t
sigmoid Tensor Build a
x) Tensor Value a
dz]
opGrad "Concat" _ _ix :: [Output]
_ix [dy :: Tensor Value a
dy]
| Node
m Node -> Node -> Bool
forall a. Eq a => a -> a -> Bool
== 1 = Maybe (Tensor Build a)
forall a. Maybe a
Nothing Maybe (Tensor Build a)
-> [Maybe (Tensor Build a)] -> [Maybe (Tensor Build a)]
forall a. a -> [a] -> [a]
: [Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a -> Tensor Build a
forall (v :: * -> *) a.
TensorKind v =>
Tensor v a -> Tensor Build a
expr Tensor Value a
dy]
| Bool
otherwise = Maybe (Tensor Build a)
forall a. Maybe a
Nothing Maybe (Tensor Build a)
-> [Maybe (Tensor Build a)] -> [Maybe (Tensor Build a)]
forall a. a -> [a] -> [a]
: (Tensor Build a -> Maybe (Tensor Build a))
-> [Tensor Build a] -> [Maybe (Tensor Build a)]
forall a b. (a -> b) -> [a] -> [b]
map Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just ([Tensor Build a]
dx [Tensor Build a] -> [Tensor Build Int32] -> [Tensor Build a]
forall (v'1 :: * -> *) (v'2 :: * -> *).
[Tensor v'1 a] -> [Tensor v'2 Int32] -> [Tensor Build a]
`reshapeZip` [Tensor Build Int32]
s)
where
reshapeZip :: [Tensor v'1 a] -> [Tensor v'2 Int32] -> [Tensor Build a]
reshapeZip = (Tensor v'1 a -> Tensor v'2 Int32 -> Tensor Build a)
-> [Tensor v'1 a] -> [Tensor v'2 Int32] -> [Tensor Build a]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Tensor v'1 a -> Tensor v'2 Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
reshape
dx :: [Tensor Build a]
dx = Int64
-> Tensor Value a
-> Tensor Build Int32
-> Tensor Build Int32
-> [Tensor Build a]
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t tlen.
(TensorType t, OneOf '[Int32, Int64] tlen) =>
Int64
-> Tensor v'1 t
-> Tensor v'2 tlen
-> Tensor v'3 Int32
-> [Tensor Build t]
CoreOps.splitV (Node -> Int64
forall a b. (Integral a, Num b) => a -> b
fromIntegral Node
m) Tensor Value a
dy Tensor Build Int32
ki Tensor Build Int32
_i
s :: [Tensor Build Int32]
s :: [Tensor Build Int32]
s = (Tensor Build a -> Tensor Build Int32)
-> [Tensor Build a] -> [Tensor Build Int32]
forall a b. (a -> b) -> [a] -> [b]
map Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape [Tensor Build a]
x
x :: [Tensor Build a]
x :: [Tensor Build a]
x = (Output -> Tensor Build a) -> [Output] -> [Tensor Build a]
forall a b. (a -> b) -> [a] -> [b]
map Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT ([Output] -> [Tensor Build a]) -> [Output] -> [Tensor Build a]
forall a b. (a -> b) -> a -> b
$ [Output] -> [Output]
forall a. [a] -> [a]
tail [Output]
_ix
_i :: Tensor Build Int32
_i = Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT ([Output] -> Output
forall a. [a] -> a
head [Output]
_ix) Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf '[Int32, Int64, Word16, Word64, Double, Float] t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
`CoreOps.floorMod` Tensor Build Int32
n
i :: Tensor Build Int32
i = Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
reshape Tensor Build Int32
_i (Tensor Build Int32 -> Tensor Build Int32)
-> Tensor Build Int32 -> Tensor Build Int32
forall a b. (a -> b) -> a -> b
$ [Int32] -> Tensor Build Int32
forall a. TensorType a => [a] -> Tensor Build a
vector [1 :: Int32]
ki :: Tensor Build Int32
ki :: Tensor Build Int32
ki = Tensor Build Int32 -> [Tensor Build Int32] -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
TensorType t =>
Tensor v'1 Int32 -> [Tensor v'2 t] -> Tensor Build t
CoreOps.concat 0 ([Tensor Build Int32] -> Tensor Build Int32)
-> [Tensor Build Int32] -> Tensor Build Int32
forall a b. (a -> b) -> a -> b
$ (Tensor Build Int32 -> Tensor Build Int32)
-> [Tensor Build Int32] -> [Tensor Build Int32]
forall a b. (a -> b) -> [a] -> [b]
map (\t :: Tensor Build Int32
t -> Tensor Build Int32
-> Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t index.
(TensorType t, OneOf '[Int32, Int64] index) =>
Tensor v'1 t
-> Tensor v'2 index -> Tensor v'3 index -> Tensor Build t
CoreOps.slice Tensor Build Int32
t Tensor Build Int32
i (Tensor Build Int32 -> Tensor Build Int32)
-> Tensor Build Int32 -> Tensor Build Int32
forall a b. (a -> b) -> a -> b
$ [Int32] -> Tensor Build Int32
forall a. TensorType a => [a] -> Tensor Build a
vector [1 :: Int32]) [Tensor Build Int32]
s
m :: Node
m = [Tensor Build a] -> Node
forall (t :: * -> *) a. Foldable t => t a -> Node
length [Tensor Build a]
x
n :: Tensor Build Int32
n = Tensor Build a -> Tensor Build Int32
forall (v'1 :: * -> *) t.
TensorType t =>
Tensor v'1 t -> Tensor Build Int32
CoreOps.rank ([Tensor Build a] -> Tensor Build a
forall a. [a] -> a
head [Tensor Build a]
x)
opGrad "Square" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x] [dz :: Tensor Value a
dz] =
[Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Value a
dz Tensor Value a -> Tensor Build a -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
OneOf
'[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
Word8, Double, Float]
t =>
Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
`CoreOps.mul` (2 Tensor Build a -> Tensor Build a -> Tensor Build a
forall a. Num a => a -> a -> a
* Tensor Build a
x)]
opGrad "Gather" _ [Output -> Tensor Build a
forall a. Output -> Tensor Build a
toT -> Tensor Build a
x, Output -> Tensor Build Int32
forall a. Output -> Tensor Build a
toT -> Tensor Build Int32
indices] [dz :: Tensor Value a
dz] =
[ Tensor Build a -> Maybe (Tensor Build a)
forall a. a -> Maybe a
Just (Tensor Build a -> Maybe (Tensor Build a))
-> Tensor Build a -> Maybe (Tensor Build a)
forall a b. (a -> b) -> a -> b
$ Tensor Build a
-> Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) (v'3 :: * -> *) t tindices
tnumsegments.
(OneOf
'[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16,
Word32, Word64, Word8, Double, Float]
t,
OneOf '[Int32, Int64] tindices,
OneOf '[Int32, Int64] tnumsegments) =>
Tensor v'1 t
-> Tensor v'2 tindices -> Tensor v'3 tnumsegments -> Tensor Build t
CoreOps.unsortedSegmentSum Tensor Build a
values Tensor Build Int32
indices' Tensor Build Int32
numRows
, Maybe (Tensor Build a)
forall a. Maybe a
Nothing
]
where
denseShape :: Tensor Build Int32
denseShape = Tensor Build a -> Tensor Build Int32
forall t (v :: * -> *).
TensorType t =>
Tensor v t -> Tensor Build Int32
shape (Tensor Build a
x :: Tensor Build a)
numRows :: Tensor Build Int32
numRows = Tensor Build Int32 -> Tensor Build Int32
forall a (v :: * -> *).
TensorType a =>
Tensor v a -> Tensor Build a
scalarize (Tensor Build Int32 -> Tensor Build Int32)
-> Tensor Build Int32 -> Tensor Build Int32
forall a b. (a -> b) -> a -> b
$ Tensor Build Int32 -> Int32 -> Int32 -> Tensor Build Int32
forall (v1 :: * -> *) t.
TensorType t =>
Tensor v1 t -> Int32 -> Int32 -> Tensor Build t
flatSlice Tensor Build Int32
denseShape 0 1
valuesShape :: Tensor Build Int32
valuesShape = Tensor Build Int32 -> [Tensor Build Int32] -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t.
TensorType t =>
Tensor v'1 Int32 -> [Tensor v'2 t] -> Tensor Build t
CoreOps.concat 0 [ Tensor Build Int32
allDimensions
, Tensor Build Int32 -> Int32 -> Int32 -> Tensor Build Int32
forall (v1 :: * -> *) t.
TensorType t =>
Tensor v1 t -> Int32 -> Int32 -> Tensor Build t
flatSlice Tensor Build Int32
denseShape 1 (-1)
]
values :: Tensor Build a
values = Tensor Value a -> Tensor Build Int32 -> Tensor Build a
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
reshape Tensor Value a
dz Tensor Build Int32
valuesShape
indices' :: Tensor Build Int32
indices' = Tensor Build Int32 -> Tensor Build Int32 -> Tensor Build Int32
forall (v'1 :: * -> *) (v'2 :: * -> *) t tshape.
(TensorType t, OneOf '[Int32, Int64] tshape) =>
Tensor v'1 t -> Tensor v'2 tshape -> Tensor Build t
reshape Tensor Build Int32
indices Tensor Build Int32
allDimensions :: Tensor Build Int32
-- Gradient of a Max reduction over the axes in @indices@.  Every input
-- element equal to the reduced maximum receives an equal share of the
-- upstream gradient (ties split it evenly); the axes input gets no gradient.
opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
    [Just (selectedShare * upstream), Nothing]
  where
    keptDims :: Tensor Build Int32
    keptDims = reducedShape (shape (x :: Tensor Build a))
                            (indices :: Tensor Build Int32)
    -- Re-insert the reduced axes (as size 1) so the reduced maximum and the
    -- upstream gradient broadcast against x.
    maxKept = reshape (CoreOps.max x indices) keptDims
    upstream = reshape dz keptDims
    hits = CoreOps.cast $ CoreOps.equal maxKept x
    hitCount = reshape (sum hits indices) keptDims
    selectedShare = hits `CoreOps.div` hitCount
-- Gradient of a Min reduction over the axes in @indices@.  Every input
-- element equal to the reduced minimum receives an equal share of the
-- upstream gradient (ties split it evenly); the axes input gets no gradient.
--
-- This cannot simply delegate to the "Max" equation: that equation
-- recomputes the reduced value with 'CoreOps.max', so for a Min op it would
-- route the gradient to the largest elements instead of the smallest.
opGrad "Min" _ [toT -> x, toT -> indices] [dz] =
    [Just $ indicators `CoreOps.div` numSelected * dz', Nothing]
  where
    sx = shape (x :: Tensor Build a)
    outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32)
    -- Recompute the reduction with the op being differentiated.
    y = CoreOps.min x indices
    -- Keep the reduced axes (as size 1) so y' and dz' broadcast against x.
    y' = reshape y outputShapeKeptDims
    dz' = reshape dz outputShapeKeptDims
    indicators = CoreOps.cast $ CoreOps.equal y' x
    numSelected = reshape (sum indicators indices) outputShapeKeptDims
-- Element-wise maximum.  Each upstream gradient entry flows to whichever
-- operand was selected.  x wins ties because the mask is x >= y.
-- 'gradForBinaryCwise' then reduces each result back to its operand's shape.
opGrad "Maximum" _ [toT -> x, toT -> y] [dz] =
    gradForBinaryCwise (x, routeTo xWins) (y, routeTo (CoreOps.logicalNot xWins))
  where
    xWins = CoreOps.greaterEqual x y
    -- Pass dz through where the mask holds, zero elsewhere.
    routeTo mask = CoreOps.select mask dz (CoreOps.zerosLike dz)
-- Gradient of a Sum reduction over the axes in @indices@.  The upstream
-- gradient is reshaped so the reduced axes have size 1, then tiled back to
-- the input's shape.  The axes input gets no gradient.
opGrad "Sum" _ [toT -> x, toT -> indices] [dz] =
    [ Just (CoreOps.tile keptGrad multiples), Nothing ]
  where
    inShape = shape (x :: Tensor Build a)
    keptDims = reducedShape inShape (indices :: Tensor Build Int32)
    -- How many times each kept-dims entry must be repeated along each axis.
    multiples = safeShapeDiv inShape keptDims
    keptGrad = reshape dz keptDims
-- Gradient of a Mean reduction: the Sum gradient divided by the number of
-- input elements averaged into each output element, i.e.
-- size(input) / size(output).  The axes input gets no gradient.
opGrad "Mean" u v@[toT -> x, _] w@[outGrad] =
    [ Just $ sumGrad `CoreOps.div` CoreOps.stopGradient (CoreOps.cast factor)
    , Nothing ]
  where
    [Just sumGrad, Nothing] = opGrad "Sum" u v w
    inputShape = shape (x :: Tensor Build a)
    -- The Mean op's output has the shape of its incoming gradient.  Do NOT
    -- measure sumGrad here: it has already been tiled back to the input's
    -- shape, which would make the factor always 1.
    outputShape = shape outGrad
    -- TODO(fmayle): Add fast path when shape is known.
    inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape
    outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape
    factor = safeShapeDiv inputSize outputSize
-- Gradient of broadcasting addition.  Each operand receives the upstream
-- gradient summed over the axes along which it was broadcast, reshaped back
-- to that operand's own shape.
opGrad "Add" _ [toT -> x, toT -> y] [dz] =
    [ Just (reduceTo xShape xAxes)
    , Just (reduceTo yShape yAxes)
    ]
  where
    xShape = shape (x :: Tensor Build a)
    yShape = shape (y :: Tensor Build a)
    (xAxes, yAxes) = broadcastGradientArgs xShape yShape
    reduceTo targetShape axes = reshape (sum dz axes) targetShape
-- AddN copies the upstream gradient unchanged to every input.  It does not
-- broadcast, so no reduction is needed.
opGrad "AddN" _ inputs [dz] =
    Just (expr dz) <$ inputs
-- Subtraction reuses the Add gradient: x - y == x + (-y), so the first
-- operand's gradient is unchanged and the second operand's is negated.
opGrad "Sub" u v w =
    [Just addGradX, Just (negate addGradY)]
  where
    -- Lazy pattern binding, matched only when an element is demanded.
    [Just addGradX, Just addGradY] = opGrad "Add" u v w
-- Gradient of SoftmaxCrossEntropyWithLogits w.r.t. its first input.  It is
-- the op's backprop result (second output of
-- 'softmaxCrossEntropyWithLogits'), scaled by the upstream loss gradient
-- expanded along the last axis.  The second input gets no gradient.  Any
-- gradient arriving at the op's second output is ignored.
opGrad "SoftmaxCrossEntropyWithLogits" _ [toT -> x, toT -> y] [dz, _] =
    [ Just (lossScale * backprop)
    , Nothing
    ]
  where
    lossScale = expandDims dz (-1)
    (_, backprop) = softmaxCrossEntropyWithLogits x y
-- Gradient of broadcasting multiplication: dx = dz * y and dy = x * dz.
-- Each is summed over its operand's broadcast axes and reshaped back to that
-- operand's shape.
-- TODO(fmayle): Handle complex numbers.
opGrad "Mul" _ [toT -> x, toT -> y] [dz] =
    [ Just $ reshape (sum dzTimesY xAxes) xShape
    , Just $ reshape (sum xTimesDz yAxes) yShape
    ]
  where
    xShape = shape (x :: Tensor Build a)
    yShape = shape (y :: Tensor Build a)
    (xAxes, yAxes) = broadcastGradientArgs xShape yShape
    dzTimesY = dz `CoreOps.mul` y
    xTimesDz = x `CoreOps.mul` dz
-- Gradient of broadcasting division:
--   dx = dz / y
--   dy = dz * (-x / (y * y))
-- Each is summed over its operand's broadcast axes and reshaped back to that
-- operand's shape.
-- TODO(fmayle): Handle complex numbers.
opGrad "Div" _ [toT -> x, toT -> y] [dz] =
    [ Just $ reshape (sum gradX xAxes) xShape
    , Just $ reshape (sum gradY yAxes) yShape
    ]
  where
    xShape = shape (x :: Tensor Build a)
    yShape = shape (y :: Tensor Build a)
    (xAxes, yAxes) = broadcastGradientArgs xShape yShape
    gradX = dz `CoreOps.div` y
    gradY = dz `CoreOps.mul` (negate x `CoreOps.div` (y * y))
-- Gradient of MatMul, honouring the transpose_a / transpose_b attributes.
-- For C = op(A) op(B) with upstream gradient dC:
--   (F, F): dA = dC B^T,    dB = A^T dC
--   (F, T): dA = dC B,      dB = dC^T A
--   (T, F): dA = B dC^T,    dB = A dC
--   (T, T): dA = B^T dC^T,  dB = dC^T A^T
opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
    case (lookupAttr nodeDef "transpose_a", lookupAttr nodeDef "transpose_b") of
        (False, False) ->
            [ Just $ matMul' (withTranspose False True) dz y
            , Just $ matMul' (withTranspose True False) x dz ]
        (False, True) ->
            [ Just $ matMul dz y
            , Just $ matMul' (withTranspose True False) dz x ]
        (True, False) ->
            [ Just $ matMul' (withTranspose False True) y dz
            , Just $ matMul x dz ]
        (True, True) ->
            [ Just $ matMul' (withTranspose True True) y dz
            , Just $ matMul' (withTranspose True True) dz x ]
  where
    -- Sets the transpose_a / transpose_b attributes on the generated MatMul.
    withTranspose transA transB =
        (opAttr "transpose_a" .~ transA) . (opAttr "transpose_b" .~ transB)
-- Gradient of BatchMatMul, z = op_x(x) * op_y(y), where op_x / op_y take the
-- adjoint when the node's "adj_x" / "adj_y" attributes are set.
-- Produces [dL/dx, dL/dy] from the incoming dz; one branch per flag combination.
opGrad "BatchMatMul" nodeDef [toT -> x, toT -> y] [dz] =
    map Just $ case (adjX, adjY) of
        (False, False) ->
            [ CoreOps.batchMatMul' (withAdj False True) dz y
            , CoreOps.batchMatMul' (withAdj True False) x dz ]
        (False, True) ->
            [ CoreOps.batchMatMul dz y
            , CoreOps.batchMatMul' (withAdj True False) dz x ]
        (True, False) ->
            [ CoreOps.batchMatMul' (withAdj False True) y dz
            , CoreOps.batchMatMul x dz ]
        (True, True) ->
            [ CoreOps.batchMatMul' (withAdj True True) y dz
            , CoreOps.batchMatMul' (withAdj True True) dz x ]
  where
    adjX, adjY :: Bool
    adjX = lookupAttr nodeDef "adj_x"
    adjY = lookupAttr nodeDef "adj_y"
    -- Sets both adjoint flags on the BatchMatMul op being built.
    withAdj :: Bool -> Bool -> OpDef -> OpDef
    withAdj ax ay = (opAttr "adj_x" .~ ax) . (opAttr "adj_y" .~ ay)
-- Gradient of Transpose: transpose dz by the inverse of the forward
-- permutation. The permutation input gets no gradient.
opGrad "Transpose" _ [_, toT -> perm] [dz] =
    [Just (CoreOps.transpose dz inversePerm), Nothing]
  where
    inversePerm :: Tensor Build Int32
    inversePerm = CoreOps.invertPermutation perm
-- Gradient of Conv2D with respect to its input (x) and its filter (y).
-- Both backprop ops reuse the forward node's strides, use_cudnn_on_gpu,
-- data_format and padding attributes.
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
    [ Just (CoreOps.conv2DBackpropInput' convAttrs padding (shape x) y dz)
    , Just (CoreOps.conv2DBackpropFilter' convAttrs padding x (shape y) dz)
    ]
  where
    -- Copies the forward node's convolution attributes onto the new op.
    convAttrs :: OpDef -> OpDef
    convAttrs = (opAttr "strides" .~ strides)
              . (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
              . (opAttr "data_format" .~ dataFormat)
    strides = lookupAttr nodeDef "strides" :: [Int64]
    padding = lookupAttr nodeDef "padding" :: ByteString
    useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
    dataFormat = lookupAttr nodeDef "data_format" :: ByteString
-- Gradient of Conv2DBackpropInput. Of its three inputs (input_sizes, filter,
-- out_backprop in TensorFlow's op definition) the first gets no gradient.
-- The filter (x) gradient is a Conv2DBackpropFilter and the out_backprop (y)
-- gradient is a forward Conv2D. Both reuse this node's strides,
-- use_cudnn_on_gpu, data_format and padding attributes.
opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
    [ Nothing
    , Just (CoreOps.conv2DBackpropFilter' convAttrs padding dz (shape x) y)
    , Just (CoreOps.conv2D' convAttrs padding dz x)
    ]
  where
    -- Copies this node's convolution attributes onto the new op.
    convAttrs :: OpDef -> OpDef
    convAttrs = (opAttr "strides" .~ strides)
              . (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
              . (opAttr "data_format" .~ dataFormat)
    strides = lookupAttr nodeDef "strides" :: [Int64]
    padding = lookupAttr nodeDef "padding" :: ByteString
    useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
    dataFormat = lookupAttr nodeDef "data_format" :: ByteString
-- Gradient of DepthwiseConv2dNative with respect to its input (x) and its
-- filter (y). Both backprop ops reuse the forward node's strides,
-- data_format and padding attributes.
opGrad "DepthwiseConv2dNative" nodeDef [toT -> x, toT -> y] [dz] =
    [ Just (CoreOps.depthwiseConv2dNativeBackpropInput'
                depthwiseAttrs padding (shape x) y dz)
    , Just (CoreOps.depthwiseConv2dNativeBackpropFilter'
                depthwiseAttrs padding x (shape y) dz)
    ]
  where
    -- Copies the forward node's convolution attributes onto the new op.
    depthwiseAttrs :: OpDef -> OpDef
    depthwiseAttrs = (opAttr "strides" .~ strides)
                   . (opAttr "data_format" .~ dataFormat)
    strides = lookupAttr nodeDef "strides" :: [Int64]
    padding = lookupAttr nodeDef "padding" :: ByteString
    dataFormat = lookupAttr nodeDef "data_format" :: ByteString
-- Gradient of DepthwiseConv2dNativeBackpropInput. The first input gets no
-- gradient. The filter (x) gradient is a DepthwiseConv2dNativeBackpropFilter
-- and the second data input's (y) gradient is a forward DepthwiseConv2dNative.
-- Both reuse this node's strides, data_format and padding attributes.
opGrad "DepthwiseConv2dNativeBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
    [ Nothing
    , Just (CoreOps.depthwiseConv2dNativeBackpropFilter'
                depthwiseAttrs padding dz (shape x) y)
    , Just (CoreOps.depthwiseConv2dNative' depthwiseAttrs padding dz x)
    ]
  where
    -- Copies this node's convolution attributes onto the new op.
    depthwiseAttrs :: OpDef -> OpDef
    depthwiseAttrs = (opAttr "strides" .~ strides)
                   . (opAttr "data_format" .~ dataFormat)
    strides = lookupAttr nodeDef "strides" :: [Int64]
    padding = lookupAttr nodeDef "padding" :: ByteString
    dataFormat = lookupAttr nodeDef "data_format" :: ByteString
-- Gradient of MaxPool. MaxPoolGrad takes the forward input x, the forward
-- result (output 0 of this same node) and dz. It is configured with the
-- node's ksize, strides, data_format and padding attributes.
opGrad "MaxPool" nodeDef [toT -> x] [dz] =
    [Just (CoreOps.maxPoolGrad' poolAttrs padding x output dz)]
  where
    -- Copies the forward node's pooling attributes onto MaxPoolGrad.
    poolAttrs :: OpDef -> OpDef
    poolAttrs = (opAttr "ksize" .~ ksize)
              . (opAttr "strides" .~ strides)
              . (opAttr "data_format" .~ dataFormat)
    output :: Tensor Build a
    output = toT (Output 0 (nodeDefName nodeDef))
    ksize = lookupAttr nodeDef "ksize" :: [Int64]
    strides = lookupAttr nodeDef "strides" :: [Int64]
    padding = lookupAttr nodeDef "padding" :: ByteString
    dataFormat = lookupAttr nodeDef "data_format" :: ByteString
-- Gradient of Reshape: reshape dz back to the shape of the original input.
-- The target-shape input gets no gradient.
opGrad "Reshape" _ [toT -> x, _] [dz] =
    [Just (reshape dz inputShape), Nothing]
  where
    inputShape = shape (x :: Tensor Build a)
-- Gradient of ExpandDims: delegated unchanged to the Reshape gradient
-- (dz reshaped to the first input's shape; no gradient for the second input).
opGrad "ExpandDims" n xs@[_, _] dzs@[_] = opGrad "Reshape" n xs dzs
-- Gradient of Squeeze: reshape dz back to the shape of the original input.
opGrad "Squeeze" _ [toT -> x] [dz] =
    [Just (reshape dz inputShape)]
  where
    inputShape = shape (x :: Tensor Build a)
-- Gradient of Pad: slice the unpadded region back out of dz. The slice
-- begins at the "before" column (column 0) of the [rank(x), 2] paddings
-- matrix and has the shape of the original input x. The paddings input gets
-- no gradient.
--
-- x is typed with the gradient's own element type @a@ (not a fixed Float).
-- The Rank and Shape ops built from it record that dtype in their "T"
-- attribute, and it must match the real input tensor.
opGrad "Pad" _ [toT -> x, toT -> padPattern] [dz] =
    [Just $ CoreOps.slice dz gradientSliceBegin gradientSliceSize, Nothing]
  where
    xA :: Tensor Build a
    xA = x
    v1 :: Tensor Build Int32
    v1 = vector [1]
    -- rank(x) as a one-element vector.
    rankx :: Tensor Build Int32
    rankx = CoreOps.reshape (CoreOps.rank xA) v1
    -- Slice window [0, 0] .. [rank(x), 1] selects the "before" amounts.
    padPatternSliceBegin :: Tensor Build Int32
    padPatternSliceBegin = vector [0, 0]
    padPatternSliceSize :: Tensor Build Int32
    padPatternSliceSize = CoreOps.concat 0 [rankx, v1]
    padPatternSliced :: Tensor Build Int32
    padPatternSliced =
        CoreOps.slice padPattern padPatternSliceBegin padPatternSliceSize
    -- Flatten the [rank, 1] column into the begin vector for slicing dz.
    gradientSliceBegin :: Tensor Build Int32
    gradientSliceBegin = CoreOps.reshape padPatternSliced rankx
    gradientSliceSize :: Tensor Build Int32
    gradientSliceSize = shape xA
-- Gradient of Slice: zero-pad dz back out to the input's shape. Along each
-- dimension, the padding before the slice is the begin offset. The padding
-- after it is (input size - slice size - begin offset), where the slice
-- size is taken from dz's shape. The begin and size inputs get no gradient.
--
-- The input is typed with the gradient's own element type @a@ (not a fixed
-- Float). The Rank and Shape ops built from it record that dtype in their
-- "T" attribute, and it must match the real input tensor.
opGrad "Slice" _ [toT -> inputvec, toT -> beginvec, _] [dz] =
    [Just $ CoreOps.pad dz paddings, Nothing, Nothing]
  where
    inputT :: Tensor Build a
    inputT = inputvec
    v1 :: Tensor Build Int32
    v1 = vector [1]
    -- rank(input) as a one-element vector, and the [rank, 1] column shape.
    inputRank :: Tensor Build Int32
    inputRank = CoreOps.reshape (CoreOps.rank inputT) v1
    padShape :: Tensor Build Int32
    padShape = CoreOps.concat 0 [inputRank, v1]
    beforePad :: Tensor Build Int32
    beforePad = CoreOps.reshape beginvec padShape
    afterPad :: Tensor Build Int32
    afterPad = CoreOps.reshape (shape inputT - shape dz - beginvec) padShape
    -- [rank, 2] paddings matrix: column 0 = before, column 1 = after.
    paddings :: Tensor Build Int32
    paddings = CoreOps.concat 1 [beforePad, afterPad]
-- The input gradient of BatchToSpaceND is SpaceToBatchND applied to dz with
-- the same block shape, passing the forward op's crops as its paddings.
-- block_shape and crops receive no gradient.
opGrad "BatchToSpaceND" _ [_, toT @Int32 -> blockDims, toT @Int32 -> cropAmounts] [dz] =
    Just (CoreOps.spaceToBatchND dz blockDims cropAmounts) : replicate 2 Nothing
-- The input gradient of SpaceToBatchND is BatchToSpaceND applied to dz with
-- the same block shape, passing the forward op's paddings as its crops.
-- block_shape and paddings receive no gradient.
opGrad "SpaceToBatchND" _ [_, toT @Int32 -> blockDims, toT @Int32 -> padAmounts] [dz] =
    Just (CoreOps.batchToSpaceND dz blockDims padAmounts) : replicate 2 Nothing
-- None of OneHot's four inputs receives a gradient.
opGrad "OneHot" _ _ _ = replicate 4 Nothing
-- TruncatedNormal's single input receives no gradient.
opGrad "TruncatedNormal" _ _ _ = [Nothing]
-- RefIdentity forwards its input unchanged, so dz passes straight through.
opGrad "RefIdentity" _ _ [dz] = [Just (expr dz)]
-- The gradient of a Cast is dz cast back the other way: a fresh "Cast" op
-- whose DstT is the forward node's SrcT and whose SrcT is the forward
-- node's DstT, fed with dz.
opGrad "Cast" nodeDef _ [dz] = [Just reverseCast]
  where
    -- NOTE(review): the type attributes are read and written as ByteString.
    -- If SrcT/DstT are stored in a different AttrValue field, this reads the
    -- default value instead -- confirm against the Attribute instances.
    forwardSrc = lookupAttr nodeDef "SrcT" :: ByteString
    forwardDst = lookupAttr nodeDef "DstT" :: ByteString
    castBack =
        opDef "Cast"
            & opAttr "DstT" .~ forwardSrc
            & opAttr "SrcT" .~ forwardDst
            & opInputs .~ [renderedOutput dz]
    reverseCast = pureOp [] (pure castBack)
-- DynamicStitch's inputs are split into two equal halves: index tensors
-- first, data tensors second. The index tensors get no gradient. Data
-- tensor i gets dz gathered at index tensor i. An odd number of inputs is
-- reported as an error.
opGrad "DynamicStitch" nodeDef inputs [dz] =
    replicate halfLen Nothing ++ map gradFor (take halfLen inputs)
  where
    gradFor idx = Just (CoreOps.gather dz (toT idx :: Tensor Build Int32))
    halfLen = case length inputs `divMod` 2 of
        (half, 0) -> half
        _ -> error ("Uneven input size "
                    ++ show (length inputs, showMessage nodeDef))
-- Gradient of DynamicPartition with respect to its data input:
--   1. Number every element of the partitions tensor 0 .. n-1, keeping its
--      shape.
--   2. Partition those positions exactly as the forward op partitioned the
--      data.
--   3. Stitch the incoming per-partition gradients back into those
--      positions.
--   4. Reshape the result to the shape of the data input.
-- The partitions input receives no gradient.
opGrad "DynamicPartition" nodeDef [toT -> xs, toT -> partitions] dz =
    [ Just (CoreOps.reshape stitched dataShape), Nothing ]
  where
    dataShape = CoreOps.shape (xs :: Tensor Build a) :: Tensor Build Int32
    numParts = lookupAttr nodeDef "num_partitions" :: Int64
    positions :: Tensor Build Int32
    positions =
        CoreOps.reshape (CoreOps.range 0 (CoreOps.size partitions) 1)
                        (CoreOps.shape partitions :: Tensor Build Int32)
    stitched =
        CoreOps.dynamicStitch
            (CoreOps.dynamicPartition numParts positions partitions)
            dz
-- Select's gradient routes dz with the same condition:
--   * the second input gets dz where the condition holds, zeros elsewhere;
--   * the third input gets the opposite;
--   * the boolean condition gets no gradient.
opGrad "Select" _ [toT -> cond, toT -> x, _] [dz] =
    [Nothing, thenGrad, elseGrad]
  where
    zeros = CoreOps.zerosLike x
    thenGrad = Just (CoreOps.select cond dz zeros)
    elseGrad = Just (CoreOps.select cond zeros dz)
-- d/dx log x = 1/x: scale dz by the reciprocal of x.
opGrad "Log" _ [toT -> x] [dz] = [Just (CoreOps.mul dz (CoreOps.inv x))]
-- d/dx exp x = exp x: scale dz by exp recomputed from x.
opGrad "Exp" _ [toT -> x] [dz] = [Just (CoreOps.mul dz (CoreOps.exp x))]
-- Gradient of SparseSegmentSum with respect to its first input:
--   1. Gather dz by the third input (segment ids).
--   2. Scatter-add the result by the second input (row indices).
--   3. Size the output by the first dimension of the first input
--      (flatSlice is assumed to take elements [0, 1) of the shape vector).
-- The index and segment-id inputs receive no gradient.
opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
    [Just dataGrad, Nothing, Nothing]
  where
    segmentIds = t :: Tensor Build Int32
    rowIndices = y :: Tensor Build Int32
    inputRows = flatSlice (shape (x :: Tensor Build a)) 0 1
    dataGrad =
        CoreOps.unsortedSegmentSum (CoreOps.gather dz segmentIds) rowIndices inputRows
-- LabelClasses, LabelWeights and Size propagate no gradient to any input.
opGrad "LabelClasses" _ _ _ = replicate 2 Nothing
opGrad "LabelWeights" _ _ _ = [Nothing]
opGrad "Size" _ _ _ = [Nothing]
-- Tile's input gradient is computed in two steps:
--   1. Reshape dz to the interleaved shape
--      [multiples[0], shape[0], multiples[1], shape[1], ...].
--   2. Sum over the even axes, which are the repetition axes.
-- The multiples input receives no gradient.
opGrad "Tile" _ [toT -> x, toT -> multiples] [dz] =
    [Just (sum (CoreOps.reshape dz splitShape) repeatAxes), Nothing]
  where
    xShape = shape (x :: Tensor Build a)
    -- Build the interleaved shape:
    --   * stack [multiples; xShape] as two rows;
    --   * transpose so each row is (multiples[i], shape[i]);
    --   * flatten the result.
    splitShape =
        CoreOps.reshape
            (CoreOps.transpose (CoreOps.pack [multiples, xShape])
                               (vector [1, 0 :: Int32]))
            allDimensions
    repeatAxes = CoreOps.range 0 (CoreOps.size splitShape) (2 :: Tensor Build Int32)
-- The image gradient of ResizeBilinear is ResizeBilinearGrad applied to dz
-- (cast to Float, the type that op takes for its gradient argument) and the
-- original image x. The size input receives no gradient.
--
-- The gradient op must mirror the forward node's sampling attributes;
-- otherwise it differentiates a different sampling grid:
--   * align_corners is always copied;
--   * half_pixel_centers is copied only when the forward node sets it.
-- Leaving the attribute unset in the second case keeps the gradient op
-- valid on runtimes that do not know the attribute.
opGrad "ResizeBilinear" nodeDef [toT -> x, _] [dz] =
    [ Just $ CoreOps.resizeBilinearGrad' copyAttrs (CoreOps.cast dz) x
    , Nothing
    ]
  where
    align = lookupAttr nodeDef "align_corners" :: Bool
    halfPixelCenters :: Maybe Bool
    halfPixelCenters
        | Map.member "half_pixel_centers" (nodeDef ^. attr) =
            Just (lookupAttr nodeDef "half_pixel_centers")
        | otherwise = Nothing
    copyAttrs =
        (opAttr "align_corners" .~ align)
            . maybe id (opAttr "half_pixel_centers" .~) halfPixelCenters
-- ZerosLike's output does not depend on its input's values: no gradient.
opGrad "ZerosLike" _ _ _ = [Nothing]
-- Fill's value gradient is dz summed over every axis of dz. The dims input
-- (first) receives no gradient.
opGrad "Fill" _ _ [dz] = [Nothing, Just (sum dz everyAxis)]
  where
    everyAxis = rangeOfRank dz
-- Reading a variable passes dz straight through to the variable handle.
opGrad "ReadVariableOp" _ _ [dz] = [Just (expr dz)]
-- These ops propagate no gradient.
-- NOTE(review): Const returns two Nothing entries; presumably extras beyond
-- the node's actual inputs are ignored by the caller -- verify.
opGrad "Const" _ _ _ = replicate 2 Nothing
opGrad "StopGradient" _ _ _ = [Nothing]
opGrad "VarHandleOp" _ _ _ = []
-- d/dx sqrt x = 1 / (2 * sqrt x). The factor is recomputed from x and then
-- multiplied by dz.
opGrad "Sqrt" _ [toT -> x] [dz] = [Just (CoreOps.mul scale dz)]
  where
    twiceRoot = CoreOps.mul (scalar 2) (CoreOps.sqrt x)
    scale = CoreOps.div (scalar 1) twiceRoot
-- Catch-all for ops without a registered gradient: fail with the op name,
-- the number of inputs and incoming gradients, the node definition and its
-- inputs.
opGrad opName nodeDef ins grads = error (prefix ++ show details)
  where
    prefix :: String
    prefix = "no gradient implemented for "
    details = (opName, length ins, length grads, showMessage nodeDef, ins)
-- | Number of outputs a node produces, decided by its op type.
--
-- * DynamicPartition: the node's @num_partitions@ attribute.
-- * SoftmaxCrossEntropyWithLogits: two.
-- * Every op listed in 'singleOutputOps': one.
-- * Any other op type: an error.
numOutputs :: NodeDef -> OutputIx
numOutputs o
    | opType == "DynamicPartition" =
        fromIntegral (lookupAttr o "num_partitions" :: Int64)
    | opType == "SoftmaxCrossEntropyWithLogits" = 2
    | opType `Set.member` singleOutputOps = 1
    | otherwise = error $ "numOutputs not implemented for " ++ show opType
  where
    opType = o ^. op

-- | Op types known to 'numOutputs' that always produce exactly one output.
singleOutputOps :: Set Text
singleOutputOps = Set.fromList
    [ "Abs", "Add", "AddN", "BatchToSpaceND", "BatchMatMul", "Cast", "Const"
    , "Concat", "Conv2D", "Conv2DBackpropInput", "DepthwiseConv2dNative"
    , "DepthwiseConv2dNativeBackpropInput", "Div", "DynamicStitch", "Exp"
    , "ExpandDims", "Gather", "LabelClasses", "LabelWeights", "Log", "MatMul"
    , "Max", "Maximum", "MaxPool", "Mean", "Min", "Mul", "Neg", "Pad"
    , "Placeholder", "StopGradient", "OneHot", "ReadVariableOp", "RefIdentity"
    , "Relu", "ReluGrad", "Reshape", "Select", "Sigmoid", "Size", "Slice"
    , "SpaceToBatchND", "SparseSegmentSum", "Square", "Squeeze", "Sqrt", "Sub"
    , "Sum", "Tanh", "Tile", "ResizeBilinear", "Transpose", "TruncatedNormal"
    , "VarHandleOp", "Variable", "ZerosLike", "Fill"
    ]
-- | Element-wise division of shape vectors that cannot divide by zero:
-- every element of the divisor is clamped to at least 1 first.
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Build Int32
safeShapeDiv x y = CoreOps.div x clampedY
  where
    clampedY :: Tensor Build Int32
    clampedY = CoreOps.maximum y 1
-- | The one-element shape vector @[-1]@. Used as a reshape target to flatten
-- a tensor to rank 1.
allDimensions :: Tensor Build Int32
allDimensions = vector [negate 1]
-- | The axis indices @[0, 1 .. rank x - 1]@ of @x@. Useful for reducing over
-- every dimension.
rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Build Int32
rangeOfRank x = CoreOps.range start (CoreOps.rank x) step
  where
    start, step :: Tensor Build Int32
    start = 0
    step = 1
-- | Read attribute @attrName@ of a node, decoded at the requested type.
-- A missing attribute is replaced by a default 'AttrValue' (via @non def@)
-- before decoding, so no error is raised.
lookupAttr :: Attribute a1 => NodeDef -> Text -> a1
lookupAttr nodeDef attrName =
    view (attr . at attrName . non def . attrLens) nodeDef