mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 11:03:34 +02:00
# This is a combination of 2 commits.
# The first commit's message is: test case to show can't calculate grad # This is the 2nd commit message: typo
This commit is contained in:
parent
93e27a12c6
commit
178ccbc68b
|
@ -23,7 +23,8 @@ import Data.List (genericLength)
|
|||
import Google.Test (googleTest)
|
||||
import TensorFlow.EmbeddingOps (embeddingLookup)
|
||||
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||||
import Test.HUnit ((@=?))
|
||||
import Test.HUnit.Lang (Assertion(..))
|
||||
import Test.HUnit ((@=?), (@?))
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
|
||||
import Test.QuickCheck.Monadic (monadicIO, run)
|
||||
|
@ -34,10 +35,15 @@ import qualified TensorFlow.Ops as TF
|
|||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Build as TF
|
||||
import qualified TensorFlow.Nodes as TF
|
||||
|
||||
|
||||
buildAndRun :: TF.Fetchable t a => TF.Build t -> IO a
|
||||
buildAndRun = TF.runSession . TF.buildAnd TF.run
|
||||
|
||||
|
||||
-- | Tries to perform a simple embedding lookup, with two partitions.
|
||||
testEmbeddingLookupHasRightShapeWithPartition = testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
|
||||
let shape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
|
||||
|
@ -85,6 +91,45 @@ testEmbeddingLookupHasRightShape = testCase "testEmbeddingLookupHasRightShape" $
|
|||
values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ]
|
||||
|
||||
|
||||
-- | Check that we can calculate gradients w.r.t embeddings.
|
||||
testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
||||
let xVals = V.fromList ([10, 20] :: [Float]) -- Same as "embedding", so gradient should be zero.
|
||||
let shape = TF.Shape [2]
|
||||
|
||||
gs <- TF.runSession $ do
|
||||
grads <- TF.build $ do
|
||||
let shape = TF.Shape [2, 1]
|
||||
let embeddingInit = [10, 20] :: [Float]
|
||||
let idValues = [0, 1] :: [Int32]
|
||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||
|
||||
x <- TF.placeholder (TF.Shape [2])
|
||||
embedding <- TF.initializedVariable =<< (TF.render $ TF.constant shape embeddingInit)
|
||||
|
||||
op <- embeddingLookup [embedding] ids
|
||||
let loss = TF.mean (CoreOps.square (TF.abs (op - x))) (TF.scalar (0 :: Int32))
|
||||
|
||||
grad <- fmap head (TF.gradients loss [embedding])
|
||||
|
||||
return $ \xs -> TF.runWithFeeds [TF.feed x xs] grad
|
||||
|
||||
grads (TF.encodeTensorData shape xVals :: TF.TensorData Float)
|
||||
|
||||
-- Gradients should be zero (or close)
|
||||
assertAllClose gs (V.fromList ([0, 0] :: [Float]))
|
||||
|
||||
|
||||
-- TODO: Move this out into a central testing utils lib and remove from NN
|
||||
-- tests
|
||||
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
|
||||
assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
|
||||
("Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys
|
||||
++ "\ntolerance: " ++ show tol)
|
||||
where
|
||||
absDiff x y = abs (x - y)
|
||||
tol = 0.001 :: Float
|
||||
|
||||
|
||||
-- Verifies that direct gather is the same as dynamic split into
|
||||
-- partitions, followed by embedding lookup.
|
||||
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
|
||||
|
@ -138,4 +183,5 @@ main = googleTest
|
|||
(testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
|
||||
, testEmbeddingLookupHasRightShape
|
||||
, testEmbeddingLookupHasRightShapeWithPartition
|
||||
, testEmbeddingLookupGradients
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue
Block a user