1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-26 21:09:44 +01:00

genericLength is too generic.

Avoid folding in TF.
This commit is contained in:
Greg Steuck 2016-11-09 14:20:26 -08:00
parent ec5c5228e1
commit d9115c716f

View file

@ -23,9 +23,8 @@ module TensorFlow.EmbeddingOps where
import Control.Monad (zipWithM) import Control.Monad (zipWithM)
import Data.Int (Int32, Int64) import Data.Int (Int32, Int64)
import Data.List (genericLength)
import TensorFlow.Build (Build, colocateWith, render) import TensorFlow.Build (Build, colocateWith, render)
import TensorFlow.Ops (shape, scalar, vector) -- Also Num instance for Tensor import TensorFlow.Ops (scalar, shape, vector) -- Also Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value) import TensorFlow.Tensor (Tensor, Value)
import TensorFlow.Types (OneOf, TensorType) import TensorFlow.Types (OneOf, TensorType)
import qualified TensorFlow.GenOps.Core as CoreOps import qualified TensorFlow.GenOps.Core as CoreOps
@ -74,7 +73,8 @@ embeddingLookup params@(p0 : _) ids = do
tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1)) tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1))
render $ CoreOps.reshape unshapedResult finalShape render $ CoreOps.reshape unshapedResult finalShape
where where
np = genericLength params -- Avoids genericLength here which would be evaluated by TF.
np = fromIntegral (length params)
flatIds = CoreOps.reshape ids (singleton (-1)) flatIds = CoreOps.reshape ids (singleton (-1))
pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np) pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np)
newIds = flatIds `CoreOps.div` np newIds = flatIds `CoreOps.div` np