{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
module TensorFlow.EmbeddingOps where
import Control.Monad (zipWithM)
import Data.Int (Int32, Int64)
import TensorFlow.Build (MonadBuild)
import TensorFlow.Ops (shape, vector)
import TensorFlow.Tensor (Tensor, Value, Rendered, colocateWith, render)
import TensorFlow.Types (OneOf, TensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
-- | Look up @ids@ in an embedding table that has been split into shards.
--
-- The list @params@ is treated as a partition of one larger embedding
-- tensor along dimension 0, using the mod partition strategy: with @n@
-- shards, an id @i@ is served by shard number @mod i n@ at row
-- @div i n@ of that shard.  With exactly one shard the lookup is a
-- single 'CoreOps.gather' on it.
--
-- In the multi-shard case the per-shard results are stitched back into
-- the original id order and reshaped to
-- @shape ids ++ drop 1 (shape (head params))@.
--
-- NOTE(review): the original positions of the flattened ids are fed to
-- 'CoreOps.dynamicStitch' as Int32, so presumably fewer than 2^31 ids
-- are supported; ids are also presumably expected to be non-negative
-- (negative ids would interact with TF's mod/div semantics) - confirm.
--
-- Calls 'error' if @params@ is empty.
embeddingLookup :: forall a b v1 v2 m .
                   ( MonadBuild m
                   , Rendered (Tensor v1)
                   , TensorType a
                   , OneOf '[Int64, Int32] b
                   , Num b
                   )
                => [Tensor v1 a]
                -- ^ Embedding shards; the first one also supplies the
                -- trailing (per-row) dimensions of the result.
                -> Tensor v2 b
                -- ^ Ids to look up, of any shape.
                -> m (Tensor Value a)
                -- ^ The gathered embeddings.
embeddingLookup [] _ =
    error "embeddingLookup requires params to be non empty"
embeddingLookup [onlyShard] ids =
    colocateWith onlyShard . render $ CoreOps.gather onlyShard ids
embeddingLookup shards@(firstShard : _) ids = do
    -- One gather per shard, each placed next to the shard it reads from.
    perShard <- zipWithM gatherFrom shards shardIds
    -- Shape of the first shard; its tail gives the per-row dimensions.
    firstShape <- colocateWith firstShard . render $ shape firstShard
    let rowShape = CoreOps.slice firstShape (singleton 1) (singleton (-1))
        outShape = CoreOps.concat 0 [shape ids, rowShape]
        stitched = CoreOps.dynamicStitch shardPositions perShard
    render $ CoreOps.reshape stitched outShape
  where
    gatherFrom shard rows =
        colocateWith shard . render $ CoreOps.gather shard rows
    -- List length converted at each use site: Int64 for the partition
    -- count, and a tensor constant for the mod/div arithmetic.
    numShards = fromIntegral (length shards)
    flatIds = CoreOps.reshape ids (singleton (-1))
    -- Which shard owns each flattened id, and its row within that shard.
    owner = CoreOps.cast (flatIds `CoreOps.mod` numShards)
    localRows = flatIds `CoreOps.div` numShards
    -- Original position of every flattened id, used to undo the split.
    positions = CoreOps.range 0 (CoreOps.size flatIds) 1
    shardIds = CoreOps.dynamicPartition numShards localRows owner
    shardPositions = CoreOps.dynamicPartition numShards positions owner
    singleton i = vector [i :: Int32]