mirror of
https://github.com/tensorflow/haskell.git
synced 2024-12-26 11:39:46 +01:00
52 lines
1.7 KiB
Haskell
52 lines
1.7 KiB
Haskell
|
-- Copyright 2016 TensorFlow authors.
|
||
|
--
|
||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
-- you may not use this file except in compliance with the License.
|
||
|
-- You may obtain a copy of the License at
|
||
|
--
|
||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||
|
--
|
||
|
-- Unless required by applicable law or agreed to in writing, software
|
||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
-- See the License for the specific language governing permissions and
|
||
|
-- limitations under the License.
|
||
|
|
||
|
{-# LANGUAGE OverloadedStrings #-}
|
||
|
module Main where
|
||
|
|
||
|
import qualified Data.ByteString as B
|
||
|
import qualified Data.ByteString.Lazy as BL
|
||
|
import Data.Monoid ((<>))
|
||
|
import Data.Word (Word8)
|
||
|
import Data.Serialize (runGet, runPut)
|
||
|
import Test.Framework (Test, defaultMain)
|
||
|
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||
|
|
||
|
import TensorFlow.Records (getTFRecord, putTFRecord)
|
||
|
|
||
|
-- | Test executable entry point: hands the property suite 'tests' to
-- test-framework's 'defaultMain'.
main :: IO ()
main = do
    let suite = tests
    defaultMain suite
|
||
|
|
||
|
-- | All record properties, registered under the names shown in the test
-- report: a generic encode/decode round trip ("Inverse") and one
-- byte-exact fixed record ("FixedRecord").
tests :: [Test]
tests = [roundTrip, knownBytes]
  where
    roundTrip  = testProperty "Inverse" propEncodeDecodeInverse
    knownBytes = testProperty "FixedRecord" propFixedRecord
|
||
|
|
||
|
-- | Encoding a payload with 'putTFRecord' and decoding it again with
-- 'getTFRecord' must give back exactly the original payload.
--
-- There's no (Arbitrary BL.ByteString), so the payload is packed from a
-- list of chunks: each inner @[Word8]@ becomes one strict chunk of the lazy
-- 'BL.ByteString', so the payload may span several chunks.
propEncodeDecodeInverse :: [[Word8]] -> Bool
propEncodeDecodeInverse chunks = decoded == Right payload
  where
    payload = BL.fromChunks (map B.pack chunks)
    encoded = runPut (putTFRecord payload)
    decoded = runGet getTFRecord encoded
|
||
|
|
||
|
-- | Checks the serializer against one hand-written record holding the
-- single payload byte @0x42@, in both directions:
--
--   * decoding the expected wire bytes yields the payload, and
--   * encoding the payload yields exactly those wire bytes.
--
-- A decode failure is raised with 'error' rather than returning 'False',
-- so the parser's message appears in the test failure.
propFixedRecord :: Bool
propFixedRecord = decodesToPayload && encodesToRecord
  where
    payload :: BL.ByteString
    payload = "\x42"

    -- Wire layout of the expected record, in order:
    --   8 bytes: the value 1 as a little-endian 64-bit integer (payload length)
    --   4 bytes: presumably the masked CRC of the length field (TFRecord
    --            format) -- NOTE(review): value taken as given, not recomputed
    --   1 byte : the payload itself
    --   4 bytes: presumably the masked CRC of the payload (same caveat)
    record :: B.ByteString
    record = mconcat
        [ "\x01\x00\x00\x00\x00\x00\x00\x00"
        , "\x01\x75\xde\x41"
        , "\x42"
        , "\x52\xcf\xb8\x1e"
        ]

    -- Make the error appear in the test failure.
    decodesToPayload = either error (== payload) (runGet getTFRecord record)

    encodesToRecord = runPut (putTFRecord payload) == record
|