1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 11:03:34 +02:00

added lib for testing helper functions

This commit is contained in:
silky 2016-11-17 13:00:37 +11:00
parent 7e1421c4a7
commit de58a3ffa4
5 changed files with 5 additions and 18 deletions

View File

@ -11,6 +11,7 @@ packages:
- tensorflow-mnist-input-data
- tensorflow-queue
- tensorflow-nn
- tensorflow-test
extra-deps:
# proto-lens is not yet in Stackage.

View File

@ -30,6 +30,7 @@ Test-Suite NNTest
, QuickCheck
, base
, tensorflow
, tensorflow-test
, tensorflow-ops
, tensorflow-nn
, google-shim

View File

@ -22,6 +22,7 @@ module Main where
import Data.Maybe (fromMaybe)
import Google.Test (googleTest)
import TensorFlow.Test (assertAllClose)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@?))
import Test.HUnit.Lang (Assertion(..))
@ -63,15 +64,6 @@ defInputs = Inputs {
}
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
("Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys
++ "\ntolerance: " ++ show tol)
where
absDiff x y = abs (x - y)
tol = 0.001 :: Float
testLogisticOutput = testCase "testLogisticOutput" $ do
let inputs = defInputs
vLogits = TF.vector $ logits inputs

View File

@ -60,6 +60,7 @@ Test-Suite EmbeddingOpsTest
, lens-family
, google-shim
, tensorflow
, tensorflow-test
, tensorflow-core-ops
, tensorflow-ops
, tensorflow-proto

View File

@ -28,6 +28,7 @@ import Test.HUnit ((@=?), (@?))
import Test.Framework.Providers.HUnit (testCase)
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
import Test.QuickCheck.Monadic (monadicIO, run)
import TensorFlow.Test (assertAllClose)
import qualified Data.Vector as V
import qualified TensorFlow.GenOps.Core as CoreOps
@ -119,15 +120,6 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
assertAllClose gs (V.fromList ([0, 0] :: [Float]))
-- TODO: Move this out into a central testing utils lib and remove from NN
-- tests
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
("Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys
++ "\ntolerance: " ++ show tol)
where
absDiff x y = abs (x - y)
tol = 0.001 :: Float
-- Verifies that direct gather is the same as dynamic split into