{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module TensorFlow.Ops
( CoreOps.add
, CoreOps.add'
, CoreOps.abs
, CoreOps.abs'
, CoreOps.addN
, CoreOps.addN'
, CoreOps.argMax
, CoreOps.argMax'
, CoreOps.assign
, CoreOps.assign'
, CoreOps.broadcastGradientArgs
, CoreOps.broadcastGradientArgs'
, CoreOps.cast
, CoreOps.cast'
, CoreOps.concat
, CoreOps.concat'
, constant
, constant'
, CoreOps.equal
, CoreOps.equal'
, expandDims
, expandDims'
, initializedVariable
, initializedVariable'
, zeroInitializedVariable
, zeroInitializedVariable'
, CoreOps.fill
, CoreOps.fill'
, CoreOps.identity
, CoreOps.identity'
, CoreOps.matMul
, CoreOps.matMul'
, CoreOps.einsum
, CoreOps.einsum'
, matTranspose
, matTranspose'
, CoreOps.mean
, CoreOps.mean'
, CoreOps.mul
, CoreOps.mul'
, CoreOps.neg
, CoreOps.neg'
, CoreOps.oneHot
, CoreOps.oneHot'
, CoreOps.pack
, CoreOps.pack'
, placeholder
, placeholder'
, CoreOps.range
, CoreOps.range'
, reducedShape
, reduceMean
, reduceMean'
, CoreOps.relu
, CoreOps.relu'
, CoreOps.reluGrad
, CoreOps.reluGrad'
, CoreOps.tanh
, CoreOps.tanhGrad
, CoreOps.reshape
, CoreOps.reshape'
, restore
, restoreFromName
, save
, scalar
, scalar'
, shape
, shape'
, CoreOps.sigmoid
, CoreOps.sigmoidGrad
, CoreOps.sign
, CoreOps.sign'
, CoreOps.size
, CoreOps.size'
, CoreOps.softmax
, CoreOps.softmax'
, CoreOps.softmaxCrossEntropyWithLogits
, CoreOps.softmaxCrossEntropyWithLogits'
, CoreOps.sparseToDense
, CoreOps.sparseToDense'
, CoreOps.sub
, CoreOps.sub'
, CoreOps.sum
, CoreOps.sum'
, reduceSum
, reduceSum'
, CoreOps.transpose
, CoreOps.transpose'
, truncatedNormal
, truncatedNormal'
, CoreOps.variable
, CoreOps.variable'
, vector
, vector'
, zeros
, CoreOps.zerosLike
, CoreOps.zerosLike'
, scalarize
) where
import Data.ByteString (ByteString)
import Data.Complex (Complex)
import Data.Int (Int32, Int64)
import Data.Word (Word16)
import Prelude hiding (abs, sum, concat)
import Data.ProtoLens.Default(def)
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 ((.~), (&))
import Text.Printf (printf)
import Proto.Tensorflow.Core.Framework.Tensor (TensorProto)
import Proto.Tensorflow.Core.Framework.Tensor_Fields
( dtype
, tensorShape
)
import qualified Proto.Tensorflow.Core.Framework.TensorShape_Fields
as TensorShape
import TensorFlow.Build
import TensorFlow.BuildOp
import TensorFlow.ControlFlow (group)
import TensorFlow.Tensor
import TensorFlow.Types
import qualified TensorFlow.GenOps.Core as CoreOps
import qualified Prelude (abs)
-- | 'Num' instance for tensors, so ordinary arithmetic syntax builds graph
-- nodes. Each method delegates to the matching op in
-- "TensorFlow.GenOps.Core".
--
-- This is an orphan instance (neither 'Num' nor 'Tensor' is defined here);
-- the module turns the orphan warning off.
--
-- The instance head matches any @v@, and the @v ~ Build@ equality
-- constraint then fixes it to 'Build'. This lets GHC pick the tensor flavour
-- of a bare numeric literal, e.g. the @0@ and @1@ in
-- @CoreOps.range 0 n 1@.
instance ( TensorType a
         , Num a
         , v ~ Build
         , OneOf '[ Double, Float, Int32, Int64
                  , Complex Float, Complex Double] a) => Num (Tensor v a) where
    (+) = CoreOps.add
    (*) = CoreOps.mul
    (-) = CoreOps.sub
    -- Prelude's unqualified 'abs' is hidden in this module; the qualified
    -- @Prelude (abs)@ import keeps the class method in scope so that it can
    -- be defined here.
    abs = CoreOps.abs
    -- Integer literals become scalar constant tensors.
    fromInteger = scalar . fromInteger
    signum = CoreOps.sign
    negate = CoreOps.neg
