-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ViewPatterns #-}

module TensorFlow.Examples.MNIST.Parse where

import Control.Monad (when, liftM)
import Data.Binary.Get (Get, runGet, getWord32be, getLazyByteString)
import Data.ByteString.Lazy (toStrict, readFile)
import Data.List.Split (chunksOf)
import Data.ProtoLens (Message, decodeMessageOrDie)
import Data.Text (Text)
import Data.Word (Word8, Word32)
import Prelude hiding (readFile)
import qualified Codec.Compression.GZip as GZip
import qualified Data.ByteString.Lazy as L
import qualified Data.Text as Text
import qualified Data.Vector as V

-- Utilities specific to MNIST.

-- | A single MNIST sample: its pixel intensities as a flat vector of bytes.
type MNIST = V.Vector Word8

-- | Produces a unicode rendering of the MNIST digit sample.
--
-- Each pixel becomes one character chosen by intensity:
-- @0@ is a space, @1..63@ is U+2591 (light shade), @64..127@ is U+2592
-- (medium shade), @128..191@ is U+2593 (dark shade), and @192..255@ is
-- U+2588 (full block). Every 28 pixels form one output line, and one extra
-- @\"\\n\"@ is appended after the last line (an empty sample renders as
-- @\"\\n\"@).
drawMNIST :: MNIST -> Text
drawMNIST = chunk . Text.pack . map shade . V.toList
  where
    -- Characters per rendered line; assumes 28-pixel-wide images.
    width :: Int
    width = 28

    -- Total mapping from a pixel intensity to its shading character.
    shade :: Word8 -> Char
    shade 0 = ' '
    shade n = case n `div` 64 of
      0 -> '\9617'
      1 -> '\9618'
      2 -> '\9619'
      _ -> '\9608'

    -- Break the flat rendering into lines of 'width' characters, then add
    -- one trailing newline. Built in a single pass rather than by repeated
    -- take/drop, so the cost stays linear in the number of pixels.
    chunk :: Text -> Text
    chunk t = Text.unlines (Text.chunksOf width t) <> "\n"

-- | Checks the file's endianness (via its magic number), throwing an error
-- if it's not as expected.
checkEndian :: Get ()
checkEndian = do
  -- The file starts with a 32-bit magic number. Reading it big-endian and
  -- comparing against the two accepted values rejects byte-swapped
  -- (little-endian) files as well as data that is not an MNIST file. The
  -- message reports the value seen, because this check is shared by the
  -- image and label readers below.
  magic <- getWord32be
  when (magic `notElem` ([2049, 2051] :: [Word32])) $
    fail $ "Unexpected MNIST magic number " ++ show magic
      ++ " (expected 2049 or 2051); the file may be little endian"
      ++ " or not an MNIST file."

-- | Reads a gzip-compressed MNIST image file and returns a list of samples,
-- one flat pixel vector of @rows * cols@ bytes per image.
--
-- The header and all pixel bytes are parsed before this action returns, so a
-- malformed or truncated file raises its error here rather than wherever the
-- result is first inspected.
readMNISTSamples :: FilePath -> IO [MNIST]
readMNISTSamples path = do
  raw <- GZip.decompress <$> readFile path
  -- runGet only produces a value once the parser has finished, so forcing
  -- its result to WHNF with '$!' runs the whole parse inside this action.
  return $! runGet getMNIST raw
  where
    getMNIST :: Get [MNIST]
    getMNIST = do
      checkEndian
      -- Parse header data: sample count, then rows and columns per sample.
      cnt <- liftM fromIntegral getWord32be
      rows <- liftM fromIntegral getWord32be
      cols <- liftM fromIntegral getWord32be
      let sampleSize = rows * cols :: Int
      -- Read all of the data, then split into samples.
      pixels <- getLazyByteString $ fromIntegral $ cnt * sampleSize
      -- chunksOf with a non-positive size returns an infinite list of empty
      -- chunks, so a zero-sized image is handled explicitly as cnt empty
      -- samples.
      return $
        if sampleSize == 0
          then replicate cnt V.empty
          else V.fromList <$> chunksOf sampleSize (L.unpack pixels)

-- | Reads a gzip-compressed MNIST label file and returns its labels, one
-- byte per sample.
--
-- As with 'readMNISTSamples', parsing completes before this action returns,
-- so a malformed or truncated file raises its error here.
readMNISTLabels :: FilePath -> IO [Word8]
readMNISTLabels path = do
  raw <- GZip.decompress <$> readFile path
  return $! runGet getLabels raw
  where
    getLabels :: Get [Word8]
    getLabels = do
      checkEndian
      -- Parse header data: the number of labels.
      cnt <- liftM fromIntegral getWord32be
      -- Read all of the labels.
      L.unpack <$> getLazyByteString cnt

-- | Reads a file holding one serialized protobuf message and decodes it with
-- 'decodeMessageOrDie'.
--
-- The decoded value is forced before returning. This reads the lazily-read
-- file to the end (via 'toStrict'), which lets its handle be released, and
-- surfaces a decode failure here rather than at first use of the message.
readMessageFromFileOrDie :: Message m => FilePath -> IO m
readMessageFromFileOrDie path = do
  pb <- readFile path
  return $! decodeMessageOrDie $ toStrict pb

-- TODO: Write a writeMessageToFileOrDie and read/write non-lethal
-- versions.