{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ViewPatterns #-}
module TensorFlow.Examples.MNIST.Parse where
import Control.Monad (when, liftM)
import Data.Binary.Get (Get, runGet, getWord32be, getLazyByteString)
import Data.ByteString.Lazy (toStrict, readFile)
import Data.List.Split (chunksOf)
import Data.ProtoLens (Message, decodeMessageOrDie)
import Data.Text (Text)
import Data.Word (Word8, Word32)
import Prelude hiding (readFile)
import qualified Codec.Compression.GZip as GZip
import qualified Data.ByteString.Lazy as L
import qualified Data.Text as Text
import qualified Data.Vector as V
-- | A single MNIST sample: a flat vector of 8-bit pixel intensities.
type MNIST = V.Vector Word8
-- | Render an MNIST sample as block-character art, e.g. for a terminal.
--
-- Every pixel becomes exactly one character: a space for intensity 0,
-- otherwise one of four shade glyphs selected by @intensity \`div\` 64@.
-- The characters are then emitted in rows of 28, each row followed by a
-- newline, and the result always ends with one additional newline (so an
-- empty sample renders as @"\\n"@).
drawMNIST :: MNIST -> Text
drawMNIST = wrapRows . Text.pack . map shade . V.toList
  where
    -- Map one pixel intensity to its glyph. Total over all of 'Word8':
    -- @n \`div\` 64@ can only be 0..3, so no list indexing is needed.
    shade :: Word8 -> Char
    shade 0 = ' '
    shade n = case n `div` 64 of
        0 -> '\9617'
        1 -> '\9618'
        2 -> '\9619'
        _ -> '\9608'

    -- Break the glyph stream into 28-character rows. 'Text.unlines'
    -- terminates every row with a newline; the extra trailing newline
    -- preserves the historical output format of this function.
    wrapRows :: Text -> Text
    wrapRows = (<> "\n") . Text.unlines . Text.chunksOf 28
-- | Consume the 4-byte big-endian magic number at the start of an MNIST
-- file and fail the parse unless it is one of the two accepted values.
--
-- NOTE(review): 2049 and 2051 are presumably the IDX magic numbers for
-- label files and image files respectively -- confirm against the format
-- specification. Any other value is reported with the endianness message
-- below, although a mismatch may equally mean the input is simply not an
-- MNIST file.
checkEndian :: Get ()
checkEndian = do
    magic <- getWord32be
    when (magic /= 2049 && magic /= 2051) $
        fail "Expected big endian, but image file is little endian."
-- | Read a gzip-compressed MNIST image file and return its samples.
--
-- The file at @path@ is decompressed and parsed as: a magic number
-- (validated by 'checkEndian'), then three big-endian 32-bit header
-- fields (sample count, rows, columns), then @count * rows * columns@
-- pixel bytes. The pixel bytes are split into consecutive groups of
-- @rows * columns@, one 'MNIST' vector per group.
--
-- Parsing uses 'runGet', so a bad magic number or truncated input
-- surfaces as an error from the parser rather than as a return value.
readMNISTSamples :: FilePath -> IO [MNIST]
readMNISTSamples path = runGet getMNIST . GZip.decompress <$> readFile path
  where
    getMNIST :: Get [MNIST]
    getMNIST = do
        checkEndian
        -- Header fields, in file order.
        cnt <- fromIntegral <$> getWord32be
        rows <- fromIntegral <$> getWord32be
        cols <- fromIntegral <$> getWord32be
        let sampleSize = rows * cols :: Int
        -- Pull the whole pixel payload at once, then cut it per sample.
        pixels <- getLazyByteString (fromIntegral (cnt * sampleSize))
        return (V.fromList <$> chunksOf sampleSize (L.unpack pixels))
-- | Read a gzip-compressed MNIST label file and return its labels.
--
-- The file at @path@ is decompressed and parsed as: a magic number
-- (validated by 'checkEndian'), then a big-endian 32-bit label count,
-- then that many bytes, each returned as one 'Word8' label.
--
-- Parsing uses 'runGet', so a bad magic number or truncated input
-- surfaces as an error from the parser rather than as a return value.
readMNISTLabels :: FilePath -> IO [Word8]
readMNISTLabels path = runGet getLabels . GZip.decompress <$> readFile path
  where
    getLabels :: Get [Word8]
    getLabels = do
        checkEndian
        -- The count is used directly as the byte length of the payload.
        cnt <- fromIntegral <$> getWord32be
        L.unpack <$> getLazyByteString cnt
-- | Read the file at @path@ and decode its entire contents as a protocol
-- buffer message of the type requested by the caller.
--
-- Decoding is delegated to 'decodeMessageOrDie'; as its name suggests, it
-- presumably raises an error on malformed input instead of returning a
-- failure value -- confirm against the proto-lens documentation.
readMessageFromFileOrDie :: Message m => FilePath -> IO m
readMessageFromFileOrDie path = decodeMessageOrDie . toStrict <$> readFile path