1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-23 03:19:44 +01:00

Added reduceSum to Ops

This commit is contained in:
Jarl Christian Berentsen 2017-05-04 09:26:43 +02:00 committed by fkm3
parent eca4ff8981
commit 97b4bb5bab

View file

@ -133,6 +133,8 @@ module TensorFlow.Ops
, CoreOps.sub' , CoreOps.sub'
, CoreOps.sum , CoreOps.sum
, CoreOps.sum' , CoreOps.sum'
, reduceSum
, reduceSum'
, CoreOps.transpose , CoreOps.transpose
, CoreOps.transpose' , CoreOps.transpose'
, truncatedNormal , truncatedNormal
@ -314,6 +316,25 @@ scalarize t = CoreOps.reshape t (vector scalarShape)
where where
scalarShape = [] :: [Int32] scalarShape = [] :: [Int32]
-- | Sum a tensor down to a scalar
-- Seee `TensorFlow.GenOps.Core.sum`
reduceSum
:: ( TensorType a
, OneOf '[ Double, Float, Int32, Int64
, Complex Float, Complex Double] a
)
=> Tensor v a -> Tensor Build a
reduceSum x = CoreOps.sum x allAxes
where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1
reduceSum'
:: ( TensorType a
, OneOf '[ Double, Float, Int32, Int64
, Complex Float, Complex Double] a
)
=> OpParams -> Tensor v a -> Tensor Build a
reduceSum' params x = CoreOps.sum' params x allAxes
where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1
-- | Create a constant vector. -- | Create a constant vector.
vector :: TensorType a => [a] -> Tensor Build a vector :: TensorType a => [a] -> Tensor Build a