Mirror of https://github.com/tensorflow/haskell.git
Address review feedback

commit 111c494fcd (parent d78c0f18e9)
@@ -181,7 +181,7 @@ gradients y xs = build $ do
     gradientMap <- graphGrads gr initPending
     -- Lookup the gradients for each x.
     forM xs $ \x ->
-        let (Output i xName) = targetOutput x
+        let Output i xName = targetOutput x
         in maybe (render $ targetZeros x) return $ do
             n <- nodeMap ^. at xName
             gradientMap ^. at n . nonEmpty . outputIxAt i
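Note for readers of this diff: the targetOutput/targetZeros calls above come from the GradientTarget class this PR introduces, whose definition never appears in these hunks. Below is a plausible reconstruction from the call sites; the signatures are my assumption, not copied from the PR.

import TensorFlow.Build (Build)
import TensorFlow.Output (Output)
import TensorFlow.Tensor (Tensor)
import TensorFlow.Types (TensorType)

-- Reconstructed sketch: a gradient target names a rendered output and
-- can produce a zero tensor of its own shape as the fallback gradient.
class GradientTarget t where
    targetOutput :: t a -> Output
    targetZeros  :: TensorType a => t a -> Tensor Build a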
@@ -696,6 +696,10 @@ opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
   where
     rx = rangeOfRank dz
 
+-- Treat read ops as an identity function on the variable. This allows us to
+-- take gradients w.r.t. the variable handle instead of the result of a read
+-- op. If a variable is read multiple times, the gradients will propagate back
+-- through each read.
+opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
+
 -- TODO(fmayle): These can go away if we properly prune the graph.
 
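A minimal sketch of what this new gradient rule enables, assuming the PR's variable-aware gradients (the module qualifiers and the main wrapper are mine): each ReadVariableOp passes its incoming gradient through unchanged, and the per-read contributions are summed on the handle, so dy/dw = 2 for y = read(w) + read(w).

import Control.Monad.IO.Class (liftIO)
import qualified TensorFlow.Core as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Variable as Var

main :: IO ()
main = TF.runSession $ do
    w <- Var.initializedVariable (TF.scalar (3 :: Float))
    -- Two separate reads of the same resource variable.
    let y = Var.readValue w `TF.add` Var.readValue w
    [dw] <- TF.gradients y [w]  -- differentiate w.r.t. the handle
    g <- TF.run dw
    liftIO $ print (TF.unScalar g :: Float)  -- expected: 2.0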
@@ -33,14 +33,14 @@ import TensorFlow.Output (opInputs, unNodeName)
 import TensorFlow.Tensor (tensorNodeName, renderedOutput)
 import TensorFlow.Types (tensorType)
 import qualified TensorFlow.GenOps.Core as CoreOps
-import TensorFlow.Ops (zeros, fill, shape)
+import TensorFlow.Ops (zeros)
 import TensorFlow.Gradient (GradientTarget(..))
 
 newtype Variable a = Variable (Tensor Value ResourceHandle)
 
 instance GradientTarget Variable where
     targetOutput (Variable v) = renderedOutput v
-    targetZeros (Variable v) = fill (shape v) 0
+    targetZeros = CoreOps.zerosLike . readValue
 
 -- | Creates a new, uninitialized variable.
 variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
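The targetZeros rewrite trades fill (shape v) 0 for zerosLike, which takes its shape from the input tensor itself, so the extra fill/shape imports can be dropped. A side-by-side sketch under that reading (only zerosViaLike mirrors the committed code; the function names are mine):

import qualified TensorFlow.GenOps.Core as CoreOps
import qualified TensorFlow.Ops as TF
import TensorFlow.Build (Build)
import TensorFlow.Tensor (Tensor)
import TensorFlow.Variable (Variable, readValue)

zerosViaFill, zerosViaLike :: Variable Float -> Tensor Build Float
-- Old approach: read the value, take its shape, fill with zeros.
zerosViaFill v = TF.fill (TF.shape (readValue v)) 0
-- New approach: one op that emits zeros shaped like its input.
zerosViaLike v = CoreOps.zerosLike (readValue v)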
@@ -32,9 +32,10 @@ import Control.Monad.IO.Class (liftIO)
 import qualified TensorFlow.Core as TF
 import qualified TensorFlow.GenOps.Core as TF (max, tile)
 import qualified TensorFlow.Gradient as TF
-import qualified TensorFlow.Ops as TF
+import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
 import qualified TensorFlow.Output as TF
 import qualified TensorFlow.Types as TF
+import qualified TensorFlow.Variable as TF
 
 import Proto.Tensorflow.Core.Framework.Graph (node)
 import Proto.Tensorflow.Core.Framework.NodeDef (op)
@@ -223,7 +224,7 @@ matMulGradient = testCase "matMulGradients" $ do
     let dfBuild = do
             x <- TF.render $ TF.zeros $ TF.Shape [3, 1 :: Int64]
             w <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
-            let f = x `TF.matMul` w :: TF.Tensor TF.Build Float
+            let f = x `TF.matMul` TF.readValue w :: TF.Tensor TF.Build Float
             dfs <- TF.gradients f [x]
             return (x, dfs)
 
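This test keeps differentiating w.r.t. x; with the new GradientTarget instance the variable itself could be a target as well. A hedged sketch of that usage (assumes the PR's generalized TF.gradients; the main wrapper is mine):

import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import qualified Data.Vector as V
import qualified TensorFlow.Core as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Variable as Var

main :: IO ()
main = TF.runSession $ do
    x <- TF.render $ TF.zeros $ TF.Shape [3, 1 :: Int64]
    w <- Var.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
    let f = x `TF.matMul` Var.readValue w :: TF.Tensor TF.Build Float
    [dfdw] <- TF.gradients f [w]  -- the Variable is the target
    g <- TF.run dfdw
    liftIO $ print (g :: V.Vector Float)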
@@ -243,11 +244,11 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
     let tower = do
             x <- TF.render $ TF.zeros $ TF.Shape [batch, 1]
             w <- TF.zeroInitializedVariable $ TF.Shape [1, width]
-            let f = x `TF.matMul` w
+            let f = x `TF.matMul` TF.readValue w
             [dfdx] <- TF.gradients f [x]
             let f'x = TF.reduceSum dfdx
             [dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
-            return [TF.value w, dfdw]
+            return [TF.readValue w, TF.expr dfdw]
 
     TF.runSession $ do
         [w, dfdw] <- TF.build tower
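Why TF.expr shows up in the new return: TF.readValue w is a Tensor Build Float while TF.gradients hands back a rendered Tensor Value Float, and a list needs one element type. expr embeds a rendered tensor back into a build expression. A tiny self-contained sketch of that conversion (the helper name is mine):

import TensorFlow.Build (Build)
import TensorFlow.Tensor (Tensor, Value, expr)

-- Embed an already-rendered tensor into a Build-time expression so it
-- can share a list with unrendered tensors such as readValue results.
asBuild :: Tensor Value Float -> Tensor Build Float
asBuild = expr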
@@ -256,12 +257,12 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
 
         let step = w `TF.add` dfdw
         w0 <- TF.run step
-        liftIO $ ((V.fromList [4, 4 :: Float]) @=? w0)
+        liftIO $ V.fromList [4, 4 :: Float] @=? w0
 
 
 -- test that gradient of matMul deals correctly with transpose_a and transpose_b
 matMulTransposeGradient :: (Bool, Bool) -> Test
-matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ (show txw)) $ do
+matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ show txw) $ do
     let (transposeX, transposeW) = txw
 
     let dfBuild = do
@@ -269,7 +270,7 @@ matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ (show txw
         let xZeros = TF.zeros xShape
         x <- TF.render $ if transposeX then TF.matTranspose xZeros else xZeros
         variable <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
-        let wv = if transposeW then TF.matTranspose variable else TF.expr variable
+        let wv = if transposeW then TF.matTranspose (TF.readValue variable) else TF.readValue variable
         let f = TF.matMul' (transAttrs transposeX transposeW) x wv :: TF.Tensor TF.Build Float
         w <- TF.render wv
         ds <- TF.gradients f [x, w]
@@ -10,8 +10,11 @@ import TensorFlow.Core
     , run_
     , runSession
     , run
-    , withControlDependencies)
+    , withControlDependencies
+    , Shape(..)
+    )
 import qualified TensorFlow.Ops as Ops
+import TensorFlow.Gradient (targetZeros)
 import TensorFlow.Variable
     ( readValue
     , initializedVariable
@@ -29,6 +32,7 @@ main = googleTest [ testInitializedVariable
                   , testDependency
                   , testRereadRef
                   , testAssignAdd
+                  , testTargetZeros
                   ]
 
 testInitializedVariable :: Test
@@ -77,3 +81,14 @@ testAssignAdd = testCase "testAssignAdd" $ runSession $ do
     run_ =<< assignAdd w 17
     f1 <- run (readValue w)
     liftIO $ (42 + 17 :: Float) @=? unScalar f1
+
+testTargetZeros :: Test
+testTargetZeros = testCase "testTargetZeros" $ runSession $ do
+    do
+        w <- initializedVariable 42
+        z <- run (targetZeros w)
+        liftIO $ (0 :: Float) @=? unScalar z
+    do
+        w <- initializedVariable (Ops.constant (Shape [2, 3]) [1..6])
+        z <- run (targetZeros w)
+        liftIO $ (replicate 6 (0 :: Float)) @=? (V.toList z)
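testTargetZeros checks targetZeros directly; the same zeros are what gradients falls back to when a target does not influence y (the `maybe (render $ targetZeros x) return` branch in the first hunk). A sketch of that end-to-end behavior, assuming the PR's generalized gradients (the qualifiers and main wrapper are mine):

import Control.Monad.IO.Class (liftIO)
import qualified Data.Vector as V
import qualified TensorFlow.Core as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Variable as Var

main :: IO ()
main = TF.runSession $ do
    w <- Var.initializedVariable (TF.vector [5, 5 :: Float])
    y <- TF.render $ TF.scalar (1 :: Float)
    [dw] <- TF.gradients y [w]  -- w plays no part in y
    z <- TF.run dw
    liftIO $ print (z :: V.Vector Float)  -- expected: [0.0, 0.0]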