{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ViewPatterns #-}
module TensorFlow.Examples.MNIST.Parse where
import Control.Monad (when, liftM)
import Data.Binary.Get (Get, runGet, getWord32be, getLazyByteString)
import Data.ByteString.Lazy (toStrict, readFile)
import Data.List.Split (chunksOf)
import Data.Monoid ((<>))
import Data.ProtoLens (Message, decodeMessageOrDie)
import Data.Text (Text)
import Data.Word (Word8, Word32)
import Prelude hiding (readFile)
import qualified Codec.Compression.GZip as GZip
import qualified Data.ByteString.Lazy as L
import qualified Data.Text as Text
import qualified Data.Vector as V
-- | A single MNIST image: the raw pixel bytes of one sample, in the order
-- they appear in the IDX image file (each one is a run of rows * cols bytes
-- produced by 'readMNISTSamples'). 'drawMNIST' renders it 28 pixels per line.
type MNIST = V.Vector Word8
-- | Render an 'MNIST' image as Unicode block art, 28 pixels per line.
--
-- A pixel value of 0 is drawn as a space. Any other value is drawn with one
-- of four characters chosen by (value `div` 64):
--
-- * 1-63 gives U+2591 (light shade).
-- * 64-127 gives U+2592 (medium shade).
-- * 128-191 gives U+2593 (dark shade).
-- * 192-255 gives U+2588 (full block).
--
-- Every line ends with a newline. One extra newline is appended after the
-- last line, so consecutive drawings are separated by a blank line. An empty
-- image renders as a single newline.
drawMNIST :: MNIST -> Text
drawMNIST =
  (<> "\n")
    . Text.unlines
    . Text.chunksOf rowWidth
    . Text.pack
    . map shade
    . V.toList
  where
    -- Pixels per rendered line; assumes 28-pixel-wide images.
    rowWidth :: Int
    rowWidth = 28

    -- Map a pixel intensity to its display character. This is total:
    -- a Word8 `div` 64 is at most 3, so the final alternative covers the
    -- darkest band.
    shade :: Word8 -> Char
    shade 0 = ' '
    shade n = case n `div` 64 of
      0 -> '\9617'
      1 -> '\9618'
      2 -> '\9619'
      _ -> '\9608'
-- | Read the 32-bit big-endian magic number at the start of an MNIST IDX
-- file. Fail the parse unless it is 2049 (0x00000801) or 2051 (0x00000803).
--
-- Either value is accepted, which is why the image reader and the label
-- reader both use this check. On a mismatch, the failure message includes
-- the value actually read:
--
-- * A byte-swapped value suggests a little-endian file.
-- * Anything else suggests the input is not an MNIST IDX file at all.
checkEndian :: Get ()
checkEndian = do
  magic <- getWord32be
  when (magic `notElem` ([2049, 2051] :: [Word32])) $
    fail $
      "Expected big endian MNIST magic number 2049 or 2051, but got "
        ++ show magic
        ++ " (file may be little endian or not an MNIST file)."
-- | Read a gzip-compressed MNIST IDX image file and return one 'MNIST'
-- vector per image.
--
-- After decompression, the file must contain, in this order:
--
-- 1. The magic number accepted by 'checkEndian'.
-- 2. The image count, as a big-endian 32-bit word.
-- 3. The row count, as a big-endian 32-bit word.
-- 4. The column count, as a big-endian 32-bit word.
-- 5. count * rows * cols pixel bytes.
--
-- Each returned image is the next rows * cols of those bytes.
--
-- The decoder is forced before this action returns. A malformed or
-- truncated file therefore raises its error here, rather than later
-- wherever the result list happens to be inspected.
readMNISTSamples :: FilePath -> IO [MNIST]
readMNISTSamples path = do
  raw <- GZip.decompress <$> readFile path
  -- `$!` forces runGet's result. That runs the decoder to completion (or to
  -- its failure) inside this IO action instead of leaving an error thunk.
  return $! runGet getMNIST raw
  where
    getMNIST :: Get [MNIST]
    getMNIST = do
      checkEndian
      cnt  <- fromIntegral <$> getWord32be
      rows <- fromIntegral <$> getWord32be
      cols <- fromIntegral <$> getWord32be
      -- Number of pixel bytes in one image.
      let imageSize = rows * cols
      pixels <- getLazyByteString $ fromIntegral $ cnt * imageSize
      return $ V.fromList <$> chunksOf imageSize (L.unpack pixels)
-- | Read a gzip-compressed MNIST IDX label file and return one label byte
-- per sample.
--
-- After decompression, the file must contain, in this order:
--
-- 1. The magic number accepted by 'checkEndian'.
-- 2. The label count, as a big-endian 32-bit word.
-- 3. That many label bytes.
--
-- The decoder is forced before this action returns, so malformed input
-- fails here instead of wherever the list is first used.
readMNISTLabels :: FilePath -> IO [Word8]
readMNISTLabels path = do
  raw <- GZip.decompress <$> readFile path
  -- `$!` forces runGet's result, so parse errors surface in this action.
  return $! runGet getLabels raw
  where
    getLabels :: Get [Word8]
    getLabels = do
      checkEndian
      cnt <- fromIntegral <$> getWord32be
      L.unpack <$> getLazyByteString cnt
-- | Read a whole file and decode its contents as a protocol buffer message
-- using 'decodeMessageOrDie', which raises an error if decoding fails.
--
-- The decoded message is forced before this action returns. Decoding
-- failures are therefore raised here, as the name promises. The lazily read
-- file is also consumed immediately, rather than whenever the message is
-- first used.
readMessageFromFileOrDie :: Message m => FilePath -> IO m
readMessageFromFileOrDie path = do
  contents <- readFile path
  -- toStrict pulls in the entire file; `$!` makes that and the decode happen
  -- now instead of being deferred behind a thunk.
  return $! decodeMessageOrDie (toStrict contents)