mirror of https://github.com/tensorflow/haskell.git

Fixed #19 by adding previously missing reshape.

The comment did say that only flat shapes were supported though.
Greg Steuck 2016-11-09 11:47:49 -08:00
parent 9c81241439
commit ec5c5228e1
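
For context: with the added reshape, `embeddingLookup` now accepts ids of any shape and returns a tensor of shape `shape(ids) + shape(params)[1:]` even when the lookup is partitioned across several shards. A minimal usage sketch follows; the surrounding session API names (TF.runSession, TF.build, TF.run, TF.constant, Shape) are assumptions about the tensorflow-haskell API of this era, not part of this commit — only embeddingLookup itself is.

    import Data.Int (Int32)
    import qualified Data.Vector as V

    import qualified TensorFlow.Ops as TF
    import qualified TensorFlow.Session as TF
    import TensorFlow.EmbeddingOps (embeddingLookup)
    import TensorFlow.Types (Shape (..))

    -- Look up a 2x2 tensor of ids across two embedding shards.
    -- Under the mod partition strategy used here, shard 0 holds the
    -- rows for ids 0 and 2, shard 1 the rows for ids 1 and 3.
    main :: IO ()
    main = do
        result <- TF.runSession $ do
            let shard0 = TF.constant (Shape [2, 1]) [10, 30 :: Float]
                shard1 = TF.constant (Shape [2, 1]) [20, 40 :: Float]
                ids    = TF.constant (Shape [2, 2]) [0, 1, 2, 3 :: Int32]
            embedded <- TF.build (embeddingLookup [shard0, shard1] ids)
            TF.run embedded
        -- Fetched flat; graph-side the result now has shape [2, 2, 1]
        -- thanks to the new reshape: [10.0, 20.0, 30.0, 40.0]
        print (result :: V.Vector Float)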


@@ -25,7 +25,7 @@ import Control.Monad (zipWithM)
 import Data.Int (Int32, Int64)
 import Data.List (genericLength)
 import TensorFlow.Build (Build, colocateWith, render)
-import TensorFlow.Ops () -- Num instance for Tensor
+import TensorFlow.Ops (shape, scalar, vector) -- Also Num instance for Tensor
 import TensorFlow.Tensor (Tensor, Value)
 import TensorFlow.Types (OneOf, TensorType)
 import qualified TensorFlow.GenOps.Core as CoreOps
@@ -56,24 +56,33 @@ embeddingLookup :: forall a b v .
                 -> Tensor Value b
                 -- ^ A `Tensor` with type `int32` or `int64`
                 -- containing the ids to be looked up in `params`.
-                -- The ids are required to be flat on entry and have
-                -- fewer than 2^31 entries.
+                -- The ids are required to have fewer than 2^31
+                -- entries.
                 -> Build (Tensor Value a)
                 -- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
-embeddingLookup params@(p1 : _) ids =
-    go (np :: Int32)
-  where
-    go 1 = colocateWith p1 (render $ CoreOps.gather p1 ids)
-    go _ = CoreOps.dynamicStitch pindices <$> partitionedResult
-    np = genericLength params
-    pAssignments = CoreOps.cast (ids `CoreOps.mod` np)
-    newIds = ids `CoreOps.div` np
-    originalIndices = CoreOps.range 0 (CoreOps.size ids) 1
-    -- Partition list of ids based on assignments into np separate lists
-    gatherIds = CoreOps.dynamicPartition np newIds pAssignments
-    -- Similarly, partition the original indices.
-    pindices = CoreOps.dynamicPartition np originalIndices pAssignments
-    -- Do np separate lookups, finding embeddings for plist[p] in params[p]
-    partitionedResult = zipWithM
-        (\p g -> colocateWith p $ render $ CoreOps.gather p g)
-        params gatherIds
+embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
+embeddingLookup params@(p0 : _) ids = do
+    -- Do np separate lookups, finding embeddings for plist[p] in params[p]
+    partitionedResult <- zipWithM
+        (\p g -> colocateWith p $ render $ CoreOps.gather p g)
+        params gatherIds
+    let unshapedResult = CoreOps.dynamicStitch pindices partitionedResult
+    -- Shape restoration is not as optimal as it would be with client
+    -- side shape tracking.
+    paramShape <- colocateWith p0 (render (shape p0))
+    let finalShape = CoreOps.concat 0 [shape ids, tailShape]
+        tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1))
+    render $ CoreOps.reshape unshapedResult finalShape
+  where
+    np = genericLength params
+    flatIds = CoreOps.reshape ids (singleton (-1))
+    pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np)
+    newIds = flatIds `CoreOps.div` np
+    originalIndices = CoreOps.range 0 (CoreOps.size flatIds) 1
+    -- Partition list of ids based on assignments into np separate lists
+    gatherIds = CoreOps.dynamicPartition np newIds pAssignments
+    -- Similarly, partition the original indices.
+    pindices = CoreOps.dynamicPartition np originalIndices pAssignments
+    singleton i = vector [i :: Int32]
+embeddingLookup [] _ = error "embeddingLookup requires params to be non empty"
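
The shape restoration in the new code computes `shape(ids) + shape(params)[1:]` at graph run time: `CoreOps.slice paramShape (singleton 1) (singleton (-1))` drops the partitioned dimension 0 (the -1 size means "to the end"), and the remainder is concatenated onto `shape ids`. A pure-list sketch of the same arithmetic, with a hypothetical helper and ordinary lists standing in for the int32 shape tensors:

    -- Pure-list analogue of the graph-time computation
    --   finalShape = CoreOps.concat 0 [shape ids, tailShape]
    --   tailShape  = CoreOps.slice paramShape (singleton 1) (singleton (-1))
    -- (hypothetical helper, not part of this commit)
    finalShape :: [Int] -> [Int] -> [Int]
    finalShape idsShape paramShape = idsShape ++ drop 1 paramShape

    -- e.g. finalShape [2, 2] [3, 1] == [2, 2, 1]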