Fixed #19 by adding previously missing reshape.

The comment did say that only flat shapes were supported though.
This commit is contained in:
Greg Steuck 2016-11-09 11:47:49 -08:00
parent 9c81241439
commit ec5c5228e1
1 changed files with 29 additions and 20 deletions

View File

@ -25,7 +25,7 @@ import Control.Monad (zipWithM)
import Data.Int (Int32, Int64)
import Data.List (genericLength)
import TensorFlow.Build (Build, colocateWith, render)
import TensorFlow.Ops () -- Num instance for Tensor
import TensorFlow.Ops (shape, scalar, vector) -- Also Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value)
import TensorFlow.Types (OneOf, TensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
@ -56,24 +56,33 @@ embeddingLookup :: forall a b v .
-> Tensor Value b
-- ^ A `Tensor` with type `int32` or `int64`
-- containing the ids to be looked up in `params`.
-- The ids are required to be flat on entry and have
-- fewer than 2^31 entries.
-- The ids are required to have fewer than 2^31
-- entries.
-> Build (Tensor Value a)
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
embeddingLookup params@(p1 : _) ids =
go (np :: Int32)
where
go 1 = colocateWith p1 (render $ CoreOps.gather p1 ids)
go _ = CoreOps.dynamicStitch pindices <$> partitionedResult
np = genericLength params
pAssignments = CoreOps.cast (ids `CoreOps.mod` np)
newIds = ids `CoreOps.div` np
originalIndices = CoreOps.range 0 (CoreOps.size ids) 1
-- Partition list of ids based on assignments into np separate lists
gatherIds = CoreOps.dynamicPartition np newIds pAssignments
-- Similarly, partition the original indices.
pindices = CoreOps.dynamicPartition np originalIndices pAssignments
-- Do np separate lookups, finding embeddings for plist[p] in params[p]
partitionedResult = zipWithM
(\p g -> colocateWith p $ render $ CoreOps.gather p g)
params gatherIds
embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
embeddingLookup params@(p0 : _) ids = do
-- Do np separate lookups, finding embeddings for plist[p] in params[p]
partitionedResult <- zipWithM
(\p g -> colocateWith p $ render $ CoreOps.gather p g)
params gatherIds
let unshapedResult = CoreOps.dynamicStitch pindices partitionedResult
-- Shape restoration is not as optimal as it would be with client
-- side shape tracking.
paramShape <- colocateWith p0 (render (shape p0))
let finalShape = CoreOps.concat 0 [shape ids, tailShape]
tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1))
render $ CoreOps.reshape unshapedResult finalShape
where
np = genericLength params
flatIds = CoreOps.reshape ids (singleton (-1))
pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np)
newIds = flatIds `CoreOps.div` np
originalIndices = CoreOps.range 0 (CoreOps.size flatIds) 1
-- Partition list of ids based on assignments into np separate lists
gatherIds = CoreOps.dynamicPartition np newIds pAssignments
-- Similarly, partition the original indices.
pindices = CoreOps.dynamicPartition np originalIndices pAssignments
singleton i = vector [i :: Int32]
embeddingLookup [] _ = error "embeddingLookup requires params to be non empty"