-- | Transpose a matrix by swapping axes 0 and 1.
--
-- Same as 'matTranspose'' with no extra op parameters.
matTranspose :: TensorType a => Tensor e a -> Tensor Build a
matTranspose = matTranspose' id
-- | 'matTranspose' with caller-supplied 'OpParams', which are passed to
-- 'CoreOps.transpose''.
--
-- The permutation is fixed to @[1, 0]@, so this assumes a rank-2 input.
matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Build a
matTranspose' params x = CoreOps.transpose' params x swapAxes
  where
    -- Permutation that exchanges axis 0 and axis 1.
    swapAxes :: Tensor Build Int32
    swapAxes = vector [1, 0]
-- | Add a @Placeholder@ op with the given shape. The element type comes from
-- the result type @a@.
--
-- Same as 'placeholder'' with no extra op parameters.
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
placeholder = placeholder' id
-- | 'placeholder' with caller-supplied 'OpParams'.
--
-- The @dtype@ and @shape@ attributes are set first and @params@ is applied
-- last, so @params@ can override either of them.
placeholder' :: forall m a . (MonadBuild m, TensorType a)
             => OpParams -> Shape -> m (Tensor Value a)
placeholder' params pShape =
    -- NOTE(review): the op is assembled with 'buildOp' directly rather than
    -- through a generated CoreOps wrapper, presumably so that every call
    -- yields its own node instead of being shared with an identical one;
    -- confirm before changing.
    build $ buildOp [] $
        opDef "Placeholder"
            & opAttr "dtype" .~ tensorType (undefined :: a)
            & opAttr "shape" .~ pShape
            & params
-- | Create a variable whose initial value is the given tensor.
--
-- The assignment of the initial value is registered with 'addInitializer'
-- rather than returned. Same as 'initializedVariable'' with no extra op
-- parameters.
initializedVariable :: (MonadBuild m, TensorType a)
                    => Tensor v a -> m (Tensor Ref a)
initializedVariable = initializedVariable' id
-- | 'initializedVariable' with caller-supplied 'OpParams'. The params are
-- applied to the variable op, not to the assignment.
initializedVariable' :: (MonadBuild m, TensorType a)
                     => OpParams -> Tensor v a -> m (Tensor Ref a)
initializedVariable' params initializer = do
    -- The variable is declared with an empty shape attribute.
    var <- CoreOps.variable' params []
    -- Shape validation is turned off on the assignment, presumably so that
    -- the variable can take the initializer's shape; confirm against the
    -- Assign op documentation.
    assignment <- CoreOps.assign' (opAttr "validate_shape" .~ False)
                                  var initializer
    -- Register the assignment as an initializer; only the variable handle is
    -- returned to the caller.
    addInitializer =<< group assignment
    return var
-- | Create a variable of the given shape whose initial value is all zeros.
--
-- Same as 'zeroInitializedVariable'' with no extra op parameters.
zeroInitializedVariable
  :: (MonadBuild m, TensorType a, Num a)
  => Shape -> m (Tensor Ref a)
zeroInitializedVariable = zeroInitializedVariable' id
-- | 'zeroInitializedVariable' with caller-supplied 'OpParams', which are
-- forwarded to 'initializedVariable''.
zeroInitializedVariable'
  :: (MonadBuild m, TensorType a, Num a)
  => OpParams -> Shape -> m (Tensor Ref a)
zeroInitializedVariable' params varShape =
    initializedVariable' params (zeros varShape)
-- | Add a @Save@ op that writes the given tensors to the file at @path@.
--
-- Each tensor is saved under its rendered output name, which is the same
-- name 'restore' looks values up by. All tensors share the element type @a@.
save :: forall a m v . (Rendered (Tensor v), MonadBuild m, TensorType a)
     => ByteString     -- ^ File path.
     -> [Tensor v a]   -- ^ Tensors to save.
     -> m ControlNode
