Summary of this patch (text before the first "diff --git" header is skipped
by patch tools):

  * TensorFlow/CRC32C.hs: renames the exported API from short names meant for
    qualified import (value, extend, mask, unmask, valueMasked) to prefixed
    names (crc32cLBS, crc32cUpdate, crc32cMask, crc32cUnmask,
    crc32cLBSMasked), adds crc32cMasked for strict ByteStrings, adds crc32c
    to the export list, and removes the module-level haddock line.
  * TensorFlow/Records.hs: switches to an unqualified import of the renamed
    functions, and in putTFRecordLength forces the CRC with 'evaluate' inside
    the callback so that, per the added comment, it is computed before ptr is
    freed.
  * tensorflow-records.cabal: constrains snappy-framing to >= 0.1.1.

NOTE(review): this copy of the patch had its newlines collapsed; the line
breaks and the indentation inside the hunks below were reconstructed so that
every hunk matches its @@ line counts. Confirm the leading whitespace against
the target files before applying.

diff --git a/tensorflow-records/src/TensorFlow/CRC32C.hs b/tensorflow-records/src/TensorFlow/CRC32C.hs
index b6c3e7f..7e62023 100644
--- a/tensorflow-records/src/TensorFlow/CRC32C.hs
+++ b/tensorflow-records/src/TensorFlow/CRC32C.hs
@@ -12,44 +12,50 @@
 -- See the License for the specific language governing permissions and
 -- limitations under the License.
 
--- | FFI wrappers of CRC32C digest. Import qualified.
-
 module TensorFlow.CRC32C
