diff --git a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs index c6d2718..af90406 100644 --- a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs +++ b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs @@ -23,9 +23,8 @@ module TensorFlow.EmbeddingOps where import Control.Monad (zipWithM) import Data.Int (Int32, Int64) -import Data.List (genericLength) import TensorFlow.Build (Build, colocateWith, render) -import TensorFlow.Ops (shape, scalar, vector) -- Also Num instance for Tensor +import TensorFlow.Ops (scalar, shape, vector) -- Also Num instance for Tensor import TensorFlow.Tensor (Tensor, Value) import TensorFlow.Types (OneOf, TensorType) import qualified TensorFlow.GenOps.Core as CoreOps @@ -74,7 +73,8 @@ embeddingLookup params@(p0 : _) ids = do tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1)) render $ CoreOps.reshape unshapedResult finalShape where - np = genericLength params + -- Avoids genericLength here which would be evaluated by TF. + np = fromIntegral (length params) flatIds = CoreOps.reshape ids (singleton (-1)) pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np) newIds = flatIds `CoreOps.div` np