mirror of
synced 2025-03-21 05:05:10 +01:00
96 lines
3.4 KiB
96 lines
3.4 KiB
-- Copyright 2016 TensorFlow authors.
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
-- http://www.apache.org/licenses/LICENSE-2.0
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ViewPatterns #-}
module TensorFlow.Examples.MNIST.Parse where
import Control.Monad (when, liftM)
import Data.Binary.Get (Get, runGet, getWord32be, getLazyByteString)
import Data.ByteString.Lazy (toStrict, readFile)
import Data.List.Split (chunksOf)
import Data.ProtoLens (Message, decodeMessageOrDie)
import Data.Text (Text)
import Data.Word (Word8, Word32)
import Prelude hiding (readFile)
import qualified Codec.Compression.GZip as GZip
import qualified Data.ByteString.Lazy as L
import qualified Data.Text as Text
import qualified Data.Vector as V
-- Utilities specific to MNIST.

-- | A single MNIST sample, stored as a flat vector of bytes.
-- 'getMNIST' builds each sample from @rows * cols@ consecutive bytes of the
-- file, and 'drawMNIST' renders one character per byte.
type MNIST = V.Vector Word8
-- | Renders an MNIST sample as unicode text: every byte is first turned into
-- one character by 'block', and the resulting text is then broken into lines
-- by 'chunk'.
drawMNIST :: MNIST -> Text
drawMNIST sample = chunk (block sample)
-- | Converts every byte of the vector into exactly one character.
--
-- A byte of 0 becomes a space. Any other byte @n@ is rendered as one of four
-- block glyphs selected by @n `div` 64@: light shade (U+2591), medium shade
-- (U+2592), dark shade (U+2593) or full block (U+2588).
--
-- The text is built in a single pass with 'Text.pack'. Peeling off one
-- element at a time and prepending with 'Text.cons' copies the accumulated
-- text on every step, which is quadratic in the number of bytes.
block :: V.Vector Word8 -> Text
block = Text.pack . map shade . V.toList
  where
    -- Total replacement for indexing a four-character string with '!!'.
    shade :: Word8 -> Char
    shade 0 = ' '
    shade n = case n `div` 64 of
        0 -> '\9617'
        1 -> '\9618'
        2 -> '\9619'
        -- A Word8 divided by 64 is at most 3, so this is the 192..255 band.
        _ -> '\9608'
-- | Breaks the text into rows of at most 28 characters, terminating every row
-- with a newline and appending one extra newline after the last row.
-- Empty input therefore yields a single newline.
chunk :: Text -> Text
chunk = (<> "\n") . Text.unlines . Text.chunksOf 28
-- | Checks the file's endianness, failing the parse if it is not as expected.
--
-- Consumes one big-endian 32-bit word and accepts only the values 2049 and
-- 2051 (presumably the IDX magic numbers of label and image files; confirm
-- against the format description). Any other value, including these numbers
-- read with swapped byte order, makes the parser 'fail'.
checkEndian :: Get ()
checkEndian = getWord32be >>= validate
  where
    validate :: Word32 -> Get ()
    validate magic
        | magic == 2049 || magic == 2051 = return ()
        | otherwise =
            fail "Expected big endian, but image file is little endian."
-- | Loads the gzip-compressed MNIST file at the given path and parses its
-- decompressed contents with 'getMNIST', returning the list of samples.
readMNISTSamples :: FilePath -> IO [MNIST]
readMNISTSamples path = runGet getMNIST . GZip.decompress <$> readFile path
-- | Parser for the contents of an MNIST image file.
--
-- Expected layout, with every header field a big-endian 32-bit word:
-- magic number, sample count, rows per sample, columns per sample, followed
-- by @count * rows * cols@ pixel bytes. The result holds one 'MNIST' vector
-- of @rows * cols@ bytes per sample.
getMNIST :: Get [MNIST]
getMNIST = do
    -- The file starts with a magic number. It has to be consumed (and
    -- validated) before the header, otherwise every field below is read one
    -- word too early: the magic would be taken for the sample count, the
    -- count for the row size, and so on.
    checkEndian
    -- Parse header data.
    cnt <- liftM fromIntegral getWord32be
    rows <- liftM fromIntegral getWord32be
    cols <- liftM fromIntegral getWord32be
    -- Read all of the data, then split into samples.
    pixels <- getLazyByteString $ fromIntegral $ cnt * rows * cols
    return $ V.fromList <$> chunksOf (rows * cols) (L.unpack pixels)
-- | Reads a list of MNIST labels from a gzip-compressed file and returns them,
-- one byte per label.
readMNISTLabels :: FilePath -> IO [Word8]
readMNISTLabels path = do
    raw <- GZip.decompress <$> readFile path
    return $ runGet getLabels raw
  where
    -- Expected layout: big-endian 32-bit magic number, big-endian 32-bit
    -- label count, then one byte per label.
    getLabels :: Get [Word8]
    getLabels = do
        -- Consume and validate the magic number first; without this the
        -- magic itself would be interpreted as the label count.
        checkEndian
        -- Parse header data.
        cnt <- liftM fromIntegral getWord32be
        -- Read all of the labels.
        L.unpack <$> getLazyByteString cnt
-- | Reads the whole file at the given path and decodes its contents as a
-- protocol buffer message with 'decodeMessageOrDie'.
readMessageFromFileOrDie :: Message m => FilePath -> IO m
readMessageFromFileOrDie path = decodeMessageOrDie . toStrict <$> readFile path
-- TODO: Write a writeMessageToFileOrDie, plus non-lethal variants of both
-- the read and the write function.