From d9115c716fd5c89aa89e26a2c361933acc52ad24 Mon Sep 17 00:00:00 2001 From: Greg Steuck Date: Wed, 9 Nov 2016 14:20:26 -0800 Subject: [PATCH] genericLength is too generic. Avoid folding in TF. --- tensorflow-ops/src/TensorFlow/EmbeddingOps.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs index c6d2718..af90406 100644 --- a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs +++ b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs @@ -23,9 +23,8 @@ module TensorFlow.EmbeddingOps where import Control.Monad (zipWithM) import Data.Int (Int32, Int64) -import Data.List (genericLength) import TensorFlow.Build (Build, colocateWith, render) -import TensorFlow.Ops (shape, scalar, vector) -- Also Num instance for Tensor +import TensorFlow.Ops (scalar, shape, vector) -- Also Num instance for Tensor import TensorFlow.Tensor (Tensor, Value) import TensorFlow.Types (OneOf, TensorType) import qualified TensorFlow.GenOps.Core as CoreOps @@ -74,7 +73,8 @@ embeddingLookup params@(p0 : _) ids = do tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1)) render $ CoreOps.reshape unshapedResult finalShape where - np = genericLength params + -- Avoids genericLength here which would be evaluated by TF. + np = fromIntegral (length params) flatIds = CoreOps.reshape ids (singleton (-1)) pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np) newIds = flatIds `CoreOps.div` np