mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Improve comments and make naming consistent.
This commit is contained in:
parent
ce6717a9f8
commit
65a1220b90
3 changed files with 38 additions and 29 deletions
|
@ -12,44 +12,50 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
-- | FFI wrappers of CRC32C digest. Import qualified.
|
|
||||||
|
|
||||||
module TensorFlow.CRC32C
|
module TensorFlow.CRC32C
|
||||||
( value, extend
|
( crc32c
|
||||||
, mask, unmask
|
, crc32cLBS
|
||||||
, valueMasked
|
, crc32cUpdate
|
||||||
|
, crc32cMasked
|
||||||
|
, crc32cLBSMasked
|
||||||
|
, crc32cMask
|
||||||
|
, crc32cUnmask
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Data.Bits (rotateL, rotateR)
|
import Data.Bits (rotateL, rotateR)
|
||||||
import qualified Data.ByteString as B
|
import qualified Data.ByteString as B
|
||||||
import qualified Data.ByteString.Lazy as BL
|
import qualified Data.ByteString.Lazy as BL
|
||||||
import Data.Digest.CRC32C (crc32c_update)
|
import Data.Digest.CRC32C (crc32c, crc32c_update)
|
||||||
import Data.List (foldl')
|
import Data.List (foldl')
|
||||||
import Data.Word (Word32)
|
import Data.Word (Word32)
|
||||||
|
|
||||||
-- | Compute the CRC32C checksum of the concatenation of the bytes checksummed
|
-- | Compute the CRC32C checksum of the concatenation of the bytes checksummed
|
||||||
-- by the given CRC32C value and the bytes in the given ByteString.
|
-- by the given CRC32C value and the bytes in the given ByteString.
|
||||||
extend :: Word32 -> B.ByteString -> Word32
|
crc32cUpdate :: Word32 -> B.ByteString -> Word32
|
||||||
extend = crc32c_update
|
crc32cUpdate = crc32c_update
|
||||||
|
|
||||||
-- | Compute the CRC32C checksum of the given bytes.
|
-- | Compute the CRC32C checksum of the given bytes.
|
||||||
value :: BL.ByteString -> Word32
|
crc32cLBS :: BL.ByteString -> Word32
|
||||||
value = foldl' extend 0 . BL.toChunks
|
crc32cLBS = foldl' crc32cUpdate 0 . BL.toChunks
|
||||||
|
|
||||||
-- | Scramble a CRC32C value so that the result can be safely stored in a
|
-- | Scramble a CRC32C value so that the result can be safely stored in a
|
||||||
-- bytestream that may itself be CRC'd.
|
-- bytestream that may itself be CRC'd.
|
||||||
--
|
--
|
||||||
-- This masking is the algorithm specified by TensorFlow's TFRecords format.
|
-- This masking is the algorithm specified by TensorFlow's TFRecords format.
|
||||||
mask :: Word32 -> Word32
|
crc32cMask :: Word32 -> Word32
|
||||||
mask x = rotateR x 15 + maskDelta
|
crc32cMask x = rotateR x 15 + maskDelta
|
||||||
|
|
||||||
-- | Inverse of 'mask'.
|
-- | Inverse of 'crc32cMask'.
|
||||||
unmask :: Word32 -> Word32
|
crc32cUnmask :: Word32 -> Word32
|
||||||
unmask x = rotateL (x - maskDelta) 15
|
crc32cUnmask x = rotateL (x - maskDelta) 15
|
||||||
|
|
||||||
-- | Convenience function combining 'value' and 'mask'.
|
-- | Convenience function combining 'crc32c' and 'crc32cMask'.
|
||||||
valueMasked :: BL.ByteString -> Word32
|
crc32cMasked :: B.ByteString -> Word32
|
||||||
valueMasked = mask . value
|
crc32cMasked = crc32cMask . crc32c
|
||||||
|
|
||||||
|
-- | Convenience function combining 'crc32cLBS' and 'crc32cMask'.
|
||||||
|
crc32cLBSMasked :: BL.ByteString -> Word32
|
||||||
|
crc32cLBSMasked = crc32cMask . crc32cLBS
|
||||||
|
|
||||||
maskDelta :: Word32
|
maskDelta :: Word32
|
||||||
maskDelta = 0xa282ead8
|
maskDelta = 0xa282ead8
|
||||||
|
|
|
@ -32,6 +32,7 @@ module TensorFlow.Records
|
||||||
, putTFRecordData
|
, putTFRecordData
|
||||||
) where
|
) where
|
||||||
|
|
||||||
|
import Control.Exception (evaluate)
|
||||||
import Control.Monad (when)
|
import Control.Monad (when)
|
||||||
import Data.ByteString.Unsafe (unsafePackCStringLen)
|
import Data.ByteString.Unsafe (unsafePackCStringLen)
|
||||||
import qualified Data.ByteString.Builder as B (Builder)
|
import qualified Data.ByteString.Builder as B (Builder)
|
||||||
|
@ -58,7 +59,7 @@ import Foreign.Marshal.Alloc (allocaBytes)
|
||||||
import Foreign.Ptr (Ptr, castPtr)
|
import Foreign.Ptr (Ptr, castPtr)
|
||||||
import System.IO.Unsafe (unsafePerformIO)
|
import System.IO.Unsafe (unsafePerformIO)
|
||||||
|
|
||||||
import qualified TensorFlow.CRC32C as CRC
|
import TensorFlow.CRC32C (crc32cLBSMasked, crc32cUpdate, crc32cMask)
|
||||||
|
|
||||||
-- | Parse one TFRecord.
|
-- | Parse one TFRecord.
|
||||||
getTFRecord :: Get BL.ByteString
|
getTFRecord :: Get BL.ByteString
|
||||||
|
@ -74,7 +75,7 @@ getTFRecords = do
|
||||||
getCheckMaskedCRC32C :: BL.ByteString -> Get ()
|
getCheckMaskedCRC32C :: BL.ByteString -> Get ()
|
||||||
getCheckMaskedCRC32C bs = do
|
getCheckMaskedCRC32C bs = do
|
||||||
wireCRC <- getWord32le
|
wireCRC <- getWord32le
|
||||||
let maskedCRC = CRC.valueMasked bs
|
let maskedCRC = crc32cLBSMasked bs
|
||||||
when (maskedCRC /= wireCRC) $ fail $
|
when (maskedCRC /= wireCRC) $ fail $
|
||||||
"getCheckMaskedCRC32C: CRC mismatch, computed: " ++ show maskedCRC ++
|
"getCheckMaskedCRC32C: CRC mismatch, computed: " ++ show maskedCRC ++
|
||||||
", expected: " ++ show wireCRC
|
", expected: " ++ show wireCRC
|
||||||
|
@ -95,12 +96,12 @@ getTFRecordData len = if len > 0x7fffffffffffffff
|
||||||
return bs
|
return bs
|
||||||
|
|
||||||
putMaskedCRC32C :: BL.ByteString -> Put
|
putMaskedCRC32C :: BL.ByteString -> Put
|
||||||
putMaskedCRC32C = putWord32le . CRC.valueMasked
|
putMaskedCRC32C = putWord32le . crc32cLBSMasked
|
||||||
|
|
||||||
-- Runs a Builder that's known to write a fixed number of bytes on a
|
-- Runs a Builder that's known to write a fixed number of bytes on an 'alloca'
|
||||||
-- stack-allocated buffer, and runs the given IO action on the result. Raises
|
-- buffer, and runs the given IO action on the result. Raises exceptions if
|
||||||
-- exceptions if the Builder yields ByteString chunks or attempts to write more
|
-- the Builder yields ByteString chunks or attempts to write more bytes than
|
||||||
-- bytes than expected.
|
-- expected.
|
||||||
unsafeWithFixedWidthBuilder :: Int -> B.Builder -> (Ptr Word8 -> IO r) -> IO r
|
unsafeWithFixedWidthBuilder :: Int -> B.Builder -> (Ptr Word8 -> IO r) -> IO r
|
||||||
unsafeWithFixedWidthBuilder n b act = allocaBytes n $ \ptr -> do
|
unsafeWithFixedWidthBuilder n b act = allocaBytes n $ \ptr -> do
|
||||||
(_, signal) <- runBuilder b ptr n
|
(_, signal) <- runBuilder b ptr n
|
||||||
|
@ -114,11 +115,13 @@ putTFRecordLength :: Word64 -> Put
|
||||||
putTFRecordLength x =
|
putTFRecordLength x =
|
||||||
let put = putWord64le x
|
let put = putWord64le x
|
||||||
len = 8
|
len = 8
|
||||||
crc = CRC.mask $ unsafePerformIO $
|
crc = crc32cMask $ unsafePerformIO $
|
||||||
-- Serialized Word64 is always 8 bytes, so we can go fast by using
|
-- Serialized Word64 is always 8 bytes, so we can go fast by using
|
||||||
-- alloca.
|
-- alloca.
|
||||||
unsafeWithFixedWidthBuilder len (execPut put) $
|
unsafeWithFixedWidthBuilder len (execPut put) $ \ptr -> do
|
||||||
\ptr -> CRC.extend 0 <$> unsafePackCStringLen (castPtr ptr, len)
|
str <- unsafePackCStringLen (castPtr ptr, len)
|
||||||
|
-- Force the result to ensure it's evaluated before freeing ptr.
|
||||||
|
evaluate $ crc32cUpdate 0 str
|
||||||
in put *> putWord32le crc
|
in put *> putWord32le crc
|
||||||
|
|
||||||
-- | Put a record payload and its checksum.
|
-- | Put a record payload and its checksum.
|
||||||
|
|
|
@ -19,7 +19,7 @@ library
|
||||||
, cereal
|
, cereal
|
||||||
-- TODO: Split Data.Digest.CRC32C out of snappy-framing for a
|
-- TODO: Split Data.Digest.CRC32C out of snappy-framing for a
|
||||||
-- lighter dependency?
|
-- lighter dependency?
|
||||||
, snappy-framing
|
, snappy-framing >= 0.1.1
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
|
|
||||||
Test-Suite RecordsTest
|
Test-Suite RecordsTest
|
||||||
|
|
Loading…
Reference in a new issue