Mirror of https://github.com/tensorflow/haskell.git (synced 2024-11-23 03:19:44 +01:00)
Misc MNIST example cleanup (#9)
* Use native oneHot op in the example code. It didn't exist when this was originally written.
* Misc cleanup in MNIST example:
  - Use unspecified dimension for batch size in model. This simplifies the code for the test set.
  - Move error rate calculation into model.
This commit is contained in:
parent 54eddcc6bd
commit 03a3a6d086
2 changed files with 53 additions and 59 deletions
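The first change replaces a hand-rolled sparseToDense one-hot encoding with the native oneHot op. As a reading aid, here is a minimal pure-Haskell sketch of what one-hot encoding computes; oneHotRef is a hypothetical reference function written for this note, not part of tensorflow-haskell, and its argument order merely mirrors the TF.oneHot call in the diff below:

    -- One row per label: `on` at the label's index, `off` elsewhere.
    oneHotRef :: Int -> Float -> Float -> [Int] -> [[Float]]
    oneHotRef depth on off labels =
        [ [ if j == lab then on else off | j <- [0 .. depth - 1] ] | lab <- labels ]

    -- ghci> oneHotRef 4 1 0 [2, 0]
    -- [[0.0,0.0,1.0,0.0],[1.0,0.0,0.0,0.0]]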
@@ -12,21 +12,24 @@
 -- See the License for the specific language governing permissions and
 -- limitations under the License.
 
+{-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE OverloadedLists #-}
 
 import Control.Monad (zipWithM, when, forM, forM_)
 import Control.Monad.IO.Class (liftIO)
 import Data.Int (Int32, Int64)
+import Data.List (genericLength)
 import qualified Data.Text.IO as T
 import qualified Data.Vector as V
 
-import qualified TensorFlow.ControlFlow as TF
+import qualified TensorFlow.Build as TF
+import qualified TensorFlow.ControlFlow as TF
+import qualified TensorFlow.Gradient as TF
 import qualified TensorFlow.Nodes as TF
 import qualified TensorFlow.Ops as TF
 import qualified TensorFlow.Session as TF
 import qualified TensorFlow.Tensor as TF
 import qualified TensorFlow.Types as TF
-import qualified TensorFlow.Gradient as TF
 
 import TensorFlow.Examples.MNIST.InputData
 import TensorFlow.Examples.MNIST.Parse
@@ -41,30 +44,10 @@ randomParam width (TF.Shape shape) =
   where
     stddev = TF.scalar (1 / sqrt (fromIntegral width))
 
--- Types must match due to model structure (sparseToDense requires
--- index types to match)
+reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32))
+
+-- Types must match due to model structure.
 type LabelType = Int32
-type BatchSize = Int32
-
--- | Convert scalar labels to one-hot vectors.
-labelClasses :: TF.Tensor TF.Value LabelType
-             -> LabelType
-             -> BatchSize
-             -> TF.Tensor TF.Value Float
-labelClasses labels numClasses batchSize =
-    let indices = TF.range 0 (TF.scalar batchSize) 1
-        concated = TF.concat 1 [TF.expandDims indices 1, TF.expandDims labels 1]
-    in TF.sparseToDense concated
-                        (TF.constant [2] [batchSize, numClasses])
-                        1 {- ON value -}
-                        0 {- default (OFF) value -}
-
--- | Fraction of elements that differ between two vectors.
-errorRate :: Eq a => V.Vector a -> V.Vector a -> Double
-errorRate xs ys = fromIntegral (len - numCorrect) / fromIntegral len
-  where
-    numCorrect = V.length $ V.filter id $ V.zipWith (==) xs ys
-    len = V.length xs
 
 data Model = Model {
       train :: TF.TensorData Float -- ^ images
@@ -72,10 +55,15 @@ data Model = Model {
           -> TF.Session ()
     , infer :: TF.TensorData Float -- ^ images
             -> TF.Session (V.Vector LabelType) -- ^ predictions
+    , errorRate :: TF.TensorData Float -- ^ images
+                -> TF.TensorData LabelType
+                -> TF.Session Float
     }
 
-createModel :: Int64 -> TF.Build Model
-createModel batchSize = do
+createModel :: TF.Build Model
+createModel = do
+    -- Use -1 batch size to support variable sized batches.
+    let batchSize = -1
     -- Inputs.
     images <- TF.placeholder [batchSize, numPixels]
     -- Hidden layer.
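Note on the -1 dimension above: leaving the batch dimension unspecified lets one graph serve both the training batches and the full test set, which is the point of the cleanup. A sketch of the shape-matching rule this relies on; matchesShape is a hypothetical helper written for this note, not a library function:

    -- A -1 in the expected shape acts as a wildcard for that dimension.
    matchesShape :: [Int] -> [Int] -> Bool
    matchesShape spec actual =
        length spec == length actual && and (zipWith ok spec actual)
      where
        ok (-1) _ = True  -- unspecified dimension: any size accepted
        ok s    a = s == a

    -- ghci> matchesShape [-1, 784] [100, 784]    -- True (training batch)
    -- ghci> matchesShape [-1, 784] [10000, 784]  -- True (full test set)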
@@ -95,22 +83,29 @@ createModel batchSize = do
 
     -- Create training action.
     labels <- TF.placeholder [batchSize]
-    let labelVecs = labelClasses labels 10 (fromIntegral batchSize)
-        loss = fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
+    let labelVecs = TF.oneHot labels (fromIntegral numLabels) 1 0
+        loss =
+            reduceMean $ fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
         params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases]
     grads <- TF.gradients loss params
 
-    let lr = TF.scalar $ 0.001 / fromIntegral batchSize
-        applyGrad param grad
-            = TF.assign param $ param `TF.sub` (lr * grad)
+    let lr = TF.scalar 0.00001
+        applyGrad param grad = TF.assign param $ param `TF.sub` (lr * grad)
     trainStep <- TF.group =<< zipWithM applyGrad params grads
 
+    let correctPredictions = TF.equal predict labels
+    errorRateTensor <- TF.render $ 1 - reduceMean (TF.cast correctPredictions)
+
     return Model {
-        train = \imFeed lFeed -> TF.runWithFeeds_ [
-            TF.feed images imFeed
-            , TF.feed labels lFeed
-            ] trainStep
+          train = \imFeed lFeed -> TF.runWithFeeds_ [
+                TF.feed images imFeed
+              , TF.feed labels lFeed
+              ] trainStep
         , infer = \imFeed -> TF.runWithFeeds [TF.feed images imFeed] predict
+        , errorRate = \imFeed lFeed -> TF.unScalar <$> TF.runWithFeeds [
+              TF.feed images imFeed
+            , TF.feed labels lFeed
+            ] errorRateTensor
         }
 
 main = TF.runSession $ do
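The in-graph error rate added above, 1 - reduceMean (TF.cast correctPredictions), computes the same quantity as the deleted vector-based errorRate. A pure sketch of the equivalence, assuming cast maps True/False to 1/0; errorRateRef is a reference function for this note, not library code:

    errorRateRef :: Eq a => [a] -> [a] -> Double
    errorRateRef preds labels = 1 - mean casts
      where
        casts = [ if p == l then 1 else 0 | (p, l) <- zip preds labels ]
        mean xs = sum xs / fromIntegral (length xs)

    -- ghci> errorRateRef [1, 2, 3, 4 :: Int] [1, 2, 0, 0]
    -- 0.5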
@@ -120,40 +115,36 @@ main = TF.runSession $ do
     testImages <- liftIO (readMNISTSamples =<< testImageData)
     testLabels <- liftIO (readMNISTLabels =<< testLabelData)
 
-    let batchSize = 100 :: Int64
-
     -- Create the model.
-    model <- TF.build $ createModel batchSize
+    model <- TF.build createModel
 
-    -- Helpers for generate batches.
-    let selectBatch i xs = take size $ drop (i * size) $ cycle xs
-          where size = fromIntegral batchSize
-    let getImageBatch i xs = TF.encodeTensorData
-            [batchSize, numPixels]
-            $ fromIntegral <$> mconcat (selectBatch i xs)
-    let getExpectedLabelBatch i xs =
-            fromIntegral <$> V.fromList (selectBatch i xs)
+    -- Functions for generating batches.
+    let encodeImageBatch xs =
+            TF.encodeTensorData [genericLength xs, numPixels]
+                                (fromIntegral <$> mconcat xs)
+    let encodeLabelBatch xs =
+            TF.encodeTensorData [genericLength xs]
+                                (fromIntegral <$> V.fromList xs)
+    let batchSize = 100
+    let selectBatch i xs = take batchSize $ drop (i * batchSize) (cycle xs)
 
     -- Train.
    forM_ ([0..1000] :: [Int]) $ \i -> do
-        let images = getImageBatch i trainingImages
-            labels = getExpectedLabelBatch i trainingLabels
-        train model images (TF.encodeTensorData [batchSize] labels)
+        let images = encodeImageBatch (selectBatch i trainingImages)
+            labels = encodeLabelBatch (selectBatch i trainingLabels)
+        train model images labels
         when (i `mod` 100 == 0) $ do
-            preds <- infer model images
-            liftIO $ putStrLn $
-                "training error " ++ show (errorRate preds labels * 100)
+            err <- errorRate model images labels
+            liftIO $ putStrLn $ "training error " ++ show (err * 100)
             liftIO $ putStrLn ""
 
     -- Test.
-    let numTestBatches = length testImages `div` fromIntegral batchSize
-    testPreds <- fmap mconcat $ forM [0..numTestBatches] $ \i -> do
-        infer model (getImageBatch i testImages)
-    let testExpected = fromIntegral <$> V.fromList testLabels
-    liftIO $ putStrLn $
-        "test error " ++ show (errorRate testPreds testExpected * 100)
+    testErr <- errorRate model (encodeImageBatch testImages)
+                               (encodeLabelBatch testLabels)
+    liftIO $ putStrLn $ "test error " ++ show (testErr * 100)
 
     -- Show some predictions.
+    testPreds <- infer model (encodeImageBatch testImages)
     liftIO $ forM_ ([0..3] :: [Int]) $ \i -> do
         putStrLn ""
         T.putStrLn $ drawMNIST $ testImages !! i
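The new batching helpers are plain list code; cycle makes selectBatch wrap around the training set, so every index yields a full batch. A standalone sketch of that behavior (batchSize is passed explicitly here for illustration):

    selectBatch :: Int -> Int -> [a] -> [a]
    selectBatch batchSize i xs = take batchSize (drop (i * batchSize) (cycle xs))

    -- ghci> selectBatch 4 3 [0 .. 9 :: Int]
    -- [2,3,4,5]  -- index 3 starts at element 12, wrapping past the end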
@@ -65,12 +65,15 @@ module TensorFlow.Ops
+    , CoreOps.cast
     , CoreOps.concat
     , constant
+    , CoreOps.equal
     , expandDims
     , initializedVariable
     , zeroInitializedVariable
     , CoreOps.fill
+    , CoreOps.oneHot
     , CoreOps.matMul
     , matTranspose
     , CoreOps.mean
     , CoreOps.mul
     , CoreOps.neg
     , CoreOps.pack