save path xs = build $ do
    let -- A tensor's name, as a scalar string tensor.
        toNameTensor :: Tensor v a -> Tensor Build ByteString
        toNameTensor = scalar . encodeUtf8 . encodeOutput . renderedOutput
        names = map toNameTensor xs
        -- One dtype entry per saved tensor.
        types = replicate (length xs) (tensorType (undefined :: a))
    names' <- buildInputs $ CoreOps.pack names
    xs' <- buildInputs xs
    path' <- buildInputs $ scalar path
    -- Inputs in order: the file path, the packed names, then the values.
    buildOp [] $ opDef "Save"
                    & opAttr "T" .~ types
                    & opInputs .~ (path' ++ names' ++ xs')
-- | Restore a tensor's value from a file, such as one produced by 'save'.
--
-- Unlike 'restore', the name to look up in the file is given explicitly.
-- This lets a value be restored into a variable whose own name differs from
-- the name it was saved under.
restoreFromName :: forall a m . (MonadBuild m, TensorType a)
                => ByteString    -- ^ File path.
                -> ByteString    -- ^ Tensor name to look up in the file.
                -> Tensor Ref a  -- ^ Tensor to restore into.
                -> m ControlNode
restoreFromName path name x = build $ do
    path' <- buildInputs $ scalar path
    name' <- buildInputs $ scalar name
    restoreOp <- buildOp [] $ opDef "Restore"
                               & opAttr "dt" .~ tensorType (undefined :: a)
                               & opInputs .~ (path' ++ name')
    -- Assign the restored value to x and return a control node for that
    -- assignment.
    group =<< CoreOps.assign x (restoreOp :: Tensor Value a)
-- | Restore a tensor's value from a file.
--
-- The value is looked up under the tensor's own rendered output name, which
-- is the name 'save' writes it under.
restore :: forall a m . (MonadBuild m, TensorType a)
        => ByteString    -- ^ File path.
        -> Tensor Ref a  -- ^ Tensor to restore into.
        -> m ControlNode
restore path x = restoreFromName path ownName x
  where
    ownName = encodeUtf8 $ encodeOutput $ renderedOutput x
-- | Create a constant tensor of the given shape from a flat list of values.
--
-- Calls 'error' if the number of values differs from the product of the
-- shape's dimensions. Same as 'constant'' with no extra op parameters.
constant :: TensorType a => Shape -> [a] -> Tensor Build a
constant = constant' id
-- | 'constant' with caller-supplied 'OpParams'.
--
-- @params@ is applied after the @value@ attribute is set, so it can
-- override it. Calls 'error' if the number of values differs from the
-- product of the shape's dimensions.
constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Build a
constant' params (Shape cShape) values
    | invalidLength = error invalidLengthMsg
    | otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode))
  where
    -- Element count implied by the shape.
    expectedLength :: Int64
    expectedLength = product cShape

    invalidLength :: Bool
    invalidLength = expectedLength /= fromIntegral (length values)

    invalidLengthMsg :: String
    invalidLengthMsg = printf "invalid tensor length: expected %d got %d"
                              expectedLength
                              (length values)

    -- The value proto: dtype from @a@, one dim entry per shape dimension,
    -- then the values themselves.
    typedNode :: TensorProto
    typedNode = def
                & dtype .~ tensorType (undefined :: a)
                & tensorShape . TensorShape.dim .~
                      [def & TensorShape.size .~ d | d <- cShape]
                & tensorVal .~ values
-- | Reshape a tensor to rank 0 by passing an empty target shape to
-- 'CoreOps.reshape'.
--
-- Presumably this only succeeds for tensors with exactly one element.
scalarize :: TensorType a => Tensor v a -> Tensor Build a
scalarize t = CoreOps.reshape t (vector scalarShape)
  where
    -- The empty shape, i.e. a scalar.
    scalarShape :: [Int32]
    scalarShape = []
-- | Sum a tensor over all of its axes using 'CoreOps.sum'.
--
-- The axes reduced are @CoreOps.range 0 (CoreOps.rank x) 1@.
reduceSum :: (OneOf '[ Double, Float, Int32, Int64
                     , Complex Float, Complex Double] a) =>
             Tensor v a -> Tensor Build a
reduceSum x = CoreOps.sum x allAxes
  where
    -- Every axis index of x.
    allAxes :: Tensor Build Int32
    allAxes = CoreOps.range 0 (CoreOps.rank x) 1
-- | 'reduceSum' with caller-supplied 'OpParams', which are passed to
-- 'CoreOps.sum''.
reduceSum' :: (OneOf '[ Double, Float, Int32, Int64
                      , Complex Float, Complex Double] a) =>
              OpParams -> Tensor v a -> Tensor Build a
reduceSum' params x = CoreOps.sum' params x allAxes
  where
    -- Every axis index of x.
    allAxes :: Tensor Build Int32
    allAxes = CoreOps.range 0 (CoreOps.rank x) 1
-- | Mean of a tensor over all of its axes using 'CoreOps.mean'.
--
-- Same as 'reduceMean'' with no extra op parameters.
reduceMean
  :: ( TensorType a
     , OneOf '[ Double, Float, Complex Float, Complex Double] a
     )
  => Tensor v a -> Tensor Build a
