MNIST Main compiles but is broken.

This commit is contained in:
Greg Steuck 2016-11-08 16:55:51 -08:00
parent 2b5e41ffeb
commit 1677c346eb
11 changed files with 69 additions and 48 deletions

View File

@ -15,7 +15,7 @@
{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedLists #-}
import Control.Monad (zipWithM, when, forM_) import Control.Monad (forM_, when)
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64) import Data.Int (Int32, Int64)
import Data.List (genericLength) import Data.List (genericLength)
@ -30,6 +30,7 @@ import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF import qualified TensorFlow.Types as TF
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Examples.MNIST.InputData import TensorFlow.Examples.MNIST.InputData
import TensorFlow.Examples.MNIST.Parse import TensorFlow.Examples.MNIST.Parse
@ -64,6 +65,7 @@ data Model = Model {
createModel :: TF.Build Model createModel :: TF.Build Model
createModel = do createModel = do
let rd = CoreOps.readVariableOp
-- Use -1 batch size to support variable sized batches. -- Use -1 batch size to support variable sized batches.
let batchSize = -1 let batchSize = -1
-- Inputs. -- Inputs.
@ -73,13 +75,14 @@ createModel = do
hiddenWeights <- hiddenWeights <-
TF.initializedVariable =<< randomParam numPixels [numPixels, numUnits] TF.initializedVariable =<< randomParam numPixels [numPixels, numUnits]
hiddenBiases <- TF.zeroInitializedVariable [numUnits] hiddenBiases <- TF.zeroInitializedVariable [numUnits]
let hiddenZ = (images `TF.matMul` hiddenWeights) `TF.add` hiddenBiases let hiddenZ = (images `TF.matMul` rd hiddenWeights)
`TF.add` rd hiddenBiases
let hidden = TF.relu hiddenZ let hidden = TF.relu hiddenZ
-- Logits. -- Logits.
logitWeights <- logitWeights <-
TF.initializedVariable =<< randomParam numUnits [numUnits, numLabels] TF.initializedVariable =<< randomParam numUnits [numUnits, numLabels]
logitBiases <- TF.zeroInitializedVariable [numLabels] logitBiases <- TF.zeroInitializedVariable [numLabels]
let logits = (hidden `TF.matMul` logitWeights) `TF.add` logitBiases let logits = (hidden `TF.matMul` rd logitWeights) `TF.add` rd logitBiases
predict <- TF.render $ TF.cast $ predict <- TF.render $ TF.cast $
TF.argMax (TF.softmax logits) (TF.scalar (1 :: LabelType)) TF.argMax (TF.softmax logits) (TF.scalar (1 :: LabelType))
@ -89,11 +92,11 @@ createModel = do
loss = loss =
reduceMean $ fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs reduceMean $ fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases] params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases]
grads <- TF.gradients loss params grads <- TF.gradients loss (map rd params)
let lr = TF.scalar 0.00001 let lr = TF.scalar (-0.00001 :: Float) -- Negative to make it descend.
applyGrad param grad = TF.assign param $ param `TF.sub` (lr * grad) applyGrad var grad = CoreOps.assignVariableOp var (lr * grad)
trainStep <- TF.group =<< zipWithM applyGrad params grads trainStep <- TF.group (zipWith applyGrad params grads)
let correctPredictions = TF.equal predict labels let correctPredictions = TF.equal predict labels
errorRateTensor <- TF.render $ 1 - reduceMean (TF.cast correctPredictions) errorRateTensor <- TF.render $ 1 - reduceMean (TF.cast correctPredictions)

View File

@ -45,6 +45,7 @@ executable Main
, lens-family , lens-family
, proto-lens , proto-lens
, tensorflow , tensorflow
, tensorflow-core-ops
, tensorflow-mnist , tensorflow-mnist
, tensorflow-mnist-input-data , tensorflow-mnist-input-data
, tensorflow-ops , tensorflow-ops

View File

@ -654,6 +654,7 @@ opGrad "ZerosLike" _ _ _ = [Nothing]
opGrad "Const" _ _ _ = [Nothing, Nothing] opGrad "Const" _ _ _ = [Nothing, Nothing]
opGrad "Placeholder" _ _ _ = [] opGrad "Placeholder" _ _ _ = []
opGrad "Variable" _ _ _ = [] opGrad "Variable" _ _ _ = []
opGrad "ReadVariableOp" _ _ _ = [Nothing, Nothing]
opGrad n nodeDef ins grads = opGrad n nodeDef ins grads =
error $ "no gradient implemented for " ++ error $ "no gradient implemented for " ++
@ -699,6 +700,7 @@ numOutputs o =
"Transpose" -> 1 "Transpose" -> 1
"TruncatedNormal" -> 1 "TruncatedNormal" -> 1
"Variable" -> 1 "Variable" -> 1
"ReadVariableOp" -> 2
"ZerosLike" -> 1 "ZerosLike" -> 1
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op) _ -> error $ "numOuputs not implemented for " ++ show (o ^. op)

