tensorflow-haskell/tensorflow-ops/tests/TypesTest.hs

120 lines
4.6 KiB
Haskell

-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
import Control.Monad (replicateM)
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import Google.Test (googleTest)
import Test.Framework.Providers.HUnit (testCase)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.HUnit ((@=?))
import Test.QuickCheck (Arbitrary(..), listOf, suchThat)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.Vector as V
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
-- Orphan instance (hence the -fno-warn-orphans pragma): a random strict
-- ByteString is produced by generating a random list of bytes and packing it.
instance Arbitrary B.ByteString where
    arbitrary = fmap B.pack arbitrary
-- Test encoding tensors, feeding them through tensorflow, and decoding the
-- results. A float tensor and a string tensor, each fed with shape [2,3],
-- must come back from the session equal to the vectors that were fed.
testFFIRoundTrip = testCase "testFFIRoundTrip" . TF.runSession $ do
    floatIn <- TF.build (TF.placeholder [2,3])
    stringIn <- TF.build (TF.placeholder [2,3])
    -- It is an error to fetch a tensor that is being fed, so what gets
    -- fetched is the identity of each placeholder rather than the
    -- placeholder itself.
    (floatOut, stringOut) <- TF.runWithFeeds
        [ TF.feed floatIn (TF.encodeTensorData [2,3] expectedFloats)
        , TF.feed stringIn (TF.encodeTensorData [2,3] expectedStrings)
        ]
        (TF.identity floatIn, TF.identity stringIn)
    liftIO (expectedFloats @=? floatOut)
    liftIO (expectedStrings @=? stringOut)
  where
    -- Six elements each, matching the 2x3 shape used for the feeds.
    expectedFloats = V.fromList [1..6 :: Float]
    expectedStrings = V.fromList (map (B8.pack . show) [1..6])
-- A list of dimension sizes together with a flat vector of elements. The
-- Arbitrary instance in this file generates the two so that the number of
-- elements equals the product of the dimensions.
data TensorDataInputs a
    = TensorDataInputs [Int64] (V.Vector a)
    deriving (Show)
-- Generates a random list of strictly positive dimension sizes whose product
-- is small, then exactly that many random elements.
instance Arbitrary a => Arbitrary (TensorDataInputs a) where
    arbitrary = do
        dims <- suchThat (listOf (suchThat arbitrary (> 0))) hasSmallProduct
        values <- replicateM (fromIntegral (product dims)) arbitrary
        return (TensorDataInputs dims (V.fromList values))
      where
        -- Limit the size of the final vector, and also guard against
        -- overflow (i.e., a non-positive product) when there are too many
        -- dimensions.
        hasSmallProduct ds =
            let p = product ds
            in 0 < p && p < 100
-- Property: encoding a vector as tensor data with the given shape and then
-- decoding it again yields a vector equal to the original.
encodeDecodeProp :: (TF.TensorType a, Eq a) => TensorDataInputs a -> Bool
encodeDecodeProp (TensorDataInputs dims original) = roundTripped == original
  where
    roundTripped =
        TF.decodeTensorData (TF.encodeTensorData (TF.Shape dims) original)
-- QuickCheck run of the encode/decode round-trip property at element type
-- Float.
testEncodeDecodeQcFloat = testProperty "testEncodeDecodeQcFloat" prop
  where
    prop :: TensorDataInputs Float -> Bool
    prop = encodeDecodeProp
-- QuickCheck run of the encode/decode round-trip property at element type
-- Int64.
testEncodeDecodeQcInt64 = testProperty "testEncodeDecodeQcInt64" prop
  where
    prop :: TensorDataInputs Int64 -> Bool
    prop = encodeDecodeProp
-- QuickCheck run of the encode/decode round-trip property at element type
-- ByteString (i.e. string tensors).
testEncodeDecodeQcString = testProperty "testEncodeDecodeQcString" prop
  where
    prop :: TensorDataInputs B.ByteString -> Bool
    prop = encodeDecodeProp
-- Identity function whose argument type is constrained to be one of Double
-- or Int64; exists only to exercise the OneOf constraint.
doubleOrInt64Func :: TF.OneOf '[Double, Int64] a => a -> a
doubleOrInt64Func x = x
-- Identity function whose argument type is constrained to be one of Double
-- or Float; exists only to exercise the OneOf constraint.
doubleOrFloatFunc :: TF.OneOf '[Double, Float] a => a -> a
doubleOrFloatFunc x = x
-- Applies both constrained identities in sequence. The declared constraint
-- admits only Double, so this compiles only if that constraint is enough to
-- satisfy the constraints of both functions being called.
doubleFunc :: TF.OneOf '[Double] a => a -> a
doubleFunc x = doubleOrFloatFunc (doubleOrInt64Func x)
-- No explicit type signature; make sure it can be inferred automatically.
-- (Note: this would fail if we didn't have NoMonomorphismRestriction, since it
-- can't simplify the type all the way to `Double -> Double`.)
-- Deliberately kept as an argument-less binding with no signature: that is
-- the form the monomorphism restriction applies to, so do not add a
-- signature or an explicit argument here.
doubleFuncNoSig = doubleOrFloatFunc . doubleOrInt64Func
-- Applies each OneOf-constrained identity function at the concrete types
-- listed below and checks that the value 42 comes back unchanged.
typeConstraintTests = testCase "type constraints" $ do
    let asDouble = 42 :: Double
        asInt64 = 42 :: Int64
        asFloat = 42 :: Float
    asDouble @=? doubleOrInt64Func asDouble
    asInt64 @=? doubleOrInt64Func asInt64
    asDouble @=? doubleOrFloatFunc asDouble
    asFloat @=? doubleOrFloatFunc asFloat
    asDouble @=? doubleFunc asDouble
    asDouble @=? doubleFuncNoSig asDouble
-- Entry point: hands every test defined in this file to the googleTest
-- runner.
main :: IO ()
main = googleTest allTests
  where
    allTests =
        [ testFFIRoundTrip
        , testEncodeDecodeQcFloat
        , testEncodeDecodeQcInt64
        , testEncodeDecodeQcString
        , typeConstraintTests
        ]