tensorflow-haskell/tensorflow-mnist-input-data/Setup.hs

119 lines
3.8 KiB
Haskell

-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE LambdaCase #-}
-- | Downloads the MNIST data set and packages them as data files.
module Main where
import Control.Monad (when)
import Data.Maybe (fromMaybe)
import Distribution.PackageDescription
( GenericPackageDescription(packageDescription)
, dataDir
)
import Distribution.Simple
( UserHooks(..)
, defaultMainWithHooks
, simpleUserHooks
)
import System.IO (hPutStrLn, stderr)
import System.FilePath ((</>))
import System.Directory (doesFileExist)
import qualified Crypto.Hash as Hash
import qualified Data.ByteString.Lazy as B
import qualified Network.HTTP as HTTP
import qualified Network.Browser as Browser
import qualified Network.URI as URI
main :: IO ()
main = defaultMainWithHooks downloadingDataFiles
downloadingDataFiles :: UserHooks
downloadingDataFiles = hooks
{ confHook = \gh@(g, _) c -> downloadFiles g >> confHook hooks gh c
}
where
hooks = simpleUserHooks
downloadFiles :: GenericPackageDescription -> IO ()
downloadFiles g = do
let dir = dataDir (packageDescription g)
mapM_ (maybeDownload dir) fileInfos
maybeDownload :: FilePath -> (String, String) -> IO ()
maybeDownload dataDir (basename, sha256) = do
let filePath = dataDir </> basename
exists <- doesFileExist filePath
when (not exists) $ do
let url = urlPrefix ++ basename
hPutStrLn stderr ("Downloading " ++ url)
httpDownload url filePath
verify filePath sha256
httpDownload :: String -> FilePath -> IO ()
httpDownload url outFile = do
let uri = fromMaybe
(error ("Can't be: invalid URI " ++ url))
(URI.parseURI url)
(_, rsp)
<- Browser.browse $ do
Browser.setAllowRedirects True
Browser.setCheckForProxy True
Browser.request $ HTTP.defaultGETRequest_ uri
case HTTP.rspCode rsp of
(2, 0, 0) -> B.writeFile outFile $ HTTP.rspBody rsp
s -> error ( "Failed to download " ++ url ++ " error code " ++ show s
++ helpfulMessage
)
verify :: FilePath -> String -> IO ()
verify filePath hash = do
let sha256 = Hash.hashlazy :: B.ByteString -> Hash.Digest Hash.SHA256
computed <- show . sha256 <$> B.readFile filePath
when (hash /= computed) $
error ( "Incorrect checksum for " ++ filePath
++ "\nexpected " ++ hash
++ "\ncomputed " ++ computed
++ helpfulMessage
)
urlPrefix = "http://yann.lecun.com/exdb/mnist/"
-- | File names relative to 'urlPrefix' and their sha256.
fileInfos = [
( "train-images-idx3-ubyte.gz"
, "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
)
,
( "train-labels-idx1-ubyte.gz"
, "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
)
,
( "t10k-images-idx3-ubyte.gz"
, "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
)
,
( "t10k-labels-idx1-ubyte.gz"
, "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
)
]
helpfulMessage =
unlines
( ""
: ""
: "Please download the following URLs manually and put them in data/"
: [ urlPrefix ++ h | (h, _) <- fileInfos ]
)