{-# LANGUAGE OverloadedStrings #-}
module TensorFlow.Test
( assertAllClose
) where
import qualified Data.Vector as V
import Test.HUnit ((@?))
import Test.HUnit.Lang (Assertion)
-- | Asserts that two vectors are element-by-element equal within a fixed
-- absolute tolerance of @0.001@.
--
-- The assertion fails, with a message showing both inputs, when:
--
--   * the vectors have different lengths (checked explicitly, because
--     'V.zipWith' silently truncates to the shorter vector and would
--     otherwise let e.g. @[1,2,3]@ vs @[1,2]@ pass), or
--
--   * any pair of corresponding elements differs by more than the
--     tolerance. A NaN element always fails, since @NaN <= tol@ is False.
--
-- Two empty vectors pass.
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
assertAllClose xs ys = do
    V.length xs == V.length ys @?
        "Length mismatch: \nxs: " ++ show (V.length xs)
            ++ "\nys: " ++ show (V.length ys)
    V.all (<= tol) (V.zipWith absDiff xs ys) @?
        "Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys
            ++ "\ntolerance: " ++ show tol
  where
    -- Absolute difference between corresponding elements.
    absDiff x y = abs (x - y)
    -- Maximum allowed absolute difference per element.
    tol = 0.001 :: Float