mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Misc MNIST example cleanup (#9)
* Use native oneHot op in the example code. It didn't exist when this was originally written.
* Misc cleanup in MNIST example
  - Use unspecified dimension for batch size in model. This simplifies the code for the test set.
  - Move error rate calculation into model.
This commit is contained in:
parent
54eddcc6bd
commit
03a3a6d086
2 changed files with 53 additions and 59 deletions
|
@ -12,21 +12,24 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE OverloadedLists #-}
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
|
|
||||||
import Control.Monad (zipWithM, when, forM, forM_)
|
import Control.Monad (zipWithM, when, forM, forM_)
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
|
import Data.List (genericLength)
|
||||||
import qualified Data.Text.IO as T
|
import qualified Data.Text.IO as T
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
import qualified TensorFlow.ControlFlow as TF
|
|
||||||
import qualified TensorFlow.Build as TF
|
import qualified TensorFlow.Build as TF
|
||||||
|
import qualified TensorFlow.ControlFlow as TF
|
||||||
|
import qualified TensorFlow.Gradient as TF
|
||||||
|
import qualified TensorFlow.Nodes as TF
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
import qualified TensorFlow.Session as TF
|
import qualified TensorFlow.Session as TF
|
||||||
import qualified TensorFlow.Tensor as TF
|
import qualified TensorFlow.Tensor as TF
|
||||||
import qualified TensorFlow.Types as TF
|
import qualified TensorFlow.Types as TF
|
||||||
import qualified TensorFlow.Gradient as TF
|
|
||||||
|
|
||||||
import TensorFlow.Examples.MNIST.InputData
|
import TensorFlow.Examples.MNIST.InputData
|
||||||
import TensorFlow.Examples.MNIST.Parse
|
import TensorFlow.Examples.MNIST.Parse
|
||||||
|
@ -41,30 +44,10 @@ randomParam width (TF.Shape shape) =
|
||||||
where
|
where
|
||||||
stddev = TF.scalar (1 / sqrt (fromIntegral width))
|
stddev = TF.scalar (1 / sqrt (fromIntegral width))
|
||||||
|
|
||||||
-- Types must match due to model structure (sparseToDense requires
|
reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32))
|
||||||
-- index types to match)
|
|
||||||
|
-- Types must match due to model structure.
|
||||||
type LabelType = Int32
|
type LabelType = Int32
|
||||||
type BatchSize = Int32
|
|
||||||
|
|
||||||
-- | Convert scalar labels to one-hot vectors.
|
|
||||||
labelClasses :: TF.Tensor TF.Value LabelType
|
|
||||||
-> LabelType
|
|
||||||
-> BatchSize
|
|
||||||
-> TF.Tensor TF.Value Float
|
|
||||||
labelClasses labels numClasses batchSize =
|
|
||||||
let indices = TF.range 0 (TF.scalar batchSize) 1
|
|
||||||
concated = TF.concat 1 [TF.expandDims indices 1, TF.expandDims labels 1]
|
|
||||||
in TF.sparseToDense concated
|
|
||||||
(TF.constant [2] [batchSize, numClasses])
|
|
||||||
1 {- ON value -}
|
|
||||||
0 {- default (OFF) value -}
|
|
||||||
|
|
||||||
-- | Fraction of elements that differ between two vectors.
|
|
||||||
errorRate :: Eq a => V.Vector a -> V.Vector a -> Double
|
|
||||||
errorRate xs ys = fromIntegral (len - numCorrect) / fromIntegral len
|
|
||||||
where
|
|
||||||
numCorrect = V.length $ V.filter id $ V.zipWith (==) xs ys
|
|
||||||
len = V.length xs
|
|
||||||
|
|
||||||
data Model = Model {
|
data Model = Model {
|
||||||
train :: TF.TensorData Float -- ^ images
|
train :: TF.TensorData Float -- ^ images
|
||||||
|
@ -72,10 +55,15 @@ data Model = Model {
|
||||||
-> TF.Session ()
|
-> TF.Session ()
|
||||||
, infer :: TF.TensorData Float -- ^ images
|
, infer :: TF.TensorData Float -- ^ images
|
||||||
-> TF.Session (V.Vector LabelType) -- ^ predictions
|
-> TF.Session (V.Vector LabelType) -- ^ predictions
|
||||||
|
, errorRate :: TF.TensorData Float -- ^ images
|
||||||
|
-> TF.TensorData LabelType
|
||||||
|
-> TF.Session Float
|
||||||
}
|
}
|
||||||
|
|
||||||
createModel :: Int64 -> TF.Build Model
|
createModel :: TF.Build Model
|
||||||
createModel batchSize = do
|
createModel = do
|
||||||
|
-- Use -1 batch size to support variable sized batches.
|
||||||
|
let batchSize = -1
|
||||||
-- Inputs.
|
-- Inputs.
|
||||||
images <- TF.placeholder [batchSize, numPixels]
|
images <- TF.placeholder [batchSize, numPixels]
|
||||||
-- Hidden layer.
|
-- Hidden layer.
|
||||||
|
@ -95,22 +83,29 @@ createModel batchSize = do
|
||||||
|
|
||||||
-- Create training action.
|
-- Create training action.
|
||||||
labels <- TF.placeholder [batchSize]
|
labels <- TF.placeholder [batchSize]
|
||||||
let labelVecs = labelClasses labels 10 (fromIntegral batchSize)
|
let labelVecs = TF.oneHot labels (fromIntegral numLabels) 1 0
|
||||||
loss = fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
|
loss =
|
||||||
|
reduceMean $ fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
|
||||||
params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases]
|
params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases]
|
||||||
grads <- TF.gradients loss params
|
grads <- TF.gradients loss params
|
||||||
|
|
||||||
let lr = TF.scalar $ 0.001 / fromIntegral batchSize
|
let lr = TF.scalar 0.00001
|
||||||
applyGrad param grad
|
applyGrad param grad = TF.assign param $ param `TF.sub` (lr * grad)
|
||||||
= TF.assign param $ param `TF.sub` (lr * grad)
|
|
||||||
trainStep <- TF.group =<< zipWithM applyGrad params grads
|
trainStep <- TF.group =<< zipWithM applyGrad params grads
|
||||||
|
|
||||||
|
let correctPredictions = TF.equal predict labels
|
||||||
|
errorRateTensor <- TF.render $ 1 - reduceMean (TF.cast correctPredictions)
|
||||||
|
|
||||||
return Model {
|
return Model {
|
||||||
train = \imFeed lFeed -> TF.runWithFeeds_ [
|
train = \imFeed lFeed -> TF.runWithFeeds_ [
|
||||||
TF.feed images imFeed
|
TF.feed images imFeed
|
||||||
, TF.feed labels lFeed
|
, TF.feed labels lFeed
|
||||||
] trainStep
|
] trainStep
|
||||||
, infer = \imFeed -> TF.runWithFeeds [TF.feed images imFeed] predict
|
, infer = \imFeed -> TF.runWithFeeds [TF.feed images imFeed] predict
|
||||||
|
, errorRate = \imFeed lFeed -> TF.unScalar <$> TF.runWithFeeds [
|
||||||
|
TF.feed images imFeed
|
||||||
|
, TF.feed labels lFeed
|
||||||
|
] errorRateTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
main = TF.runSession $ do
|
main = TF.runSession $ do
|
||||||
|
@ -120,40 +115,36 @@ main = TF.runSession $ do
|
||||||
testImages <- liftIO (readMNISTSamples =<< testImageData)
|
testImages <- liftIO (readMNISTSamples =<< testImageData)
|
||||||
testLabels <- liftIO (readMNISTLabels =<< testLabelData)
|
testLabels <- liftIO (readMNISTLabels =<< testLabelData)
|
||||||
|
|
||||||
let batchSize = 100 :: Int64
|
|
||||||
|
|
||||||
-- Create the model.
|
-- Create the model.
|
||||||
model <- TF.build $ createModel batchSize
|
model <- TF.build createModel
|
||||||
|
|
||||||
-- Helpers for generate batches.
|
-- Functions for generating batches.
|
||||||
let selectBatch i xs = take size $ drop (i * size) $ cycle xs
|
let encodeImageBatch xs =
|
||||||
where size = fromIntegral batchSize
|
TF.encodeTensorData [genericLength xs, numPixels]
|
||||||
let getImageBatch i xs = TF.encodeTensorData
|
(fromIntegral <$> mconcat xs)
|
||||||
[batchSize, numPixels]
|
let encodeLabelBatch xs =
|
||||||
$ fromIntegral <$> mconcat (selectBatch i xs)
|
TF.encodeTensorData [genericLength xs]
|
||||||
let getExpectedLabelBatch i xs =
|
(fromIntegral <$> V.fromList xs)
|
||||||
fromIntegral <$> V.fromList (selectBatch i xs)
|
let batchSize = 100
|
||||||
|
let selectBatch i xs = take batchSize $ drop (i * batchSize) (cycle xs)
|
||||||
|
|
||||||
-- Train.
|
-- Train.
|
||||||
forM_ ([0..1000] :: [Int]) $ \i -> do
|
forM_ ([0..1000] :: [Int]) $ \i -> do
|
||||||
let images = getImageBatch i trainingImages
|
let images = encodeImageBatch (selectBatch i trainingImages)
|
||||||
labels = getExpectedLabelBatch i trainingLabels
|
labels = encodeLabelBatch (selectBatch i trainingLabels)
|
||||||
train model images (TF.encodeTensorData [batchSize] labels)
|
train model images labels
|
||||||
when (i `mod` 100 == 0) $ do
|
when (i `mod` 100 == 0) $ do
|
||||||
preds <- infer model images
|
err <- errorRate model images labels
|
||||||
liftIO $ putStrLn $
|
liftIO $ putStrLn $ "training error " ++ show (err * 100)
|
||||||
"training error " ++ show (errorRate preds labels * 100)
|
|
||||||
liftIO $ putStrLn ""
|
liftIO $ putStrLn ""
|
||||||
|
|
||||||
-- Test.
|
-- Test.
|
||||||
let numTestBatches = length testImages `div` fromIntegral batchSize
|
testErr <- errorRate model (encodeImageBatch testImages)
|
||||||
testPreds <- fmap mconcat $ forM [0..numTestBatches] $ \i -> do
|
(encodeLabelBatch testLabels)
|
||||||
infer model (getImageBatch i testImages)
|
liftIO $ putStrLn $ "test error " ++ show (testErr * 100)
|
||||||
let testExpected = fromIntegral <$> V.fromList testLabels
|
|
||||||
liftIO $ putStrLn $
|
|
||||||
"test error " ++ show (errorRate testPreds testExpected * 100)
|
|
||||||
|
|
||||||
-- Show some predictions.
|
-- Show some predictions.
|
||||||
|
testPreds <- infer model (encodeImageBatch testImages)
|
||||||
liftIO $ forM_ ([0..3] :: [Int]) $ \i -> do
|
liftIO $ forM_ ([0..3] :: [Int]) $ \i -> do
|
||||||
putStrLn ""
|
putStrLn ""
|
||||||
T.putStrLn $ drawMNIST $ testImages !! i
|
T.putStrLn $ drawMNIST $ testImages !! i
|
||||||
|
|
|
@ -65,12 +65,15 @@ module TensorFlow.Ops
|
||||||
, CoreOps.cast
|
, CoreOps.cast
|
||||||
, CoreOps.concat
|
, CoreOps.concat
|
||||||
, constant
|
, constant
|
||||||
|
, CoreOps.equal
|
||||||
, expandDims
|
, expandDims
|
||||||
, initializedVariable
|
, initializedVariable
|
||||||
, zeroInitializedVariable
|
, zeroInitializedVariable
|
||||||
, CoreOps.fill
|
, CoreOps.fill
|
||||||
|
, CoreOps.oneHot
|
||||||
, CoreOps.matMul
|
, CoreOps.matMul
|
||||||
, matTranspose
|
, matTranspose
|
||||||
|
, CoreOps.mean
|
||||||
, CoreOps.mul
|
, CoreOps.mul
|
||||||
, CoreOps.neg
|
, CoreOps.neg
|
||||||
, CoreOps.pack
|
, CoreOps.pack
|
||||||
|
|
Loading…
Reference in a new issue