mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Support Variable in TensorFlow.Gradient and use in mnist example (#116)
This commit is contained in:
parent
ddb4fe4f90
commit
b86945f008
3 changed files with 33 additions and 19 deletions
|
@ -24,7 +24,8 @@ import qualified Data.Vector as V
|
||||||
|
|
||||||
import qualified TensorFlow.Core as TF
|
import qualified TensorFlow.Core as TF
|
||||||
import qualified TensorFlow.Gradient as TF
|
import qualified TensorFlow.Gradient as TF
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF hiding (initializedVariable, zeroInitializedVariable)
|
||||||
|
import qualified TensorFlow.Variable as TF
|
||||||
|
|
||||||
import TensorFlow.Examples.MNIST.InputData
|
import TensorFlow.Examples.MNIST.InputData
|
||||||
import TensorFlow.Examples.MNIST.Parse
|
import TensorFlow.Examples.MNIST.Parse
|
||||||
|
@ -68,13 +69,15 @@ createModel = do
|
||||||
hiddenWeights <-
|
hiddenWeights <-
|
||||||
TF.initializedVariable =<< randomParam numPixels [numPixels, numUnits]
|
TF.initializedVariable =<< randomParam numPixels [numPixels, numUnits]
|
||||||
hiddenBiases <- TF.zeroInitializedVariable [numUnits]
|
hiddenBiases <- TF.zeroInitializedVariable [numUnits]
|
||||||
let hiddenZ = (images `TF.matMul` hiddenWeights) `TF.add` hiddenBiases
|
let hiddenZ = (images `TF.matMul` TF.readValue hiddenWeights)
|
||||||
|
`TF.add` TF.readValue hiddenBiases
|
||||||
let hidden = TF.relu hiddenZ
|
let hidden = TF.relu hiddenZ
|
||||||
-- Logits.
|
-- Logits.
|
||||||
logitWeights <-
|
logitWeights <-
|
||||||
TF.initializedVariable =<< randomParam numUnits [numUnits, numLabels]
|
TF.initializedVariable =<< randomParam numUnits [numUnits, numLabels]
|
||||||
logitBiases <- TF.zeroInitializedVariable [numLabels]
|
logitBiases <- TF.zeroInitializedVariable [numLabels]
|
||||||
let logits = (hidden `TF.matMul` logitWeights) `TF.add` logitBiases
|
let logits = (hidden `TF.matMul` TF.readValue logitWeights)
|
||||||
|
`TF.add` TF.readValue logitBiases
|
||||||
predict <- TF.render $ TF.cast $
|
predict <- TF.render $ TF.cast $
|
||||||
TF.argMax (TF.softmax logits) (TF.scalar (1 :: LabelType))
|
TF.argMax (TF.softmax logits) (TF.scalar (1 :: LabelType))
|
||||||
|
|
||||||
|
@ -87,7 +90,7 @@ createModel = do
|
||||||
grads <- TF.gradients loss params
|
grads <- TF.gradients loss params
|
||||||
|
|
||||||
let lr = TF.scalar 0.00001
|
let lr = TF.scalar 0.00001
|
||||||
applyGrad param grad = TF.assign param $ param `TF.sub` (lr `TF.mul` grad)
|
applyGrad param grad = TF.assignAdd param (negate $ lr `TF.mul` grad)
|
||||||
trainStep <- TF.group =<< zipWithM applyGrad params grads
|
trainStep <- TF.group =<< zipWithM applyGrad params grads
|
||||||
|
|
||||||
let correctPredictions = TF.equal predict labels
|
let correctPredictions = TF.equal predict labels
|
||||||
|
|
|
@ -99,6 +99,7 @@ import TensorFlow.Tensor
|
||||||
, tensorNodeName
|
, tensorNodeName
|
||||||
, renderedOutput
|
, renderedOutput
|
||||||
, renderValue
|
, renderValue
|
||||||
|
, ToTensor(..)
|
||||||
)
|
)
|
||||||
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
|
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
|
||||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||||
|
@ -116,12 +117,13 @@ type GradientCompatible a =
|
||||||
|
|
||||||
|
|
||||||
-- | Gradient of @y@ w.r.t. each element of @xs@.
|
-- | Gradient of @y@ w.r.t. each element of @xs@.
|
||||||
gradients :: forall a v1 v2 m . ( MonadBuild m
|
gradients :: forall a v1 t m . ( MonadBuild m
|
||||||
, Rendered (Tensor v2)
|
, Rendered t
|
||||||
|
, ToTensor t
|
||||||
, GradientCompatible a
|
, GradientCompatible a
|
||||||
)
|
)
|
||||||
=> Tensor v1 a -- ^ The output of the graph.
|
=> Tensor v1 a -- ^ The output of the graph.
|
||||||
-> [Tensor v2 a] -- ^ Tensors for which gradients are computed.
|
-> [t a] -- ^ Tensors for which gradients are computed.
|
||||||
-> m [Tensor Value a]
|
-> m [Tensor Value a]
|
||||||
gradients y xs = build $ do
|
gradients y xs = build $ do
|
||||||
-- The gradients are computed using "reverse accumulation", similarly to
|
-- The gradients are computed using "reverse accumulation", similarly to
|
||||||
|
@ -171,10 +173,9 @@ gradients y xs = build $ do
|
||||||
gradientMap <- graphGrads gr initPending
|
gradientMap <- graphGrads gr initPending
|
||||||
-- Lookup the gradients for each x.
|
-- Lookup the gradients for each x.
|
||||||
forM xs $ \x ->
|
forM xs $ \x ->
|
||||||
let xName = tensorNodeName x
|
let Output i xName = renderedOutput x
|
||||||
in maybe (render $ zerosLike x) return $ do
|
in maybe (render $ zerosLike $ toTensor x) return $ do
|
||||||
n <- nodeMap ^. at xName
|
n <- nodeMap ^. at xName
|
||||||
let i = outputIndex $ renderedOutput x
|
|
||||||
gradientMap ^. at n . nonEmpty . outputIxAt i
|
gradientMap ^. at n . nonEmpty . outputIxAt i
|
||||||
|
|
||||||
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
|
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
|
||||||
|
@ -687,9 +688,16 @@ opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
|
||||||
where
|
where
|
||||||
rx = rangeOfRank dz
|
rx = rangeOfRank dz
|
||||||
|
|
||||||
|
-- Treat read ops as an identity function on the variable. This allows us to
|
||||||
|
-- take gradients w.r.t. to the variable handle instead of the result of a read
|
||||||
|
-- op. If a variable is read multiple times, the gradients will propagate back
|
||||||
|
-- through each read.
|
||||||
|
opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
|
||||||
|
|
||||||
-- TODO(fmayle): These can go away if we properly prune the graph.
|
-- TODO(fmayle): These can go away if we properly prune the graph.
|
||||||
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
||||||
opGrad "Placeholder" _ _ _ = []
|
opGrad "Placeholder" _ _ _ = []
|
||||||
|
opGrad "VarHandleOp" _ _ _ = []
|
||||||
opGrad "Variable" _ _ _ = []
|
opGrad "Variable" _ _ _ = []
|
||||||
|
|
||||||
opGrad n nodeDef ins grads =
|
opGrad n nodeDef ins grads =
|
||||||
|
@ -723,6 +731,7 @@ numOutputs o =
|
||||||
"Neg" -> 1
|
"Neg" -> 1
|
||||||
"Placeholder" -> 1
|
"Placeholder" -> 1
|
||||||
"OneHot" -> 1
|
"OneHot" -> 1
|
||||||
|
"ReadVariableOp" -> 1
|
||||||
"RefIdentity" -> 1
|
"RefIdentity" -> 1
|
||||||
"Relu" -> 1
|
"Relu" -> 1
|
||||||
"ReluGrad" -> 1
|
"ReluGrad" -> 1
|
||||||
|
@ -737,6 +746,7 @@ numOutputs o =
|
||||||
"Tile" -> 1
|
"Tile" -> 1
|
||||||
"Transpose" -> 1
|
"Transpose" -> 1
|
||||||
"TruncatedNormal" -> 1
|
"TruncatedNormal" -> 1
|
||||||
|
"VarHandleOp" -> 1
|
||||||
"Variable" -> 1
|
"Variable" -> 1
|
||||||
"ZerosLike" -> 1
|
"ZerosLike" -> 1
|
||||||
"Fill" -> 1
|
"Fill" -> 1
|
||||||
|
|
|
@ -31,9 +31,10 @@ import Control.Monad.IO.Class (liftIO)
|
||||||
import qualified TensorFlow.Core as TF
|
import qualified TensorFlow.Core as TF
|
||||||
import qualified TensorFlow.GenOps.Core as TF (max, tile)
|
import qualified TensorFlow.GenOps.Core as TF (max, tile)
|
||||||
import qualified TensorFlow.Gradient as TF
|
import qualified TensorFlow.Gradient as TF
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
|
||||||
import qualified TensorFlow.Output as TF
|
import qualified TensorFlow.Output as TF
|
||||||
import qualified TensorFlow.Types as TF
|
import qualified TensorFlow.Types as TF
|
||||||
|
import qualified TensorFlow.Variable as TF
|
||||||
|
|
||||||
import Proto.Tensorflow.Core.Framework.Graph (node)
|
import Proto.Tensorflow.Core.Framework.Graph (node)
|
||||||
import Proto.Tensorflow.Core.Framework.NodeDef (op)
|
import Proto.Tensorflow.Core.Framework.NodeDef (op)
|
||||||
|
@ -222,7 +223,7 @@ matMulGradient = testCase "matMulGradients" $ do
|
||||||
let dfBuild = do
|
let dfBuild = do
|
||||||
x <- TF.render $ TF.zeros $ TF.Shape [3, 1 :: Int64]
|
x <- TF.render $ TF.zeros $ TF.Shape [3, 1 :: Int64]
|
||||||
w <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
|
w <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
|
||||||
let f = x `TF.matMul` w :: TF.Tensor TF.Build Float
|
let f = x `TF.matMul` TF.readValue w :: TF.Tensor TF.Build Float
|
||||||
dfs <- TF.gradients f [x]
|
dfs <- TF.gradients f [x]
|
||||||
return (x, dfs)
|
return (x, dfs)
|
||||||
|
|
||||||
|
@ -242,11 +243,11 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
|
||||||
let tower = do
|
let tower = do
|
||||||
x <- TF.render $ TF.zeros $ TF.Shape [batch, 1]
|
x <- TF.render $ TF.zeros $ TF.Shape [batch, 1]
|
||||||
w <- TF.zeroInitializedVariable $ TF.Shape [1, width]
|
w <- TF.zeroInitializedVariable $ TF.Shape [1, width]
|
||||||
let f = x `TF.matMul` w
|
let f = x `TF.matMul` TF.readValue w
|
||||||
[dfdx] <- TF.gradients f [x]
|
[dfdx] <- TF.gradients f [x]
|
||||||
let f'x = TF.reduceSum dfdx
|
let f'x = TF.reduceSum dfdx
|
||||||
[dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
|
[dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
|
||||||
return [TF.value w, dfdw]
|
return [TF.readValue w, TF.expr dfdw]
|
||||||
|
|
||||||
TF.runSession $ do
|
TF.runSession $ do
|
||||||
[w, dfdw] <- TF.build tower
|
[w, dfdw] <- TF.build tower
|
||||||
|
@ -255,12 +256,12 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
|
||||||
|
|
||||||
let step = w `TF.add` dfdw
|
let step = w `TF.add` dfdw
|
||||||
w0 <- TF.run step
|
w0 <- TF.run step
|
||||||
liftIO $ ((V.fromList [4, 4 :: Float]) @=? w0)
|
liftIO $ V.fromList [4, 4 :: Float] @=? w0
|
||||||
|
|
||||||
|
|
||||||
-- test that gradient of matMul deals correctly with transpose_a and transpose_b
|
-- test that gradient of matMul deals correctly with transpose_a and transpose_b
|
||||||
matMulTransposeGradient :: (Bool, Bool) -> Test
|
matMulTransposeGradient :: (Bool, Bool) -> Test
|
||||||
matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ (show txw)) $ do
|
matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ show txw) $ do
|
||||||
let (transposeX, transposeW) = txw
|
let (transposeX, transposeW) = txw
|
||||||
|
|
||||||
let dfBuild = do
|
let dfBuild = do
|
||||||
|
@ -268,7 +269,7 @@ matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ (show txw
|
||||||
let xZeros = TF.zeros xShape
|
let xZeros = TF.zeros xShape
|
||||||
x <- TF.render $ if transposeX then TF.matTranspose xZeros else xZeros
|
x <- TF.render $ if transposeX then TF.matTranspose xZeros else xZeros
|
||||||
variable <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
|
variable <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
|
||||||
let wv = if transposeW then TF.matTranspose variable else TF.expr variable
|
let wv = if transposeW then TF.matTranspose (TF.readValue variable) else TF.readValue variable
|
||||||
let f = TF.matMul' (transAttrs transposeX transposeW) x wv :: TF.Tensor TF.Build Float
|
let f = TF.matMul' (transAttrs transposeX transposeW) x wv :: TF.Tensor TF.Build Float
|
||||||
w <- TF.render wv
|
w <- TF.render wv
|
||||||
ds <- TF.gradients f [x, w]
|
ds <- TF.gradients f [x, w]
|
||||||
|
|
Loading…
Reference in a new issue