mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 11:29:43 +01:00
170 lines
6.3 KiB
Haskell
170 lines
6.3 KiB
Haskell
-- Copyright 2016 TensorFlow authors.
|
|
--
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
-- you may not use this file except in compliance with the License.
|
|
-- You may obtain a copy of the License at
|
|
--
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
--
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
-- See the License for the specific language governing permissions and
|
|
-- limitations under the License.
|
|
|
|
{-# LANGUAGE OverloadedLists #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
|
|
module Main where
|
|
|
|
import Control.Monad.IO.Class (liftIO)
|
|
import Data.Int (Int64)
|
|
import Data.Text (Text)
|
|
import qualified Data.Text.IO as Text
|
|
import Lens.Family2 ((&), (.~), (^.))
|
|
import Prelude hiding (abs)
|
|
import Proto.Tensorflow.Core.Framework.Graph
|
|
( GraphDef(..)
|
|
, version
|
|
, node )
|
|
import Proto.Tensorflow.Core.Framework.NodeDef
|
|
( NodeDef(..)
|
|
, op )
|
|
import System.IO as IO
|
|
import TensorFlow.Examples.MNIST.InputData
|
|
import TensorFlow.Examples.MNIST.Parse
|
|
import TensorFlow.Examples.MNIST.TrainedGraph
|
|
import TensorFlow.Build
|
|
( asGraphDef
|
|
, addGraphDef
|
|
, render
|
|
)
|
|
import TensorFlow.Tensor
|
|
( Tensor(..)
|
|
, Ref
|
|
, Value
|
|
, feed
|
|
, TensorKind(..)
|
|
, tensorFromName
|
|
)
|
|
import TensorFlow.Ops
|
|
import TensorFlow.Nodes (unScalar)
|
|
import TensorFlow.Session
|
|
(runSession, run, run_, runWithFeeds, build, buildAnd)
|
|
import TensorFlow.Types (TensorType(..), Shape(..))
|
|
import Test.Framework.Providers.HUnit (testCase)
|
|
import Test.HUnit ((@=?), Assertion)
|
|
import Google.Test
|
|
import qualified Data.Vector as V
|
|
|
|
-- | Parse the bundled MNIST GraphDef with 'readMessageFromFileOrDie' and
-- spot-check a few of its fields against known values.
testReadMessageFromFileOrDie = testCase "testReadMessageFromFileOrDie" $ do
    -- The fixture is a known well-formed GraphDef file.
    path <- mnistPb
    mnist <- readMessageFromFileOrDie path :: IO GraphDef
    let nodes :: [NodeDef]
        nodes = mnist ^. node
    -- A plain scalar field.
    1 @=? mnist ^. version
    -- The graph holds exactly 100 nodes.
    100 @=? length nodes
    -- The node at index 6 (safe: the count was just checked) is a Variable.
    "Variable" @=? (nodes !! 6) ^. op
|
|
|
|
-- | Read the MNIST test images and test labels and check that each set has
-- 10000 entries. Only a missing or badly corrupted file should fail this.
testReadMNIST = testCase "testReadMNIST" $ do
    samples <- testImageData >>= readMNISTSamples
    10000 @=? length samples
    labels <- testLabelData >>= readMNISTLabels
    10000 @=? length labels
|
|
|
|
-- | Assert that the first node of the GraphDef produced by rendering @g@ has
-- op type @n@.
--
-- The node list is matched totally instead of via 'head'. A graph with no
-- nodes therefore produces an ordinary assertion failure ("got: Nothing")
-- rather than an opaque @Prelude.head: empty list@ error.
testNodeName :: Text -> Tensor v a -> Assertion
testNodeName n g = Just n @=? firstOp
  where
    firstOp = case gDef ^. node of
        firstNode : _ -> Just (firstNode ^. op)
        []            -> Nothing
    gDef = asGraphDef $ render g
|
|
|
|
-- | Check which op type each arithmetic expression on a 'Tensor' renders to.
testGraphDefGen = testCase "testGraphDefGen" $ do
    -- Each single operator maps to its matching op type.
    testNodeName "Const" zero
    testNodeName "Add" $ 1 + zero
    testNodeName "Mul" $ 1 * zero
    testNodeName "Sub" $ 1 - zero
    testNodeName "Abs" $ abs zero
    testNodeName "Sign" $ signum zero
    testNodeName "Neg" $ negate zero
    -- Operator precedence: the expected op is the outermost one in each
    -- expression.
    testNodeName "Add" $ 1 + zero * 2
    testNodeName "Add" $ 1 + (zero * 2)
    testNodeName "Mul" $ (1 + zero) * 2
  where
    zero :: Tensor Value Float
    zero = 0
|
|
|
|
-- | Render @5 * 10@ to a GraphDef, load that GraphDef into a fresh session,
-- fetch the node named "Mul_2", and check that it evaluates to 50.
testGraphDefExec = testCase "testGraphDefExec" $ runSession $ do
    build (addGraphDef mulGraph)
    result <- run (tensorFromName ValueKind "Mul_2")
    liftIO ((50 :: Float) @=? unScalar result)
  where
    mulGraph = asGraphDef (render (scalar (5 :: Float) * 10))
|
|
|
|
-- | Load MNIST from a GraphDef and the weights from a checkpoint and run on
-- sample data.
--
-- Restores the "weights" and "bias" checkpoints into the imported graph.
-- Feeds test sample 12 as "x-input" and its one-hot label as "y-input".
-- Checks that the "test/ArgMax" output equals the recorded label.
testMNISTExec = testCase "testMNISTExec" $ do
    -- Switch to unicode to enable pretty printing of MNIST digits.
    IO.hSetEncoding IO.stdout IO.utf8

    -- Parse the Graph definition, samples, & labels from files.
    mnist <- readMessageFromFileOrDie =<< mnistPb :: IO GraphDef
    mnistSamples <- readMNISTSamples =<< testImageData
    mnistLabels <- readMNISTLabels =<< testLabelData

    -- Select a sample to run on and convert it into a TensorData of Floats.
    let idx = 12
        sample :: MNIST
        sample = mnistSamples !! idx
        label = mnistLabels !! idx
        tensorSample = encodeTensorData (Shape [1,784]) floatSample
          where
            floatSample :: V.Vector Float
            floatSample = V.map fromIntegral sample
    Text.putStrLn $ drawMNIST sample

    -- Execute the graph on the sample data.
    runSession $ do
        -- The version of this session is 0, but the version of the graph is 1.
        -- Change the graph version to 0 so they're compatible.
        build $ addGraphDef $ mnist & version .~ 0
        -- Define nodes that restore saved weights and biases.
        let bias, wts :: Tensor Ref Float
            bias = tensorFromName RefKind "Variable"
            wts = tensorFromName RefKind "weights"
        wtsCkptPath <- liftIO wtsCkpt
        biasCkptPath <- liftIO biasCkpt
        -- Run those restoring nodes on the graph in the current session.
        -- The annotation pins 'sequence' to plain lists. This is presumably
        -- needed because OverloadedLists otherwise leaves the container type
        -- of the literal below ambiguous.
        buildAnd run_ $ (sequence :: Monad m => [m a] -> m [a])
            [ restore wtsCkptPath wts
            , restoreFromName biasCkptPath "bias" bias
            ]
        -- Encode the expected sample data as one-hot data.
        let ty = encodeTensorData [10] oneHotLabels
              where
                oneHotLabels = V.replicate 10 (0 :: Float) V.// updates
                updates = [(fromIntegral label, 1)]
        let feeds = [ feed (tensorFromName ValueKind "x-input") tensorSample
                    , feed (tensorFromName ValueKind "y-input") ty
                    ]
        -- Run the graph with the input feeds and read the ArgMax'd result from
        -- the test (not training) side of the evaluation.
        x <- runWithFeeds feeds $ tensorFromName ValueKind "test/ArgMax"
        let prediction = unScalar x :: Int64
            expected = fromIntegral label :: Int64
        -- Print the trained model's predicted outcome.
        liftIO $ putStrLn $ "Expectation: " ++ show label ++ "\n"
            ++ "Prediction: " ++ show prediction
        -- Check whether the prediction matches the expectation.
        liftIO $ expected @=? prediction
|
|
|
|
-- | Entry point: hands every test case defined in this module to 'googleTest'.
main :: IO ()
main =
    googleTest
        [ testReadMessageFromFileOrDie
        , testReadMNIST
        , testGraphDefGen
        , testGraphDefExec
        , testMNISTExec
        ]
|