diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index 91fa72c..cae1473 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -444,8 +444,8 @@ gradForBinaryCwise (x, gx) (y, gy) = where dx = reshape (sum gx rx) sx dy = reshape (sum gy ry) sy - sx = shape x -- (x :: Tensor Build t) - sy = shape y -- (y :: Tensor Build t) + sx = shape x + sy = shape y (rx, ry) = broadcastGradientArgs sx sy -- | The gradient function for an op type.