2016-10-24 21:26:42 +02:00
|
|
|
-- Copyright 2016 TensorFlow authors.
|
|
|
|
--
|
|
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
-- you may not use this file except in compliance with the License.
|
|
|
|
-- You may obtain a copy of the License at
|
|
|
|
--
|
|
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
--
|
|
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
-- See the License for the specific language governing permissions and
|
|
|
|
-- limitations under the License.
|
|
|
|
|
2016-12-15 03:53:06 +01:00
|
|
|
{-# LANGUAGE FlexibleContexts #-}
|
2016-10-24 21:26:42 +02:00
|
|
|
{-# LANGUAGE RankNTypes #-}
|
|
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
|
|
|
|
|
|
-- | Tests for EmbeddingOps.
|
|
|
|
module Main where
|
|
|
|
|
2017-04-07 00:10:33 +02:00
|
|
|
import Control.Monad.IO.Class (liftIO)
|
2016-10-24 21:26:42 +02:00
|
|
|
import Data.Int (Int32, Int64)
|
|
|
|
import Data.List (genericLength)
|
|
|
|
import TensorFlow.EmbeddingOps (embeddingLookup)
|
2017-05-11 00:26:03 +02:00
|
|
|
import Test.Framework (defaultMain, Test)
|
2016-10-24 21:26:42 +02:00
|
|
|
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
2016-11-18 19:42:02 +01:00
|
|
|
import Test.HUnit ((@=?))
|
2016-11-09 00:30:05 +01:00
|
|
|
import Test.Framework.Providers.HUnit (testCase)
|
2016-10-24 21:26:42 +02:00
|
|
|
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
|
|
|
|
import Test.QuickCheck.Monadic (monadicIO, run)
|
2016-11-17 22:54:36 +01:00
|
|
|
import TensorFlow.Test (assertAllClose)
|
2016-10-24 21:26:42 +02:00
|
|
|
|
|
|
|
import qualified Data.Vector as V
|
|
|
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
|
|
|
import qualified TensorFlow.Ops as TF
|
|
|
|
import qualified TensorFlow.Session as TF
|
|
|
|
import qualified TensorFlow.Tensor as TF
|
|
|
|
import qualified TensorFlow.Types as TF
|
2016-11-17 22:54:36 +01:00
|
|
|
import qualified TensorFlow.Gradient as TF
|
|
|
|
import qualified TensorFlow.Build as TF
|
2016-11-09 00:30:05 +01:00
|
|
|
|
2016-11-17 22:54:36 +01:00
|
|
|
|
2016-11-09 00:30:05 +01:00
|
|
|
-- | Tries to perform a simple embedding lookup, with two partitions.
--
-- The embedding table is split into two partitions of one 3-dim row each;
-- looking up ids @[0, 1]@ should return the first partition's row followed
-- by the second partition's row.
testEmbeddingLookupHasRightShapeWithPartition :: Test
testEmbeddingLookupHasRightShapeWithPartition =
    testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
        -- Each partition holds a single 3-dim embedding row; together they
        -- form an embedding of two items.
        let embShape = TF.Shape [1, 3]
        let embedding1 = [1, 1, 1 :: Int32]
        let embedding2 = [0, 0, 0 :: Int32]

        let idValues = [0, 1 :: Int32]

        (values, shape) <- TF.runSession $ do
            embedding <- mapM TF.render [ TF.constant embShape embedding1
                                        , TF.constant embShape embedding2
                                        ]
            let ids = TF.constant (TF.Shape [1, 2]) idValues
            vs <- embeddingLookup embedding ids
            TF.run (vs, TF.shape vs)

        -- HUnit's (@=?) takes the expected value first, then the actual one.
        -- This is the shape that is returned in the equiv. Python.
        V.fromList [1, 2, 3] @=? shape

        -- "[0, 1]" should pull out the resulting vector.
        V.fromList [1, 1, 1, 0, 0, 0] @=? values
|
2016-11-09 00:30:05 +01:00
|
|
|
|
|
|
|
|
|
|
|
-- | Tries to perform a simple embedding lookup, with only a single partition.
--
-- The whole 2x3 embedding table lives in one partition; looking up ids
-- @[0, 1]@ should return both rows in order.
testEmbeddingLookupHasRightShape :: Test
testEmbeddingLookupHasRightShape =
    testCase "testEmbeddingLookupHasRightShape" $ do
        -- Consider a 3-dim embedding of two items
        let embShape = TF.Shape [2, 3]
        let embeddingInit = [ 1, 1, 1
                            , 0, 0, 0 :: Int32
                            ]

        let idValues = [0, 1 :: Int32]

        (values, shape) <- TF.runSession $ do
            embedding <- TF.render $ TF.constant embShape embeddingInit
            let ids = TF.constant (TF.Shape [1, 2]) idValues
            vs <- embeddingLookup [embedding] ids
            TF.run (vs, TF.shape vs)

        -- HUnit's (@=?) takes the expected value first, then the actual one.
        -- This is the shape that is returned in the equiv. Python.
        V.fromList [1, 2, 3] @=? shape

        -- "[0, 1]" should pull out the resulting vector.
        V.fromList [1, 1, 1, 0, 0, 0] @=? values
|
|
|
|
|
|
|
|
-- | Check that we can calculate gradients w.r.t embeddings.
--
-- Both looked-up ids select the embedding row @[20]@, which equals every
-- element of the fed placeholder @x@, so the difference @op - x@ is zero and
-- the gradient of the loss w.r.t. the embedding should be (close to) zero.
testEmbeddingLookupGradients :: Test
testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
    -- Agrees with "embedding", so gradient should be zero.
    let xVals = V.fromList ([20, 20 :: Float])
    let shape = TF.Shape [2]

    gs <- TF.runSession $ do
        let embShape = TF.Shape [2, 1]
        let embeddingInit = [1, 20 :: Float]
        let idValues = [1, 1 :: Int32]
        let ids = TF.constant (TF.Shape [1, 2]) idValues

        x <- TF.placeholder (TF.Shape [2])
        embedding <- TF.initializedVariable
                         (TF.constant embShape embeddingInit)

        op <- embeddingLookup [embedding] ids
        -- Elementwise squared difference (not a norm), averaged over axis 0.
        let squaredError = CoreOps.square $ TF.abs (op `TF.sub` x)
            loss = TF.mean squaredError (TF.scalar (0 :: Int32))

        grads <- TF.gradients loss [embedding]
        -- Take the gradient for the single variable passed in; fail with a
        -- descriptive message instead of a bare 'head' exception.
        grad <- case grads of
            (g : _) -> return g
            []      -> error "testEmbeddingLookupGradients: no gradient returned"
        TF.runWithFeeds
            [TF.feed x $ TF.encodeTensorData shape xVals]
            grad
    -- Gradients should be zero (or close)
    assertAllClose gs (V.fromList ([0, 0 :: Float]))
|
2016-11-09 00:30:05 +01:00
|
|
|
|
|
|
|
|
2016-10-24 21:26:42 +02:00
|
|
|
-- | Verifies that direct gather is the same as dynamic split into
-- partitions, followed by embedding lookup.
--
-- Row @i@ of the input tensor is sent to partition @i `mod` numParts@ via
-- 'CoreOps.dynamicPartition'; 'embeddingLookup' over those partitions must
-- then agree with a plain 'CoreOps.gather' on the unsplit tensor, and its
-- output shape must be @length indices : restDims@.
testEmbeddingLookupUndoesSplit ::
    forall a. (TF.TensorDataType V.Vector a, Show a, Eq a)
    => LookupExample a -> Property
testEmbeddingLookupUndoesSplit
    (LookupExample numParts
                   shape@(TF.Shape (firstDim : restDims))
                   values
                   indices) = monadicIO $ run $ TF.runSession $ do
    let shapedValues = TF.constant shape values
    indicesVector <- TF.render $ TF.vector indices
    let directs = CoreOps.gather shapedValues indicesVector
    -- Partition id for each row: 0, 1, .., numParts-1, 0, 1, ..
    let cyclicCounter :: TF.Tensor TF.Build Int32 =
            TF.vector [0..fromIntegral firstDim-1]
                `CoreOps.mod` fromIntegral numParts
    modShardedValues :: [TF.Tensor TF.Value a] <-
        mapM TF.render $
            CoreOps.dynamicPartition numParts shapedValues cyclicCounter
    embeddings <- embeddingLookup modShardedValues indicesVector
    (shapeOut, got, want :: V.Vector a) <-
        TF.run (TF.cast (TF.shape embeddings), embeddings, directs)
    -- Checks the explicitly documented invariant of embeddingLookup.
    -- HUnit's (@=?) takes the expected value first, then the actual one.
    liftIO $ V.fromList (genericLength indices : restDims) @=? shapeOut
    liftIO $ want @=? got
testEmbeddingLookupUndoesSplit _ = error "Bug in Arbitrary (LookupExample)"
|
|
|
|
|
|
|
|
-- | Self-consistent inputs for 'testEmbeddingLookupUndoesSplit'.
--
-- As produced by the 'Arbitrary' instance, the data list holds exactly
-- @product@ of the shape's dimensions elements, and every index lies in
-- @[0, firstDim - 1]@.
data LookupExample a = LookupExample
    Int64     -- ^ number of partitions the tensor is split into
    TF.Shape  -- ^ shape of the generated tensor
    [a]       -- ^ flattened data for the tensor
    [Int32]   -- ^ first-dimension row indices to look up
    deriving Show
|
|
|
|
|
|
|
|
-- | Generates a tensor of rank 1 to 4, together with a split count and a
-- list of valid first-dimension indices for it.
instance Arbitrary a => Arbitrary (LookupExample a) where
    arbitrary = do
        rank <- choose (1, 4)
        -- Each dimension is capped at ceiling (100 ** (1 / rank)), which
        -- keeps the total element count small for every rank.
        let rootOfHundred :: Double
            rootOfHundred = 100 ** (1 / fromIntegral rank)
            dimCap = fromIntegral (ceiling rootOfHundred :: Int64)
        dims@(leadDim : _) <- vectorOf rank (choose (1, dimCap))
        elems <- vectorOf (fromIntegral (product dims)) arbitrary
        splits <- choose (2, 15)
        lookupCount <- choose (0, fromIntegral (leadDim - 1))
        lookupIds <- vectorOf lookupCount (choose (0, fromIntegral leadDim - 1))
        pure (LookupExample splits (TF.Shape dims) elems lookupIds)
|
|
|
|
|
|
|
|
-- | Runs the split/lookup property followed by the fixed-example shape and
-- gradient test cases.
main :: IO ()
main = defaultMain allTests
  where
    allTests =
        [ testProperty "EmbeddingLookupUndoesSplit"
              (testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
        , testEmbeddingLookupHasRightShape
        , testEmbeddingLookupHasRightShapeWithPartition
        , testEmbeddingLookupGradients
        ]
|