diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs
index 8f9c44c..612eb42 100644
--- a/tensorflow-ops/src/TensorFlow/Gradient.hs
+++ b/tensorflow-ops/src/TensorFlow/Gradient.hs
@@ -551,13 +551,13 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
            , Just $ matMul' (transAttrs True False) x dz]
        (False, True) ->
            [ Just $ matMul dz y
-           , Just $ matMul' (transAttrs True False) x dz]
+           , Just $ matMul' (transAttrs True False) dz x]
        (True, False) ->
-           [ Just $ matMul' (transAttrs False True) dz y
+           [ Just $ matMul' (transAttrs False True) y dz
            , Just $ matMul x dz]
        (True, True) ->
-           [ Just $ matMul' (transAttrs True True) dz y
-           , Just $ matMul' (transAttrs True True) x dz]
+           [ Just $ matMul' (transAttrs True True) y dz
+           , Just $ matMul' (transAttrs True True) dz x]
 
 opGrad "Transpose" _ [_, toT -> p] [dz] =
     [ Just $ CoreOps.transpose dz
diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs
index 40cedfc..e52b00f 100644
--- a/tensorflow-ops/src/TensorFlow/Ops.hs
+++ b/tensorflow-ops/src/TensorFlow/Ops.hs
@@ -318,21 +318,15 @@ scalarize t = CoreOps.reshape t (vector scalarShape)
 
 -- | Sum a tensor down to a scalar
 -- Seee `TensorFlow.GenOps.Core.sum`
-reduceSum
-    :: ( TensorType a
-       , OneOf '[ Double, Float, Int32, Int64
-                , Complex Float, Complex Double] a
-       )
-    => Tensor v a -> Tensor Build a
+reduceSum :: (OneOf '[ Double, Float, Int32, Int64
+                     , Complex Float, Complex Double] a) =>
+             Tensor v a -> Tensor Build a
 reduceSum x = CoreOps.sum x allAxes
   where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1
 
-reduceSum'
-    :: ( TensorType a
-       , OneOf '[ Double, Float, Int32, Int64
-                , Complex Float, Complex Double] a
-       )
-    => OpParams -> Tensor v a -> Tensor Build a
+reduceSum' :: (OneOf '[ Double, Float, Int32, Int64
+                      , Complex Float, Complex Double] a) =>
+              OpParams -> Tensor v a -> Tensor Build a
 reduceSum' params x = CoreOps.sum' params x allAxes
   where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1
diff --git a/tensorflow-ops/tensorflow-ops.cabal b/tensorflow-ops/tensorflow-ops.cabal
index 0a11db8..07d851a 100644
--- a/tensorflow-ops/tensorflow-ops.cabal
+++ b/tensorflow-ops/tensorflow-ops.cabal
@@ -200,6 +200,7 @@ Test-Suite GradientTest
                , tensorflow-proto
                , test-framework
                , test-framework-hunit
+               , transformers
                , vector
 
 Test-Suite MiscTest
diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs
index 56fc1d4..cb9222b 100644
--- a/tensorflow-ops/tests/GradientTest.hs
+++ b/tensorflow-ops/tests/GradientTest.hs
@@ -15,21 +15,26 @@
 {-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE NoMonomorphismRestriction #-}
 {-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE FlexibleContexts #-}
 
-import Data.Int (Int32)
+import Data.Int (Int32, Int64)
 import Data.List (sort)
 import Data.ProtoLens.TextFormat (showMessage)
 import Google.Test (googleTest)
-import Lens.Family2 ((^..))
+import Lens.Family2 ((^..), (.~))
+
 import Test.Framework (Test)
 import Test.Framework.Providers.HUnit (testCase)
-import Test.HUnit ((@=?))
+import Test.HUnit ((@=?), assertEqual)
 import qualified Data.Vector as V
 
+import Control.Monad.IO.Class (liftIO)
 
 import qualified TensorFlow.Core as TF
 import qualified TensorFlow.GenOps.Core as TF (max, tile)
 import qualified TensorFlow.Gradient as TF
 import qualified TensorFlow.Ops as TF
+import qualified TensorFlow.Output as TF
+import qualified TensorFlow.Types as TF
 import Proto.Tensorflow.Core.Framework.Graph (node)
 import Proto.Tensorflow.Core.Framework.NodeDef (op)
@@ -207,15 +212,85 @@ testTile2DGrad = testCase "testTileGrad2D" $ do
         let y = TF.tile x multiples
         [dx] <- TF.gradients y [x]
-
-        shapeDX <- TF.run $ TF.shape dx
-        shapeX <- TF.run $ TF.shape x
-        dxv <- TF.run dx
-        return (dxv, shapeDX, shapeX)
+        TF.run (dx, TF.shape dx, TF.shape x)
     shapeX @=? (shapeDX :: V.Vector Int32)
     V.fromList [6, 6, 6, 6, 6, 6::Float] @=? (dx :: V.Vector Float)
 
+matMulGradient :: Test
+matMulGradient = testCase "matMulGradients" $ do
+
+    let dfBuild = do
+            x <- TF.render $ TF.zeros $ TF.Shape [3, 1 :: Int64]
+            w <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
+            let f = x `TF.matMul` w :: TF.Tensor TF.Build Float
+            dfs <- TF.gradients f [x]
+            return (x, dfs)
+
+    (xShape, dxShape) <- TF.runSession $ do
+        (x, [dx]) <- TF.build dfBuild
+        TF.run (TF.shape x, TF.shape dx)
+
+    assertEqual "Shape of gradient must match shape of input" xShape (dxShape :: V.Vector Int32)
+
+
+-- test that gradient of matMul can be taken gradient of
+matMulGradGrad :: Test
+matMulGradGrad = testCase "matMulGradGrad" $ do
+    let width = 2 :: Int64
+        batch = 4 :: Int64
+
+    let tower = do
+            x <- TF.render $ TF.zeros $ TF.Shape [batch, 1]
+            w <- TF.zeroInitializedVariable $ TF.Shape [1, width]
+            let f = x `TF.matMul` w
+            [dfdx] <- TF.gradients f [x]
+            let f'x = TF.reduceSum dfdx
+            [dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
+            return [TF.value w, dfdw]
+
+    TF.runSession $ do
+        [w, dfdw] <- TF.build tower
+        (wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw)
+        liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32)
+
+        let step = w `TF.add` dfdw
+        w0 <- TF.run step
+        liftIO $ ((V.fromList [4, 4 :: Float]) @=? w0)
+
+
+-- test that gradient of matMul deals correctly with transpose_a and transpose_b
+matMulTransposeGradient :: (Bool, Bool) -> Test
+matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ (show txw)) $ do
+    let (transposeX, transposeW) = txw
+
+    let dfBuild = do
+            let xShape = TF.Shape [3, 1 :: Int64]
+            let xZeros = TF.zeros xShape
+            x <- TF.render $ if transposeX then TF.matTranspose xZeros else xZeros
+            variable <- TF.zeroInitializedVariable $ TF.Shape [1, 2 :: Int64]
+            let wv = if transposeW then TF.matTranspose variable else TF.expr variable
+            let f = TF.matMul' (transAttrs transposeX transposeW) x wv :: TF.Tensor TF.Build Float
+            w <- TF.render wv
+            ds <- TF.gradients f [x, w]
+            return (x, w, ds)
+
+    TF.runSession $ do
+        (x, w, [dx, dw]) <- TF.build dfBuild
+        xShape <- TF.run $ TF.shape x
+        dxShape <- TF.run $ TF.shape dx
+        liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)
+
+        wShape <- TF.run $ TF.shape w
+        dwShape <- TF.run $ TF.shape dw
+        liftIO $ assertEqual "wShape must match dwShape" wShape (dwShape :: V.Vector Int32)
+
+transAttrs :: (TF.Attribute a,
+               TF.Attribute b) =>
+              a -> b -> TF.OpDef -> TF.OpDef
+transAttrs a b =
+    (TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b)
+
 main :: IO ()
 main = googleTest [ testGradientSimple
                   , testGradientDisconnected
@@ -228,4 +303,10 @@
                   , testFillGrad
                   , testTileGrad
                   , testTile2DGrad
+                  , matMulGradient
+                  , matMulGradGrad
+                  , matMulTransposeGradient (False, False)
+                  , matMulTransposeGradient (False, True)
+                  , matMulTransposeGradient (True, False)
+                  , matMulTransposeGradient (True, True)
 ]
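
Reviewer note, not part of the patch: the corrected MatMul branches follow the standard reverse-mode matrix-calculus identities. Writing X and Y for the forward inputs x and y, \bar{Z} for the incoming gradient dz, and labelling each case by (transpose_a, transpose_b), a sketch of the algebra that the matMul'/transAttrs calls above implement is:

\begin{aligned}
(\mathrm{F},\mathrm{F}):\quad & Z = XY,               & \bar X &= \bar Z\,Y^{\top},       & \bar Y &= X^{\top}\bar Z \\
(\mathrm{F},\mathrm{T}):\quad & Z = XY^{\top},        & \bar X &= \bar Z\,Y,              & \bar Y &= \bar Z^{\top}X \\
(\mathrm{T},\mathrm{F}):\quad & Z = X^{\top}Y,        & \bar X &= Y\,\bar Z^{\top},       & \bar Y &= X\,\bar Z \\
(\mathrm{T},\mathrm{T}):\quad & Z = X^{\top}Y^{\top}, & \bar X &= Y^{\top}\bar Z^{\top},  & \bar Y &= \bar Z^{\top}X^{\top}
\end{aligned}

In the branches the patch touches, the two factors of the gradient product had been swapped, so for non-square operands the gradient came back with a transposed shape; the shape assertions in matMulTransposeGradient are what catch that class of bug.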