mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 11:03:34 +02:00
added lib for testing helper functions
This commit is contained in:
parent
7e1421c4a7
commit
de58a3ffa4
|
@ -11,6 +11,7 @@ packages:
|
|||
- tensorflow-mnist-input-data
|
||||
- tensorflow-queue
|
||||
- tensorflow-nn
|
||||
- tensorflow-test
|
||||
|
||||
extra-deps:
|
||||
# proto-lens is not yet in Stackage.
|
||||
|
|
|
@ -30,6 +30,7 @@ Test-Suite NNTest
|
|||
, QuickCheck
|
||||
, base
|
||||
, tensorflow
|
||||
, tensorflow-test
|
||||
, tensorflow-ops
|
||||
, tensorflow-nn
|
||||
, google-shim
|
||||
|
|
|
@ -22,6 +22,7 @@ module Main where
|
|||
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Google.Test (googleTest)
|
||||
import TensorFlow.Test (assertAllClose)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@?))
|
||||
import Test.HUnit.Lang (Assertion(..))
|
||||
|
@ -63,15 +64,6 @@ defInputs = Inputs {
|
|||
}
|
||||
|
||||
|
||||
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
|
||||
assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
|
||||
("Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys
|
||||
++ "\ntolerance: " ++ show tol)
|
||||
where
|
||||
absDiff x y = abs (x - y)
|
||||
tol = 0.001 :: Float
|
||||
|
||||
|
||||
testLogisticOutput = testCase "testLogisticOutput" $ do
|
||||
let inputs = defInputs
|
||||
vLogits = TF.vector $ logits inputs
|
||||
|
|
|
@ -60,6 +60,7 @@ Test-Suite EmbeddingOpsTest
|
|||
, lens-family
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-test
|
||||
, tensorflow-core-ops
|
||||
, tensorflow-ops
|
||||
, tensorflow-proto
|
||||
|
|
|
@ -28,6 +28,7 @@ import Test.HUnit ((@=?), (@?))
|
|||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
|
||||
import Test.QuickCheck.Monadic (monadicIO, run)
|
||||
import TensorFlow.Test (assertAllClose)
|
||||
|
||||
import qualified Data.Vector as V
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
|
@ -119,15 +120,6 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
|||
assertAllClose gs (V.fromList ([0, 0] :: [Float]))
|
||||
|
||||
|
||||
-- TODO: Move this out into a central testing utils lib and remove from NN
|
||||
-- tests
|
||||
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
|
||||
assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
|
||||
("Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys
|
||||
++ "\ntolerance: " ++ show tol)
|
||||
where
|
||||
absDiff x y = abs (x - y)
|
||||
tol = 0.001 :: Float
|
||||
|
||||
|
||||
-- Verifies that direct gather is the same as dynamic split into
|
||||
|
|
Loading…
Reference in New Issue
Block a user