1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-26 21:09:44 +01:00

Add pure-Haskell implementation of TFRecords.

The tensorflow-records package implements encoding/decoding of the
format, and the tensorflow-records-conduit package provides wrappers and
utilities for use with Conduit.
This commit is contained in:
Andrew Pritchard 2017-02-09 07:06:55 -08:00 committed by fkm3
parent 72631cb9f3
commit bf0abd6d82
4 changed files with 289 additions and 0 deletions

View file

@ -0,0 +1,53 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- | Conduit wrappers for TensorFlow.Records.
{-# LANGUAGE Rank2Types #-}
module TensorFlow.Records.Conduit
(
-- * Encode/Decode
encodeTFRecords
, decodeTFRecords
-- * Source/Sink
, sinkTFRecords
, sourceTFRecords
) where
import Control.Monad.Catch (MonadThrow)
import Control.Monad.Trans.Resource (MonadResource)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Conduit ((=$=), Conduit, Consumer, Producer)
import Data.Conduit.Binary (sinkFile, sourceFile)
import Data.Conduit.Cereal (conduitGet2, conduitPut)
import TensorFlow.Records (getTFRecord, putTFRecord)
-- | Decode TFRecords from a stream of bytes.
decodeTFRecords :: MonadThrow m => Conduit B.ByteString m BL.ByteString
decodeTFRecords = conduitGet2 getTFRecord
-- | Read TFRecords from a file.
sourceTFRecords :: (MonadResource m, MonadThrow m) => FilePath -> Producer m BL.ByteString
sourceTFRecords path = sourceFile path =$= decodeTFRecords
-- | Encode TFRecords to a stream of bytes.
encodeTFRecords :: Monad m => Conduit BL.ByteString m B.ByteString
encodeTFRecords = conduitPut putTFRecord
-- | Write TFRecords to a file.
sinkTFRecords :: (MonadResource m) => FilePath -> Consumer BL.ByteString m ()
sinkTFRecords path = encodeTFRecords =$= sinkFile path

View file

@ -0,0 +1,54 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- | FFI wrappers of CRC32C digest. Import qualified.
module TensorFlow.CRC32C
( value, extend
, mask, unmask
, valueMasked
) where
import Data.Bits (rotateL, rotateR)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.List (foldl')
import Data.Word (Word32)
-- | Compute the CRC32C checksum of the concatenation of the bytes checksummed
-- by the given CRC32C value and the bytes in the given ByteString.
extend :: Word32 -> B.ByteString -> Word32
extend = _crcExtend
-- | Compute the CRC32C checksum of the given bytes.
value :: BL.ByteString -> Word32
value = foldl' extend 0 . BL.toChunks
-- | Scramble a CRC32C value so that the result can be safely stored in a
-- bytestream that may itself be CRC'd.
--
-- This masking is the algorithm specified by TensorFlow's TFRecords format.
mask :: Word32 -> Word32
mask x = rotateR x 15 + maskDelta
-- | Inverse of 'mask'.
unmask :: Word32 -> Word32
unmask x = rotateL (x - maskDelta) 15
-- | Convenience function combining 'value' and 'mask'.
valueMasked :: BL.ByteString -> Word32
valueMasked = mask . value
maskDelta :: Word32
maskDelta = 0xa282ead8

View file

@ -0,0 +1,131 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- | Encoder and decoder for the TensorFlow \"TFRecords\" format.
{-# LANGUAGE Rank2Types #-}
module TensorFlow.Records
(
-- * Records
putTFRecord
, getTFRecord
, getTFRecords
-- * Implementation
-- | These may be useful for encoding or decoding to types other than
-- 'ByteString' that have their own Cereal codecs.
, getTFRecordLength
, getTFRecordData
, putTFRecordLength
, putTFRecordData
) where
import Control.Monad (when)
import Data.ByteString.Unsafe (unsafePackCStringLen)
import qualified Data.ByteString.Builder as B (Builder)
import Data.ByteString.Builder.Extra (runBuilder, Next(..))
import qualified Data.ByteString.Lazy as BL
import Data.Serialize.Get
( Get
, getBytes
, getWord32le
, getWord64le
, getLazyByteString
, isEmpty
, lookAhead
)
import Data.Serialize
( Put
, execPut
, putLazyByteString
, putWord32le
, putWord64le
)
import Data.Word (Word8, Word64)
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Ptr (Ptr, castPtr)
import System.IO.Unsafe (unsafePerformIO)
import qualified TensorFlow.CRC32C as CRC
-- | Parse one TFRecord.
getTFRecord :: Get BL.ByteString
getTFRecord = getTFRecordLength >>= getTFRecordData
-- | Parse many TFRecords as a list. Note you probably want streaming instead
-- as provided by the tensorflow-records-conduit package.
getTFRecords :: Get [BL.ByteString]
getTFRecords = do
e <- isEmpty
if e then return [] else (:) <$> getTFRecord <*> getTFRecords
getCheckMaskedCRC32C :: BL.ByteString -> Get ()
getCheckMaskedCRC32C bs = do
wireCRC <- getWord32le
let maskedCRC = CRC.valueMasked bs
when (maskedCRC /= wireCRC) $ fail $
"getCheckMaskedCRC32C: CRC mismatch, computed: " ++ show maskedCRC ++
", expected: " ++ show wireCRC
-- | Get a length and verify its checksum.
getTFRecordLength :: Get Word64
getTFRecordLength = do
buf <- lookAhead (getBytes 8)
getWord64le <* getCheckMaskedCRC32C (BL.fromStrict buf)
-- | Get a record payload and verify its checksum.
getTFRecordData :: Word64 -> Get BL.ByteString
getTFRecordData len = if len > 0x7fffffffffffffff
then fail "getTFRecordData: Record size overflows Int64"
else do
bs <- getLazyByteString (fromIntegral len)
getCheckMaskedCRC32C bs
return bs
putMaskedCRC32C :: BL.ByteString -> Put
putMaskedCRC32C = putWord32le . CRC.valueMasked
-- Runs a Builder that's known to write a fixed number of bytes on a
-- stack-allocated buffer, and runs the given IO action on the result. Raises
-- exceptions if the Builder yields ByteString chunks or attempts to write more
-- bytes than expected.
unsafeWithFixedWidthBuilder :: Int -> B.Builder -> (Ptr Word8 -> IO r) -> IO r
unsafeWithFixedWidthBuilder n b act = allocaBytes n $ \ptr -> do
(_, signal) <- runBuilder b ptr n
case signal of
Done -> act ptr
More _ _ -> error "unsafeWithFixedWidthBuilder: Builder returned More."
Chunk _ _ -> error "unsafeWithFixedWidthBuilder: Builder returned Chunk."
-- | Put a record length and its checksum.
putTFRecordLength :: Word64 -> Put
putTFRecordLength x =
let put = putWord64le x
len = 8
crc = CRC.mask $ unsafePerformIO $
-- Serialized Word64 is always 8 bytes, so we can go fast by using
-- alloca.
unsafeWithFixedWidthBuilder len (execPut put) $
\ptr -> CRC.extend 0 <$> unsafePackCStringLen (castPtr ptr, len)
in put *> putWord32le crc
-- | Put a record payload and its checksum.
putTFRecordData :: BL.ByteString -> Put
putTFRecordData bs = putLazyByteString bs *> putMaskedCRC32C bs
-- | Put one TFRecord with the given contents.
putTFRecord :: BL.ByteString -> Put
putTFRecord bs =
putTFRecordLength (fromIntegral $ BL.length bs) *> putTFRecordData bs

View file

@ -0,0 +1,51 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
module Main where
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Monoid ((<>))
import Data.Word (Word8)
import Data.Serialize (runGet, runPut)
import Test.Framework (Test, defaultMain)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import TensorFlow.Records (getTFRecord, putTFRecord)
main :: IO ()
main = defaultMain tests
tests :: [Test]
tests =
[ testProperty "Inverse" propEncodeDecodeInverse
, testProperty "FixedRecord" propFixedRecord
]
-- There's no (Arbitrary BL.ByteString), so pack it from a list of chunks.
propEncodeDecodeInverse :: [[Word8]] -> Bool
propEncodeDecodeInverse s =
let bs = BL.fromChunks . fmap B.pack $ s
in runGet getTFRecord (runPut (putTFRecord bs)) == Right bs
propFixedRecord :: Bool
propFixedRecord =
("\x42" == case runGet getTFRecord record of
Left err -> error err -- Make the error appear in the test failure.
Right x -> x) &&
(runPut (putTFRecord "\x42") == record)
where
record = "\x01\x00\x00\x00\x00\x00\x00\x00" <> "\x01\x75\xde\x41" <>
"\x42" <> "\x52\xcf\xb8\x1e"