reduceMean = reduceMean' id
-- | 'reduceMean' with caller-supplied 'OpParams', which are passed to
-- 'CoreOps.mean''.
reduceMean'
  :: ( TensorType a
     , OneOf '[ Double, Float, Complex Float, Complex Double] a
     )
  => OpParams -> Tensor v a -> Tensor Build a
reduceMean' params x = CoreOps.mean' params x allAxes
  where
    -- Every axis index of x.
    allAxes :: Tensor Build Int32
    allAxes = CoreOps.range 0 (CoreOps.rank x) 1
-- | Create a constant one-dimensional tensor holding the given values.
--
-- Same as 'vector'' with no extra op parameters.
vector :: TensorType a => [a] -> Tensor Build a
vector = vector' id
-- | 'vector' with caller-supplied 'OpParams'. The shape is the single
-- dimension @[length xs]@.
vector' :: TensorType a => OpParams -> [a] -> Tensor Build a
vector' params xs = constant' params (Shape [fromIntegral (length xs)]) xs
-- | Create a constant scalar (empty-shape) tensor.
--
-- Same as 'scalar'' with no extra op parameters.
scalar :: TensorType a => a -> Tensor Build a
scalar = scalar' id
-- | 'scalar' with caller-supplied 'OpParams': a one-value constant with an
-- empty shape.
scalar' :: TensorType a => OpParams -> a -> Tensor Build a
scalar' params x = constant' params (Shape []) [x]
-- | 'CoreOps.truncatedNormal' with the shape argument restricted to an
-- @Int64@ tensor.
truncatedNormal :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
                => Tensor v Int64  -- ^ Shape.
                -> m (Tensor Value a)
truncatedNormal = CoreOps.truncatedNormal
-- | 'CoreOps.truncatedNormal'' with the shape argument restricted to an
-- @Int64@ tensor.
truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
                 => OpParams
                 -> Tensor v Int64  -- ^ Shape.
                 -> m (Tensor Value a)
truncatedNormal' = CoreOps.truncatedNormal'
-- | A tensor of the given shape with every element set to zero, built with
-- 'CoreOps.fill'.
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Build a
zeros (Shape dims) = CoreOps.fill (vector dims) (scalar 0)
-- | 'CoreOps.shape' with its output element type fixed to 'Int32'.
shape :: TensorType t => Tensor v t -> Tensor Build Int32
shape = CoreOps.shape
-- | 'CoreOps.shape'' with its output element type fixed to 'Int32'.
shape' :: TensorType t => OpParams -> Tensor v t -> Tensor Build Int32
shape' = CoreOps.shape'
-- | 'CoreOps.expandDims' with the axis argument fixed to an @Int32@ tensor.
expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t
expandDims = CoreOps.expandDims
-- | 'CoreOps.expandDims'' with the axis argument fixed to an @Int32@ tensor.
expandDims' :: TensorType t
            => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t
expandDims' = CoreOps.expandDims'
-- | Helper for reduction ops. Stitches the input shape together with a @1@
-- for each reduced axis, giving the input shape with the reduced axes kept
-- as size 1.
--
-- Axes are first mapped with @(axis + rank) `mod` rank@, so negative axes
-- count from the end.
--
-- NOTE(review): the result relies on 'CoreOps.dynamicStitch' letting later
-- entries win for duplicate indices, so reduced positions end up as 1;
-- confirm against the DynamicStitch op documentation.
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
                Tensor v1 t1 -> Tensor v2 t2 -> Tensor Build Int32
reducedShape inputShape axes =
    let inputShape32 :: Tensor Build Int32
        inputShape32 = toInt32 inputShape
        axes32 :: Tensor Build Int32
        axes32 = toInt32 axes
        -- Number of entries in the input shape, i.e. the input's rank.
        inputRank :: Tensor Build Int32
        inputRank = CoreOps.size inputShape32
        -- Axes normalised into [0, rank).
        axesMod :: Tensor Build Int32
        axesMod = (axes32 + inputRank) `CoreOps.mod` inputRank
        axesShape :: Tensor Build Int32
        axesShape = shape axesMod
    in CoreOps.dynamicStitch
         -- Indices: every position of the input shape, then the reduced axes.
         [CoreOps.range 0 inputRank 1, axesMod]
         -- Data: the original dimensions, then a 1 for each reduced axis.
         [inputShape32, CoreOps.fill axesShape 1]
  where
    -- Cast either index type to Int32. The explicit signature keeps this
    -- polymorphic, since it is used at both t1 and t2.
    toInt32 :: TensorType s => Tensor w s -> Tensor Build Int32
    toInt32 x = CoreOps.cast x