mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-27 05:19:45 +01:00
Add tensorflow-records and tensorflow-records-conduit package
Original author: awpr@
This commit is contained in:
parent
4b5a57152f
commit
7484214854
11 changed files with 372 additions and 0 deletions
|
@ -18,6 +18,8 @@ RUN \
|
||||||
echo 'deb http://download.fpcomplete.com/ubuntu trusty main'| tee /etc/apt/sources.list.d/fpco.list && \
|
echo 'deb http://download.fpcomplete.com/ubuntu trusty main'| tee /etc/apt/sources.list.d/fpco.list && \
|
||||||
apt-get update && \
|
apt-get update && \
|
||||||
apt-get install -y \
|
apt-get install -y \
|
||||||
|
# Required by snappy-frames dependency.
|
||||||
|
libsnappy-dev \
|
||||||
# Avoids /usr/bin/ld: cannot find -ltinfo
|
# Avoids /usr/bin/ld: cannot find -ltinfo
|
||||||
libncurses5-dev \
|
libncurses5-dev \
|
||||||
# Makes stack viable in the container
|
# Makes stack viable in the container
|
||||||
|
|
|
@ -15,6 +15,8 @@ RUN \
|
||||||
RUN apt-get update
|
RUN apt-get update
|
||||||
|
|
||||||
RUN apt-get install -y \
|
RUN apt-get install -y \
|
||||||
|
# Required by snappy-frames dependency.
|
||||||
|
libsnappy-dev \
|
||||||
# Avoids /usr/bin/ld: cannot find -ltinfo
|
# Avoids /usr/bin/ld: cannot find -ltinfo
|
||||||
libncurses5-dev \
|
libncurses5-dev \
|
||||||
# Makes stack viable in the container
|
# Makes stack viable in the container
|
||||||
|
|
|
@ -11,12 +11,16 @@ packages:
|
||||||
- tensorflow-mnist-input-data
|
- tensorflow-mnist-input-data
|
||||||
- tensorflow-queue
|
- tensorflow-queue
|
||||||
- tensorflow-nn
|
- tensorflow-nn
|
||||||
|
- tensorflow-records
|
||||||
|
- tensorflow-records-conduit
|
||||||
- tensorflow-test
|
- tensorflow-test
|
||||||
|
|
||||||
extra-deps:
|
extra-deps:
|
||||||
# proto-lens is not yet in Stackage.
|
# proto-lens is not yet in Stackage.
|
||||||
- proto-lens-0.1.0.4
|
- proto-lens-0.1.0.4
|
||||||
- proto-lens-protoc-0.1.0.4
|
- proto-lens-protoc-0.1.0.4
|
||||||
|
- snappy-framing-0.1.1
|
||||||
|
- snappy-0.2.0.2
|
||||||
|
|
||||||
# Allow our custom Setup.hs scripts to import Data.ProtoLens.Setup from the version of
|
# Allow our custom Setup.hs scripts to import Data.ProtoLens.Setup from the version of
|
||||||
# `proto-lens-protoc` in stack's local DB. See:
|
# `proto-lens-protoc` in stack's local DB. See:
|
||||||
|
|
3
tensorflow-records-conduit/Setup.hs
Normal file
3
tensorflow-records-conduit/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
import Distribution.Simple
|
||||||
|
|
||||||
|
main = defaultMain
|
53
tensorflow-records-conduit/src/TensorFlow/Records/Conduit.hs
Normal file
53
tensorflow-records-conduit/src/TensorFlow/Records/Conduit.hs
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
-- | Conduit wrappers for TensorFlow.Records.
|
||||||
|
|
||||||
|
{-# LANGUAGE Rank2Types #-}
|
||||||
|
module TensorFlow.Records.Conduit
|
||||||
|
(
|
||||||
|
-- * Encode/Decode
|
||||||
|
encodeTFRecords
|
||||||
|
, decodeTFRecords
|
||||||
|
|
||||||
|
-- * Source/Sink
|
||||||
|
, sinkTFRecords
|
||||||
|
, sourceTFRecords
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Control.Monad.Catch (MonadThrow)
|
||||||
|
import Control.Monad.Trans.Resource (MonadResource)
|
||||||
|
import qualified Data.ByteString as B
|
||||||
|
import qualified Data.ByteString.Lazy as BL
|
||||||
|
import Data.Conduit ((=$=), Conduit, Consumer, Producer)
|
||||||
|
import Data.Conduit.Binary (sinkFile, sourceFile)
|
||||||
|
import Data.Conduit.Cereal (conduitGet2, conduitPut)
|
||||||
|
|
||||||
|
import TensorFlow.Records (getTFRecord, putTFRecord)
|
||||||
|
|
||||||
|
-- | Decode TFRecords from a stream of bytes.
|
||||||
|
decodeTFRecords :: MonadThrow m => Conduit B.ByteString m BL.ByteString
|
||||||
|
decodeTFRecords = conduitGet2 getTFRecord
|
||||||
|
|
||||||
|
-- | Read TFRecords from a file.
|
||||||
|
sourceTFRecords :: (MonadResource m, MonadThrow m) => FilePath -> Producer m BL.ByteString
|
||||||
|
sourceTFRecords path = sourceFile path =$= decodeTFRecords
|
||||||
|
|
||||||
|
-- | Encode TFRecords to a stream of bytes.
|
||||||
|
encodeTFRecords :: Monad m => Conduit BL.ByteString m B.ByteString
|
||||||
|
encodeTFRecords = conduitPut putTFRecord
|
||||||
|
|
||||||
|
-- | Write TFRecords to a file.
|
||||||
|
sinkTFRecords :: (MonadResource m) => FilePath -> Consumer BL.ByteString m ()
|
||||||
|
sinkTFRecords path = encodeTFRecords =$= sinkFile path
|
28
tensorflow-records-conduit/tensorflow-records-conduit.cabal
Normal file
28
tensorflow-records-conduit/tensorflow-records-conduit.cabal
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
name: tensorflow-records-conduit
|
||||||
|
version: 0.1.0.0
|
||||||
|
synopsis: Conduit wrappers for TensorFlow.Records.
|
||||||
|
homepage: https://github.com/tensorflow/haskell#readme
|
||||||
|
license: Apache
|
||||||
|
author: TensorFlow authors
|
||||||
|
maintainer: tensorflow-haskell@googlegroups.com
|
||||||
|
copyright: Google Inc.
|
||||||
|
category: Machine Learning
|
||||||
|
build-type: Simple
|
||||||
|
cabal-version: >=1.22
|
||||||
|
|
||||||
|
library
|
||||||
|
hs-source-dirs: src
|
||||||
|
exposed-modules: TensorFlow.Records.Conduit
|
||||||
|
build-depends: base >= 4.7 && < 5
|
||||||
|
, bytestring
|
||||||
|
, cereal-conduit
|
||||||
|
, conduit
|
||||||
|
, conduit-extra
|
||||||
|
, exceptions
|
||||||
|
, resourcet
|
||||||
|
, tensorflow-records
|
||||||
|
default-language: Haskell2010
|
||||||
|
|
||||||
|
source-repository head
|
||||||
|
type: git
|
||||||
|
location: https://github.com/tensorflow/haskell
|
3
tensorflow-records/Setup.hs
Normal file
3
tensorflow-records/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
import Distribution.Simple
|
||||||
|
|
||||||
|
main = defaultMain
|
55
tensorflow-records/src/TensorFlow/CRC32C.hs
Normal file
55
tensorflow-records/src/TensorFlow/CRC32C.hs
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
-- | FFI wrappers of CRC32C digest. Import qualified.
|
||||||
|
|
||||||
|
module TensorFlow.CRC32C
|
||||||
|
( value, extend
|
||||||
|
, mask, unmask
|
||||||
|
, valueMasked
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Data.Bits (rotateL, rotateR)
|
||||||
|
import qualified Data.ByteString as B
|
||||||
|
import qualified Data.ByteString.Lazy as BL
|
||||||
|
import Data.Digest.CRC32C (crc32c_update)
|
||||||
|
import Data.List (foldl')
|
||||||
|
import Data.Word (Word32)
|
||||||
|
|
||||||
|
-- | Compute the CRC32C checksum of the concatenation of the bytes checksummed
|
||||||
|
-- by the given CRC32C value and the bytes in the given ByteString.
|
||||||
|
extend :: Word32 -> B.ByteString -> Word32
|
||||||
|
extend = crc32c_update
|
||||||
|
|
||||||
|
-- | Compute the CRC32C checksum of the given bytes.
|
||||||
|
value :: BL.ByteString -> Word32
|
||||||
|
value = foldl' extend 0 . BL.toChunks
|
||||||
|
|
||||||
|
-- | Scramble a CRC32C value so that the result can be safely stored in a
|
||||||
|
-- bytestream that may itself be CRC'd.
|
||||||
|
--
|
||||||
|
-- This masking is the algorithm specified by TensorFlow's TFRecords format.
|
||||||
|
mask :: Word32 -> Word32
|
||||||
|
mask x = rotateR x 15 + maskDelta
|
||||||
|
|
||||||
|
-- | Inverse of 'mask'.
|
||||||
|
unmask :: Word32 -> Word32
|
||||||
|
unmask x = rotateL (x - maskDelta) 15
|
||||||
|
|
||||||
|
-- | Convenience function combining 'value' and 'mask'.
|
||||||
|
valueMasked :: BL.ByteString -> Word32
|
||||||
|
valueMasked = mask . value
|
||||||
|
|
||||||
|
maskDelta :: Word32
|
||||||
|
maskDelta = 0xa282ead8
|
131
tensorflow-records/src/TensorFlow/Records.hs
Normal file
131
tensorflow-records/src/TensorFlow/Records.hs
Normal file
|
@ -0,0 +1,131 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
-- | Encoder and decoder for the TensorFlow \"TFRecords\" format.
|
||||||
|
|
||||||
|
{-# LANGUAGE Rank2Types #-}
|
||||||
|
module TensorFlow.Records
|
||||||
|
(
|
||||||
|
-- * Records
|
||||||
|
putTFRecord
|
||||||
|
, getTFRecord
|
||||||
|
, getTFRecords
|
||||||
|
|
||||||
|
-- * Implementation
|
||||||
|
|
||||||
|
-- | These may be useful for encoding or decoding to types other than
|
||||||
|
-- 'ByteString' that have their own Cereal codecs.
|
||||||
|
, getTFRecordLength
|
||||||
|
, getTFRecordData
|
||||||
|
, putTFRecordLength
|
||||||
|
, putTFRecordData
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Control.Monad (when)
|
||||||
|
import Data.ByteString.Unsafe (unsafePackCStringLen)
|
||||||
|
import qualified Data.ByteString.Builder as B (Builder)
|
||||||
|
import Data.ByteString.Builder.Extra (runBuilder, Next(..))
|
||||||
|
import qualified Data.ByteString.Lazy as BL
|
||||||
|
import Data.Serialize.Get
|
||||||
|
( Get
|
||||||
|
, getBytes
|
||||||
|
, getWord32le
|
||||||
|
, getWord64le
|
||||||
|
, getLazyByteString
|
||||||
|
, isEmpty
|
||||||
|
, lookAhead
|
||||||
|
)
|
||||||
|
import Data.Serialize
|
||||||
|
( Put
|
||||||
|
, execPut
|
||||||
|
, putLazyByteString
|
||||||
|
, putWord32le
|
||||||
|
, putWord64le
|
||||||
|
)
|
||||||
|
import Data.Word (Word8, Word64)
|
||||||
|
import Foreign.Marshal.Alloc (allocaBytes)
|
||||||
|
import Foreign.Ptr (Ptr, castPtr)
|
||||||
|
import System.IO.Unsafe (unsafePerformIO)
|
||||||
|
|
||||||
|
import qualified TensorFlow.CRC32C as CRC
|
||||||
|
|
||||||
|
-- | Parse one TFRecord.
|
||||||
|
getTFRecord :: Get BL.ByteString
|
||||||
|
getTFRecord = getTFRecordLength >>= getTFRecordData
|
||||||
|
|
||||||
|
-- | Parse many TFRecords as a list. Note you probably want streaming instead
|
||||||
|
-- as provided by the tensorflow-records-conduit package.
|
||||||
|
getTFRecords :: Get [BL.ByteString]
|
||||||
|
getTFRecords = do
|
||||||
|
e <- isEmpty
|
||||||
|
if e then return [] else (:) <$> getTFRecord <*> getTFRecords
|
||||||
|
|
||||||
|
getCheckMaskedCRC32C :: BL.ByteString -> Get ()
|
||||||
|
getCheckMaskedCRC32C bs = do
|
||||||
|
wireCRC <- getWord32le
|
||||||
|
let maskedCRC = CRC.valueMasked bs
|
||||||
|
when (maskedCRC /= wireCRC) $ fail $
|
||||||
|
"getCheckMaskedCRC32C: CRC mismatch, computed: " ++ show maskedCRC ++
|
||||||
|
", expected: " ++ show wireCRC
|
||||||
|
|
||||||
|
-- | Get a length and verify its checksum.
|
||||||
|
getTFRecordLength :: Get Word64
|
||||||
|
getTFRecordLength = do
|
||||||
|
buf <- lookAhead (getBytes 8)
|
||||||
|
getWord64le <* getCheckMaskedCRC32C (BL.fromStrict buf)
|
||||||
|
|
||||||
|
-- | Get a record payload and verify its checksum.
|
||||||
|
getTFRecordData :: Word64 -> Get BL.ByteString
|
||||||
|
getTFRecordData len = if len > 0x7fffffffffffffff
|
||||||
|
then fail "getTFRecordData: Record size overflows Int64"
|
||||||
|
else do
|
||||||
|
bs <- getLazyByteString (fromIntegral len)
|
||||||
|
getCheckMaskedCRC32C bs
|
||||||
|
return bs
|
||||||
|
|
||||||
|
putMaskedCRC32C :: BL.ByteString -> Put
|
||||||
|
putMaskedCRC32C = putWord32le . CRC.valueMasked
|
||||||
|
|
||||||
|
-- Runs a Builder that's known to write a fixed number of bytes on a
|
||||||
|
-- stack-allocated buffer, and runs the given IO action on the result. Raises
|
||||||
|
-- exceptions if the Builder yields ByteString chunks or attempts to write more
|
||||||
|
-- bytes than expected.
|
||||||
|
unsafeWithFixedWidthBuilder :: Int -> B.Builder -> (Ptr Word8 -> IO r) -> IO r
|
||||||
|
unsafeWithFixedWidthBuilder n b act = allocaBytes n $ \ptr -> do
|
||||||
|
(_, signal) <- runBuilder b ptr n
|
||||||
|
case signal of
|
||||||
|
Done -> act ptr
|
||||||
|
More _ _ -> error "unsafeWithFixedWidthBuilder: Builder returned More."
|
||||||
|
Chunk _ _ -> error "unsafeWithFixedWidthBuilder: Builder returned Chunk."
|
||||||
|
|
||||||
|
-- | Put a record length and its checksum.
|
||||||
|
putTFRecordLength :: Word64 -> Put
|
||||||
|
putTFRecordLength x =
|
||||||
|
let put = putWord64le x
|
||||||
|
len = 8
|
||||||
|
crc = CRC.mask $ unsafePerformIO $
|
||||||
|
-- Serialized Word64 is always 8 bytes, so we can go fast by using
|
||||||
|
-- alloca.
|
||||||
|
unsafeWithFixedWidthBuilder len (execPut put) $
|
||||||
|
\ptr -> CRC.extend 0 <$> unsafePackCStringLen (castPtr ptr, len)
|
||||||
|
in put *> putWord32le crc
|
||||||
|
|
||||||
|
-- | Put a record payload and its checksum.
|
||||||
|
putTFRecordData :: BL.ByteString -> Put
|
||||||
|
putTFRecordData bs = putLazyByteString bs *> putMaskedCRC32C bs
|
||||||
|
|
||||||
|
-- | Put one TFRecord with the given contents.
|
||||||
|
putTFRecord :: BL.ByteString -> Put
|
||||||
|
putTFRecord bs =
|
||||||
|
putTFRecordLength (fromIntegral $ BL.length bs) *> putTFRecordData bs
|
40
tensorflow-records/tensorflow-records.cabal
Normal file
40
tensorflow-records/tensorflow-records.cabal
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
name: tensorflow-records
|
||||||
|
version: 0.1.0.0
|
||||||
|
synopsis: Encoder and decoder for the TensorFlow \"TFRecords\" format.
|
||||||
|
homepage: https://github.com/tensorflow/haskell#readme
|
||||||
|
license: Apache
|
||||||
|
author: TensorFlow authors
|
||||||
|
maintainer: tensorflow-haskell@googlegroups.com
|
||||||
|
copyright: Google Inc.
|
||||||
|
category: Machine Learning
|
||||||
|
build-type: Simple
|
||||||
|
cabal-version: >=1.22
|
||||||
|
|
||||||
|
library
|
||||||
|
hs-source-dirs: src
|
||||||
|
exposed-modules: TensorFlow.Records
|
||||||
|
other-modules: TensorFlow.CRC32C
|
||||||
|
build-depends: base >= 4.7 && < 5
|
||||||
|
, bytestring
|
||||||
|
, cereal
|
||||||
|
-- TODO: Split Data.Digest.CRC32C out of snappy-framing for a
|
||||||
|
-- lighter dependency?
|
||||||
|
, snappy-framing
|
||||||
|
default-language: Haskell2010
|
||||||
|
|
||||||
|
Test-Suite RecordsTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: Main.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: base
|
||||||
|
, bytestring
|
||||||
|
, cereal
|
||||||
|
, tensorflow-records
|
||||||
|
, test-framework
|
||||||
|
, test-framework-quickcheck2
|
||||||
|
|
||||||
|
|
||||||
|
source-repository head
|
||||||
|
type: git
|
||||||
|
location: https://github.com/tensorflow/haskell
|
51
tensorflow-records/tests/Main.hs
Normal file
51
tensorflow-records/tests/Main.hs
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import qualified Data.ByteString as B
|
||||||
|
import qualified Data.ByteString.Lazy as BL
|
||||||
|
import Data.Monoid ((<>))
|
||||||
|
import Data.Word (Word8)
|
||||||
|
import Data.Serialize (runGet, runPut)
|
||||||
|
import Test.Framework (Test, defaultMain)
|
||||||
|
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||||||
|
|
||||||
|
import TensorFlow.Records (getTFRecord, putTFRecord)
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = defaultMain tests
|
||||||
|
|
||||||
|
tests :: [Test]
|
||||||
|
tests =
|
||||||
|
[ testProperty "Inverse" propEncodeDecodeInverse
|
||||||
|
, testProperty "FixedRecord" propFixedRecord
|
||||||
|
]
|
||||||
|
|
||||||
|
-- There's no (Arbitrary BL.ByteString), so pack it from a list of chunks.
|
||||||
|
propEncodeDecodeInverse :: [[Word8]] -> Bool
|
||||||
|
propEncodeDecodeInverse s =
|
||||||
|
let bs = BL.fromChunks . fmap B.pack $ s
|
||||||
|
in runGet getTFRecord (runPut (putTFRecord bs)) == Right bs
|
||||||
|
|
||||||
|
propFixedRecord :: Bool
|
||||||
|
propFixedRecord =
|
||||||
|
("\x42" == case runGet getTFRecord record of
|
||||||
|
Left err -> error err -- Make the error appear in the test failure.
|
||||||
|
Right x -> x) &&
|
||||||
|
(runPut (putTFRecord "\x42") == record)
|
||||||
|
where
|
||||||
|
record = "\x01\x00\x00\x00\x00\x00\x00\x00" <> "\x01\x75\xde\x41" <>
|
||||||
|
"\x42" <> "\x52\xcf\xb8\x1e"
|
Loading…
Reference in a new issue