
Address review feedback

fkm3 2017-05-13 16:50:38 -07:00
parent d78c0f18e9
commit 111c494fcd
4 changed files with 31 additions and 11 deletions

View File

@@ -181,7 +181,7 @@ gradients y xs = build $ do
     gradientMap <- graphGrads gr initPending
     -- Lookup the gradients for each x.
     forM xs $ \x ->
-        let (Output i xName) = targetOutput x
+        let Output i xName = targetOutput x
         in maybe (render $ targetZeros x) return $ do
             n <- nodeMap ^. at xName
             gradientMap ^. at n . nonEmpty . outputIxAt i
@@ -696,6 +696,10 @@ opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
   where
     rx = rangeOfRank dz
+-- Treat read ops as an identity function on the variable. This allows us to
+-- take gradients w.r.t. the variable handle instead of the result of a read
+-- op. If a variable is read multiple times, the gradients will propagate back
+-- through each read.
+opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
 -- TODO(fmayle): These can go away if we properly prune the graph.
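
The identity gradient above is what lets gradients flow back to a variable handle. A minimal sketch of the intended usage, assuming the GradientTarget-based `gradients` from this change (the constant values and the double read are illustrative, not from the commit):

-- Sketch only, not part of the commit.
import Control.Monad.IO.Class (liftIO)
import qualified Data.Vector as V
import qualified TensorFlow.Core as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Variable as Var

main :: IO ()
main = TF.runSession $ do
    w <- Var.initializedVariable (TF.constant (TF.Shape [2]) [1, 2 :: Float])
    -- The loss reads w twice; the ReadVariableOp gradient propagates
    -- back through each read, so the contributions sum.
    let loss = TF.reduceSum (Var.readValue w `TF.add` Var.readValue w)
    [dw] <- TF.gradients loss [w]
    g <- TF.run dw
    liftIO $ print (g :: V.Vector Float)  -- expected: [2.0, 2.0]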

View File

@@ -33,14 +33,14 @@ import TensorFlow.Output (opInputs, unNodeName)
 import TensorFlow.Tensor (tensorNodeName, renderedOutput)
 import TensorFlow.Types (tensorType)
 import qualified TensorFlow.GenOps.Core as CoreOps
-import TensorFlow.Ops (zeros, fill, shape)
+import TensorFlow.Ops (zeros)
 import TensorFlow.Gradient (GradientTarget(..))

 newtype Variable a = Variable (Tensor Value ResourceHandle)

 instance GradientTarget Variable where
     targetOutput (Variable v) = renderedOutput v
-    targetZeros (Variable v) = fill (shape v) 0
+    targetZeros = CoreOps.zerosLike . readValue

 -- | Creates a new, uninitialized variable.
 variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
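
The GradientTarget interface this instance implements is not shown in the diff; inferred from its use in `gradients` above, it is roughly the following (a reconstruction, which may differ from the actual definition in TensorFlow.Gradient):

-- Inferred sketch; constraints and naming may differ in the real class.
import TensorFlow.Build (Build)
import TensorFlow.Output (Output)
import TensorFlow.Tensor (Tensor)
import TensorFlow.Types (TensorType)

class GradientTarget t where
    -- The graph output whose gradient should be looked up.
    targetOutput :: t a -> Output
    -- Zeros matching the target, used when no gradient reaches it.
    targetZeros :: TensorType a => t a -> Tensor Build a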

View File

@@ -32,9 +32,10 @@ import Control.Monad.IO.Class (liftIO)
 import qualified TensorFlow.Core as TF
 import qualified TensorFlow.GenOps.Core as TF (max, tile)
 import qualified TensorFlow.Gradient as TF
-import qualified TensorFlow.Ops as TF
+import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
 import qualified TensorFlow.Output as TF
 import qualified TensorFlow.Types as TF
+import qualified TensorFlow.Variable as TF
 import Proto.Tensorflow.Core.Framework.Graph (node)
 import Proto.Tensorflow.Core.Framework.NodeDef (op)
@@ -223,7 +224,7 @@ matMulGradient = testCase "matMulGradients" $ do
     let dfBuild = do
             x <- TF.render $ TF.zeros $ TF.Shape [3, 1 :: Int64]
             w <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
-            let f = x `TF.matMul` w :: TF.Tensor TF.Build Float
+            let f = x `TF.matMul` TF.readValue w :: TF.Tensor TF.Build Float
             dfs <- TF.gradients f [x]
             return (x, dfs)
@@ -243,11 +244,11 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
     let tower = do
             x <- TF.render $ TF.zeros $ TF.Shape [batch, 1]
             w <- TF.zeroInitializedVariable $ TF.Shape [1, width]
-            let f = x `TF.matMul` w
+            let f = x `TF.matMul` TF.readValue w
             [dfdx] <- TF.gradients f [x]
             let f'x = TF.reduceSum dfdx
             [dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
-            return [TF.value w, dfdw]
+            return [TF.readValue w, TF.expr dfdw]

     TF.runSession $ do
         [w, dfdw] <- TF.build tower
@@ -256,12 +257,12 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
         let step = w `TF.add` dfdw
         w0 <- TF.run step
-        liftIO $ ((V.fromList [4, 4 :: Float]) @=? w0)
+        liftIO $ V.fromList [4, 4 :: Float] @=? w0

 -- test that gradient of matMul deals correctly with transpose_a and transpose_b
 matMulTransposeGradient :: (Bool, Bool) -> Test
-matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ (show txw)) $ do
+matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ show txw) $ do
     let (transposeX, transposeW) = txw
     let dfBuild = do
@@ -269,7 +270,7 @@ matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ (show txw
             let xZeros = TF.zeros xShape
             x <- TF.render $ if transposeX then TF.matTranspose xZeros else xZeros
             variable <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
-            let wv = if transposeW then TF.matTranspose variable else TF.expr variable
+            let wv = if transposeW then TF.matTranspose (TF.readValue variable) else TF.readValue variable
             let f = TF.matMul' (transAttrs transposeX transposeW) x wv :: TF.Tensor TF.Build Float
             w <- TF.render wv
             ds <- TF.gradients f [x, w]

View File

@@ -10,8 +10,11 @@ import TensorFlow.Core
     , run_
     , runSession
     , run
-    , withControlDependencies)
+    , withControlDependencies
+    , Shape(..)
+    )
 import qualified TensorFlow.Ops as Ops
+import TensorFlow.Gradient (targetZeros)
 import TensorFlow.Variable
     ( readValue
     , initializedVariable
@@ -29,6 +32,7 @@ main = googleTest [ testInitializedVariable
                   , testDependency
                   , testRereadRef
                   , testAssignAdd
+                  , testTargetZeros
                   ]

 testInitializedVariable :: Test
@@ -77,3 +81,14 @@ testAssignAdd = testCase "testAssignAdd" $ runSession $ do
     run_ =<< assignAdd w 17
     f1 <- run (readValue w)
     liftIO $ (42 + 17 :: Float) @=? unScalar f1
+
+testTargetZeros :: Test
+testTargetZeros = testCase "testTargetZeros" $ runSession $ do
+    do
+        w <- initializedVariable 42
+        z <- run (targetZeros w)
+        liftIO $ (0 :: Float) @=? unScalar z
+    do
+        w <- initializedVariable (Ops.constant (Shape [2, 3]) [1..6])
+        z <- run (targetZeros w)
+        liftIO $ (replicate 6 (0 :: Float)) @=? (V.toList z)