1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-16 16:09:43 +01:00
tensorflow-haskell/tensorflow-ops/tests/EmbeddingOpsTest.hs

142 lines
5.6 KiB
Haskell
Raw Normal View History

2016-10-24 21:26:42 +02:00
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Tests for EmbeddingOps.
module Main where
import Data.Int (Int32, Int64)
import Data.List (genericLength)
import Google.Test (googleTest)
import TensorFlow.EmbeddingOps (embeddingLookup)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.HUnit ((@=?))
import Test.Framework.Providers.HUnit (testCase)
2016-10-24 21:26:42 +02:00
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
import Test.QuickCheck.Monadic (monadicIO, run)
import qualified Data.Vector as V
import qualified TensorFlow.GenOps.Core as CoreOps
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
-- | Builds the graph described by @build@ in a fresh session, runs it via
-- 'TF.run', and returns the fetched result.
buildAndRun build = TF.runSession (TF.buildAnd TF.run build)
-- | Tries to perform a simple embedding lookup, with two partitions.
--
-- Each partition is a @[1, 3]@ tensor holding a single 3-dim embedding, so
-- together the two partitions cover two items. The test expects id 0 to
-- resolve to @embedding1@ and id 1 to @embedding2@.
testEmbeddingLookupHasRightShapeWithPartition =
    testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
        let partitionShape = TF.Shape [1, 3]
            embedding1 = [ 1, 1, 1 ] :: [Int32]
            embedding2 = [ 0, 0, 0 ] :: [Int32]
            embedding = [ TF.constant partitionShape embedding1
                        , TF.constant partitionShape embedding2
                        ]
            idValues = [0, 1] :: [Int32]
            ids = TF.constant (TF.Shape [1, 2]) idValues
            op = embeddingLookup embedding ids
        -- Fetch both the looked-up values and the shape of the result.
        -- (Named valuesShape so it does not shadow partitionShape.)
        (values, valuesShape) <- buildAndRun $ do
            vs <- op
            return (vs, TF.shape vs)
        -- This is the shape that is returned in the equivalent Python.
        -- HUnit's (@=?) takes the EXPECTED value on the left.
        V.fromList [ 1, 2, 3 ] @=? valuesShape
        -- Looking up ids [0, 1] yields embedding1 followed by embedding2.
        V.fromList [ 1, 1, 1, 0, 0, 0 ] @=? values
-- | Tries to perform a simple embedding lookup, with only a single partition.
--
-- The single partition is a @[2, 3]@ table: a 3-dim embedding for each of
-- two items. Looking up ids @[0, 1]@ should return both rows in order.
testEmbeddingLookupHasRightShape =
    testCase "testEmbeddingLookupHasRightShape" $ do
        let tableShape = TF.Shape [2, 3]
            embeddingInit = [ 1, 1, 1
                            , 0, 0, 0 ] :: [Int32]
            embedding = TF.constant tableShape embeddingInit
            idValues = [0, 1] :: [Int32]
            ids = TF.constant (TF.Shape [1, 2]) idValues
            op = embeddingLookup [embedding] ids
        -- Fetch both the looked-up values and the shape of the result.
        -- (Named valuesShape so it does not shadow tableShape.)
        (values, valuesShape) <- buildAndRun $ do
            vs <- op
            return (vs, TF.shape vs)
        -- This is the shape that is returned in the equivalent Python.
        -- HUnit's (@=?) takes the EXPECTED value on the left.
        V.fromList [ 1, 2, 3 ] @=? valuesShape
        -- Looking up ids [0, 1] yields row 0 followed by row 1.
        V.fromList [ 1, 1, 1, 0, 0, 0 ] @=? values