View File

@ -121,8 +121,7 @@ import qualified Proto.Tensorflow.Core.Framework.TensorShape
as TensorShape as TensorShape
import TensorFlow.Build import TensorFlow.Build
import TensorFlow.BuildOp import TensorFlow.BuildOp
import TensorFlow.ControlFlow (group) import TensorFlow.Output (ResourceHandle, unNodeName)
import TensorFlow.Output (unNodeName)
import TensorFlow.Tensor import TensorFlow.Tensor
import TensorFlow.Types import TensorFlow.Types
@ -176,37 +175,30 @@ assign = buildOp $ opDef "Assign"
-- | Creates a variable initialized to the given value. -- | Creates a variable initialized to the given value.
-- Initialization happens next time session runs. -- Initialization happens next time session runs.
initializedVariable :: forall a . TensorType a initializedVariable :: forall a . TensorType a
=> Tensor Value a -> Build (Tensor Ref a) => Tensor Value a -> Build (ResourceHandle a)
initializedVariable initializer = do initializedVariable initializer = do
v <- variable [] -- The shape is not known initially. let v = CoreOps.varHandleOp
(i :: Tensor Ref a) <- & resourceHandleAttr "shape" .~ (Shape [])
buildOp (opDef "Assign" addInitializer (CoreOps.createVariableOp v initializer)
& opAttr "T" .~ tensorType (undefined :: a)
& opAttr "use_locking" .~ True
& opAttr "validate_shape" .~ False
)
v initializer
addInitializer =<< group i
return v return v
-- | Creates a zero-initialized variable with the given shape. -- | Creates a zero-initialized variable with the given shape.
zeroInitializedVariable zeroInitializedVariable
:: (TensorType a, Num a) => :: (TensorType a, Num a) => TensorFlow.Types.Shape -> Build (ResourceHandle a)
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
zeroInitializedVariable = initializedVariable . zeros zeroInitializedVariable = initializedVariable . zeros
-- TODO: Support heterogeneous list of tensors. -- TODO: Support heterogeneous list of tensors.
save :: forall a v . TensorType a save :: forall a . TensorType a
=> ByteString -- ^ File path. => ByteString -- ^ File path.
-> [Tensor v a] -- ^ Tensors to save. -> [ResourceHandle a] -- ^ Tensors to save.
-> Build ControlNode -> Build ControlNode
save path xs = do save path xs = do
let toByteStringTensor = scalar . encodeUtf8 . unNodeName let toByteStringTensor = scalar . encodeUtf8 . unNodeName
names <- mapM (fmap toByteStringTensor . renderNodeName) xs names <- mapM (fmap toByteStringTensor . renderResourceHandle) xs
let types = replicate (length xs) (tensorType (undefined :: a)) let types = replicate (length xs) (tensorType (undefined :: a))
let saveOp = buildOp $ opDef "Save" let saveOp = buildOp $ opDef "Save"
& opAttr "T" .~ types & opAttr "T" .~ types
saveOp (scalar path) (CoreOps.pack names) xs saveOp (scalar path) (CoreOps.pack names) (map CoreOps.readVariableOp xs)
-- | Restore a tensor's value from a checkpoint file. -- | Restore a tensor's value from a checkpoint file.
-- --
@ -215,21 +207,21 @@ save path xs = do
restoreFromName :: forall a . TensorType a restoreFromName :: forall a . TensorType a
=> ByteString -- ^ File path. => ByteString -- ^ File path.
-> ByteString -- ^ Tensor name override. -> ByteString -- ^ Tensor name override.
-> Tensor Ref a -- ^ Tensor to restore. -> Build (Tensor Value a)
-> Build ControlNode restoreFromName path name = do
restoreFromName path name x = do
let restoreOp = buildOp $ opDef "Restore" let restoreOp = buildOp $ opDef "Restore"
& opAttr "dt" .~ tensorType (undefined :: a) & opAttr "dt" .~ tensorType (undefined :: a)
group =<< assign x (restoreOp (scalar path) (scalar name) :: Tensor Value a) restoreOp (scalar path) (scalar name)
-- | Restore a tensor's value from a checkpoint file. -- | Restore a tensor's value from a checkpoint file.
restore :: forall a . TensorType a restore :: forall a . TensorType a
=> ByteString -- ^ File path. => ByteString -- ^ File path.
-> Tensor Ref a -- ^ Tensor to restore. -> ResourceHandle a
-> Build ControlNode -> Build ControlNode
restore path x = do restore path x = do
name <- encodeUtf8 . unNodeName <$> renderNodeName x name <- encodeUtf8 . unNodeName <$> renderResourceHandle x
restoreFromName path name x CoreOps.assignVariableOp x <$> restoreFromName path name
-- | Create a constant tensor. -- | Create a constant tensor.
-- --
@ -253,12 +245,13 @@ constant (Shape shape') values
typedNode :: TensorProto typedNode :: TensorProto
typedNode = def typedNode = def
& dtype .~ nodeType & dtype .~ nodeType
-- Use shapeToProto from Types.hs
& tensorShape.TensorShape.dim .~ & tensorShape.TensorShape.dim .~
[def & TensorShape.size .~ x | x <- shape'] [def & TensorShape.size .~ x | x <- shape']
& tensorVal .~ values & tensorVal .~ values
-- | Reshape a N-D tensor down to a scalar. -- | Reshape a N-D tensor down to a scalar.
-- --
-- See `TensorFlow.GenOps.Core.reshape`. -- See `TensorFlow.GenOps.Core.reshape`.
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
scalarize t = CoreOps.reshape t (vector scalarShape) scalarize t = CoreOps.reshape t (vector scalarShape)

View File

@ -41,6 +41,7 @@ Test-Suite BuildTest
, lens-family , lens-family
, google-shim , google-shim
, tensorflow , tensorflow
, tensorflow-core-ops
, tensorflow-ops , tensorflow-ops
, tensorflow-proto , tensorflow-proto
, test-framework , test-framework

View File

@ -64,6 +64,7 @@ import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?)) import Test.HUnit ((@=?))
import Google.Test (googleTest) import Google.Test (googleTest)
import qualified Data.Vector as V import qualified Data.Vector as V
import qualified TensorFlow.GenOps.Core as CoreOps
-- | Test named behavior. -- | Test named behavior.
testNamed :: Test testNamed :: Test
@ -96,13 +97,12 @@ testPureRender = testCase "testPureRender" $ runSession $ do
testInitializedVariable :: Test testInitializedVariable :: Test
testInitializedVariable = testInitializedVariable =
testCase "testInitializedVariable" $ runSession $ do testCase "testInitializedVariable" $ runSession $ do
(formula, reset) <- build $ do (v, formula) <- build $ do
v <- initializedVariable 42 v <- initializedVariable 42
r <- assign v 24 return (v, 1 `add` CoreOps.readVariableOp v)
return (1 `add` v, r)
result <- run formula result <- run formula
liftIO $ 43 @=? (unScalar result :: Float) liftIO $ 43 @=? (unScalar result :: Float)
run_ reset -- Updates v to a different value run_ (CoreOps.assignVariableOp v 24)
rerunResult <- run formula rerunResult <- run formula
liftIO $ 25 @=? (unScalar rerunResult :: Float) liftIO $ 25 @=? (unScalar rerunResult :: Float)
@ -110,7 +110,7 @@ testInitializedVariableShape :: Test
testInitializedVariableShape = testInitializedVariableShape =
testCase "testInitializedVariableShape" $ runSession $ do testCase "testInitializedVariableShape" $ runSession $ do
vector <- build $ initializedVariable (constant [1] [42 :: Float]) vector <- build $ initializedVariable (constant [1] [42 :: Float])
result <- run vector result <- run (CoreOps.readVariableOp vector)
liftIO $ [42] @=? (result :: V.Vector Float) liftIO $ [42] @=? (result :: V.Vector Float)
-- | Test nameScoped behavior. -- | Test nameScoped behavior.

