mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Added reduceSum to Ops
This commit is contained in:
parent
eca4ff8981
commit
97b4bb5bab
1 changed files with 21 additions and 0 deletions
|
@ -133,6 +133,8 @@ module TensorFlow.Ops
|
||||||
, CoreOps.sub'
|
, CoreOps.sub'
|
||||||
, CoreOps.sum
|
, CoreOps.sum
|
||||||
, CoreOps.sum'
|
, CoreOps.sum'
|
||||||
|
, reduceSum
|
||||||
|
, reduceSum'
|
||||||
, CoreOps.transpose
|
, CoreOps.transpose
|
||||||
, CoreOps.transpose'
|
, CoreOps.transpose'
|
||||||
, truncatedNormal
|
, truncatedNormal
|
||||||
|
@ -314,6 +316,25 @@ scalarize t = CoreOps.reshape t (vector scalarShape)
|
||||||
where
|
where
|
||||||
scalarShape = [] :: [Int32]
|
scalarShape = [] :: [Int32]
|
||||||
|
|
||||||
|
-- | Sum a tensor down to a scalar
|
||||||
|
-- Seee `TensorFlow.GenOps.Core.sum`
|
||||||
|
reduceSum
|
||||||
|
:: ( TensorType a
|
||||||
|
, OneOf '[ Double, Float, Int32, Int64
|
||||||
|
, Complex Float, Complex Double] a
|
||||||
|
)
|
||||||
|
=> Tensor v a -> Tensor Build a
|
||||||
|
reduceSum x = CoreOps.sum x allAxes
|
||||||
|
where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1
|
||||||
|
|
||||||
|
reduceSum'
|
||||||
|
:: ( TensorType a
|
||||||
|
, OneOf '[ Double, Float, Int32, Int64
|
||||||
|
, Complex Float, Complex Double] a
|
||||||
|
)
|
||||||
|
=> OpParams -> Tensor v a -> Tensor Build a
|
||||||
|
reduceSum' params x = CoreOps.sum' params x allAxes
|
||||||
|
where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1
|
||||||
|
|
||||||
-- | Create a constant vector.
|
-- | Create a constant vector.
|
||||||
vector :: TensorType a => [a] -> Tensor Build a
|
vector :: TensorType a => [a] -> Tensor Build a
|
||||||
|
|
Loading…
Reference in a new issue