-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Tests for parsing a serialized MNIST 'GraphDef' and the MNIST sample and
-- label files, for the op names produced when rendering tensor expressions,
-- and for executing graphs (including the MNIST graph with weights restored
-- from checkpoints) inside a session.
module Main where

import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import Data.Text (Text)
import qualified Data.Text.IO as Text
import Lens.Family2 ((&), (.~), (^.))
import Prelude hiding (abs)
import Proto.Tensorflow.Core.Framework.Graph
    ( GraphDef(..)
    , version
    , node )
import Proto.Tensorflow.Core.Framework.NodeDef
    ( NodeDef(..)
    , op )
import System.IO as IO
import TensorFlow.Examples.MNIST.InputData
import TensorFlow.Examples.MNIST.Parse
import TensorFlow.Examples.MNIST.TrainedGraph
import TensorFlow.Build
    ( asGraphDef
    , addGraphDef
    , render )
import TensorFlow.Tensor
    ( Tensor(..)
    , Ref
    , Value
    , feed
    , TensorKind(..)
    , tensorFromName )
import TensorFlow.Ops
import TensorFlow.Session
    (runSession, run, run_, runWithFeeds, build, buildAnd)
import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?), Assertion, assertFailure)
import Google.Test
import qualified Data.Vector as V

-- | Test that a file can be read and the GraphDef proto correctly parsed.
testReadMessageFromFileOrDie :: Test
testReadMessageFromFileOrDie = testCase "testReadMessageFromFileOrDie" $ do
    -- Check the function on a known well-formatted file.
    mnist <- readMessageFromFileOrDie =<< mnistPb :: IO GraphDef
    -- Simple field read.
    1 @=? mnist^.version
    -- Count the number of nodes.
    let nodes :: [NodeDef]
        nodes = mnist^.node
    100 @=? length nodes
    -- Check that the expected op is found at an arbitrary index. The length
    -- assertion above aborts the test first if there are too few nodes, so
    -- the index is in range here.
    "Variable" @=? nodes!!6^.op

-- | Parse the test set for label and image data. Will only fail if the file is
-- missing or incredibly corrupt.
testReadMNIST :: Test
testReadMNIST = testCase "testReadMNIST" $ do
    imageData <- readMNISTSamples =<< testImageData
    10000 @=? length imageData
    labelData <- readMNISTLabels =<< testLabelData
    10000 @=? length labelData

-- | Render @g@ to a 'GraphDef' and assert that the op type of the first node
-- in the resulting node list equals @n@.
--
-- NOTE(review): the grouping checks in 'testGraphDefGen' rely on the first
-- listed node being the outermost op of the expression; that ordering comes
-- from 'render'/'asGraphDef' and is not visible here — confirm there.
--
-- An empty node list is reported as a test failure rather than crashing
-- with a partial-function error.
testNodeName :: Text -> Tensor v a -> Assertion
testNodeName n g =
    case gDef^.node of
        firstNode : _ -> n @=? firstNode^.op
        []            ->
            assertFailure "testNodeName: rendered GraphDef contains no nodes"
  where
    gDef = asGraphDef $ render g

-- | Check the op types generated for constants, the 'Num' operators, and
-- operator precedence in tensor expressions.
testGraphDefGen :: Test
testGraphDefGen = testCase "testGraphDefGen" $ do
    -- Test the inferred operation type.
    let f0 :: Tensor Value Float
        f0 = 0
    testNodeName "Const" f0
    testNodeName "Add" $ 1 + f0
    testNodeName "Mul" $ 1 * f0
    testNodeName "Sub" $ 1 - f0
    testNodeName "Abs" $ abs f0
    testNodeName "Sign" $ signum f0
    testNodeName "Neg" $ -f0
    -- Test the grouping: Haskell operator precedence decides which op ends
    -- up outermost.
    testNodeName "Add" $ 1 + f0 * 2
    testNodeName "Add" $ 1 + (f0 * 2)
    testNodeName "Mul" $ (1 + f0) * 2

