From bf0abd6d8287e3d40d73c22a2e30e3ee96ffd529 Mon Sep 17 00:00:00 2001
From: Andrew Pritchard
Date: Thu, 9 Feb 2017 07:06:55 -0800
Subject: [PATCH] Add pure-Haskell implementation of TFRecords.

The tensorflow-records package implements encoding/decoding of the format,
and the tensorflow-records-conduit package provides wrappers and utilities
for use with Conduit.
---
 .../src/TensorFlow/Records/Conduit.hs         |  53 +++++++
 tensorflow-records/src/TensorFlow/CRC32C.hs   |  54 ++++++++
 tensorflow-records/src/TensorFlow/Records.hs  | 131 ++++++++++++++++++
 tensorflow-records/tests/Main.hs              |  51 +++++++
 4 files changed, 289 insertions(+)
 create mode 100644 tensorflow-records-conduit/src/TensorFlow/Records/Conduit.hs
 create mode 100644 tensorflow-records/src/TensorFlow/CRC32C.hs
 create mode 100644 tensorflow-records/src/TensorFlow/Records.hs
 create mode 100644 tensorflow-records/tests/Main.hs

diff --git a/tensorflow-records-conduit/src/TensorFlow/Records/Conduit.hs b/tensorflow-records-conduit/src/TensorFlow/Records/Conduit.hs
new file mode 100644
index 0000000..9728233
--- /dev/null
+++ b/tensorflow-records-conduit/src/TensorFlow/Records/Conduit.hs
@@ -0,0 +1,53 @@
+-- Copyright 2016 TensorFlow authors.
+--
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+-- | Conduit wrappers for TensorFlow.Records.
+
+{-# LANGUAGE Rank2Types #-}
+module TensorFlow.Records.Conduit
+  (
+  -- * Encode/Decode
+    encodeTFRecords
+  , decodeTFRecords
+
+  -- * Source/Sink
+  , sinkTFRecords
+  , sourceTFRecords
+  ) where
+
+import Control.Monad.Catch (MonadThrow)
+import Control.Monad.Trans.Resource (MonadResource)
+import qualified Data.ByteString as B
+import qualified Data.ByteString.Lazy as BL
+import Data.Conduit ((=$=), Conduit, Consumer, Producer)
+import Data.Conduit.Binary (sinkFile, sourceFile)
+import Data.Conduit.Cereal (conduitGet2, conduitPut)
+
+import TensorFlow.Records (getTFRecord, putTFRecord)
+
+-- | Conduit that runs 'getTFRecord' over an incoming stream of bytes.
+decodeTFRecords :: (MonadThrow m) => Conduit B.ByteString m BL.ByteString
+decodeTFRecords = conduitGet2 getTFRecord
+
+-- | Stream the bytes of the given file through 'decodeTFRecords'.
+sourceTFRecords :: (MonadResource m, MonadThrow m) => FilePath -> Producer m BL.ByteString
+sourceTFRecords file = (=$=) (sourceFile file) decodeTFRecords
+
+-- | Conduit that runs 'putTFRecord' on every incoming payload.
+encodeTFRecords :: (Monad m) => Conduit BL.ByteString m B.ByteString
+encodeTFRecords = conduitPut putTFRecord
+
+-- | Pipe payloads through 'encodeTFRecords' into the given file.
+sinkTFRecords :: MonadResource m => FilePath -> Consumer BL.ByteString m ()
+sinkTFRecords file = (=$=) encodeTFRecords (sinkFile file)
diff --git a/tensorflow-records/src/TensorFlow/CRC32C.hs b/tensorflow-records/src/TensorFlow/CRC32C.hs
new file mode 100644
index 0000000..eaa2dfa
--- /dev/null
+++ b/tensorflow-records/src/TensorFlow/CRC32C.hs
@@ -0,0 +1,54 @@
+-- Copyright 2016 TensorFlow authors.
+--
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+-- | Pure-Haskell CRC32C (Castagnoli) digest.  Import qualified.
+
+module TensorFlow.CRC32C
+  ( value, extend
+  , mask, unmask
+  , valueMasked
+  ) where
+
+import Data.Bits (complement, rotateL, rotateR, shiftR, testBit, xor)
+import qualified Data.ByteString as B
+import qualified Data.ByteString.Lazy as BL
+import Data.List (foldl')
+import Data.Word (Word32, Word8)
+
+-- | Compute the CRC32C checksum of the concatenation of the bytes checksummed
+-- by the given CRC32C value and the bytes in the given ByteString.
+--
+-- The incoming value is a finished checksum, so it is un-inverted before the
+-- new bytes are folded in and inverted again afterwards.  This makes
+-- @extend (extend 0 a) b@ equal to @extend 0 (a <> b)@, and @extend c@ of an
+-- empty ByteString equal to @c@.
+extend :: Word32 -> B.ByteString -> Word32
+extend crc = complement . B.foldl' updateByte (complement crc)
+
+-- Fold one input byte into the (inverted) CRC register: xor the byte into the
+-- low bits, then shift out eight bits, applying the polynomial whenever the
+-- bit shifted out is set.  Bit-at-a-time keeps the module dependency-free; a
+-- table-driven variant would compute the same values.
+updateByte :: Word32 -> Word8 -> Word32
+updateByte reg byte =
+  foldl' shiftOut (reg `xor` fromIntegral byte) [1 .. 8 :: Int]
+  where
+    shiftOut r _
+      | testBit r 0 = (r `shiftR` 1) `xor` castagnoliReflected
+      | otherwise   = r `shiftR` 1
+
+-- The Castagnoli polynomial in reflected (least-significant-bit-first) form.
+castagnoliReflected :: Word32
+castagnoliReflected = 0x82f63b78
+
+-- | Compute the CRC32C checksum of the given bytes.
+--
+-- The lazy ByteString is consumed chunk by chunk, starting from the checksum
+-- of the empty input (0).
+value :: BL.ByteString -> Word32
+value = foldl' extend 0 . BL.toChunks
+
+-- | Scramble a CRC32C value so that the result can be safely stored in a
+-- bytestream that may itself be CRC'd.
+--
+-- This masking is the algorithm specified by TensorFlow's TFRecords format.
+mask :: Word32 -> Word32
+mask x = rotateR x 15 + maskDelta
+
+-- | Inverse of 'mask'.
+unmask :: Word32 -> Word32
+unmask x = rotateL (x - maskDelta) 15
+
+-- | Convenience function combining 'value' and 'mask'.
+valueMasked :: BL.ByteString -> Word32
+valueMasked = mask . value
+
+-- Constant added (modulo 2^32) after the rotation in 'mask'.
+maskDelta :: Word32
+maskDelta = 0xa282ead8
diff --git a/tensorflow-records/src/TensorFlow/Records.hs b/tensorflow-records/src/TensorFlow/Records.hs
new file mode 100644
index 0000000..337199b
--- /dev/null
+++ b/tensorflow-records/src/TensorFlow/Records.hs
@@ -0,0 +1,131 @@
+-- Copyright 2016 TensorFlow authors.
+--
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+-- | Encoder and decoder for the TensorFlow \"TFRecords\" format.
+
+{-# LANGUAGE Rank2Types #-}
+module TensorFlow.Records
+  (
+  -- * Records
+    putTFRecord
+  , getTFRecord
+  , getTFRecords
+
+  -- * Implementation
+
+  -- | These may be useful for encoding or decoding to types other than
+  -- 'ByteString' that have their own Cereal codecs.
+  , getTFRecordLength
+  , getTFRecordData
+  , putTFRecordLength
+  , putTFRecordData
+  ) where
+
+import Control.Monad (when)
+import Data.ByteString.Unsafe (unsafePackCStringLen)
+import qualified Data.ByteString.Builder as B (Builder)
+import Data.ByteString.Builder.Extra (runBuilder, Next(..))
+import qualified Data.ByteString.Lazy as BL
+import Data.Serialize.Get
+  ( Get
+  , getBytes
+  , getWord32le
+  , getWord64le
+  , getLazyByteString
+  , isEmpty
+  , lookAhead
+  )
+import Data.Serialize
+  ( Put
+  , execPut
+  , putLazyByteString
+  , putWord32le
+  , putWord64le
+  )
+import Data.Word (Word8, Word64)
+import Foreign.Marshal.Alloc (allocaBytes)
+import Foreign.Ptr (Ptr, castPtr)
+import System.IO.Unsafe (unsafePerformIO)
+
+import qualified TensorFlow.CRC32C as CRC
+
+-- | Parse one TFRecord: the checked length first, then a payload of that size.
+getTFRecord :: Get BL.ByteString
+getTFRecord = getTFRecordData =<< getTFRecordLength
+
+-- | Parse many TFRecords as a list.  Note you probably want streaming instead
+-- as provided by the tensorflow-records-conduit package.
+getTFRecords :: Get [BL.ByteString]
+getTFRecords = do
+  done <- isEmpty
+  if done then return [] else (:) <$> getTFRecord <*> getTFRecords
+
+-- Reads a little-endian 32-bit checksum from the input and fails the parse
+-- unless it equals the masked CRC32C computed over the given bytes.
+getCheckMaskedCRC32C :: BL.ByteString -> Get ()
+getCheckMaskedCRC32C bs = do
+  wireCRC <- getWord32le
+  let maskedCRC = CRC.valueMasked bs
+  when (maskedCRC /= wireCRC) $ fail $
+      "getCheckMaskedCRC32C: CRC mismatch, computed: " ++ show maskedCRC ++
+      ", expected: " ++ show wireCRC
+
+-- | Get a length and verify its checksum.
+--
+-- The 8 length bytes are first peeked with 'lookAhead' so the checksum can be
+-- computed over the raw bytes; they are then consumed as a little-endian
+-- 'Word64', followed by the 4-byte masked CRC32C.
+getTFRecordLength :: Get Word64
+getTFRecordLength = do
+  buf <- lookAhead (getBytes 8)
+  getWord64le <* getCheckMaskedCRC32C (BL.fromStrict buf)
+
+-- | Get a record payload and verify its checksum.
+--
+-- Fails the parse if the length does not fit in an Int64 or if the masked
+-- CRC32C that follows the payload does not match.
+getTFRecordData :: Word64 -> Get BL.ByteString
+getTFRecordData len
+  | len > 0x7fffffffffffffff =
+      fail "getTFRecordData: Record size overflows Int64"
+  | otherwise = do
+      bs <- getLazyByteString (fromIntegral len)
+      getCheckMaskedCRC32C bs
+      return bs
+
+-- Writes the masked CRC32C of the given bytes as a little-endian Word32.
+putMaskedCRC32C :: BL.ByteString -> Put
+putMaskedCRC32C = putWord32le . CRC.valueMasked
+
+-- Runs a Builder that's known to write a fixed number of bytes into a
+-- temporary buffer obtained from 'allocaBytes', and runs the given IO action
+-- on the result.  The pointer is only valid while the action runs, so the
+-- action must not let anything that still reads from it escape unevaluated.
+-- Raises exceptions if the Builder yields ByteString chunks, attempts to
+-- write more bytes than expected, or finishes after writing fewer bytes than
+-- expected (the rest of the buffer would be uninitialised).
+unsafeWithFixedWidthBuilder :: Int -> B.Builder -> (Ptr Word8 -> IO r) -> IO r
+unsafeWithFixedWidthBuilder n b act = allocaBytes n $ \ptr -> do
+  (written, signal) <- runBuilder b ptr n
+  case signal of
+    Done
+      | written == n -> act ptr
+      | otherwise -> error $
+          "unsafeWithFixedWidthBuilder: Builder wrote " ++ show written ++
+          " bytes, expected " ++ show n ++ "."
+    More _ _ -> error "unsafeWithFixedWidthBuilder: Builder returned More."
+    Chunk _ _ -> error "unsafeWithFixedWidthBuilder: Builder returned Chunk."
+
+-- | Put a record length and its checksum.
+--
+-- Writes the length as a little-endian 'Word64' followed by the masked CRC32C
+-- of those 8 bytes as a little-endian Word32.
+putTFRecordLength :: Word64 -> Put
+putTFRecordLength x =
+  let put = putWord64le x
+      len = 8
+      -- unsafePerformIO is used only to serialise into a private scratch
+      -- buffer; the resulting checksum depends on nothing but x.
+      crc = CRC.mask $ unsafePerformIO $
+          -- Serialized Word64 is always 8 bytes, so we can go fast by using
+          -- alloca.
+          unsafeWithFixedWidthBuilder len (execPut put) $ \ptr -> do
+            bytes <- unsafePackCStringLen (castPtr ptr, len)
+            -- Force the checksum before returning: bytes aliases the scratch
+            -- buffer without copying, and a lazy result would only be
+            -- computed after allocaBytes has released that buffer.
+            return $! CRC.extend 0 bytes
+  in put *> putWord32le crc
+
+-- | Put a record payload and its checksum.
+putTFRecordData :: BL.ByteString -> Put
+putTFRecordData bs = putLazyByteString bs *> putMaskedCRC32C bs
+
+-- | Put one TFRecord with the given contents.
+--
+-- The result is the checked length written by 'putTFRecordLength' followed by
+-- the checked payload written by 'putTFRecordData'.
+putTFRecord :: BL.ByteString -> Put
+putTFRecord bs =
+  putTFRecordLength (fromIntegral $ BL.length bs) *> putTFRecordData bs
diff --git a/tensorflow-records/tests/Main.hs b/tensorflow-records/tests/Main.hs
new file mode 100644
index 0000000..b7473d8
--- /dev/null
+++ b/tensorflow-records/tests/Main.hs
@@ -0,0 +1,51 @@
+-- Copyright 2016 TensorFlow authors.
+--
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+{-# LANGUAGE OverloadedStrings #-}
+module Main where
+
+import qualified Data.ByteString as B
+import qualified Data.ByteString.Lazy as BL
+import Data.Monoid ((<>))
+import Data.Word (Word8)
+import Data.Serialize (runGet, runPut)
+import Test.Framework (Test, defaultMain)
+import Test.Framework.Providers.QuickCheck2 (testProperty)
+
+import TensorFlow.Records (getTFRecord, putTFRecord)
+
+main :: IO ()
+main = defaultMain tests
+
+tests :: [Test]
+tests =
+    [ testProperty "Inverse" propEncodeDecodeInverse
+    , testProperty "FixedRecord" propFixedRecord
+    ]
+
+-- There's no (Arbitrary BL.ByteString), so pack it from a list of chunks.
+propEncodeDecodeInverse :: [[Word8]] -> Bool
+propEncodeDecodeInverse chunks =
+    runGet getTFRecord encoded == Right payload
+  where
+    payload = BL.fromChunks (map B.pack chunks)
+    encoded = runPut (putTFRecord payload)
+
+propFixedRecord :: Bool
+propFixedRecord = decoded == "\x42" && runPut (putTFRecord "\x42") == wire
+  where
+    -- Make the error appear in the test failure.
+    decoded = either error id (runGet getTFRecord wire)
+    wire = "\x01\x00\x00\x00\x00\x00\x00\x00" <> "\x01\x75\xde\x41" <>
+           "\x42" <> "\x52\xcf\xb8\x1e"