From 97b4bb5bab731860f8c25b77333ad3f47a57ba1c Mon Sep 17 00:00:00 2001 From: Jarl Christian Berentsen Date: Thu, 4 May 2017 09:26:43 +0200 Subject: [PATCH] Added reduceSum to Ops --- tensorflow-ops/src/TensorFlow/Ops.hs | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs index 3f27a80..40cedfc 100644 --- a/tensorflow-ops/src/TensorFlow/Ops.hs +++ b/tensorflow-ops/src/TensorFlow/Ops.hs @@ -133,6 +133,8 @@ module TensorFlow.Ops , CoreOps.sub' , CoreOps.sum , CoreOps.sum' + , reduceSum + , reduceSum' , CoreOps.transpose , CoreOps.transpose' , truncatedNormal @@ -314,6 +316,25 @@ scalarize t = CoreOps.reshape t (vector scalarShape) where scalarShape = [] :: [Int32] +-- | Sum a tensor down to a scalar +-- Seee `TensorFlow.GenOps.Core.sum` +reduceSum + :: ( TensorType a + , OneOf '[ Double, Float, Int32, Int64 + , Complex Float, Complex Double] a + ) + => Tensor v a -> Tensor Build a +reduceSum x = CoreOps.sum x allAxes + where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1 + +reduceSum' + :: ( TensorType a + , OneOf '[ Double, Float, Int32, Int64 + , Complex Float, Complex Double] a + ) + => OpParams -> Tensor v a -> Tensor Build a +reduceSum' params x = CoreOps.sum' params x allAxes + where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1 -- | Create a constant vector. vector :: TensorType a => [a] -> Tensor Build a