mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 11:03:34 +02:00
PR feedback changes
This commit is contained in:
parent
68b5a3d2f9
commit
92d6d6c6a3
|
@ -18,6 +18,7 @@
|
|||
import Control.Monad (zipWithM, when, forM, forM_)
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.List (genericLength)
|
||||
import qualified Data.Text.IO as T
|
||||
import qualified Data.Vector as V
|
||||
|
||||
|
@ -61,8 +62,10 @@ data Model = Model {
|
|||
|
||||
createModel :: TF.Build Model
|
||||
createModel = do
|
||||
-- Use -1 batch size to support variable sized batches.
|
||||
let batchSize = -1
|
||||
-- Inputs.
|
||||
images <- TF.placeholder [-1, numPixels]
|
||||
images <- TF.placeholder [batchSize, numPixels]
|
||||
-- Hidden layer.
|
||||
let numUnits = 500
|
||||
hiddenWeights <-
|
||||
|
@ -79,7 +82,7 @@ createModel = do
|
|||
TF.argMax (TF.softmax logits) (TF.scalar (1 :: LabelType))
|
||||
|
||||
-- Create training action.
|
||||
labels <- TF.placeholder [-1]
|
||||
labels <- TF.placeholder [batchSize]
|
||||
let labelVecs = TF.oneHot labels (fromIntegral numLabels) 1 0
|
||||
loss =
|
||||
reduceMean $ fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
|
||||
|
@ -117,10 +120,10 @@ main = TF.runSession $ do
|
|||
|
||||
-- Functions for generating batches.
|
||||
let encodeImageBatch xs =
|
||||
TF.encodeTensorData [fromIntegral (length xs), numPixels]
|
||||
TF.encodeTensorData [genericLength xs, numPixels]
|
||||
(fromIntegral <$> mconcat xs)
|
||||
let encodeLabelBatch xs =
|
||||
TF.encodeTensorData [fromIntegral (length xs)]
|
||||
TF.encodeTensorData [genericLength xs]
|
||||
(fromIntegral <$> V.fromList xs)
|
||||
let batchSize = 100
|
||||
let selectBatch i xs = take batchSize $ drop (i * batchSize) (cycle xs)
|
||||
|
|
Loading…
Reference in New Issue
Block a user