Misc MNIST example cleanup (#9)

* Use native oneHot op in the example code. It didn't exist when this was originally written.
* Misc cleanup in MNIST example

- Use unspecified dimension for batch size in model. This simplifies the
  code for the test set.
- Move error rate calculation into model.
This commit is contained in:
fkm3 2016-10-26 11:14:38 -07:00 committed by Greg Steuck
parent 54eddcc6bd
commit 03a3a6d086
2 changed files with 53 additions and 59 deletions

View File

@ -12,21 +12,24 @@
-- See the License for the specific language governing permissions and -- See the License for the specific language governing permissions and
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedLists #-}
import Control.Monad (zipWithM, when, forM, forM_) import Control.Monad (zipWithM, when, forM, forM_)
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64) import Data.Int (Int32, Int64)
import Data.List (genericLength)
import qualified Data.Text.IO as T import qualified Data.Text.IO as T
import qualified Data.Vector as V import qualified Data.Vector as V
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.Build as TF import qualified TensorFlow.Build as TF
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.Ops as TF import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF import qualified TensorFlow.Types as TF
import qualified TensorFlow.Gradient as TF
import TensorFlow.Examples.MNIST.InputData import TensorFlow.Examples.MNIST.InputData
import TensorFlow.Examples.MNIST.Parse import TensorFlow.Examples.MNIST.Parse
@ -41,30 +44,10 @@ randomParam width (TF.Shape shape) =
where where
stddev = TF.scalar (1 / sqrt (fromIntegral width)) stddev = TF.scalar (1 / sqrt (fromIntegral width))
-- Types must match due to model structure (sparseToDense requires reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32))
-- index types to match)
-- Types must match due to model structure.
type LabelType = Int32 type LabelType = Int32
type BatchSize = Int32
-- | Convert scalar labels to one-hot vectors.
labelClasses :: TF.Tensor TF.Value LabelType
-> LabelType
-> BatchSize
-> TF.Tensor TF.Value Float
labelClasses labels numClasses batchSize =
let indices = TF.range 0 (TF.scalar batchSize) 1
concated = TF.concat 1 [TF.expandDims indices 1, TF.expandDims labels 1]
in TF.sparseToDense concated
(TF.constant [2] [batchSize, numClasses])
1 {- ON value -}
0 {- default (OFF) value -}
-- | Fraction of elements that differ between two vectors.
errorRate :: Eq a => V.Vector a -> V.Vector a -> Double
errorRate xs ys = fromIntegral (len - numCorrect) / fromIntegral len
where
numCorrect = V.length $ V.filter id $ V.zipWith (==) xs ys
len = V.length xs
data Model = Model { data Model = Model {
train :: TF.TensorData Float -- ^ images train :: TF.TensorData Float -- ^ images
@ -72,10 +55,15 @@ data Model = Model {
-> TF.Session () -> TF.Session ()
, infer :: TF.TensorData Float -- ^ images , infer :: TF.TensorData Float -- ^ images
-> TF.Session (V.Vector LabelType) -- ^ predictions -> TF.Session (V.Vector LabelType) -- ^ predictions
, errorRate :: TF.TensorData Float -- ^ images
-> TF.TensorData LabelType
-> TF.Session Float
} }
createModel :: Int64 -> TF.Build Model createModel :: TF.Build Model
createModel batchSize = do createModel = do
-- Use -1 batch size to support variable sized batches.
let batchSize = -1
-- Inputs. -- Inputs.
images <- TF.placeholder [batchSize, numPixels] images <- TF.placeholder [batchSize, numPixels]
-- Hidden layer. -- Hidden layer.
@ -95,22 +83,29 @@ createModel batchSize = do
-- Create training action. -- Create training action.
labels <- TF.placeholder [batchSize] labels <- TF.placeholder [batchSize]
let labelVecs = labelClasses labels 10 (fromIntegral batchSize) let labelVecs = TF.oneHot labels (fromIntegral numLabels) 1 0
loss = fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs loss =
reduceMean $ fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases] params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases]
grads <- TF.gradients loss params grads <- TF.gradients loss params
let lr = TF.scalar $ 0.001 / fromIntegral batchSize let lr = TF.scalar 0.00001
applyGrad param grad applyGrad param grad = TF.assign param $ param `TF.sub` (lr * grad)
= TF.assign param $ param `TF.sub` (lr * grad)
trainStep <- TF.group =<< zipWithM applyGrad params grads trainStep <- TF.group =<< zipWithM applyGrad params grads
let correctPredictions = TF.equal predict labels
errorRateTensor <- TF.render $ 1 - reduceMean (TF.cast correctPredictions)
return Model { return Model {
train = \imFeed lFeed -> TF.runWithFeeds_ [ train = \imFeed lFeed -> TF.runWithFeeds_ [
TF.feed images imFeed TF.feed images imFeed
, TF.feed labels lFeed , TF.feed labels lFeed
] trainStep ] trainStep
, infer = \imFeed -> TF.runWithFeeds [TF.feed images imFeed] predict , infer = \imFeed -> TF.runWithFeeds [TF.feed images imFeed] predict
, errorRate = \imFeed lFeed -> TF.unScalar <$> TF.runWithFeeds [
TF.feed images imFeed
, TF.feed labels lFeed
] errorRateTensor
} }
main = TF.runSession $ do main = TF.runSession $ do
@ -120,40 +115,36 @@ main = TF.runSession $ do
testImages <- liftIO (readMNISTSamples =<< testImageData) testImages <- liftIO (readMNISTSamples =<< testImageData)
testLabels <- liftIO (readMNISTLabels =<< testLabelData) testLabels <- liftIO (readMNISTLabels =<< testLabelData)
let batchSize = 100 :: Int64
-- Create the model. -- Create the model.
model <- TF.build $ createModel batchSize model <- TF.build createModel
-- Helpers for generate batches. -- Functions for generating batches.
let selectBatch i xs = take size $ drop (i * size) $ cycle xs let encodeImageBatch xs =
where size = fromIntegral batchSize TF.encodeTensorData [genericLength xs, numPixels]
let getImageBatch i xs = TF.encodeTensorData (fromIntegral <$> mconcat xs)
[batchSize, numPixels] let encodeLabelBatch xs =
$ fromIntegral <$> mconcat (selectBatch i xs) TF.encodeTensorData [genericLength xs]
let getExpectedLabelBatch i xs = (fromIntegral <$> V.fromList xs)
fromIntegral <$> V.fromList (selectBatch i xs) let batchSize = 100
let selectBatch i xs = take batchSize $ drop (i * batchSize) (cycle xs)
-- Train. -- Train.
forM_ ([0..1000] :: [Int]) $ \i -> do forM_ ([0..1000] :: [Int]) $ \i -> do
let images = getImageBatch i trainingImages let images = encodeImageBatch (selectBatch i trainingImages)
labels = getExpectedLabelBatch i trainingLabels labels = encodeLabelBatch (selectBatch i trainingLabels)
train model images (TF.encodeTensorData [batchSize] labels) train model images labels
when (i `mod` 100 == 0) $ do when (i `mod` 100 == 0) $ do
preds <- infer model images err <- errorRate model images labels
liftIO $ putStrLn $ liftIO $ putStrLn $ "training error " ++ show (err * 100)
"training error " ++ show (errorRate preds labels * 100)
liftIO $ putStrLn "" liftIO $ putStrLn ""
-- Test. -- Test.
let numTestBatches = length testImages `div` fromIntegral batchSize testErr <- errorRate model (encodeImageBatch testImages)
testPreds <- fmap mconcat $ forM [0..numTestBatches] $ \i -> do (encodeLabelBatch testLabels)
infer model (getImageBatch i testImages) liftIO $ putStrLn $ "test error " ++ show (testErr * 100)
let testExpected = fromIntegral <$> V.fromList testLabels
liftIO $ putStrLn $
"test error " ++ show (errorRate testPreds testExpected * 100)
-- Show some predictions. -- Show some predictions.
testPreds <- infer model (encodeImageBatch testImages)
liftIO $ forM_ ([0..3] :: [Int]) $ \i -> do liftIO $ forM_ ([0..3] :: [Int]) $ \i -> do
putStrLn "" putStrLn ""
T.putStrLn $ drawMNIST $ testImages !! i T.putStrLn $ drawMNIST $ testImages !! i

View File

@ -65,12 +65,15 @@ module TensorFlow.Ops
, CoreOps.cast , CoreOps.cast
, CoreOps.concat , CoreOps.concat
, constant , constant
, CoreOps.equal
, expandDims , expandDims
, initializedVariable , initializedVariable
, zeroInitializedVariable , zeroInitializedVariable
, CoreOps.fill , CoreOps.fill
, CoreOps.oneHot
, CoreOps.matMul , CoreOps.matMul
, matTranspose , matTranspose
, CoreOps.mean
, CoreOps.mul , CoreOps.mul
, CoreOps.neg , CoreOps.neg
, CoreOps.pack , CoreOps.pack