mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Fixed #19 by adding previously missing reshape.
The comment did say that only flat shapes were supported though.
This commit is contained in:
parent
9c81241439
commit
ec5c5228e1
1 changed files with 29 additions and 20 deletions
|
@ -25,7 +25,7 @@ import Control.Monad (zipWithM)
|
|||
import Data.Int (Int32, Int64)
|
||||
import Data.List (genericLength)
|
||||
import TensorFlow.Build (Build, colocateWith, render)
|
||||
import TensorFlow.Ops () -- Num instance for Tensor
|
||||
import TensorFlow.Ops (shape, scalar, vector) -- Also Num instance for Tensor
|
||||
import TensorFlow.Tensor (Tensor, Value)
|
||||
import TensorFlow.Types (OneOf, TensorType)
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
|
@ -56,24 +56,33 @@ embeddingLookup :: forall a b v .
|
|||
-> Tensor Value b
|
||||
-- ^ A `Tensor` with type `int32` or `int64`
|
||||
-- containing the ids to be looked up in `params`.
|
||||
-- The ids are required to be flat on entry and have
|
||||
-- fewer than 2^31 entries.
|
||||
-- The ids are required to have fewer than 2^31
|
||||
-- entries.
|
||||
-> Build (Tensor Value a)
|
||||
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
|
||||
embeddingLookup params@(p1 : _) ids =
|
||||
go (np :: Int32)
|
||||
embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
|
||||
embeddingLookup params@(p0 : _) ids = do
|
||||
-- Do np separate lookups, finding embeddings for plist[p] in params[p]
|
||||
partitionedResult <- zipWithM
|
||||
(\p g -> colocateWith p $ render $ CoreOps.gather p g)
|
||||
params gatherIds
|
||||
let unshapedResult = CoreOps.dynamicStitch pindices partitionedResult
|
||||
-- Shape restoration is not as optimal as it would be with client
|
||||
-- side shape tracking.
|
||||
paramShape <- colocateWith p0 (render (shape p0))
|
||||
let finalShape = CoreOps.concat 0 [shape ids, tailShape]
|
||||
tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1))
|
||||
render $ CoreOps.reshape unshapedResult finalShape
|
||||
where
|
||||
go 1 = colocateWith p1 (render $ CoreOps.gather p1 ids)
|
||||
go _ = CoreOps.dynamicStitch pindices <$> partitionedResult
|
||||
np = genericLength params
|
||||
pAssignments = CoreOps.cast (ids `CoreOps.mod` np)
|
||||
newIds = ids `CoreOps.div` np
|
||||
originalIndices = CoreOps.range 0 (CoreOps.size ids) 1
|
||||
-- Partition list of ids based on assignments into np separate lists
|
||||
gatherIds = CoreOps.dynamicPartition np newIds pAssignments
|
||||
-- Similarly, partition the original indices.
|
||||
pindices = CoreOps.dynamicPartition np originalIndices pAssignments
|
||||
-- Do np separate lookups, finding embeddings for plist[p] in params[p]
|
||||
partitionedResult = zipWithM
|
||||
(\p g -> colocateWith p $ render $ CoreOps.gather p g)
|
||||
params gatherIds
|
||||
np = genericLength params
|
||||
flatIds = CoreOps.reshape ids (singleton (-1))
|
||||
pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np)
|
||||
newIds = flatIds `CoreOps.div` np
|
||||
originalIndices = CoreOps.range 0 (CoreOps.size flatIds) 1
|
||||
-- Partition list of ids based on assignments into np separate lists
|
||||
gatherIds = CoreOps.dynamicPartition np newIds pAssignments
|
||||
-- Similarly, partition the original indices.
|
||||
pindices = CoreOps.dynamicPartition np originalIndices pAssignments
|
||||
singleton i = vector [i :: Int32]
|
||||
|
||||
embeddingLookup [] _ = error "embeddingLookup requires params to be non empty"
|
||||
|
|
Loading…
Reference in a new issue