1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-12-25 02:59:46 +01:00

genericLength is too generic.

Avoid folding in TF.
This commit is contained in:
Greg Steuck 2016-11-09 14:20:26 -08:00
parent ec5c5228e1
commit d9115c716f

View file

@ -23,9 +23,8 @@ module TensorFlow.EmbeddingOps where
import Control.Monad (zipWithM)
import Data.Int (Int32, Int64)
import Data.List (genericLength)
import TensorFlow.Build (Build, colocateWith, render)
import TensorFlow.Ops (shape, scalar, vector) -- Also Num instance for Tensor
import TensorFlow.Ops (scalar, shape, vector) -- Also Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value)
import TensorFlow.Types (OneOf, TensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
@ -74,7 +73,8 @@ embeddingLookup params@(p0 : _) ids = do
tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1))
render $ CoreOps.reshape unshapedResult finalShape
where
np = genericLength params
-- Avoids genericLength here which would be evaluated by TF.
np = fromIntegral (length params)
flatIds = CoreOps.reshape ids (singleton (-1))
pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np)
newIds = flatIds `CoreOps.div` np