120 lines
4.6 KiB
Haskell
120 lines
4.6 KiB
Haskell
-- Copyright 2016 TensorFlow authors.
|
|
--
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
-- you may not use this file except in compliance with the License.
|
|
-- You may obtain a copy of the License at
|
|
--
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
--
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
-- See the License for the specific language governing permissions and
|
|
-- limitations under the License.
|
|
|
|
{-# LANGUAGE ConstraintKinds #-}
|
|
{-# LANGUAGE DataKinds #-}
|
|
{-# LANGUAGE NoMonomorphismRestriction #-}
|
|
{-# LANGUAGE OverloadedLists #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE TypeFamilies #-}
|
|
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
|
|
|
import Control.Monad (replicateM)
|
|
import Control.Monad.IO.Class (liftIO)
|
|
import Data.Int (Int64)
|
|
import Google.Test (googleTest)
|
|
import Test.Framework.Providers.HUnit (testCase)
|
|
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
|
import Test.HUnit ((@=?))
|
|
import Test.QuickCheck (Arbitrary(..), listOf, suchThat)
|
|
import qualified Data.ByteString as B
|
|
import qualified Data.ByteString.Char8 as B8
|
|
import qualified Data.Vector as V
|
|
|
|
import qualified TensorFlow.ControlFlow as TF
|
|
import qualified TensorFlow.Ops as TF
|
|
import qualified TensorFlow.Session as TF
|
|
import qualified TensorFlow.Tensor as TF
|
|
import qualified TensorFlow.Types as TF
|
|
|
|
-- QuickCheck generator for strict ByteStrings, built by packing an arbitrary
-- list of bytes. Neither the class nor the type is defined in this module, so
-- this is an orphan instance (hence -fno-warn-orphans at the top of the file).
instance Arbitrary B.ByteString where
    arbitrary = fmap B.pack arbitrary
|
|
|
|
-- Encodes a float vector and a string vector as tensor data, feeds them into
-- placeholders, fetches them back from the session and checks that the
-- decoded results equal the original vectors.
testFFIRoundTrip = testCase "testFFIRoundTrip" $ TF.runSession $ do
    floatInput <- TF.build (TF.placeholder [2,3])
    stringInput <- TF.build (TF.placeholder [2,3])
    -- Fetching a tensor that is itself being fed is an error, so each
    -- placeholder is routed through TF.identity before it is fetched.
    (floatsOut, stringsOut) <-
        TF.runWithFeeds
            [ TF.feed floatInput (TF.encodeTensorData [2,3] floats)
            , TF.feed stringInput (TF.encodeTensorData [2,3] strings)
            ]
            (TF.identity floatInput, TF.identity stringInput)
    liftIO $ do
        floats @=? floatsOut
        strings @=? stringsOut
  where
    -- Six elements each, matching the [2,3] shape used above.
    floats = V.fromList [1..6 :: Float]
    strings = V.fromList (map (B8.pack . show) [1..6])
|
|
|
|
|
|
-- Input for the encode/decode property: the sizes of a tensor's dimensions
-- together with the flat vector of elements to encode with that shape.
data TensorDataInputs a
    = TensorDataInputs
        [Int64]       -- size of each dimension
        (V.Vector a)  -- the elements
    deriving Show
|
|
|
|
-- Generates a list of strictly positive dimension sizes together with exactly
-- as many arbitrary elements as a tensor of that shape contains.
instance Arbitrary a => Arbitrary (TensorDataInputs a) where
    arbitrary = do
        sizes <- listOf (arbitrary `suchThat` (> 0))
                     `suchThat` smallShape
        -- smallShape guarantees the true product is below 100, so computing
        -- it in Int64 here cannot overflow.
        elems <- replicateM (fromIntegral $ product sizes) arbitrary
        return $ TensorDataInputs sizes (V.fromList elems)
      where
        -- Limit the size of the final vector. The element count is computed
        -- over Integer rather than Int64: an Int64 product that overflows is
        -- not necessarily negative (it can wrap to a small positive value),
        -- so checking the sign of the wrapped result is not a reliable
        -- overflow guard. Every dimension is positive, so the exact product
        -- is always at least 1 and only the upper bound needs checking.
        smallShape :: [Int64] -> Bool
        smallShape dims = product (map toInteger dims) < 100
|
|
|
|
-- Property: encoding a vector as tensor data with the given shape and then
-- decoding it again yields the original vector.
encodeDecodeProp :: (TF.TensorType a, Eq a) => TensorDataInputs a -> Bool
encodeDecodeProp (TensorDataInputs dims elems) = roundTripped == elems
  where
    encoded = TF.encodeTensorData (TF.Shape dims) elems
    roundTripped = TF.decodeTensorData encoded
|
|
|
|
-- QuickCheck run of the encode/decode property at element type Float.
testEncodeDecodeQcFloat = testProperty "testEncodeDecodeQcFloat" floatProp
  where
    floatProp :: TensorDataInputs Float -> Bool
    floatProp = encodeDecodeProp
|
|
|
|
-- QuickCheck run of the encode/decode property at element type Int64.
testEncodeDecodeQcInt64 = testProperty "testEncodeDecodeQcInt64" int64Prop
  where
    int64Prop :: TensorDataInputs Int64 -> Bool
    int64Prop = encodeDecodeProp
|
|
|
|
-- QuickCheck run of the encode/decode property at element type ByteString.
testEncodeDecodeQcString = testProperty "testEncodeDecodeQcString" stringProp
  where
    stringProp :: TensorDataInputs B.ByteString -> Bool
    stringProp = encodeDecodeProp
|
|
|
|
-- Identity function whose argument type is constrained by
-- TF.OneOf '[Double, Int64]; used below to exercise that constraint.
doubleOrInt64Func :: TF.OneOf '[Double, Int64] a => a -> a
doubleOrInt64Func x = x
|
|
|
|
-- Identity function whose argument type is constrained by
-- TF.OneOf '[Double, Float]; used below to exercise that constraint.
doubleOrFloatFunc :: TF.OneOf '[Double, Float] a => a -> a
doubleOrFloatFunc x = x
|
|
|
|
-- Applies doubleOrInt64Func and then doubleOrFloatFunc. The signature asks the
-- compiler to accept the two combined constraints as TF.OneOf '[Double].
doubleFunc :: TF.OneOf '[Double] a => a -> a
doubleFunc x = doubleOrFloatFunc (doubleOrInt64Func x)
|
|
|
|
-- Deliberately has no type signature: the type must be inferred automatically.
-- (Note: this would fail without NoMonomorphismRestriction, since the type
-- cannot be simplified all the way to `Double -> Double`.) The definition is
-- kept as a binding without arguments so that the restriction would apply to
-- it.
doubleFuncNoSig = \x -> doubleOrFloatFunc (doubleOrInt64Func x)
|
|
|
|
-- Applies each OneOf-constrained function at the types its constraint lists
-- and checks that the value 42 comes back unchanged.
typeConstraintTests = testCase "type constraints" $ do
    let asDouble = 42 :: Double
        asInt64 = 42 :: Int64
        asFloat = 42 :: Float
    42 @=? doubleOrInt64Func asDouble
    42 @=? doubleOrInt64Func asInt64
    42 @=? doubleOrFloatFunc asDouble
    42 @=? doubleOrFloatFunc asFloat
    42 @=? doubleFunc asDouble
    42 @=? doubleFuncNoSig asDouble
|
|
|
|
|
|
-- Entry point: hands every test defined in this module to googleTest.
main :: IO ()
main = googleTest allTests
  where
    allTests =
        [ testFFIRoundTrip
        , testEncodeDecodeQcFloat
        , testEncodeDecodeQcInt64
        , testEncodeDecodeQcString
        , typeConstraintTests
        ]
|