mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
genericLength is too generic.
Avoid folding in TF.
This commit is contained in:
parent
ec5c5228e1
commit
d9115c716f
1 changed files with 3 additions and 3 deletions
|
@ -23,9 +23,8 @@ module TensorFlow.EmbeddingOps where
|
|||
|
||||
import Control.Monad (zipWithM)
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.List (genericLength)
|
||||
import TensorFlow.Build (Build, colocateWith, render)
|
||||
import TensorFlow.Ops (shape, scalar, vector) -- Also Num instance for Tensor
|
||||
import TensorFlow.Ops (scalar, shape, vector) -- Also Num instance for Tensor
|
||||
import TensorFlow.Tensor (Tensor, Value)
|
||||
import TensorFlow.Types (OneOf, TensorType)
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
|
@ -74,7 +73,8 @@ embeddingLookup params@(p0 : _) ids = do
|
|||
tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1))
|
||||
render $ CoreOps.reshape unshapedResult finalShape
|
||||
where
|
||||
np = genericLength params
|
||||
-- Avoids genericLength here which would be evaluated by TF.
|
||||
np = fromIntegral (length params)
|
||||
flatIds = CoreOps.reshape ids (singleton (-1))
|
||||
pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np)
|
||||
newIds = flatIds `CoreOps.div` np
|
||||
|
|
Loading…
Reference in a new issue