-- | Convert a simple graph to GraphDef, load it, run it, and check the output.
testGraphDefExec :: Test
testGraphDefExec = testCase "testGraphDefExec" $ do
    let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
    runSession $ do
        build $ addGraphDef graphDef
        -- NOTE(review): "Mul_2" is presumably the auto-generated name of the
        -- Mul node (created after the two constants); confirm against the
        -- naming scheme in TensorFlow.Build if this lookup ever fails.
        x <- run $ tensorFromName ValueKind "Mul_2"
        liftIO $ (50 :: Float) @=? unScalar x

-- | Load MNIST from a GraphDef and the weights from a checkpoint and run on
-- sample data.
testMNISTExec :: Test
testMNISTExec = testCase "testMNISTExec" $ do
    -- Switch to unicode to enable pretty printing of MNIST digits.
    IO.hSetEncoding IO.stdout IO.utf8

    -- Parse the Graph definition, samples, & labels from files.
    mnist <- readMessageFromFileOrDie =<< mnistPb :: IO GraphDef
    mnistSamples <- readMNISTSamples =<< testImageData
    mnistLabels <- readMNISTLabels =<< testLabelData

    -- Select a sample to run on and convert it into a TensorData of Floats.
    let idx = 12
        sample :: MNIST
        sample = mnistSamples !! idx
        label = mnistLabels !! idx
        tensorSample = encodeTensorData (Shape [1,784]) floatSample
          where
            floatSample :: V.Vector Float
            floatSample = V.map fromIntegral sample
    Text.putStrLn $ drawMNIST sample

    -- Execute the graph on the sample data.
    runSession $ do
        -- The version of this session is 0, but the version of the graph is 1.
        -- Change the graph version to 0 so they're compatible.
        build $ addGraphDef $ mnist & version .~ 0
        -- Define nodes that restore saved weights and biases.
        let bias, wts :: Tensor Ref Float
            bias = tensorFromName RefKind "Variable"
            wts = tensorFromName RefKind "weights"
        wtsCkptPath <- liftIO wtsCkpt
        biasCkptPath <- liftIO biasCkpt
        -- Run those restoring nodes on the graph in the current session.
        -- OverloadedLists is enabled, so the container type of the list
        -- literal must be pinned to [] for 'sequence'; hence the annotation.
        -- NOTE(review): the bias is restored by the explicit name "bias"
        -- (its graph node is "Variable"), presumably the checkpoint's entry
        -- name — confirm against the checkpoint files.
        buildAnd run_ $ (sequence :: Monad m => [m a] -> m [a])
            [ restore wtsCkptPath wts
            , restoreFromName biasCkptPath "bias" bias
            ]
        -- Encode the expected sample data as one-hot data.
        let ty = encodeTensorData [10] oneHotLabels
              where oneHotLabels = V.replicate 10 (0 :: Float) V.// updates
                    updates = [(fromIntegral label, 1)]
        let feeds = [ feed (tensorFromName ValueKind "x-input") tensorSample
                    , feed (tensorFromName ValueKind "y-input") ty
                    ]
        -- Run the graph with the input feeds and read the ArgMax'd result from
        -- the test (not training) side of the evaluation.
        x <- runWithFeeds feeds $ tensorFromName ValueKind "test/ArgMax"
        -- Print the trained model's predicted outcome.
        liftIO $ putStrLn $ "Expectation: " ++ show label ++ "\n" ++
                            "Prediction: " ++ show (unScalar x :: Int64)
        -- Check whether the prediction matches the expectation.
        liftIO $ (fromIntegral label :: Int64) @=? unScalar x

-- | Run every test in this module through the 'googleTest' driver.
main :: IO ()
main = googleTest [ testReadMessageFromFileOrDie
                  , testReadMNIST
                  , testGraphDefGen
                  , testGraphDefExec
                  , testMNISTExec]