-- Source: tensorflow-haskell/tensorflow-ops/tests/TypesTest.hs
-- (mirror of https://github.com/tensorflow/haskell.git)
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- Purposely disabled to confirm doubleFuncNoSig can be written without type.
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
import Control.Monad (replicateM)
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import Test.Framework (defaultMain, Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.HUnit ((@=?))
import Test.QuickCheck (Arbitrary(..), listOf, suchThat)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.Vector as V
import qualified TensorFlow.GenOps.Core as TF (select)
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
-- Orphan QuickCheck instance for strict ByteStrings, so that they can be used
-- as randomly generated tensor elements. A value is produced by packing an
-- arbitrary list of bytes.
instance Arbitrary B.ByteString where
    arbitrary = B.pack <$> arbitrary
-- Test encoding tensors, feeding them through tensorflow, and decoding the
-- results.
--
-- Three 2x3 tensors (Float, ByteString and Bool elements) are fed into
-- placeholders, passed through an op that should leave each value unchanged,
-- fetched back, and compared against the vectors that were fed in.
testFFIRoundTrip :: Test
testFFIRoundTrip = testCase "testFFIRoundTrip" $
    TF.runSession $ do
        let floatData = V.fromList [1..6 :: Float]
            stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]]
            boolData = V.fromList [True, True, False, True, False, False]
        f <- TF.placeholder [2,3]
        s <- TF.placeholder [2,3]
        b <- TF.placeholder [2,3]
        let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
                    , TF.feed s (TF.encodeTensorData [2,3] stringData)
                    , TF.feed b (TF.encodeTensorData [2,3] boolData)
                    ]
        -- Do something idempotent to the tensors to verify that tensorflow can
        -- handle the encoding. Originally this used `TF.identity`, but that
        -- wasn't enough to catch a bug in the encoding of Bool.
        (f', s', b') <- TF.runWithFeeds feeds
            (f `TF.add` 0, TF.identity s, TF.select b b b)
        liftIO $ do
            floatData @=? f'
            stringData @=? s'
            boolData @=? b'
-- A list of dimension sizes paired with a flat vector of elements; the input
-- type for the encode/decode round-trip properties.
data TensorDataInputs a = TensorDataInputs [Int64] (V.Vector a)
    deriving Show
-- Generates a list of strictly positive dimension sizes whose product lies in
-- 1..99, together with exactly that many arbitrary elements, so the vector
-- length always agrees with the generated shape. An empty size list has
-- product 1 and therefore yields a single-element vector.
instance Arbitrary a => Arbitrary (TensorDataInputs a) where
    arbitrary = do
        -- Limit the size of the final vector, and also guard against overflow
        -- (i.e., p<0) when there are too many dimensions.
        let validProduct p = p > 0 && p < 100
        sizes <- listOf (arbitrary `suchThat` (>0))
                     `suchThat` (validProduct . product)
        elems <- replicateM (fromIntegral $ product sizes) arbitrary
        return $ TensorDataInputs sizes (V.fromList elems)
-- Test that a vector is unchanged after being encoded and decoded.
--
-- The dimension list is wrapped in 'TF.Shape' for encoding; the property holds
-- when decoding the encoded tensor data yields a vector equal to the input.
encodeDecodeProp
    :: (TF.TensorDataType V.Vector a, Eq a) => TensorDataInputs a -> Bool
encodeDecodeProp (TensorDataInputs shape vec) =
    TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec
-- QuickCheck run of 'encodeDecodeProp' with Float elements.
testEncodeDecodeQcFloat :: Test
testEncodeDecodeQcFloat = testProperty "testEncodeDecodeQcFloat"
    (encodeDecodeProp :: TensorDataInputs Float -> Bool)
-- QuickCheck run of 'encodeDecodeProp' with Int64 elements.
testEncodeDecodeQcInt64 :: Test
testEncodeDecodeQcInt64 = testProperty "testEncodeDecodeQcInt64"
    (encodeDecodeProp :: TensorDataInputs Int64 -> Bool)
-- QuickCheck run of 'encodeDecodeProp' with ByteString elements.
testEncodeDecodeQcString :: Test
testEncodeDecodeQcString = testProperty "testEncodeDecodeQcString"
    (encodeDecodeProp :: TensorDataInputs B.ByteString -> Bool)
-- Identity function whose argument type is constrained by 'TF.OneOf' to the
-- list [Double, Int64].
doubleOrInt64Func :: TF.OneOf '[Double, Int64] a => a -> a
doubleOrInt64Func x = x
-- Identity function whose argument type is constrained by 'TF.OneOf' to the
-- list [Double, Float].
doubleOrFloatFunc :: TF.OneOf '[Double, Float] a => a -> a
doubleOrFloatFunc x = x
-- Applies doubleOrInt64Func and then doubleOrFloatFunc under the single
-- constraint OneOf '[Double]; this only type-checks if that constraint
-- satisfies the constraints of both helpers.
doubleFunc :: TF.OneOf '[Double] a => a -> a
doubleFunc x = doubleOrFloatFunc (doubleOrInt64Func x)
-- No explicit type signature; make sure it can be inferred automatically.
-- (Note: this would fail if we didn't have NoMonomorphismRestriction, since it
-- can't simplify the type all the way to `Double -> Double`.)
--
-- Deliberately left as a point-free binding with no arguments: that is the
-- form the monomorphism restriction applies to, so rewriting it with an
-- explicit argument would stop exercising the extension. The missing-signature
-- warning is silenced by the OPTIONS_GHC pragma at the top of the file.
doubleFuncNoSig = doubleOrFloatFunc . doubleOrInt64Func
-- Checks that each OneOf-constrained helper can be applied at every element
-- type exercised below and hands the value 42 back unchanged.
typeConstraintTests :: Test
typeConstraintTests = testCase "type constraints" $ do
    42 @=? doubleOrInt64Func (42 :: Double)
    42 @=? doubleOrInt64Func (42 :: Int64)
    42 @=? doubleOrFloatFunc (42 :: Double)
    42 @=? doubleOrFloatFunc (42 :: Float)
    42 @=? doubleFunc (42 :: Double)
    42 @=? doubleFuncNoSig (42 :: Double)
-- Entry point: hands every test defined in this module to test-framework's
-- 'defaultMain'.
main :: IO ()
main = defaultMain
    [ testFFIRoundTrip
    , testEncodeDecodeQcFloat
    , testEncodeDecodeQcInt64
    , testEncodeDecodeQcString
    , typeConstraintTests
    ]