diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs index 3f27a80..40cedfc 100644 --- a/tensorflow-ops/src/TensorFlow/Ops.hs +++ b/tensorflow-ops/src/TensorFlow/Ops.hs @@ -133,6 +133,8 @@ module TensorFlow.Ops , CoreOps.sub' , CoreOps.sum , CoreOps.sum' + , reduceSum + , reduceSum' , CoreOps.transpose , CoreOps.transpose' , truncatedNormal @@ -314,6 +316,25 @@ scalarize t = CoreOps.reshape t (vector scalarShape) where scalarShape = [] :: [Int32] +-- | Sum a tensor down to a scalar +-- Seee `TensorFlow.GenOps.Core.sum` +reduceSum + :: ( TensorType a + , OneOf '[ Double, Float, Int32, Int64 + , Complex Float, Complex Double] a + ) + => Tensor v a -> Tensor Build a +reduceSum x = CoreOps.sum x allAxes + where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1 + +reduceSum' + :: ( TensorType a + , OneOf '[ Double, Float, Int32, Int64 + , Complex Float, Complex Double] a + ) + => OpParams -> Tensor v a -> Tensor Build a +reduceSum' params x = CoreOps.sum' params x allAxes + where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1 -- | Create a constant vector. vector :: TensorType a => [a] -> Tensor Build a