mirror of
https://github.com/tensorflow/haskell.git
synced 2024-12-24 18:49:46 +01:00
MNIST Main compiles but is broken.
This commit is contained in:
parent
2b5e41ffeb
commit
1677c346eb
11 changed files with 69 additions and 48 deletions
|
@ -15,7 +15,7 @@
|
|||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
|
||||
import Control.Monad (zipWithM, when, forM_)
|
||||
import Control.Monad (forM_, when)
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.List (genericLength)
|
||||
|
@ -30,6 +30,7 @@ import qualified TensorFlow.Ops as TF
|
|||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
|
||||
import TensorFlow.Examples.MNIST.InputData
|
||||
import TensorFlow.Examples.MNIST.Parse
|
||||
|
@ -64,6 +65,7 @@ data Model = Model {
|
|||
|
||||
createModel :: TF.Build Model
|
||||
createModel = do
|
||||
let rd = CoreOps.readVariableOp
|
||||
-- Use -1 batch size to support variable sized batches.
|
||||
let batchSize = -1
|
||||
-- Inputs.
|
||||
|
@ -73,13 +75,14 @@ createModel = do
|
|||
hiddenWeights <-
|
||||
TF.initializedVariable =<< randomParam numPixels [numPixels, numUnits]
|
||||
hiddenBiases <- TF.zeroInitializedVariable [numUnits]
|
||||
let hiddenZ = (images `TF.matMul` hiddenWeights) `TF.add` hiddenBiases
|
||||
let hiddenZ = (images `TF.matMul` rd hiddenWeights)
|
||||
`TF.add` rd hiddenBiases
|
||||
let hidden = TF.relu hiddenZ
|
||||
-- Logits.
|
||||
logitWeights <-
|
||||
TF.initializedVariable =<< randomParam numUnits [numUnits, numLabels]
|
||||
logitBiases <- TF.zeroInitializedVariable [numLabels]
|
||||
let logits = (hidden `TF.matMul` logitWeights) `TF.add` logitBiases
|
||||
let logits = (hidden `TF.matMul` rd logitWeights) `TF.add` rd logitBiases
|
||||
predict <- TF.render $ TF.cast $
|
||||
TF.argMax (TF.softmax logits) (TF.scalar (1 :: LabelType))
|
||||
|
||||
|
@ -89,11 +92,11 @@ createModel = do
|
|||
loss =
|
||||
reduceMean $ fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
|
||||
params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases]
|
||||
grads <- TF.gradients loss params
|
||||
grads <- TF.gradients loss (map rd params)
|
||||
|
||||
let lr = TF.scalar 0.00001
|
||||
applyGrad param grad = TF.assign param $ param `TF.sub` (lr * grad)
|
||||
trainStep <- TF.group =<< zipWithM applyGrad params grads
|
||||
let lr = TF.scalar (-0.00001 :: Float) -- Negative to make it descend.
|
||||
applyGrad var grad = CoreOps.assignVariableOp var (lr * grad)
|
||||
trainStep <- TF.group (zipWith applyGrad params grads)
|
||||
|
||||
let correctPredictions = TF.equal predict labels
|
||||
errorRateTensor <- TF.render $ 1 - reduceMean (TF.cast correctPredictions)
|
||||
|
|
|
@ -45,6 +45,7 @@ executable Main
|
|||
, lens-family
|
||||
, proto-lens
|
||||
, tensorflow
|
||||
, tensorflow-core-ops
|
||||
, tensorflow-mnist
|
||||
, tensorflow-mnist-input-data
|
||||
, tensorflow-ops
|
||||
|
|
|
@ -654,6 +654,7 @@ opGrad "ZerosLike" _ _ _ = [Nothing]
|
|||
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
||||
opGrad "Placeholder" _ _ _ = []
|
||||
opGrad "Variable" _ _ _ = []
|
||||
opGrad "ReadVariableOp" _ _ _ = [Nothing, Nothing]
|
||||
|
||||
opGrad n nodeDef ins grads =
|
||||
error $ "no gradient implemented for " ++
|
||||
|
@ -699,6 +700,7 @@ numOutputs o =
|
|||
"Transpose" -> 1
|
||||
"TruncatedNormal" -> 1
|
||||
"Variable" -> 1
|
||||
"ReadVariableOp" -> 2
|
||||
"ZerosLike" -> 1
|
||||
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
|
||||
|
||||
|
|
|
@ -121,8 +121,7 @@ import qualified Proto.Tensorflow.Core.Framework.TensorShape
|
|||
as TensorShape
|
||||
import TensorFlow.Build
|
||||
import TensorFlow.BuildOp
|
||||
import TensorFlow.ControlFlow (group)
|
||||
import TensorFlow.Output (unNodeName)
|
||||
import TensorFlow.Output (ResourceHandle, unNodeName)
|
||||
import TensorFlow.Tensor
|
||||
import TensorFlow.Types
|
||||
|
||||
|
@ -176,37 +175,30 @@ assign = buildOp $ opDef "Assign"
|
|||
-- | Creates a variable initialized to the given value.
|
||||
-- Initialization happens next time session runs.
|
||||
initializedVariable :: forall a . TensorType a
|
||||
=> Tensor Value a -> Build (Tensor Ref a)
|
||||
=> Tensor Value a -> Build (ResourceHandle a)
|
||||
initializedVariable initializer = do
|
||||
v <- variable [] -- The shape is not known initially.
|
||||
(i :: Tensor Ref a) <-
|
||||
buildOp (opDef "Assign"
|
||||
& opAttr "T" .~ tensorType (undefined :: a)
|
||||
& opAttr "use_locking" .~ True
|
||||
& opAttr "validate_shape" .~ False
|
||||
)
|
||||
v initializer
|
||||
addInitializer =<< group i
|
||||
let v = CoreOps.varHandleOp
|
||||
& resourceHandleAttr "shape" .~ (Shape [])
|
||||
addInitializer (CoreOps.createVariableOp v initializer)
|
||||
return v
|
||||
|
||||
-- | Creates a zero-initialized variable with the given shape.
|
||||
zeroInitializedVariable
|
||||
:: (TensorType a, Num a) =>
|
||||
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
|
||||
:: (TensorType a, Num a) => TensorFlow.Types.Shape -> Build (ResourceHandle a)
|
||||
zeroInitializedVariable = initializedVariable . zeros
|
||||
|
||||
-- TODO: Support heterogeneous list of tensors.
|
||||
save :: forall a v . TensorType a
|
||||
save :: forall a . TensorType a
|
||||
=> ByteString -- ^ File path.
|
||||
-> [Tensor v a] -- ^ Tensors to save.
|
||||
-> [ResourceHandle a] -- ^ Tensors to save.
|
||||
-> Build ControlNode
|
||||
save path xs = do
|
||||
let toByteStringTensor = scalar . encodeUtf8 . unNodeName
|
||||
names <- mapM (fmap toByteStringTensor . renderNodeName) xs
|
||||
names <- mapM (fmap toByteStringTensor . renderResourceHandle) xs
|
||||
let types = replicate (length xs) (tensorType (undefined :: a))
|
||||
let saveOp = buildOp $ opDef "Save"
|
||||
& opAttr "T" .~ types
|
||||
saveOp (scalar path) (CoreOps.pack names) xs
|
||||
saveOp (scalar path) (CoreOps.pack names) (map CoreOps.readVariableOp xs)
|
||||
|
||||
-- | Restore a tensor's value from a checkpoint file.
|
||||
--
|
||||
|
@ -215,21 +207,21 @@ save path xs = do
|
|||
restoreFromName :: forall a . TensorType a
|
||||
=> ByteString -- ^ File path.
|
||||
-> ByteString -- ^ Tensor name override.
|
||||
-> Tensor Ref a -- ^ Tensor to restore.
|
||||
-> Build ControlNode
|
||||
restoreFromName path name x = do
|
||||
-> Build (Tensor Value a)
|
||||
restoreFromName path name = do
|
||||
let restoreOp = buildOp $ opDef "Restore"
|
||||
& opAttr "dt" .~ tensorType (undefined :: a)
|
||||
group =<< assign x (restoreOp (scalar path) (scalar name) :: Tensor Value a)
|
||||
restoreOp (scalar path) (scalar name)
|
||||
|
||||
|
||||
-- | Restore a tensor's value from a checkpoint file.
|
||||
restore :: forall a . TensorType a
|
||||
=> ByteString -- ^ File path.
|
||||
-> Tensor Ref a -- ^ Tensor to restore.
|
||||
-> ResourceHandle a
|
||||
-> Build ControlNode
|
||||
restore path x = do
|
||||
name <- encodeUtf8 . unNodeName <$> renderNodeName x
|
||||
restoreFromName path name x
|
||||
name <- encodeUtf8 . unNodeName <$> renderResourceHandle x
|
||||
CoreOps.assignVariableOp x <$> restoreFromName path name
|
||||
|
||||
-- | Create a constant tensor.
|
||||
--
|
||||
|
@ -253,12 +245,13 @@ constant (Shape shape') values
|
|||
typedNode :: TensorProto
|
||||
typedNode = def
|
||||
& dtype .~ nodeType
|
||||
-- Use shapeToProto from Types.hs
|
||||
& tensorShape.TensorShape.dim .~
|
||||
[def & TensorShape.size .~ x | x <- shape']
|
||||
& tensorVal .~ values
|
||||
|
||||
-- | Reshape an N-D tensor down to a scalar.
|
||||
--
|
||||
--
|
||||
-- See `TensorFlow.GenOps.Core.reshape`.
|
||||
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
|
||||
-- Delegates to 'CoreOps.reshape', passing 'scalarShape' (defined elsewhere,
-- not visible here) wrapped by 'vector' as its second argument.
scalarize t = CoreOps.reshape t (vector scalarShape)
|
||||
|
|
|
@ -41,6 +41,7 @@ Test-Suite BuildTest
|
|||
, lens-family
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-core-ops
|
||||
, tensorflow-ops
|
||||
, tensorflow-proto
|
||||
, test-framework
|
||||
|
|
|
@ -64,6 +64,7 @@ import Test.Framework.Providers.HUnit (testCase)
|
|||
import Test.HUnit ((@=?))
|
||||
import Google.Test (googleTest)
|
||||
import qualified Data.Vector as V
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
|
||||
-- | Test named behavior.
|
||||
testNamed :: Test
|
||||
|
@ -96,13 +97,12 @@ testPureRender = testCase "testPureRender" $ runSession $ do
|
|||
testInitializedVariable :: Test
|
||||
testInitializedVariable =
|
||||
testCase "testInitializedVariable" $ runSession $ do
|
||||
(formula, reset) <- build $ do
|
||||
(v, formula) <- build $ do
|
||||
v <- initializedVariable 42
|
||||
r <- assign v 24
|
||||
return (1 `add` v, r)
|
||||
return (v, 1 `add` CoreOps.readVariableOp v)
|
||||
result <- run formula
|
||||
liftIO $ 43 @=? (unScalar result :: Float)
|
||||
run_ reset -- Updates v to a different value
|
||||
run_ (CoreOps.assignVariableOp v 24)
|
||||
rerunResult <- run formula
|
||||
liftIO $ 25 @=? (unScalar rerunResult :: Float)
|
||||
|
||||
|
@ -110,7 +110,7 @@ testInitializedVariableShape :: Test
|
|||
testInitializedVariableShape =
|
||||
testCase "testInitializedVariableShape" $ runSession $ do
|
||||
vector <- build $ initializedVariable (constant [1] [42 :: Float])
|
||||
result <- run vector
|
||||
result <- run (CoreOps.readVariableOp vector)
|
||||
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||
|
||||
-- | Test nameScoped behavior.
|
||||
|
|
|
@ -112,8 +112,9 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
|||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||
|
||||
x <- TF.placeholder (TF.Shape [2])
|
||||
embedding <- TF.initializedVariable
|
||||
=<< TF.render (TF.constant embShape embeddingInit)
|
||||
embedding <- CoreOps.readVariableOp <$>
|
||||
(TF.initializedVariable
|
||||
=<< TF.render (TF.constant embShape embeddingInit))
|
||||
|
||||
op <- embeddingLookup [embedding] ids
|
||||
let twoNorm = CoreOps.square $ TF.abs (op - x)
|
||||
|
|
|
@ -28,10 +28,11 @@ import qualified Data.ByteString.Char8 as B8
|
|||
import qualified Data.Vector as V
|
||||
import qualified TensorFlow.Build as TF
|
||||
import qualified TensorFlow.ControlFlow as TF
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
import qualified TensorFlow.Nodes as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Output as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
|
||||
-- | Test that one can easily determine number of elements in the tensor.
|
||||
|
@ -54,17 +55,16 @@ testSaveRestore :: Test
|
|||
testSaveRestore = testCase "testSaveRestore" $
|
||||
withSystemTempDirectory "" $ \dirPath -> do
|
||||
let path = B8.pack $ dirPath ++ "/checkpoint"
|
||||
var :: TF.Build (TF.Tensor TF.Ref Float)
|
||||
var = TF.render =<<
|
||||
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape [])
|
||||
var :: TF.Build (TF.ResourceHandle Float)
|
||||
var = TF.zeroInitializedVariable (TF.Shape [])
|
||||
TF.runSession $ do
|
||||
v <- TF.build var
|
||||
TF.buildAnd TF.run_ $ TF.assign v 134
|
||||
TF.buildAnd TF.run_ $ TF.group $ CoreOps.assignVariableOp v 134
|
||||
TF.buildAnd TF.run_ $ TF.save path [v]
|
||||
result <- TF.runSession $ do
|
||||
v <- TF.build var
|
||||
TF.buildAnd TF.run_ $ TF.restore path v
|
||||
TF.run v
|
||||
TF.run (CoreOps.readVariableOp v)
|
||||
liftIO $ TF.Scalar 134 @=? result
|
||||
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ module TensorFlow.Build
|
|||
, GraphState
|
||||
, render
|
||||
, renderNodeName
|
||||
, renderResourceHandle
|
||||
, renderedNodeDefs
|
||||
, BuildT
|
||||
, Build
|
||||
|
@ -361,6 +362,9 @@ render = tensorOutput $ outputOp $ fmap Rendered . resolveOp
|
|||
renderNodeName :: Tensor v a -> Build NodeName
|
||||
renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp)
|
||||
|
||||
-- | Returns the 'NodeName' that 'getOrAddOp' yields for the op underlying
-- the given 'ResourceHandle'. This is the 'ResourceHandle' counterpart of
-- 'renderNodeName', which does the same for a 'Tensor'.
renderResourceHandle :: ResourceHandle a -> Build NodeName
|
||||
-- Unwraps the handle's 'Output' and looks up its op via the 'outputOp' lens.
renderResourceHandle (ResourceHandle r) = getOrAddOp (r ^. outputOp)
|
||||
|
||||
-- | Records the given summary action in Build for retrieval with
|
||||
-- 'collectAllSummaries'. The summary op is required to produce a
|
||||
-- Summary protocol buffer in string form. For safety, use the
|
||||
|
|
|
@ -38,6 +38,7 @@ module TensorFlow.Output
|
|||
, outputOp
|
||||
, PendingNodeName(..)
|
||||
, ResourceHandle(..)
|
||||
, resourceHandleOutput
|
||||
) where
|
||||
|
||||
import qualified Data.Map.Strict as Map
|
||||
|
@ -160,4 +161,7 @@ instance IsString Output where
|
|||
-- | Opaque handle to a mutable resource in the graph. Typical such
|
||||
-- resources are variables. The type parameter corresponds to the
|
||||
-- dtype of the tensor held in the variable.
|
||||
newtype ResourceHandle a = ResourceHandle Output
|
||||
newtype ResourceHandle a = ResourceHandle { unResourceHandle :: Output }
|
||||
|
||||
-- | Lens onto the 'Output' wrapped by a 'ResourceHandle'.
--
-- The getter is 'unResourceHandle'; the setter discards the old handle and
-- wraps the new 'Output' in a fresh 'ResourceHandle'.
resourceHandleOutput :: Lens' (ResourceHandle a) Output
|
||||
resourceHandleOutput = lens unResourceHandle (\_ x -> ResourceHandle x)
|
||||
|
|
|
@ -24,7 +24,14 @@ import qualified Data.Text as Text
|
|||
import Lens.Family2 (Lens', Traversal')
|
||||
import Lens.Family2.Unchecked (lens)
|
||||
|
||||
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr)
|
||||
import TensorFlow.Output
|
||||
( Output
|
||||
, ResourceHandle
|
||||
, outputOp
|
||||
, opUnrendered
|
||||
, opAttr
|
||||
, resourceHandleOutput
|
||||
)
|
||||
import TensorFlow.Types (TensorData(..), Attribute)
|
||||
import qualified TensorFlow.Internal.FFI as FFI
|
||||
|
||||
|
@ -61,6 +68,11 @@ tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o)
|
|||
tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr
|
||||
tensorAttr x = tensorOutput . outputOp . opUnrendered . opAttr x
|
||||
|
||||
-- | Traversal onto the attribute with the given name on the op behind a
-- 'ResourceHandle'. Mirrors 'tensorAttr', but starts from
-- 'resourceHandleOutput' instead of 'tensorOutput'.
--
-- NOTE(review): the path goes through 'opUnrendered', so this presumably
-- only targets ops that have not been rendered yet -- confirm against
-- 'opUnrendered' in "TensorFlow.Output".
resourceHandleAttr :: Attribute attr
|
||||
=> Text.Text -> Traversal' (ResourceHandle a) attr
|
||||
resourceHandleAttr attr =
|
||||
resourceHandleOutput . outputOp . opUnrendered . opAttr attr
|
||||
|
||||
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
|
||||
-- Ref into Value. This behaves like a no-op.
|
||||
value :: Tensor v a -> Tensor Value a
|
||||
|
|
Loading…
Reference in a new issue