1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-28 23:48:36 +02:00
tensorflow-haskell/tensorflow-ops/tests/TypesTest.hs
fkm3 f170df9d13 Support fetching storable vectors + use them in benchmark (#50)
In addition, you can now fetch TensorData directly. This might be useful in
scenarios where you feed the result of a computation back in, like RNN.

Before:

benchmarking feedFetch/4 byte
time                 83.31 μs   (81.88 μs .. 84.75 μs)
                     0.997 R²   (0.994 R² .. 0.998 R²)
mean                 87.32 μs   (86.06 μs .. 88.83 μs)
std dev              4.580 μs   (3.698 μs .. 5.567 μs)
variance introduced by outliers: 55% (severely inflated)

benchmarking feedFetch/4 KiB
time                 114.9 μs   (111.5 μs .. 118.2 μs)
                     0.996 R²   (0.994 R² .. 0.998 R²)
mean                 117.3 μs   (116.2 μs .. 118.6 μs)
std dev              3.877 μs   (3.058 μs .. 5.565 μs)
variance introduced by outliers: 31% (moderately inflated)

benchmarking feedFetch/4 MiB
time                 109.0 ms   (107.9 ms .. 110.7 ms)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 108.6 ms   (108.2 ms .. 109.2 ms)
std dev              740.2 μs   (353.2 μs .. 1.186 ms)

After:

benchmarking feedFetch/4 byte
time                 82.92 μs   (80.55 μs .. 85.24 μs)
                     0.996 R²   (0.993 R² .. 0.998 R²)
mean                 83.58 μs   (82.34 μs .. 84.89 μs)
std dev              4.327 μs   (3.664 μs .. 5.375 μs)
variance introduced by outliers: 54% (severely inflated)

benchmarking feedFetch/4 KiB
time                 85.69 μs   (83.81 μs .. 87.30 μs)
                     0.997 R²   (0.996 R² .. 0.999 R²)
mean                 86.99 μs   (86.11 μs .. 88.15 μs)
std dev              3.608 μs   (2.854 μs .. 5.273 μs)
variance introduced by outliers: 43% (moderately inflated)

benchmarking feedFetch/4 MiB
time                 1.582 ms   (1.509 ms .. 1.677 ms)
                     0.970 R²   (0.936 R² .. 0.993 R²)
mean                 1.645 ms   (1.554 ms .. 1.981 ms)
std dev              490.6 μs   (138.9 μs .. 1.067 ms)
variance introduced by outliers: 97% (severely inflated)
2016-12-14 18:53:06 -08:00

135 lines
5.3 KiB
Haskell

-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- Purposely disabled to confirm doubleFuncNoSig can be written without type.
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
import Control.Monad (replicateM)
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import Google.Test (googleTest)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.HUnit ((@=?))
import Test.QuickCheck (Arbitrary(..), listOf, suchThat)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.Vector as V
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.GenOps.Core as TF (select)
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
-- Orphan instance (see -fno-warn-orphans above): random ByteStrings are built
-- by generating an arbitrary list of bytes and packing it.
instance Arbitrary B.ByteString where
    arbitrary = fmap B.pack arbitrary
-- Round-trip test through the FFI: encode float, string and bool vectors as
-- tensor data, feed them into placeholders, run them through tensorflow, and
-- check that the fetched vectors equal the ones that were fed in.
testFFIRoundTrip :: Test
testFFIRoundTrip = testCase "testFFIRoundTrip" . TF.runSession $ do
    floats  <- TF.build (TF.placeholder shape)
    strings <- TF.build (TF.placeholder shape)
    bools   <- TF.build (TF.placeholder shape)
    let feeds = [ TF.feed floats  (TF.encodeTensorData shape floatData)
                , TF.feed strings (TF.encodeTensorData shape stringData)
                , TF.feed bools   (TF.encodeTensorData shape boolData)
                ]
    -- Run each tensor through an op that leaves its value unchanged, to
    -- verify that tensorflow can handle the encoding.  Originally this used
    -- `TF.identity` everywhere, but that wasn't enough to catch a bug in the
    -- encoding of Bool.
    (floats', strings', bools') <-
        TF.runWithFeeds feeds
            (floats + 0, TF.identity strings, TF.select bools bools bools)
    liftIO $ do
        floatData  @=? floats'
        stringData @=? strings'
        boolData   @=? bools'
  where
    -- Every placeholder and every fed tensor uses the same 2x3 shape.
    shape      = TF.Shape [2, 3]
    floatData  = V.fromList [1..6 :: Float]
    stringData = V.fromList (map (B8.pack . show) [1..6 :: Integer])
    boolData   = V.fromList [True, True, False, True, False, False]
-- Input for the encode/decode property: a list of dimension sizes (used as
-- the tensor shape) together with the flat vector of elements to encode.
data TensorDataInputs a
    = TensorDataInputs [Int64] (V.Vector a)
    deriving (Show)
-- Generates a list of strictly positive dimension sizes whose product (the
-- element count) is below 100, plus exactly that many arbitrary elements.
instance Arbitrary a => Arbitrary (TensorDataInputs a) where
    arbitrary = do
        sizes <- listOf (arbitrary `suchThat` (> 0))
                     `suchThat` (smallEnough . elemCount)
        elems <- replicateM (fromInteger (elemCount sizes)) arbitrary
        return $ TensorDataInputs sizes (V.fromList elems)
      where
        -- Multiply in 'Integer' rather than 'Int64': with many dimensions an
        -- 'Int64' product can wrap around, and a wrapped value that happens
        -- to land in the accepted range would produce a vector whose length
        -- disagrees with the shape.  'Integer' cannot overflow, and since
        -- every dimension is positive the product is always at least 1.
        elemCount :: [Int64] -> Integer
        elemCount = product . map toInteger

        -- Limit the size of the final vector.
        smallEnough :: Integer -> Bool
        smallEnough = (< 100)
-- Property: encoding a vector under the given shape and decoding the result
-- gives back a vector equal to the original.
encodeDecodeProp :: (TF.TensorDataType V.Vector a, Eq a) => TensorDataInputs a -> Bool
encodeDecodeProp (TensorDataInputs dims original) = roundTripped == original
  where
    encoded      = TF.encodeTensorData (TF.Shape dims) original
    roundTripped = TF.decodeTensorData encoded
-- QuickCheck run of the encode/decode round-trip property at element type Float.
testEncodeDecodeQcFloat :: Test
testEncodeDecodeQcFloat = testProperty "testEncodeDecodeQcFloat" prop
  where
    prop :: TensorDataInputs Float -> Bool
    prop = encodeDecodeProp
-- QuickCheck run of the encode/decode round-trip property at element type Int64.
testEncodeDecodeQcInt64 :: Test
testEncodeDecodeQcInt64 = testProperty "testEncodeDecodeQcInt64" prop
  where
    prop :: TensorDataInputs Int64 -> Bool
    prop = encodeDecodeProp
-- QuickCheck run of the encode/decode round-trip property at element type
-- ByteString.
testEncodeDecodeQcString :: Test
testEncodeDecodeQcString = testProperty "testEncodeDecodeQcString" prop
  where
    prop :: TensorDataInputs B.ByteString -> Bool
    prop = encodeDecodeProp
-- Identity function whose argument type is constrained to Double or Int64;
-- used to exercise the OneOf type constraint.
doubleOrInt64Func :: TF.OneOf '[Double, Int64] a => a -> a
doubleOrInt64Func x = x
-- Identity function whose argument type is constrained to Double or Float;
-- used to exercise the OneOf type constraint.
doubleOrFloatFunc :: TF.OneOf '[Double, Float] a => a -> a
doubleOrFloatFunc x = x
-- Chains the two constrained identities under a declared constraint that
-- allows only Double, the one type both of them accept.
doubleFunc :: TF.OneOf '[Double] a => a -> a
doubleFunc x = doubleOrFloatFunc (doubleOrInt64Func x)
-- No explicit type signature; make sure it can be inferred automatically.
-- (Note: this would fail if we didn't have NoMonomorphismRestriction, since it
-- can't simplify the type all the way to `Double -> Double`.)
-- The body is the same composition as doubleFunc; only the signature is
-- omitted, so the point-free form must be kept as is.
doubleFuncNoSig = doubleOrFloatFunc . doubleOrInt64Func
-- Applies each OneOf-constrained function at every element type used below
-- and checks that the value 42 comes back unchanged.
typeConstraintTests :: Test
typeConstraintTests = testCase "type constraints" $ do
    -- Double or Int64.
    42 @=? doubleOrInt64Func asDouble
    42 @=? doubleOrInt64Func asInt64
    -- Double or Float.
    42 @=? doubleOrFloatFunc asDouble
    42 @=? doubleOrFloatFunc asFloat
    -- Double only, with and without an explicit signature.
    42 @=? doubleFunc asDouble
    42 @=? doubleFuncNoSig asDouble
  where
    asDouble = 42 :: Double
    asInt64  = 42 :: Int64
    asFloat  = 42 :: Float
-- Entry point: hands every test in this module to the googleTest runner.
main :: IO ()
main = googleTest tests
  where
    tests :: [Test]
    tests =
        [ testFFIRoundTrip
        , testEncodeDecodeQcFloat
        , testEncodeDecodeQcInt64
        , testEncodeDecodeQcString
        , typeConstraintTests
        ]