1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 19:13:34 +02:00

Use native oneHot op in the example code

It didn't exist when this was originally written.
This commit is contained in:
Frederick Mayle 2016-10-25 22:24:01 -07:00
parent ea8b62e47b
commit d53332f1ae
2 changed files with 4 additions and 18 deletions

View File

@ -41,23 +41,8 @@ randomParam width (TF.Shape shape) =
where
stddev = TF.scalar (1 / sqrt (fromIntegral width))
-- Types must match due to model structure (sparseToDense requires
-- index types to match)
-- Types must match due to model structure.
type LabelType = Int32
type BatchSize = Int32
-- | Convert scalar labels to one-hot vectors.
labelClasses :: TF.Tensor TF.Value LabelType
-> LabelType
-> BatchSize
-> TF.Tensor TF.Value Float
labelClasses labels numClasses batchSize =
let indices = TF.range 0 (TF.scalar batchSize) 1
concated = TF.concat 1 [TF.expandDims indices 1, TF.expandDims labels 1]
in TF.sparseToDense concated
(TF.constant [2] [batchSize, numClasses])
1 {- ON value -}
0 {- default (OFF) value -}
-- | Fraction of elements that differ between two vectors.
errorRate :: Eq a => V.Vector a -> V.Vector a -> Double
@ -95,7 +80,7 @@ createModel batchSize = do
-- Create training action.
labels <- TF.placeholder [batchSize]
let labelVecs = labelClasses labels 10 (fromIntegral batchSize)
let labelVecs = TF.oneHot labels 10 1 0
loss = fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases]
grads <- TF.gradients loss params
@ -147,7 +132,7 @@ main = TF.runSession $ do
-- Test.
let numTestBatches = length testImages `div` fromIntegral batchSize
testPreds <- fmap mconcat $ forM [0..numTestBatches] $ \i -> do
testPreds <- fmap mconcat $ forM [0..numTestBatches] $ \i ->
infer model (getImageBatch i testImages)
let testExpected = fromIntegral <$> V.fromList testLabels
liftIO $ putStrLn $

View File

@ -69,6 +69,7 @@ module TensorFlow.Ops
, initializedVariable
, zeroInitializedVariable
, CoreOps.fill
, CoreOps.oneHot
, CoreOps.matMul
, matTranspose
, CoreOps.mul