mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Support Variable in TensorFlow.Gradient and use in mnist example (#116)
This commit is contained in:
parent
ddb4fe4f90
commit
b86945f008
3 changed files with 33 additions and 19 deletions
|
@ -24,7 +24,8 @@ import qualified Data.Vector as V
|
|||
|
||||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Ops as TF hiding (initializedVariable, zeroInitializedVariable)
|
||||
import qualified TensorFlow.Variable as TF
|
||||
|
||||
import TensorFlow.Examples.MNIST.InputData
|
||||
import TensorFlow.Examples.MNIST.Parse
|
||||
|
@ -68,13 +69,15 @@ createModel = do
|
|||
hiddenWeights <-
|
||||
TF.initializedVariable =<< randomParam numPixels [numPixels, numUnits]
|
||||
hiddenBiases <- TF.zeroInitializedVariable [numUnits]
|
||||
let hiddenZ = (images `TF.matMul` hiddenWeights) `TF.add` hiddenBiases
|
||||
let hiddenZ = (images `TF.matMul` TF.readValue hiddenWeights)
|
||||
`TF.add` TF.readValue hiddenBiases
|
||||
let hidden = TF.relu hiddenZ
|
||||
-- Logits.
|
||||
logitWeights <-
|
||||
TF.initializedVariable =<< randomParam numUnits [numUnits, numLabels]
|
||||
logitBiases <- TF.zeroInitializedVariable [numLabels]
|
||||
let logits = (hidden `TF.matMul` logitWeights) `TF.add` logitBiases
|
||||
let logits = (hidden `TF.matMul` TF.readValue logitWeights)
|
||||
`TF.add` TF.readValue logitBiases
|
||||
predict <- TF.render $ TF.cast $
|
||||
TF.argMax (TF.softmax logits) (TF.scalar (1 :: LabelType))
|
||||
|
||||
|
@ -87,7 +90,7 @@ createModel = do
|
|||
grads <- TF.gradients loss params
|
||||
|
||||
let lr = TF.scalar 0.00001
|
||||
applyGrad param grad = TF.assign param $ param `TF.sub` (lr `TF.mul` grad)
|
||||
applyGrad param grad = TF.assignAdd param (negate $ lr `TF.mul` grad)
|
||||
trainStep <- TF.group =<< zipWithM applyGrad params grads
|
||||
|
||||
let correctPredictions = TF.equal predict labels
|
||||
|
|
|
@ -99,6 +99,7 @@ import TensorFlow.Tensor
|
|||
, tensorNodeName
|
||||
, renderedOutput
|
||||
, renderValue
|
||||
, ToTensor(..)
|
||||
)
|
||||
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||
|
@ -116,12 +117,13 @@ type GradientCompatible a =
|
|||
|
||||
|
||||
-- | Gradient of @y@ w.r.t. each element of @xs@.
|
||||
gradients :: forall a v1 v2 m . ( MonadBuild m
|
||||
, Rendered (Tensor v2)
|
||||
, GradientCompatible a
|
||||
)
|
||||
gradients :: forall a v1 t m . ( MonadBuild m
|
||||
, Rendered t
|
||||
, ToTensor t
|
||||
, GradientCompatible a
|
||||
)
|
||||
=> Tensor v1 a -- ^ The output of the graph.
|
||||
-> [Tensor v2 a] -- ^ Tensors for which gradients are computed.
|
||||
-> [t a] -- ^ Tensors for which gradients are computed.
|
||||
-> m [Tensor Value a]
|
||||
gradients y xs = build $ do
|
||||
-- The gradients are computed using "reverse accumulation", similarly to
|
||||
|
@ -171,10 +173,9 @@ gradients y xs = build $ do
|
|||
gradientMap <- graphGrads gr initPending
|
||||
-- Lookup the gradients for each x.
|
||||
forM xs $ \x ->
|
||||
let xName = tensorNodeName x
|
||||
in maybe (render $ zerosLike x) return $ do
|
||||
let Output i xName = renderedOutput x
|
||||
in maybe (render $ zerosLike $ toTensor x) return $ do
|
||||
n <- nodeMap ^. at xName
|
||||
let i = outputIndex $ renderedOutput x
|
||||
gradientMap ^. at n . nonEmpty . outputIxAt i
|
||||
|
||||
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
|
||||
|
@ -687,9 +688,16 @@ opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
|
|||
where
|
||||
rx = rangeOfRank dz
|
||||
|
||||
-- Treat read ops as an identity function on the variable. This allows us to
|
||||
-- take gradients w.r.t. to the variable handle instead of the result of a read
|
||||
-- op. If a variable is read multiple times, the gradients will propagate back
|
||||
-- through each read.
|
||||
opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
|
||||
|
||||
-- TODO(fmayle): These can go away if we properly prune the graph.
|
||||
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
||||
opGrad "Placeholder" _ _ _ = []
|
||||
opGrad "VarHandleOp" _ _ _ = []
|
||||
opGrad "Variable" _ _ _ = []
|
||||
|
||||
opGrad n nodeDef ins grads =
|
||||
|
@ -723,6 +731,7 @@ numOutputs o =
|
|||
"Neg" -> 1
|
||||
"Placeholder" -> 1
|
||||
"OneHot" -> 1
|
||||
"ReadVariableOp" -> 1
|
||||
"RefIdentity" -> 1
|
||||
"Relu" -> 1
|
||||
"ReluGrad" -> 1
|
||||
|
@ -737,6 +746,7 @@ numOutputs o =
|
|||
"Tile" -> 1
|
||||
"Transpose" -> 1
|
||||
"TruncatedNormal" -> 1
|
||||
"VarHandleOp" -> 1
|
||||
"Variable" -> 1
|
||||
"ZerosLike" -> 1
|
||||
"Fill" -> 1
|
||||
|
|
|
@ -31,9 +31,10 @@ import Control.Monad.IO.Class (liftIO)
|
|||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.GenOps.Core as TF (max, tile)
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
|
||||
import qualified TensorFlow.Output as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
import qualified TensorFlow.Variable as TF
|
||||
|
||||
import Proto.Tensorflow.Core.Framework.Graph (node)
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef (op)
|
||||
|
@ -222,7 +223,7 @@ matMulGradient = testCase "matMulGradients" $ do
|
|||
let dfBuild = do
|
||||
x <- TF.render $ TF.zeros $ TF.Shape [3, 1 :: Int64]
|
||||
w <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
|
||||
let f = x `TF.matMul` w :: TF.Tensor TF.Build Float
|
||||
let f = x `TF.matMul` TF.readValue w :: TF.Tensor TF.Build Float
|
||||
dfs <- TF.gradients f [x]
|
||||
return (x, dfs)
|
||||
|
||||
|
@ -242,11 +243,11 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
|
|||
let tower = do
|
||||
x <- TF.render $ TF.zeros $ TF.Shape [batch, 1]
|
||||
w <- TF.zeroInitializedVariable $ TF.Shape [1, width]
|
||||
let f = x `TF.matMul` w
|
||||
let f = x `TF.matMul` TF.readValue w
|
||||
[dfdx] <- TF.gradients f [x]
|
||||
let f'x = TF.reduceSum dfdx
|
||||
[dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
|
||||
return [TF.value w, dfdw]
|
||||
return [TF.readValue w, TF.expr dfdw]
|
||||
|
||||
TF.runSession $ do
|
||||
[w, dfdw] <- TF.build tower
|
||||
|
@ -255,12 +256,12 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
|
|||
|
||||
let step = w `TF.add` dfdw
|
||||
w0 <- TF.run step
|
||||
liftIO $ ((V.fromList [4, 4 :: Float]) @=? w0)
|
||||
liftIO $ V.fromList [4, 4 :: Float] @=? w0
|
||||
|
||||
|
||||
-- test that gradient of matMul deals correctly with transpose_a and transpose_b
|
||||
matMulTransposeGradient :: (Bool, Bool) -> Test
|
||||
matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ (show txw)) $ do
|
||||
matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ show txw) $ do
|
||||
let (transposeX, transposeW) = txw
|
||||
|
||||
let dfBuild = do
|
||||
|
@ -268,7 +269,7 @@ matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ (show txw
|
|||
let xZeros = TF.zeros xShape
|
||||
x <- TF.render $ if transposeX then TF.matTranspose xZeros else xZeros
|
||||
variable <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
|
||||
let wv = if transposeW then TF.matTranspose variable else TF.expr variable
|
||||
let wv = if transposeW then TF.matTranspose (TF.readValue variable) else TF.readValue variable
|
||||
let f = TF.matMul' (transAttrs transposeX transposeW) x wv :: TF.Tensor TF.Build Float
|
||||
w <- TF.render wv
|
||||
ds <- TF.gradients f [x, w]
|
||||
|
|
Loading…
Reference in a new issue