2016-10-24 21:26:42 +02:00
-- | Verifies that direct gather is the same as dynamic split into
-- partitions, followed by embedding lookup.
--
-- Rows of the generated tensor are dealt round-robin into @numParts@
-- partitions (row @i@ goes to partition @i `mod` numParts@) with
-- 'CoreOps.dynamicPartition'. Looking @indices@ up in those partitions with
-- 'embeddingLookup' must give the same values as 'CoreOps.gather' on the
-- unsplit tensor. The result shape must also be @length indices : restDims@.
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
                               => LookupExample a -> Property
testEmbeddingLookupUndoesSplit
    (LookupExample numParts
                   shape@(TF.Shape (firstDim : restDims))
                   values
                   indices) =
    let modShardedValues :: [TF.Tensor TF.Value a] =
            CoreOps.dynamicPartition numParts shapedValues cyclicCounter
        -- Partition id for each row: 0, 1, ..., numParts-1, 0, 1, ...
        cyclicCounter :: TF.Tensor TF.Value Int32 =
            TF.vector [0..fromIntegral firstDim-1]
                `CoreOps.mod` fromIntegral numParts
        indicesVector = TF.vector indices
        directs = CoreOps.gather shapedValues indicesVector
        shapedValues = TF.constant shape values
    in monadicIO $ run $ do
        (shapeOut, got, want :: V.Vector a) <-
            TF.runSession $ TF.buildAnd TF.run $ do
                embeddings <- embeddingLookup modShardedValues indicesVector
                return (TF.cast (TF.shape embeddings), embeddings, directs)
        -- Checks the explicitly documented invariant of embeddingLookup.
        -- HUnit's (@=?) takes the EXPECTED value on the left.
        V.fromList (genericLength indices : restDims) @=? shapeOut
        want @=? got
-- The Arbitrary instance always generates rank >= 1, so a scalar shape here
-- means the generator is broken.
testEmbeddingLookupUndoesSplit _ = error "Bug in Arbitrary (LookupExample)"
-- | A mutually consistent set of inputs for the EmbeddingLookupUndoesSplit
-- property: a tensor (its shape plus its element values), how many
-- partitions to split it into, and which rows to look up afterwards.
data LookupExample a
    = LookupExample
        Int64     -- ^ Number of partitions to split the tensor into.
        TF.Shape  -- ^ Shape of the generated tensor.
        [a]       -- ^ Tensor contents, one value per element of the shape.
        [Int32]   -- ^ Row indices to look up in the tensor.
    deriving (Show)
-- | Generates a small tensor of rank 1 to 4, a partition count in @[2, 15]@,
-- and a list of row indices that are all valid for the tensor's leading
-- dimension.
instance Arbitrary a => Arbitrary (LookupExample a) where
    arbitrary = do
        rank <- choose (1, 4)
        -- Bounds every dimension by the rank-th root of 100, rounded up, so
        -- the tensor stays small (never more than 4^4 = 256 elements).
        -- The types are spelled out; they match what defaulting picked before.
        let rootOf100 :: Double
            rootOf100 = 100 ** (1 / fromIntegral rank)
            maxDim = fromIntegral (ceiling rootOf100 :: Integer)
        -- Draw the leading dimension separately from the rest, so no
        -- failable pattern (and hence no MonadFail Gen instance) is needed.
        firstDim <- choose (1, maxDim)
        restDims <- vectorOf (rank - 1) (choose (1, maxDim))
        let shape = firstDim : restDims
        values <- vectorOf (fromIntegral $ product shape) arbitrary
        numParts <- choose (2, 15)
        -- At most firstDim - 1 lookups, each a valid row index.
        indSize <- choose (0, fromIntegral $ firstDim - 1)
        indices <- vectorOf indSize (choose (0, fromIntegral firstDim - 1))
        return $ LookupExample numParts (TF.Shape shape) values indices
-- | Runs the split/lookup property test and the two shape checks.
main :: IO ()
main = googleTest
    [ testProperty "EmbeddingLookupUndoesSplit"
        (testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
    , testEmbeddingLookupHasRightShape
    , testEmbeddingLookupHasRightShapeWithPartition
    ]