View File

@ -112,8 +112,9 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
let ids = TF.constant (TF.Shape [1, 2]) idValues let ids = TF.constant (TF.Shape [1, 2]) idValues
x <- TF.placeholder (TF.Shape [2]) x <- TF.placeholder (TF.Shape [2])
embedding <- TF.initializedVariable embedding <- CoreOps.readVariableOp <$>
=<< TF.render (TF.constant embShape embeddingInit) (TF.initializedVariable
=<< TF.render (TF.constant embShape embeddingInit))
op <- embeddingLookup [embedding] ids op <- embeddingLookup [embedding] ids
let twoNorm = CoreOps.square $ TF.abs (op - x) let twoNorm = CoreOps.square $ TF.abs (op - x)

View File

@ -28,10 +28,11 @@ import qualified Data.ByteString.Char8 as B8
import qualified Data.Vector as V import qualified Data.Vector as V
import qualified TensorFlow.Build as TF import qualified TensorFlow.Build as TF
import qualified TensorFlow.ControlFlow as TF import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.GenOps.Core as CoreOps
import qualified TensorFlow.Nodes as TF import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.Ops as TF import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Output as TF
import qualified TensorFlow.Session as TF import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF import qualified TensorFlow.Types as TF
-- | Test that one can easily determine number of elements in the tensor. -- | Test that one can easily determine number of elements in the tensor.
@ -54,17 +55,16 @@ testSaveRestore :: Test
testSaveRestore = testCase "testSaveRestore" $ testSaveRestore = testCase "testSaveRestore" $
withSystemTempDirectory "" $ \dirPath -> do withSystemTempDirectory "" $ \dirPath -> do
let path = B8.pack $ dirPath ++ "/checkpoint" let path = B8.pack $ dirPath ++ "/checkpoint"
var :: TF.Build (TF.Tensor TF.Ref Float) var :: TF.Build (TF.ResourceHandle Float)
var = TF.render =<< var = TF.zeroInitializedVariable (TF.Shape [])
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape [])
TF.runSession $ do TF.runSession $ do
v <- TF.build var v <- TF.build var
TF.buildAnd TF.run_ $ TF.assign v 134 TF.buildAnd TF.run_ $ TF.group $ CoreOps.assignVariableOp v 134
TF.buildAnd TF.run_ $ TF.save path [v] TF.buildAnd TF.run_ $ TF.save path [v]
result <- TF.runSession $ do result <- TF.runSession $ do
v <- TF.build var v <- TF.build var
TF.buildAnd TF.run_ $ TF.restore path v TF.buildAnd TF.run_ $ TF.restore path v
TF.run v TF.run (CoreOps.readVariableOp v)
liftIO $ TF.Scalar 134 @=? result liftIO $ TF.Scalar 134 @=? result

