mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-30 06:49:44 +01:00
genericLength is too generic.
Avoid folding in TF.
This commit is contained in:
parent
ec5c5228e1
commit
d9115c716f
1 changed files with 3 additions and 3 deletions
|
@ -23,9 +23,8 @@ module TensorFlow.EmbeddingOps where
|
||||||
|
|
||||||
import Control.Monad (zipWithM)
|
import Control.Monad (zipWithM)
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
import Data.List (genericLength)
|
|
||||||
import TensorFlow.Build (Build, colocateWith, render)
|
import TensorFlow.Build (Build, colocateWith, render)
|
||||||
import TensorFlow.Ops (shape, scalar, vector) -- Also Num instance for Tensor
|
import TensorFlow.Ops (scalar, shape, vector) -- Also Num instance for Tensor
|
||||||
import TensorFlow.Tensor (Tensor, Value)
|
import TensorFlow.Tensor (Tensor, Value)
|
||||||
import TensorFlow.Types (OneOf, TensorType)
|
import TensorFlow.Types (OneOf, TensorType)
|
||||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
@ -74,7 +73,8 @@ embeddingLookup params@(p0 : _) ids = do
|
||||||
tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1))
|
tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1))
|
||||||
render $ CoreOps.reshape unshapedResult finalShape
|
render $ CoreOps.reshape unshapedResult finalShape
|
||||||
where
|
where
|
||||||
np = genericLength params
|
-- Avoids genericLength here which would be evaluated by TF.
|
||||||
|
np = fromIntegral (length params)
|
||||||
flatIds = CoreOps.reshape ids (singleton (-1))
|
flatIds = CoreOps.reshape ids (singleton (-1))
|
||||||
pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np)
|
pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np)
|
||||||
newIds = flatIds `CoreOps.div` np
|
newIds = flatIds `CoreOps.div` np
|
||||||
|
|
Loading…
Reference in a new issue