1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-12-24 18:49:46 +01:00

Support Variable in TensorFlow.Gradient and use in mnist example (#116)

This commit is contained in:
fkm3 2017-05-17 13:20:51 -07:00 committed by Judah Jacobson
parent ddb4fe4f90
commit b86945f008
3 changed files with 33 additions and 19 deletions

View file

@ -24,7 +24,8 @@ import qualified Data.Vector as V
import qualified TensorFlow.Core as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Ops as TF hiding (initializedVariable, zeroInitializedVariable)
import qualified TensorFlow.Variable as TF
import TensorFlow.Examples.MNIST.InputData
import TensorFlow.Examples.MNIST.Parse
@ -68,13 +69,15 @@ createModel = do
hiddenWeights <-
TF.initializedVariable =<< randomParam numPixels [numPixels, numUnits]
hiddenBiases <- TF.zeroInitializedVariable [numUnits]
let hiddenZ = (images `TF.matMul` hiddenWeights) `TF.add` hiddenBiases
let hiddenZ = (images `TF.matMul` TF.readValue hiddenWeights)
`TF.add` TF.readValue hiddenBiases
let hidden = TF.relu hiddenZ
-- Logits.
logitWeights <-
TF.initializedVariable =<< randomParam numUnits [numUnits, numLabels]
logitBiases <- TF.zeroInitializedVariable [numLabels]
let logits = (hidden `TF.matMul` logitWeights) `TF.add` logitBiases
let logits = (hidden `TF.matMul` TF.readValue logitWeights)
`TF.add` TF.readValue logitBiases
predict <- TF.render $ TF.cast $
TF.argMax (TF.softmax logits) (TF.scalar (1 :: LabelType))
@ -87,7 +90,7 @@ createModel = do
grads <- TF.gradients loss params
let lr = TF.scalar 0.00001
applyGrad param grad = TF.assign param $ param `TF.sub` (lr `TF.mul` grad)
applyGrad param grad = TF.assignAdd param (negate $ lr `TF.mul` grad)
trainStep <- TF.group =<< zipWithM applyGrad params grads
let correctPredictions = TF.equal predict labels

View file

@ -99,6 +99,7 @@ import TensorFlow.Tensor
, tensorNodeName
, renderedOutput
, renderValue
, ToTensor(..)
)
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef
@ -116,12 +117,13 @@ type GradientCompatible a =
-- | Gradient of @y@ w.r.t. each element of @xs@.
gradients :: forall a v1 v2 m . ( MonadBuild m
, Rendered (Tensor v2)
, GradientCompatible a
)
gradients :: forall a v1 t m . ( MonadBuild m
, Rendered t
, ToTensor t
, GradientCompatible a
)
=> Tensor v1 a -- ^ The output of the graph.
-> [Tensor v2 a] -- ^ Tensors for which gradients are computed.
-> [t a] -- ^ Tensors for which gradients are computed.
-> m [Tensor Value a]
gradients y xs = build $ do
-- The gradients are computed using "reverse accumulation", similarly to
@ -171,10 +173,9 @@ gradients y xs = build $ do
gradientMap <- graphGrads gr initPending
-- Lookup the gradients for each x.
forM xs $ \x ->
let xName = tensorNodeName x
in maybe (render $ zerosLike x) return $ do
let Output i xName = renderedOutput x
in maybe (render $ zerosLike $ toTensor x) return $ do
n <- nodeMap ^. at xName
let i = outputIndex $ renderedOutput x
gradientMap ^. at n . nonEmpty . outputIxAt i
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
@ -687,9 +688,16 @@ opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
where
rx = rangeOfRank dz
-- Treat read ops as an identity function on the variable. This allows us to
-- take gradients w.r.t. to the variable handle instead of the result of a read
-- op. If a variable is read multiple times, the gradients will propagate back
-- through each read.
opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
-- TODO(fmayle): These can go away if we properly prune the graph.
opGrad "Const" _ _ _ = [Nothing, Nothing]
opGrad "Placeholder" _ _ _ = []
opGrad "VarHandleOp" _ _ _ = []
opGrad "Variable" _ _ _ = []
opGrad n nodeDef ins grads =
@ -723,6 +731,7 @@ numOutputs o =
"Neg" -> 1
"Placeholder" -> 1
"OneHot" -> 1
"ReadVariableOp" -> 1
"RefIdentity" -> 1
"Relu" -> 1
"ReluGrad" -> 1
@ -737,6 +746,7 @@ numOutputs o =
"Tile" -> 1
"Transpose" -> 1
"TruncatedNormal" -> 1
"VarHandleOp" -> 1
"Variable" -> 1
"ZerosLike" -> 1
"Fill" -> 1

View file

@ -31,9 +31,10 @@ import Control.Monad.IO.Class (liftIO)
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (max, tile)
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
import qualified TensorFlow.Output as TF
import qualified TensorFlow.Types as TF
import qualified TensorFlow.Variable as TF
import Proto.Tensorflow.Core.Framework.Graph (node)
import Proto.Tensorflow.Core.Framework.NodeDef (op)
@ -222,7 +223,7 @@ matMulGradient = testCase "matMulGradients" $ do
let dfBuild = do
x <- TF.render $ TF.zeros $ TF.Shape [3, 1 :: Int64]
w <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
let f = x `TF.matMul` w :: TF.Tensor TF.Build Float
let f = x `TF.matMul` TF.readValue w :: TF.Tensor TF.Build Float
dfs <- TF.gradients f [x]
return (x, dfs)
@ -242,11 +243,11 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
let tower = do
x <- TF.render $ TF.zeros $ TF.Shape [batch, 1]
w <- TF.zeroInitializedVariable $ TF.Shape [1, width]
let f = x `TF.matMul` w
let f = x `TF.matMul` TF.readValue w
[dfdx] <- TF.gradients f [x]
let f'x = TF.reduceSum dfdx
[dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
return [TF.value w, dfdw]
return [TF.readValue w, TF.expr dfdw]
TF.runSession $ do
[w, dfdw] <- TF.build tower
@ -255,12 +256,12 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
let step = w `TF.add` dfdw
w0 <- TF.run step
liftIO $ ((V.fromList [4, 4 :: Float]) @=? w0)
liftIO $ V.fromList [4, 4 :: Float] @=? w0
-- test that gradient of matMul deals correctly with transpose_a and transpose_b
matMulTransposeGradient :: (Bool, Bool) -> Test
matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ (show txw)) $ do
matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ show txw) $ do
let (transposeX, transposeW) = txw
let dfBuild = do
@ -268,7 +269,7 @@ matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ (show txw
let xZeros = TF.zeros xShape
x <- TF.render $ if transposeX then TF.matTranspose xZeros else xZeros
variable <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
let wv = if transposeW then TF.matTranspose variable else TF.expr variable
let wv = if transposeW then TF.matTranspose (TF.readValue variable) else TF.readValue variable
let f = TF.matMul' (transAttrs transposeX transposeW) x wv :: TF.Tensor TF.Build Float
w <- TF.render wv
ds <- TF.gradients f [x, w]