1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-12-27 03:59:46 +01:00
tensorflow-haskell/tensorflow-mnist/tests/ParseTest.hs

177 lines
6.5 KiB
Haskell
Raw Normal View History

2016-10-24 21:26:42 +02:00
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import Data.Text (Text)
import qualified Data.Text.IO as Text
import Lens.Family2 ((&), (.~), (^.))
import Prelude hiding (abs)
import Proto.Tensorflow.Core.Framework.Graph
2019-04-12 04:27:15 +02:00
( GraphDef )
import Proto.Tensorflow.Core.Framework.Graph_Fields
( version
2016-10-24 21:26:42 +02:00
, node )
import Proto.Tensorflow.Core.Framework.NodeDef
2019-04-12 04:27:15 +02:00
( NodeDef )
import Proto.Tensorflow.Core.Framework.NodeDef_Fields (op)
2016-10-24 21:26:42 +02:00
import System.IO as IO
import TensorFlow.Examples.MNIST.InputData
import TensorFlow.Examples.MNIST.Parse
import TensorFlow.Examples.MNIST.TrainedGraph
import TensorFlow.Build
( asGraphDef
, addGraphDef
, Build
2016-10-24 21:26:42 +02:00
)
import TensorFlow.Tensor
( Tensor(..)
, Ref
, feed
, render
2016-10-24 21:26:42 +02:00
, tensorFromName
, tensorValueFromName
2016-10-24 21:26:42 +02:00
)
import TensorFlow.Ops
import TensorFlow.Session
(runSession, run, run_, runWithFeeds, build)
Support fetching storable vectors + use them in benchmark (#50) In addition, you can now fetch TensorData directly. This might be useful in scenarios where you feed the result of a computation back in, like RNN. Before: benchmarking feedFetch/4 byte time 83.31 μs (81.88 μs .. 84.75 μs) 0.997 R² (0.994 R² .. 0.998 R²) mean 87.32 μs (86.06 μs .. 88.83 μs) std dev 4.580 μs (3.698 μs .. 5.567 μs) variance introduced by outliers: 55% (severely inflated) benchmarking feedFetch/4 KiB time 114.9 μs (111.5 μs .. 118.2 μs) 0.996 R² (0.994 R² .. 0.998 R²) mean 117.3 μs (116.2 μs .. 118.6 μs) std dev 3.877 μs (3.058 μs .. 5.565 μs) variance introduced by outliers: 31% (moderately inflated) benchmarking feedFetch/4 MiB time 109.0 ms (107.9 ms .. 110.7 ms) 1.000 R² (0.999 R² .. 1.000 R²) mean 108.6 ms (108.2 ms .. 109.2 ms) std dev 740.2 μs (353.2 μs .. 1.186 ms) After: benchmarking feedFetch/4 byte time 82.92 μs (80.55 μs .. 85.24 μs) 0.996 R² (0.993 R² .. 0.998 R²) mean 83.58 μs (82.34 μs .. 84.89 μs) std dev 4.327 μs (3.664 μs .. 5.375 μs) variance introduced by outliers: 54% (severely inflated) benchmarking feedFetch/4 KiB time 85.69 μs (83.81 μs .. 87.30 μs) 0.997 R² (0.996 R² .. 0.999 R²) mean 86.99 μs (86.11 μs .. 88.15 μs) std dev 3.608 μs (2.854 μs .. 5.273 μs) variance introduced by outliers: 43% (moderately inflated) benchmarking feedFetch/4 MiB time 1.582 ms (1.509 ms .. 1.677 ms) 0.970 R² (0.936 R² .. 0.993 R²) mean 1.645 ms (1.554 ms .. 1.981 ms) std dev 490.6 μs (138.9 μs .. 1.067 ms) variance introduced by outliers: 97% (severely inflated)
2016-12-15 03:53:06 +01:00
import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar)
import Test.Framework (defaultMain, Test)
2016-10-24 21:26:42 +02:00
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?), Assertion)
import qualified Data.Vector as V
-- | Read the graph file located by 'mnistPb' and check that it parses into a
-- 'GraphDef' with the expected version, node count, and op at a fixed index.
testReadMessageFromFileOrDie :: Test
testReadMessageFromFileOrDie = testCase "testReadMessageFromFileOrDie" $ do
    -- Load a file that is known to be well-formed.
    pbPath <- mnistPb
    graph <- readMessageFromFileOrDie pbPath :: IO GraphDef
    -- A plain scalar field can be read back.
    1 @=? graph ^. version
    -- The repeated node field holds the expected number of entries.
    let graphNodes = graph ^. node :: [NodeDef]
    100 @=? length graphNodes
    -- Spot-check the op stored at one fixed position. The length assertion
    -- above has already failed the test if this index were out of range.
    "Variable" @=? (graphNodes !! 6) ^. op
-- | Read the image and label files of the MNIST test set and check that each
-- yields 10000 entries. Expected to fail only when a file is missing or badly
-- corrupted.
testReadMNIST :: Test
testReadMNIST = testCase "testReadMNIST" $ do
    images <- testImageData >>= readMNISTSamples
    10000 @=? length images
    labels <- testLabelData >>= readMNISTLabels
    10000 @=? length labels
