mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Added reduceSum to Ops
This commit is contained in:
parent
eca4ff8981
commit
97b4bb5bab
1 changed files with 21 additions and 0 deletions
|
@ -133,6 +133,8 @@ module TensorFlow.Ops
|
|||
, CoreOps.sub'
|
||||
, CoreOps.sum
|
||||
, CoreOps.sum'
|
||||
, reduceSum
|
||||
, reduceSum'
|
||||
, CoreOps.transpose
|
||||
, CoreOps.transpose'
|
||||
, truncatedNormal
|
||||
|
@ -314,6 +316,25 @@ scalarize t = CoreOps.reshape t (vector scalarShape)
|
|||
where
|
||||
scalarShape = [] :: [Int32]
|
||||
|
||||
-- | Sum a tensor down to a scalar
|
||||
-- Seee `TensorFlow.GenOps.Core.sum`
|
||||
reduceSum
|
||||
:: ( TensorType a
|
||||
, OneOf '[ Double, Float, Int32, Int64
|
||||
, Complex Float, Complex Double] a
|
||||
)
|
||||
=> Tensor v a -> Tensor Build a
|
||||
reduceSum x = CoreOps.sum x allAxes
|
||||
where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1
|
||||
|
||||
reduceSum'
|
||||
:: ( TensorType a
|
||||
, OneOf '[ Double, Float, Int32, Int64
|
||||
, Complex Float, Complex Double] a
|
||||
)
|
||||
=> OpParams -> Tensor v a -> Tensor Build a
|
||||
reduceSum' params x = CoreOps.sum' params x allAxes
|
||||
where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1
|
||||
|
||||
-- | Create a constant vector.
|
||||
vector :: TensorType a => [a] -> Tensor Build a
|
||||
|
|
Loading…
Reference in a new issue