1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2025-01-11 19:39:49 +01:00

Improve comments and make naming consistent.

This commit is contained in:
Andrew Pritchard 2017-02-09 08:20:26 -08:00 committed by fkm3
parent ce6717a9f8
commit 65a1220b90
3 changed files with 38 additions and 29 deletions

View file

@ -12,44 +12,50 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- | FFI wrappers of CRC32C digest. Import qualified.
module TensorFlow.CRC32C
( value, extend
, mask, unmask
, valueMasked
( crc32c
, crc32cLBS
, crc32cUpdate
, crc32cMasked
, crc32cLBSMasked
, crc32cMask
, crc32cUnmask
) where
import Data.Bits (rotateL, rotateR)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Digest.CRC32C (crc32c_update)
import Data.Digest.CRC32C (crc32c, crc32c_update)
import Data.List (foldl')
import Data.Word (Word32)
-- | Compute the CRC32C checksum of the concatenation of the bytes checksummed
-- by the given CRC32C value and the bytes in the given ByteString.
extend :: Word32 -> B.ByteString -> Word32
extend = crc32c_update
crc32cUpdate :: Word32 -> B.ByteString -> Word32
crc32cUpdate = crc32c_update
-- | Compute the CRC32C checksum of the given bytes.
value :: BL.ByteString -> Word32
value = foldl' extend 0 . BL.toChunks
crc32cLBS :: BL.ByteString -> Word32
crc32cLBS = foldl' crc32cUpdate 0 . BL.toChunks
-- | Scramble a CRC32C value so that the result can be safely stored in a
-- bytestream that may itself be CRC'd.
--
-- This masking is the algorithm specified by TensorFlow's TFRecords format.
mask :: Word32 -> Word32
mask x = rotateR x 15 + maskDelta
crc32cMask :: Word32 -> Word32
crc32cMask x = rotateR x 15 + maskDelta
-- | Inverse of 'mask'.
unmask :: Word32 -> Word32
unmask x = rotateL (x - maskDelta) 15
-- | Inverse of 'crc32cMask'.
crc32cUnmask :: Word32 -> Word32
crc32cUnmask x = rotateL (x - maskDelta) 15
-- | Convenience function combining 'value' and 'mask'.
valueMasked :: BL.ByteString -> Word32
valueMasked = mask . value
-- | Convenience function combining 'crc32c' and 'crc32cMask'.
crc32cMasked :: B.ByteString -> Word32
crc32cMasked = crc32cMask . crc32c
-- | Convenience function combining 'crc32cLBS' and 'crc32cMask'.
crc32cLBSMasked :: BL.ByteString -> Word32
crc32cLBSMasked = crc32cMask . crc32cLBS
maskDelta :: Word32
maskDelta = 0xa282ead8

View file

@ -32,6 +32,7 @@ module TensorFlow.Records
, putTFRecordData
) where
import Control.Exception (evaluate)
import Control.Monad (when)
import Data.ByteString.Unsafe (unsafePackCStringLen)
import qualified Data.ByteString.Builder as B (Builder)
@ -58,7 +59,7 @@ import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Ptr (Ptr, castPtr)
import System.IO.Unsafe (unsafePerformIO)
import qualified TensorFlow.CRC32C as CRC
import TensorFlow.CRC32C (crc32cLBSMasked, crc32cUpdate, crc32cMask)
-- | Parse one TFRecord.
getTFRecord :: Get BL.ByteString
@ -74,7 +75,7 @@ getTFRecords = do
getCheckMaskedCRC32C :: BL.ByteString -> Get ()
getCheckMaskedCRC32C bs = do
wireCRC <- getWord32le
let maskedCRC = CRC.valueMasked bs
let maskedCRC = crc32cLBSMasked bs
when (maskedCRC /= wireCRC) $ fail $
"getCheckMaskedCRC32C: CRC mismatch, computed: " ++ show maskedCRC ++
", expected: " ++ show wireCRC
@ -95,12 +96,12 @@ getTFRecordData len = if len > 0x7fffffffffffffff
return bs
putMaskedCRC32C :: BL.ByteString -> Put
putMaskedCRC32C = putWord32le . CRC.valueMasked
putMaskedCRC32C = putWord32le . crc32cLBSMasked
-- Runs a Builder that's known to write a fixed number of bytes on a
-- stack-allocated buffer, and runs the given IO action on the result. Raises
-- exceptions if the Builder yields ByteString chunks or attempts to write more
-- bytes than expected.
-- Runs a Builder that's known to write a fixed number of bytes on an 'alloca'
-- buffer, and runs the given IO action on the result. Raises exceptions if
-- the Builder yields ByteString chunks or attempts to write more bytes than
-- expected.
unsafeWithFixedWidthBuilder :: Int -> B.Builder -> (Ptr Word8 -> IO r) -> IO r
unsafeWithFixedWidthBuilder n b act = allocaBytes n $ \ptr -> do
(_, signal) <- runBuilder b ptr n
@ -114,11 +115,13 @@ putTFRecordLength :: Word64 -> Put
putTFRecordLength x =
let put = putWord64le x
len = 8
crc = CRC.mask $ unsafePerformIO $
crc = crc32cMask $ unsafePerformIO $
-- Serialized Word64 is always 8 bytes, so we can go fast by using
-- alloca.
unsafeWithFixedWidthBuilder len (execPut put) $
\ptr -> CRC.extend 0 <$> unsafePackCStringLen (castPtr ptr, len)
unsafeWithFixedWidthBuilder len (execPut put) $ \ptr -> do
str <- unsafePackCStringLen (castPtr ptr, len)
-- Force the result to ensure it's evaluated before freeing ptr.
evaluate $ crc32cUpdate 0 str
in put *> putWord32le crc
-- | Put a record payload and its checksum.

View file

@ -19,7 +19,7 @@ library
, cereal
-- TODO: Split Data.Digest.CRC32C out of snappy-framing for a
-- lighter dependency?
, snappy-framing
, snappy-framing >= 0.1.1
default-language: Haskell2010
Test-Suite RecordsTest