1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 11:03:34 +02:00

# This is a combination of 2 commits.

# The first commit's message is:

test case to show can't calculate grad

# This is the 2nd commit message:

typo
This commit is contained in:
silky 2016-11-11 10:06:07 +11:00
parent 93e27a12c6
commit 178ccbc68b

View File

@ -23,7 +23,8 @@ import Data.List (genericLength)
import Google.Test (googleTest)
import TensorFlow.EmbeddingOps (embeddingLookup)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.HUnit ((@=?))
import Test.HUnit.Lang (Assertion(..))
import Test.HUnit ((@=?), (@?))
import Test.Framework.Providers.HUnit (testCase)
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
import Test.QuickCheck.Monadic (monadicIO, run)
@ -34,10 +35,15 @@ import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Build as TF
import qualified TensorFlow.Nodes as TF
buildAndRun :: TF.Fetchable t a => TF.Build t -> IO a
buildAndRun = TF.runSession . TF.buildAnd TF.run
-- | Tries to perform a simple embedding lookup, with two partitions.
testEmbeddingLookupHasRightShapeWithPartition = testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
let shape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
@ -85,6 +91,45 @@ testEmbeddingLookupHasRightShape = testCase "testEmbeddingLookupHasRightShape" $
values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ]
-- | Check that we can calculate gradients w.r.t embeddings.
testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
let xVals = V.fromList ([10, 20] :: [Float]) -- Same as "embedding", so gradient should be zero.
let shape = TF.Shape [2]
gs <- TF.runSession $ do
grads <- TF.build $ do
let shape = TF.Shape [2, 1]
let embeddingInit = [10, 20] :: [Float]
let idValues = [0, 1] :: [Int32]
let ids = TF.constant (TF.Shape [1, 2]) idValues
x <- TF.placeholder (TF.Shape [2])
embedding <- TF.initializedVariable =<< (TF.render $ TF.constant shape embeddingInit)
op <- embeddingLookup [embedding] ids
let loss = TF.mean (CoreOps.square (TF.abs (op - x))) (TF.scalar (0 :: Int32))
grad <- fmap head (TF.gradients loss [embedding])
return $ \xs -> TF.runWithFeeds [TF.feed x xs] grad
grads (TF.encodeTensorData shape xVals :: TF.TensorData Float)
-- Gradients should be zero (or close)
assertAllClose gs (V.fromList ([0, 0] :: [Float]))
-- TODO: Move this out into a central testing utils lib and remove from NN
-- tests
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
("Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys
++ "\ntolerance: " ++ show tol)
where
absDiff x y = abs (x - y)
tol = 0.001 :: Float
-- Verifies that direct gather is the same as dynamic split into
-- partitions, followed by embedding lookup.
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
@ -138,4 +183,5 @@ main = googleTest
(testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
, testEmbeddingLookupHasRightShape
, testEmbeddingLookupHasRightShapeWithPartition
, testEmbeddingLookupGradients
]