View File

@ -34,6 +34,7 @@ module TensorFlow.Build
, GraphState , GraphState
, render , render
, renderNodeName , renderNodeName
, renderResourceHandle
, renderedNodeDefs , renderedNodeDefs
, BuildT , BuildT
, Build , Build
@ -361,6 +362,9 @@ render = tensorOutput $ outputOp $ fmap Rendered . resolveOp
renderNodeName :: Tensor v a -> Build NodeName renderNodeName :: Tensor v a -> Build NodeName
renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp) renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp)
renderResourceHandle :: ResourceHandle a -> Build NodeName
renderResourceHandle (ResourceHandle r) = getOrAddOp (r ^. outputOp)
-- | Records the given summary action in Build for retrieval with -- | Records the given summary action in Build for retrieval with
-- 'collectAllSummaries'. The summary op is required to produce a -- 'collectAllSummaries'. The summary op is required to produce a
-- Summary protocol buffer in string form. For safety, use the -- Summary protocol buffer in string form. For safety, use the

View File

@ -38,6 +38,7 @@ module TensorFlow.Output
, outputOp , outputOp
, PendingNodeName(..) , PendingNodeName(..)
, ResourceHandle(..) , ResourceHandle(..)
, resourceHandleOutput
) where ) where
import qualified Data.Map.Strict as Map import qualified Data.Map.Strict as Map
@ -160,4 +161,7 @@ instance IsString Output where
-- | Opaque handle to a mutable resource in the graph. Typical such -- | Opaque handle to a mutable resource in the graph. Typical such
-- resources are variables. The type parameter corresponds to the -- resources are variables. The type parameter corresponds to the
-- dtype of the tensor held in the variable. -- dtype of the tensor held in the variable.
newtype ResourceHandle a = ResourceHandle Output newtype ResourceHandle a = ResourceHandle { unResourceHandle :: Output }
resourceHandleOutput :: Lens' (ResourceHandle a) Output
resourceHandleOutput = lens unResourceHandle (\_ x -> ResourceHandle x)

View File

@ -24,7 +24,14 @@ import qualified Data.Text as Text
import Lens.Family2 (Lens', Traversal') import Lens.Family2 (Lens', Traversal')
import Lens.Family2.Unchecked (lens) import Lens.Family2.Unchecked (lens)
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr) import TensorFlow.Output
( Output
, ResourceHandle
, outputOp
, opUnrendered
, opAttr
, resourceHandleOutput
)
import TensorFlow.Types (TensorData(..), Attribute) import TensorFlow.Types (TensorData(..), Attribute)
import qualified TensorFlow.Internal.FFI as FFI import qualified TensorFlow.Internal.FFI as FFI
@ -61,6 +68,11 @@ tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o)
tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr
tensorAttr x = tensorOutput . outputOp . opUnrendered . opAttr x tensorAttr x = tensorOutput . outputOp . opUnrendered . opAttr x
resourceHandleAttr :: Attribute attr
=> Text.Text -> Traversal' (ResourceHandle a) attr
resourceHandleAttr attr =
resourceHandleOutput . outputOp . opUnrendered . opAttr attr
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a -- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
-- Ref into Value. This behaves like a no-op. -- Ref into Value. This behaves like a no-op.
value :: Tensor v a -> Tensor Value a value :: Tensor v a -> Tensor Value a