From b86945f0088fb45ca8fd9317fbeaf43911b6b78d Mon Sep 17 00:00:00 2001
From: fkm3
Date: Wed, 17 May 2017 13:20:51 -0700
Subject: [PATCH] Support Variable in TensorFlow.Gradient and use in mnist example (#116)

---
 tensorflow-mnist/app/Main.hs              | 11 ++++++----
 tensorflow-ops/src/TensorFlow/Gradient.hs | 26 ++++++++++++++++-------
 tensorflow-ops/tests/GradientTest.hs      | 15 +++++++------
 3 files changed, 33 insertions(+), 19 deletions(-)

diff --git a/tensorflow-mnist/app/Main.hs b/tensorflow-mnist/app/Main.hs
index ccb032b..a2dc0d7 100644
--- a/tensorflow-mnist/app/Main.hs
+++ b/tensorflow-mnist/app/Main.hs
@@ -24,7 +24,8 @@ import qualified Data.Vector as V
 
 import qualified TensorFlow.Core as TF
 import qualified TensorFlow.Gradient as TF
-import qualified TensorFlow.Ops as TF
+import qualified TensorFlow.Ops as TF hiding (initializedVariable, zeroInitializedVariable)
+import qualified TensorFlow.Variable as TF
 
 import TensorFlow.Examples.MNIST.InputData
 import TensorFlow.Examples.MNIST.Parse
@@ -68,13 +69,15 @@ createModel = do
     hiddenWeights <-
         TF.initializedVariable =<< randomParam numPixels [numPixels, numUnits]
     hiddenBiases <- TF.zeroInitializedVariable [numUnits]
-    let hiddenZ = (images `TF.matMul` hiddenWeights) `TF.add` hiddenBiases
+    let hiddenZ = (images `TF.matMul` TF.readValue hiddenWeights)
+                  `TF.add` TF.readValue hiddenBiases
     let hidden = TF.relu hiddenZ
     -- Logits.
     logitWeights <-
        TF.initializedVariable =<< randomParam numUnits [numUnits, numLabels]
    logitBiases <- TF.zeroInitializedVariable [numLabels]
-    let logits = (hidden `TF.matMul` logitWeights) `TF.add` logitBiases
+    let logits = (hidden `TF.matMul` TF.readValue logitWeights)
+                 `TF.add` TF.readValue logitBiases
     predict <- TF.render $ TF.cast $
                TF.argMax (TF.softmax logits) (TF.scalar (1 :: LabelType))
 
@@ -87,7 +90,7 @@ createModel = do
     grads <- TF.gradients loss params
 
     let lr = TF.scalar 0.00001
-        applyGrad param grad = TF.assign param $ param `TF.sub` (lr `TF.mul` grad)
+        applyGrad param grad = TF.assignAdd param (negate $ lr `TF.mul` grad)
     trainStep <- TF.group =<< zipWithM applyGrad params grads
 
     let correctPredictions = TF.equal predict labels
diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs
index 491800f..4056101 100644
--- a/tensorflow-ops/src/TensorFlow/Gradient.hs
+++ b/tensorflow-ops/src/TensorFlow/Gradient.hs
@@ -99,6 +99,7 @@ import TensorFlow.Tensor
     , tensorNodeName
     , renderedOutput
     , renderValue
+    , ToTensor(..)
     )
 import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
 import Proto.Tensorflow.Core.Framework.NodeDef
@@ -116,12 +117,13 @@ type GradientCompatible a =
 
 
 -- | Gradient of @y@ w.r.t. each element of @xs@.
-gradients :: forall a v1 v2 m . ( MonadBuild m
-                                , Rendered (Tensor v2)
-                                , GradientCompatible a
-                                )
+gradients :: forall a v1 t m . ( MonadBuild m
+                                , Rendered t
+                                , ToTensor t
+                                , GradientCompatible a
+                                )
           => Tensor v1 a  -- ^ The output of the graph.
-          -> [Tensor v2 a]  -- ^ Tensors for which gradients are computed.
+          -> [t a]  -- ^ Tensors for which gradients are computed.
          -> m [Tensor Value a]
 gradients y xs = build $ do
     -- The gradients are computed using "reverse accumulation", similarly to
@@ -171,10 +173,9 @@ gradients y xs = build $ do
     gradientMap <- graphGrads gr initPending
     -- Lookup the gradients for each x.
     forM xs $ \x ->
-        let xName = tensorNodeName x
-        in maybe (render $ zerosLike x) return $ do
+        let Output i xName = renderedOutput x
+        in maybe (render $ zerosLike $ toTensor x) return $ do
             n <- nodeMap ^. at xName
-            let i = outputIndex $ renderedOutput x
             gradientMap ^. at n . nonEmpty . outputIxAt i
 
 outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
@@ -687,9 +688,16 @@ opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
   where
     rx = rangeOfRank dz
 
+-- Treat read ops as an identity function on the variable. This allows us to
+-- take gradients w.r.t. the variable handle instead of the result of a read
+-- op. If a variable is read multiple times, the gradients will propagate back
+-- through each read.
+opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
+
 -- TODO(fmayle): These can go away if we properly prune the graph.
 opGrad "Const" _ _ _ = [Nothing, Nothing]
 opGrad "Placeholder" _ _ _ = []
+opGrad "VarHandleOp" _ _ _ = []
 opGrad "Variable" _ _ _ = []
 
 opGrad n nodeDef ins grads =
@@ -723,6 +731,7 @@ numOutputs o =
         "Neg" -> 1
         "Placeholder" -> 1
         "OneHot" -> 1
+        "ReadVariableOp" -> 1
         "RefIdentity" -> 1
         "Relu" -> 1
         "ReluGrad" -> 1
@@ -737,6 +746,7 @@ numOutputs o =
         "Tile" -> 1
         "Transpose" -> 1
         "TruncatedNormal" -> 1
+        "VarHandleOp" -> 1
         "Variable" -> 1
         "ZerosLike" -> 1
         "Fill" -> 1
diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs
index 5e57c15..ce3d284 100644
--- a/tensorflow-ops/tests/GradientTest.hs
+++ b/tensorflow-ops/tests/GradientTest.hs
@@ -31,9 +31,10 @@ import Control.Monad.IO.Class (liftIO)
 import qualified TensorFlow.Core as TF
 import qualified TensorFlow.GenOps.Core as TF (max, tile)
 import qualified TensorFlow.Gradient as TF
-import qualified TensorFlow.Ops as TF
+import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
 import qualified TensorFlow.Output as TF
 import qualified TensorFlow.Types as TF
+import qualified TensorFlow.Variable as TF
 
 import Proto.Tensorflow.Core.Framework.Graph (node)
 import Proto.Tensorflow.Core.Framework.NodeDef (op)
@@ -222,7 +223,7 @@ matMulGradient = testCase "matMulGradients" $ do
     let dfBuild = do
             x <- TF.render $ TF.zeros $ TF.Shape [3, 1 :: Int64]
             w <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
-            let f = x `TF.matMul` w :: TF.Tensor TF.Build Float
+            let f = x `TF.matMul` TF.readValue w :: TF.Tensor TF.Build Float
             dfs <- TF.gradients f [x]
             return (x, dfs)
 
@@ -242,11 +243,11 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
     let tower = do
             x <- TF.render $ TF.zeros $ TF.Shape [batch, 1]
             w <- TF.zeroInitializedVariable $ TF.Shape [1, width]
-            let f = x `TF.matMul` w
+            let f = x `TF.matMul` TF.readValue w
             [dfdx] <- TF.gradients f [x]
             let f'x = TF.reduceSum dfdx
             [dfdw] <- TF.gradients f'x [w]  -- take gradient again (this time over w)
-            return [TF.value w, dfdw]
+            return [TF.readValue w, TF.expr dfdw]
 
     TF.runSession $ do
         [w, dfdw] <- TF.build tower
@@ -255,12 +256,12 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
 
         let step = w `TF.add` dfdw
         w0 <- TF.run step
-        liftIO $ ((V.fromList [4, 4 :: Float]) @=? w0)
+        liftIO $ V.fromList [4, 4 :: Float] @=? w0
 
 
 -- test that gradient of matMul deals correctly with transpose_a and transpose_b
 matMulTransposeGradient :: (Bool, Bool) -> Test
-matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ (show txw)) $ do
+matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ show txw) $ do
     let (transposeX, transposeW) = txw
 
     let dfBuild = do
@@ -268,7 +269,7 @@ matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ (show txw
             let xShape = TF.Shape [1, 2 :: Int64]
             let xZeros = TF.zeros xShape
             x <- TF.render $ if transposeX then TF.matTranspose xZeros else xZeros
             variable <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
-            let wv = if transposeW then TF.matTranspose variable else TF.expr variable
+            let wv = if transposeW then TF.matTranspose (TF.readValue variable) else TF.readValue variable
             let f = TF.matMul' (transAttrs transposeX transposeW) x wv :: TF.Tensor TF.Build Float
             w <- TF.render wv
             ds <- TF.gradients f [x, w]
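
With this patch applied, gradients accepts anything that is both Rendered and ToTensor, so a TF.Variable can be passed to it directly; reads go through ReadVariableOp, whose gradient is treated as the identity, so the gradient flows back to the variable handle, and the update is applied with assignAdd as in the Main.hs hunk. The following is a minimal sketch of that usage, not code from the repository; the names sgdStep, w, err, and loss are illustrative.

-- Minimal sketch (not from the patch): take a gradient with respect to a
-- Variable and apply one gradient-descent step with assignAdd.
import qualified TensorFlow.Core as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF hiding (initializedVariable, zeroInitializedVariable)
import qualified TensorFlow.Variable as TF

sgdStep :: TF.Session (TF.Scalar Float)
sgdStep = do
    -- A scalar parameter, initialized to 0.
    w <- TF.initializedVariable (TF.scalar (0 :: Float))
    -- loss = (3 * w - 6)^2; the read of w is a ReadVariableOp, whose gradient
    -- is the identity, so d(loss)/dw reaches the variable handle.
    let err  = (TF.scalar 3 `TF.mul` TF.readValue w) `TF.sub` TF.scalar 6
        loss = err `TF.mul` err
    [grad] <- TF.gradients loss [w]
    let lr = TF.scalar 0.1
    step <- TF.assignAdd w (negate (lr `TF.mul` grad))
    TF.run_ step
    TF.run (TF.readValue w)

main :: IO ()
main = TF.runSession sgdStep >>= print . TF.unScalar

Applying the update with assignAdd on a negated, scaled gradient mirrors the Main.hs change: the parameter is now a Variable rather than a Ref tensor, so the old assign-based form would need an explicit readValue of the parameter before the subtraction.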
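
Because the ReadVariableOp gradient is the identity, a variable that is read more than once accumulates gradient contributions from every read, as the comment in the Gradient.hs hunk notes. Below is a small self-contained sketch of that behaviour; it is illustrative only and not part of the test suite, and readTwiceGradient is an assumed name.

-- Sketch (not from the patch): two separate reads of the same variable both
-- contribute to the gradient at the variable handle.
import qualified Data.Vector as V
import qualified TensorFlow.Core as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF hiding (initializedVariable, zeroInitializedVariable)
import qualified TensorFlow.Variable as TF

readTwiceGradient :: IO (V.Vector Float)
readTwiceGradient = TF.runSession $ do
    v <- TF.initializedVariable (TF.vector [2, 3 :: Float])
    -- y = v * v, built from two distinct ReadVariableOp nodes.
    let y = TF.readValue v `TF.mul` TF.readValue v
    [dv] <- TF.gradients y [v]
    -- dy/dv = 2 * v, since each read passes its gradient back to the handle;
    -- the expected result here is [4, 6].
    TF.run dv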