From 03a3a6d086937cfbc1b11e8741c8425e214499e9 Mon Sep 17 00:00:00 2001 From: fkm3 Date: Wed, 26 Oct 2016 11:14:38 -0700 Subject: [PATCH] Misc MNIST example cleanup (#9) * Use native oneHot op in the example code. It didn't exist when this was originally written. * Misc cleanup in MNIST example - Use unspecified dimension for batch size in model. This simplifies the code for the test set. - Move error rate calculation into model. --- tensorflow-mnist/app/Main.hs | 109 ++++++++++++--------------- tensorflow-ops/src/TensorFlow/Ops.hs | 3 + 2 files changed, 53 insertions(+), 59 deletions(-) diff --git a/tensorflow-mnist/app/Main.hs b/tensorflow-mnist/app/Main.hs index 57d5ce0..d39e430 100644 --- a/tensorflow-mnist/app/Main.hs +++ b/tensorflow-mnist/app/Main.hs @@ -12,21 +12,24 @@ -- See the License for the specific language governing permissions and -- limitations under the License. +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE OverloadedLists #-} import Control.Monad (zipWithM, when, forM, forM_) import Control.Monad.IO.Class (liftIO) import Data.Int (Int32, Int64) +import Data.List (genericLength) import qualified Data.Text.IO as T import qualified Data.Vector as V -import qualified TensorFlow.ControlFlow as TF import qualified TensorFlow.Build as TF +import qualified TensorFlow.ControlFlow as TF +import qualified TensorFlow.Gradient as TF +import qualified TensorFlow.Nodes as TF import qualified TensorFlow.Ops as TF import qualified TensorFlow.Session as TF import qualified TensorFlow.Tensor as TF import qualified TensorFlow.Types as TF -import qualified TensorFlow.Gradient as TF import TensorFlow.Examples.MNIST.InputData import TensorFlow.Examples.MNIST.Parse @@ -41,30 +44,10 @@ randomParam width (TF.Shape shape) = where stddev = TF.scalar (1 / sqrt (fromIntegral width)) --- Types must match due to model structure (sparseToDense requires --- index types to match) +reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32)) + +-- Types must match due to model structure. type LabelType = Int32 -type BatchSize = Int32 - --- | Convert scalar labels to one-hot vectors. -labelClasses :: TF.Tensor TF.Value LabelType - -> LabelType - -> BatchSize - -> TF.Tensor TF.Value Float -labelClasses labels numClasses batchSize = - let indices = TF.range 0 (TF.scalar batchSize) 1 - concated = TF.concat 1 [TF.expandDims indices 1, TF.expandDims labels 1] - in TF.sparseToDense concated - (TF.constant [2] [batchSize, numClasses]) - 1 {- ON value -} - 0 {- default (OFF) value -} - --- | Fraction of elements that differ between two vectors. -errorRate :: Eq a => V.Vector a -> V.Vector a -> Double -errorRate xs ys = fromIntegral (len - numCorrect) / fromIntegral len - where - numCorrect = V.length $ V.filter id $ V.zipWith (==) xs ys - len = V.length xs data Model = Model { train :: TF.TensorData Float -- ^ images @@ -72,10 +55,15 @@ data Model = Model { -> TF.Session () , infer :: TF.TensorData Float -- ^ images -> TF.Session (V.Vector LabelType) -- ^ predictions + , errorRate :: TF.TensorData Float -- ^ images + -> TF.TensorData LabelType + -> TF.Session Float } -createModel :: Int64 -> TF.Build Model -createModel batchSize = do +createModel :: TF.Build Model +createModel = do + -- Use -1 batch size to support variable sized batches. + let batchSize = -1 -- Inputs. images <- TF.placeholder [batchSize, numPixels] -- Hidden layer. @@ -95,22 +83,29 @@ createModel batchSize = do -- Create training action. labels <- TF.placeholder [batchSize] - let labelVecs = labelClasses labels 10 (fromIntegral batchSize) - loss = fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs + let labelVecs = TF.oneHot labels (fromIntegral numLabels) 1 0 + loss = + reduceMean $ fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases] grads <- TF.gradients loss params - let lr = TF.scalar $ 0.001 / fromIntegral batchSize - applyGrad param grad - = TF.assign param $ param `TF.sub` (lr * grad) + let lr = TF.scalar 0.00001 + applyGrad param grad = TF.assign param $ param `TF.sub` (lr * grad) trainStep <- TF.group =<< zipWithM applyGrad params grads + let correctPredictions = TF.equal predict labels + errorRateTensor <- TF.render $ 1 - reduceMean (TF.cast correctPredictions) + return Model { - train = \imFeed lFeed -> TF.runWithFeeds_ [ - TF.feed images imFeed - , TF.feed labels lFeed - ] trainStep + train = \imFeed lFeed -> TF.runWithFeeds_ [ + TF.feed images imFeed + , TF.feed labels lFeed + ] trainStep , infer = \imFeed -> TF.runWithFeeds [TF.feed images imFeed] predict + , errorRate = \imFeed lFeed -> TF.unScalar <$> TF.runWithFeeds [ + TF.feed images imFeed + , TF.feed labels lFeed + ] errorRateTensor } main = TF.runSession $ do @@ -120,40 +115,36 @@ main = TF.runSession $ do testImages <- liftIO (readMNISTSamples =<< testImageData) testLabels <- liftIO (readMNISTLabels =<< testLabelData) - let batchSize = 100 :: Int64 - -- Create the model. - model <- TF.build $ createModel batchSize + model <- TF.build createModel - -- Helpers for generate batches. - let selectBatch i xs = take size $ drop (i * size) $ cycle xs - where size = fromIntegral batchSize - let getImageBatch i xs = TF.encodeTensorData - [batchSize, numPixels] - $ fromIntegral <$> mconcat (selectBatch i xs) - let getExpectedLabelBatch i xs = - fromIntegral <$> V.fromList (selectBatch i xs) + -- Functions for generating batches. + let encodeImageBatch xs = + TF.encodeTensorData [genericLength xs, numPixels] + (fromIntegral <$> mconcat xs) + let encodeLabelBatch xs = + TF.encodeTensorData [genericLength xs] + (fromIntegral <$> V.fromList xs) + let batchSize = 100 + let selectBatch i xs = take batchSize $ drop (i * batchSize) (cycle xs) -- Train. forM_ ([0..1000] :: [Int]) $ \i -> do - let images = getImageBatch i trainingImages - labels = getExpectedLabelBatch i trainingLabels - train model images (TF.encodeTensorData [batchSize] labels) + let images = encodeImageBatch (selectBatch i trainingImages) + labels = encodeLabelBatch (selectBatch i trainingLabels) + train model images labels when (i `mod` 100 == 0) $ do - preds <- infer model images - liftIO $ putStrLn $ - "training error " ++ show (errorRate preds labels * 100) + err <- errorRate model images labels + liftIO $ putStrLn $ "training error " ++ show (err * 100) liftIO $ putStrLn "" -- Test. - let numTestBatches = length testImages `div` fromIntegral batchSize - testPreds <- fmap mconcat $ forM [0..numTestBatches] $ \i -> do - infer model (getImageBatch i testImages) - let testExpected = fromIntegral <$> V.fromList testLabels - liftIO $ putStrLn $ - "test error " ++ show (errorRate testPreds testExpected * 100) + testErr <- errorRate model (encodeImageBatch testImages) + (encodeLabelBatch testLabels) + liftIO $ putStrLn $ "test error " ++ show (testErr * 100) -- Show some predictions. + testPreds <- infer model (encodeImageBatch testImages) liftIO $ forM_ ([0..3] :: [Int]) $ \i -> do putStrLn "" T.putStrLn $ drawMNIST $ testImages !! i diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs index 0730363..3fff01d 100644 --- a/tensorflow-ops/src/TensorFlow/Ops.hs +++ b/tensorflow-ops/src/TensorFlow/Ops.hs @@ -65,12 +65,15 @@ module TensorFlow.Ops , CoreOps.cast , CoreOps.concat , constant + , CoreOps.equal , expandDims , initializedVariable , zeroInitializedVariable , CoreOps.fill + , CoreOps.oneHot , CoreOps.matMul , matTranspose + , CoreOps.mean , CoreOps.mul , CoreOps.neg , CoreOps.pack