mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-11 19:39:49 +01:00
Improve comments and make naming consistent.
This commit is contained in:
parent
ce6717a9f8
commit
65a1220b90
3 changed files with 38 additions and 29 deletions
|
@ -12,44 +12,50 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
-- | FFI wrappers of CRC32C digest. Import qualified.
|
||||
|
||||
module TensorFlow.CRC32C
|
||||
( value, extend
|
||||
, mask, unmask
|
||||
, valueMasked
|
||||
( crc32c
|
||||
, crc32cLBS
|
||||
, crc32cUpdate
|
||||
, crc32cMasked
|
||||
, crc32cLBSMasked
|
||||
, crc32cMask
|
||||
, crc32cUnmask
|
||||
) where
|
||||
|
||||
import Data.Bits (rotateL, rotateR)
|
||||
import qualified Data.ByteString as B
|
||||
import qualified Data.ByteString.Lazy as BL
|
||||
import Data.Digest.CRC32C (crc32c_update)
|
||||
import Data.Digest.CRC32C (crc32c, crc32c_update)
|
||||
import Data.List (foldl')
|
||||
import Data.Word (Word32)
|
||||
|
||||
-- | Compute the CRC32C checksum of the concatenation of the bytes checksummed
|
||||
-- by the given CRC32C value and the bytes in the given ByteString.
|
||||
extend :: Word32 -> B.ByteString -> Word32
|
||||
extend = crc32c_update
|
||||
crc32cUpdate :: Word32 -> B.ByteString -> Word32
|
||||
crc32cUpdate = crc32c_update
|
||||
|
||||
-- | Compute the CRC32C checksum of the given bytes.
|
||||
value :: BL.ByteString -> Word32
|
||||
value = foldl' extend 0 . BL.toChunks
|
||||
crc32cLBS :: BL.ByteString -> Word32
|
||||
crc32cLBS = foldl' crc32cUpdate 0 . BL.toChunks
|
||||
|
||||
-- | Scramble a CRC32C value so that the result can be safely stored in a
|
||||
-- bytestream that may itself be CRC'd.
|
||||
--
|
||||
-- This masking is the algorithm specified by TensorFlow's TFRecords format.
|
||||
mask :: Word32 -> Word32
|
||||
mask x = rotateR x 15 + maskDelta
|
||||
crc32cMask :: Word32 -> Word32
|
||||
crc32cMask x = rotateR x 15 + maskDelta
|
||||
|
||||
-- | Inverse of 'mask'.
|
||||
unmask :: Word32 -> Word32
|
||||
unmask x = rotateL (x - maskDelta) 15
|
||||
-- | Inverse of 'crc32cMask'.
|
||||
crc32cUnmask :: Word32 -> Word32
|
||||
crc32cUnmask x = rotateL (x - maskDelta) 15
|
||||
|
||||
-- | Convenience function combining 'value' and 'mask'.
|
||||
valueMasked :: BL.ByteString -> Word32
|
||||
valueMasked = mask . value
|
||||
-- | Convenience function combining 'crc32c' and 'crc32cMask'.
|
||||
crc32cMasked :: B.ByteString -> Word32
|
||||
crc32cMasked = crc32cMask . crc32c
|
||||
|
||||
-- | Convenience function combining 'crc32cLBS' and 'crc32cMask'.
|
||||
crc32cLBSMasked :: BL.ByteString -> Word32
|
||||
crc32cLBSMasked = crc32cMask . crc32cLBS
|
||||
|
||||
maskDelta :: Word32
|
||||
maskDelta = 0xa282ead8
|
||||
|
|
|
@ -32,6 +32,7 @@ module TensorFlow.Records
|
|||
, putTFRecordData
|
||||
) where
|
||||
|
||||
import Control.Exception (evaluate)
|
||||
import Control.Monad (when)
|
||||
import Data.ByteString.Unsafe (unsafePackCStringLen)
|
||||
import qualified Data.ByteString.Builder as B (Builder)
|
||||
|
@ -58,7 +59,7 @@ import Foreign.Marshal.Alloc (allocaBytes)
|
|||
import Foreign.Ptr (Ptr, castPtr)
|
||||
import System.IO.Unsafe (unsafePerformIO)
|
||||
|
||||
import qualified TensorFlow.CRC32C as CRC
|
||||
import TensorFlow.CRC32C (crc32cLBSMasked, crc32cUpdate, crc32cMask)
|
||||
|
||||
-- | Parse one TFRecord.
|
||||
getTFRecord :: Get BL.ByteString
|
||||
|
@ -74,7 +75,7 @@ getTFRecords = do
|
|||
getCheckMaskedCRC32C :: BL.ByteString -> Get ()
|
||||
getCheckMaskedCRC32C bs = do
|
||||
wireCRC <- getWord32le
|
||||
let maskedCRC = CRC.valueMasked bs
|
||||
let maskedCRC = crc32cLBSMasked bs
|
||||
when (maskedCRC /= wireCRC) $ fail $
|
||||
"getCheckMaskedCRC32C: CRC mismatch, computed: " ++ show maskedCRC ++
|
||||
", expected: " ++ show wireCRC
|
||||
|
@ -95,12 +96,12 @@ getTFRecordData len = if len > 0x7fffffffffffffff
|
|||
return bs
|
||||
|
||||
putMaskedCRC32C :: BL.ByteString -> Put
|
||||
putMaskedCRC32C = putWord32le . CRC.valueMasked
|
||||
putMaskedCRC32C = putWord32le . crc32cLBSMasked
|
||||
|
||||
-- Runs a Builder that's known to write a fixed number of bytes on a
|
||||
-- stack-allocated buffer, and runs the given IO action on the result. Raises
|
||||
-- exceptions if the Builder yields ByteString chunks or attempts to write more
|
||||
-- bytes than expected.
|
||||
-- Runs a Builder that's known to write a fixed number of bytes on an 'alloca'
|
||||
-- buffer, and runs the given IO action on the result. Raises exceptions if
|
||||
-- the Builder yields ByteString chunks or attempts to write more bytes than
|
||||
-- expected.
|
||||
unsafeWithFixedWidthBuilder :: Int -> B.Builder -> (Ptr Word8 -> IO r) -> IO r
|
||||
unsafeWithFixedWidthBuilder n b act = allocaBytes n $ \ptr -> do
|
||||
(_, signal) <- runBuilder b ptr n
|
||||
|
@ -114,11 +115,13 @@ putTFRecordLength :: Word64 -> Put
|
|||
putTFRecordLength x =
|
||||
let put = putWord64le x
|
||||
len = 8
|
||||
crc = CRC.mask $ unsafePerformIO $
|
||||
crc = crc32cMask $ unsafePerformIO $
|
||||
-- Serialized Word64 is always 8 bytes, so we can go fast by using
|
||||
-- alloca.
|
||||
unsafeWithFixedWidthBuilder len (execPut put) $
|
||||
\ptr -> CRC.extend 0 <$> unsafePackCStringLen (castPtr ptr, len)
|
||||
unsafeWithFixedWidthBuilder len (execPut put) $ \ptr -> do
|
||||
str <- unsafePackCStringLen (castPtr ptr, len)
|
||||
-- Force the result to ensure it's evaluated before freeing ptr.
|
||||
evaluate $ crc32cUpdate 0 str
|
||||
in put *> putWord32le crc
|
||||
|
||||
-- | Put a record payload and its checksum.
|
||||
|
|
|
@ -19,7 +19,7 @@ library
|
|||
, cereal
|
||||
-- TODO: Split Data.Digest.CRC32C out of snappy-framing for a
|
||||
-- lighter dependency?
|
||||
, snappy-framing
|
||||
, snappy-framing >= 0.1.1
|
||||
default-language: Haskell2010
|
||||
|
||||
Test-Suite RecordsTest
|
||||
|
|
Loading…
Reference in a new issue