mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 19:13:34 +02:00
Use native oneHot op in the example code
It didn't exist when this was originally written.
This commit is contained in:
parent
ea8b62e47b
commit
d53332f1ae
|
@ -41,23 +41,8 @@ randomParam width (TF.Shape shape) =
|
|||
where
|
||||
stddev = TF.scalar (1 / sqrt (fromIntegral width))
|
||||
|
||||
-- Types must match due to model structure (sparseToDense requires
|
||||
-- index types to match)
|
||||
-- Types must match due to model structure.
|
||||
type LabelType = Int32
|
||||
type BatchSize = Int32
|
||||
|
||||
-- | Convert scalar labels to one-hot vectors.
|
||||
labelClasses :: TF.Tensor TF.Value LabelType
|
||||
-> LabelType
|
||||
-> BatchSize
|
||||
-> TF.Tensor TF.Value Float
|
||||
labelClasses labels numClasses batchSize =
|
||||
let indices = TF.range 0 (TF.scalar batchSize) 1
|
||||
concated = TF.concat 1 [TF.expandDims indices 1, TF.expandDims labels 1]
|
||||
in TF.sparseToDense concated
|
||||
(TF.constant [2] [batchSize, numClasses])
|
||||
1 {- ON value -}
|
||||
0 {- default (OFF) value -}
|
||||
|
||||
-- | Fraction of elements that differ between two vectors.
|
||||
errorRate :: Eq a => V.Vector a -> V.Vector a -> Double
|
||||
|
@ -95,7 +80,7 @@ createModel batchSize = do
|
|||
|
||||
-- Create training action.
|
||||
labels <- TF.placeholder [batchSize]
|
||||
let labelVecs = labelClasses labels 10 (fromIntegral batchSize)
|
||||
let labelVecs = TF.oneHot labels 10 1 0
|
||||
loss = fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
|
||||
params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases]
|
||||
grads <- TF.gradients loss params
|
||||
|
@ -147,7 +132,7 @@ main = TF.runSession $ do
|
|||
|
||||
-- Test.
|
||||
let numTestBatches = length testImages `div` fromIntegral batchSize
|
||||
testPreds <- fmap mconcat $ forM [0..numTestBatches] $ \i -> do
|
||||
testPreds <- fmap mconcat $ forM [0..numTestBatches] $ \i ->
|
||||
infer model (getImageBatch i testImages)
|
||||
let testExpected = fromIntegral <$> V.fromList testLabels
|
||||
liftIO $ putStrLn $
|
||||
|
|
|
@ -69,6 +69,7 @@ module TensorFlow.Ops
|
|||
, initializedVariable
|
||||
, zeroInitializedVariable
|
||||
, CoreOps.fill
|
||||
, CoreOps.oneHot
|
||||
, CoreOps.matMul
|
||||
, matTranspose
|
||||
, CoreOps.mul
|
||||
|
|
Loading…
Reference in New Issue
Block a user