-  ( value, extend
-  , mask, unmask
-  , valueMasked
+  ( crc32c
+  , crc32cLBS
+  , crc32cUpdate
+  , crc32cMasked
+  , crc32cLBSMasked
+  , crc32cMask
+  , crc32cUnmask
   ) where
 
 import Data.Bits (rotateL, rotateR)
 import qualified Data.ByteString as B
 import qualified Data.ByteString.Lazy as BL
-import Data.Digest.CRC32C (crc32c_update)
+import Data.Digest.CRC32C (crc32c, crc32c_update)
 import Data.List (foldl')
 import Data.Word (Word32)
 
 -- | Compute the CRC32C checksum of the concatenation of the bytes checksummed
 -- by the given CRC32C value and the bytes in the given ByteString.
-extend :: Word32 -> B.ByteString -> Word32
-extend = crc32c_update
+crc32cUpdate :: Word32 -> B.ByteString -> Word32
+crc32cUpdate = crc32c_update
 
 -- | Compute the CRC32C checksum of the given bytes.
-value :: BL.ByteString -> Word32
-value = foldl' extend 0 . BL.toChunks
+crc32cLBS :: BL.ByteString -> Word32
+crc32cLBS = foldl' crc32cUpdate 0 . BL.toChunks
 
 -- | Scramble a CRC32C value so that the result can be safely stored in a
 -- bytestream that may itself be CRC'd.
 --
 -- This masking is the algorithm specified by TensorFlow's TFRecords format.
-mask :: Word32 -> Word32
-mask x = rotateR x 15 + maskDelta
+crc32cMask :: Word32 -> Word32
+crc32cMask x = rotateR x 15 + maskDelta
 
--- | Inverse of 'mask'.
-unmask :: Word32 -> Word32
-unmask x = rotateL (x - maskDelta) 15
+-- | Inverse of 'crc32cMask'.
+crc32cUnmask :: Word32 -> Word32
+crc32cUnmask x = rotateL (x - maskDelta) 15
 
--- | Convenience function combining 'value' and 'mask'.
-valueMasked :: BL.ByteString -> Word32
-valueMasked = mask . value
+-- | Convenience function combining 'crc32c' and 'crc32cMask'.
+crc32cMasked :: B.ByteString -> Word32
+crc32cMasked = crc32cMask . crc32c
+
+-- | Convenience function combining 'crc32cLBS' and 'crc32cMask'.
+crc32cLBSMasked :: BL.ByteString -> Word32
+crc32cLBSMasked = crc32cMask . crc32cLBS
 
 maskDelta :: Word32
 maskDelta = 0xa282ead8
diff --git a/tensorflow-records/src/TensorFlow/Records.hs b/tensorflow-records/src/TensorFlow/Records.hs
index 337199b..63c30fd 100644
--- a/tensorflow-records/src/TensorFlow/Records.hs
+++ b/tensorflow-records/src/TensorFlow/Records.hs
@@ -32,6 +32,7 @@ module TensorFlow.Records
   , putTFRecordData
   ) where
 
+import Control.Exception (evaluate)
 import Control.Monad (when)
 import Data.ByteString.Unsafe (unsafePackCStringLen)
 import qualified Data.ByteString.Builder as B (Builder)
@@ -58,7 +59,7 @@ import Foreign.Marshal.Alloc (allocaBytes)
 import Foreign.Ptr (Ptr, castPtr)
 import System.IO.Unsafe (unsafePerformIO)
 
-import qualified TensorFlow.CRC32C as CRC
+import TensorFlow.CRC32C (crc32cLBSMasked, crc32cUpdate, crc32cMask)
 
 -- | Parse one TFRecord.
 getTFRecord :: Get BL.ByteString
@@ -74,7 +75,7 @@ getTFRecords = do
 getCheckMaskedCRC32C :: BL.ByteString -> Get ()
 getCheckMaskedCRC32C bs = do
     wireCRC <- getWord32le
-    let maskedCRC = CRC.valueMasked bs
+    let maskedCRC = crc32cLBSMasked bs
     when (maskedCRC /= wireCRC) $ fail $
         "getCheckMaskedCRC32C: CRC mismatch, computed: " ++ show maskedCRC ++
         ", expected: " ++ show wireCRC
@@ -95,12 +96,12 @@ getTFRecordData len = if len > 0x7fffffffffffffff
         return bs
 
 putMaskedCRC32C :: BL.ByteString -> Put
-putMaskedCRC32C = putWord32le . CRC.valueMasked
+putMaskedCRC32C = putWord32le . crc32cLBSMasked
 
--- Runs a Builder that's known to write a fixed number of bytes on a
--- stack-allocated buffer, and runs the given IO action on the result. Raises
--- exceptions if the Builder yields ByteString chunks or attempts to write more
--- bytes than expected.
+-- Runs a Builder that's known to write a fixed number of bytes on an 'alloca'
+-- buffer, and runs the given IO action on the result. Raises exceptions if
+-- the Builder yields ByteString chunks or attempts to write more bytes than
+-- expected.
 unsafeWithFixedWidthBuilder :: Int -> B.Builder -> (Ptr Word8 -> IO r) -> IO r
 unsafeWithFixedWidthBuilder n b act = allocaBytes n $ \ptr -> do
     (_, signal) <- runBuilder b ptr n
@@ -114,11 +115,13 @@ putTFRecordLength :: Word64 -> Put
 putTFRecordLength x =
     let put = putWord64le x
         len = 8
-        crc = CRC.mask $ unsafePerformIO $
+        crc = crc32cMask $ unsafePerformIO $
             -- Serialized Word64 is always 8 bytes, so we can go fast by using
             -- alloca.
-            unsafeWithFixedWidthBuilder len (execPut put) $
-                \ptr -> CRC.extend 0 <$> unsafePackCStringLen (castPtr ptr, len)
+            unsafeWithFixedWidthBuilder len (execPut put) $ \ptr -> do
+                str <- unsafePackCStringLen (castPtr ptr, len)
+                -- Force the result to ensure it's evaluated before freeing ptr.
+                evaluate $ crc32cUpdate 0 str
     in put *> putWord32le crc
 
 -- | Put a record payload and its checksum.
diff --git a/tensorflow-records/tensorflow-records.cabal b/tensorflow-records/tensorflow-records.cabal
index 368791f..8fca811 100644
--- a/tensorflow-records/tensorflow-records.cabal
+++ b/tensorflow-records/tensorflow-records.cabal
@@ -19,7 +19,7 @@ library
                  , cereal
                  -- TODO: Split Data.Digest.CRC32C out of snappy-framing for a
                  -- lighter dependency?
-                 , snappy-framing
+                 , snappy-framing >= 0.1.1
     default-language: Haskell2010
 
 Test-Suite RecordsTest