1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 11:03:34 +02:00

PR feedback changes

This commit is contained in:
Frederick Mayle 2016-10-26 10:46:27 -07:00
parent 68b5a3d2f9
commit 92d6d6c6a3

View File

@ -18,6 +18,7 @@
import Control.Monad (zipWithM, when, forM, forM_)
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64)
import Data.List (genericLength)
import qualified Data.Text.IO as T
import qualified Data.Vector as V
@ -61,8 +62,10 @@ data Model = Model {
createModel :: TF.Build Model
createModel = do
-- Use -1 batch size to support variable sized batches.
let batchSize = -1
-- Inputs.
images <- TF.placeholder [-1, numPixels]
images <- TF.placeholder [batchSize, numPixels]
-- Hidden layer.
let numUnits = 500
hiddenWeights <-
@ -79,7 +82,7 @@ createModel = do
TF.argMax (TF.softmax logits) (TF.scalar (1 :: LabelType))
-- Create training action.
labels <- TF.placeholder [-1]
labels <- TF.placeholder [batchSize]
let labelVecs = TF.oneHot labels (fromIntegral numLabels) 1 0
loss =
reduceMean $ fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
@ -117,10 +120,10 @@ main = TF.runSession $ do
-- Functions for generating batches.
let encodeImageBatch xs =
TF.encodeTensorData [fromIntegral (length xs), numPixels]
TF.encodeTensorData [genericLength xs, numPixels]
(fromIntegral <$> mconcat xs)
let encodeLabelBatch xs =
TF.encodeTensorData [fromIntegral (length xs)]
TF.encodeTensorData [genericLength xs]
(fromIntegral <$> V.fromList xs)
let batchSize = 100
let selectBatch i xs = take batchSize $ drop (i * batchSize) (cycle xs)