-- | Assert that the first node in the 'GraphDef' obtained by rendering @g@
-- has the op type @n@.
--
-- The comparison is done on lists of at most one element rather than through
-- 'head': if rendering yields no nodes at all, the assertion fails with a
-- normal expected/actual report (@["Add"]@ versus @[]@) instead of aborting
-- with an empty-list exception.
testNodeName :: Text -> Tensor Build a -> Assertion
testNodeName n g = [n] @=? firstOps
  where
    gDef = asGraphDef $ render g
    -- Op type of the first node, or the empty list when there is no node.
    firstOps :: [Text]
    firstOps = map (^. op) (take 1 (gDef ^. node))
-- | Check, via 'testNodeName', the op type of the first node generated for
-- simple numeric expressions built on a constant tensor.
testGraphDefGen :: Test
testGraphDefGen = testCase "testGraphDefGen" $ do
    -- Test the inferred operation type.
    mapM_ (uncurry testNodeName) opTypeCases
    -- Test the grouping.
    mapM_ (uncurry testNodeName) groupingCases
  where
    f0 :: Tensor Build Float
    f0 = 0
    -- The list type is written out because this file enables OverloadedLists;
    -- a bare literal handed to 'mapM_' would have an ambiguous container type.
    opTypeCases, groupingCases :: [(Text, Tensor Build Float)]
    opTypeCases =
        [ ("Const", f0)
        , ("Add", 1 + f0)
        , ("Mul", 1 * f0)
        , ("Sub", 1 - f0)
        , ("Abs", abs f0)
        , ("Sign", signum f0)
        , ("Neg", negate f0)
        ]
    groupingCases =
        [ ("Add", 1 + f0 * 2)
        , ("Add", 1 + (f0 * 2))
        , ("Mul", (1 + f0) * 2)
        ]
-- | Render a small multiplication to a 'GraphDef', add that definition to a
-- fresh session, fetch the node named @Mul_2@, and compare it with the
-- expected product.
testGraphDefExec :: Test
testGraphDefExec = testCase "testGraphDefExec" $ runSession $ do
    addGraphDef productGraph
    result <- run (tensorValueFromName "Mul_2")
    liftIO $ (50 :: Float) @=? unScalar result
  where
    -- Graph definition for the scalar expression 5 * 10.
    productGraph = asGraphDef . render $ scalar (5 :: Float) * 10
-- | Load the MNIST graph from a 'GraphDef', restore the weights and bias from
-- checkpoint files, feed one test sample through the graph, and check that
-- the predicted digit equals the sample's label.
testMNISTExec :: Test
testMNISTExec = testCase "testMNISTExec" $ do
    -- Use UTF-8 on stdout so the drawn digit can be pretty-printed.
    IO.hSetEncoding IO.stdout IO.utf8
    -- Read the graph definition, the samples and the labels from files.
    mnist <- mnistPb >>= readMessageFromFileOrDie :: IO GraphDef
    mnistSamples <- testImageData >>= readMNISTSamples
    mnistLabels <- testLabelData >>= readMNISTLabels
    -- Pick one sample and convert it into a TensorData of Floats.
    let sampleIndex = 12
        sample :: MNIST
        sample = mnistSamples !! sampleIndex
        label = mnistLabels !! sampleIndex
        samplePixels :: V.Vector Float
        samplePixels = V.map fromIntegral sample
        tensorSample = encodeTensorData (Shape [1,784]) samplePixels
    Text.putStrLn $ drawMNIST sample
    -- Execute the graph on the sample data.
    runSession $ do
        -- The version of this session is 0, but the version of the graph is
        -- 1. Set the graph version to 0 so the two are compatible.
        build (addGraphDef (mnist & version .~ 0))
        -- Tensors referring to the saved bias and weights by node name.
        let bias, wts :: Tensor Ref Float
            bias = tensorFromName "Variable"
            wts = tensorFromName "weights"
        wtsCkptPath <- liftIO wtsCkpt
        biasCkptPath <- liftIO biasCkpt
        -- Build the restoring nodes, then run them in the current session.
        -- 'sequence' is pinned to lists because OverloadedLists is enabled.
        restoreNodes <- (sequence :: Monad m => [m a] -> m [a])
            [ restore wtsCkptPath wts
            , restoreFromName biasCkptPath "bias" bias
            ]
        run_ restoreNodes
        -- Encode the expected label as one-hot data and assemble the feeds.
        let oneHotLabels =
                V.replicate 10 (0 :: Float) V.// [(fromIntegral label, 1)]
            ty = encodeTensorData [10] oneHotLabels
            feeds = [ feed (tensorValueFromName "x-input") tensorSample
                    , feed (tensorValueFromName "y-input") ty
                    ]
        -- Run the graph with the input feeds and read the ArgMax'd result
        -- from the test (not training) side of the evaluation.
        fetched <- runWithFeeds feeds (tensorValueFromName "test/ArgMax")
        let predicted = unScalar fetched :: Int64
        liftIO $ do
            -- Print the trained model's predicted outcome.
            putStrLn $ "Expectation: " ++ show label ++ "\n"
                ++ "Prediction: " ++ show predicted
            -- Check whether the prediction matches the expectation.
            (fromIntegral label :: Int64) @=? predicted
-- | Entry point: hand the test cases listed below to the test runner.
main :: IO ()
main = defaultMain allTests
  where
    allTests :: [Test]
    allTests =
        [ testReadMessageFromFileOrDie
        , testReadMNIST
        , testGraphDefGen
        , testGraphDefExec
        , testMNISTExec
        ]