1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2025-04-08 23:15:15 +02:00

Fixed by adding previously missing reshape.

The comment did say that only flat shapes were supported though.
This commit is contained in:
Greg Steuck 2016-11-09 11:47:49 -08:00
parent 9c81241439
commit ec5c5228e1

View file

@@ -25,7 +25,7 @@ import Control.Monad (zipWithM)
 import Data.Int (Int32, Int64)
 import Data.List (genericLength)
 import TensorFlow.Build (Build, colocateWith, render)
-import TensorFlow.Ops () -- Num instance for Tensor
+import TensorFlow.Ops (shape, scalar, vector) -- Also Num instance for Tensor
 import TensorFlow.Tensor (Tensor, Value)
 import TensorFlow.Types (OneOf, TensorType)
 import qualified TensorFlow.GenOps.Core as CoreOps
@@ -56,24 +56,33 @@ embeddingLookup :: forall a b v .
                 -> Tensor Value b
                 -- ^ A `Tensor` with type `int32` or `int64`
                 -- containing the ids to be looked up in `params`.
-                -- The ids are required to be flat on entry and have
-                -- fewer than 2^31 entries.
+                -- The ids are required to have fewer than 2^31
+                -- entries.
                 -> Build (Tensor Value a)
                 -- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
-embeddingLookup params@(p1 : _) ids =
-    go (np :: Int32)
+embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
+embeddingLookup params@(p0 : _) ids = do
+    -- Do np separate lookups, finding embeddings for plist[p] in params[p]
+    partitionedResult <- zipWithM
+                         (\p g -> colocateWith p $ render $ CoreOps.gather p g)
+                         params gatherIds
+    let unshapedResult = CoreOps.dynamicStitch pindices partitionedResult
+    -- Shape restoration is not as optimal as it would be with client
+    -- side shape tracking.
+    paramShape <- colocateWith p0 (render (shape p0))
+    let finalShape = CoreOps.concat 0 [shape ids, tailShape]
+        tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1))
+    render $ CoreOps.reshape unshapedResult finalShape
   where
-    go 1 = colocateWith p1 (render $ CoreOps.gather p1 ids)
-    go _ = CoreOps.dynamicStitch pindices <$> partitionedResult
     np = genericLength params
-    pAssignments = CoreOps.cast (ids `CoreOps.mod` np)
-    newIds = ids `CoreOps.div` np
-    originalIndices = CoreOps.range 0 (CoreOps.size ids) 1
+    flatIds = CoreOps.reshape ids (singleton (-1))
+    pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np)
+    newIds = flatIds `CoreOps.div` np
+    originalIndices = CoreOps.range 0 (CoreOps.size flatIds) 1
     -- Partition list of ids based on assignments into np separate lists
     gatherIds = CoreOps.dynamicPartition np newIds pAssignments
     -- Similarly, partition the original indices.
     pindices = CoreOps.dynamicPartition np originalIndices pAssignments
-    -- Do np separate lookups, finding embeddings for plist[p] in params[p]
-    partitionedResult = zipWithM
-        (\p g -> colocateWith p $ render $ CoreOps.gather p g)
-        params gatherIds
+    singleton i = vector [i :: Int32]
 embeddingLookup [] _ = error "embeddingLookup requires params to be non empty"