From 0f322b2e0611cbe7011c84ba8b6cb822e4725ebc Mon Sep 17 00:00:00 2001 From: Mike Sperber Date: Tue, 14 Apr 2020 01:48:43 +0200 Subject: [PATCH] Fix MonadFail-related errors to support ghc 8.8 --- tensorflow-ops/src/TensorFlow/Variable.hs | 5 ++- tensorflow-ops/tests/EmbeddingOpsTest.hs | 3 +- tensorflow-ops/tests/GradientTest.hs | 44 +++++++++++++++++------ 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/tensorflow-ops/src/TensorFlow/Variable.hs b/tensorflow-ops/src/TensorFlow/Variable.hs index ee2fd15..239bf0c 100644 --- a/tensorflow-ops/src/TensorFlow/Variable.hs +++ b/tensorflow-ops/src/TensorFlow/Variable.hs @@ -90,7 +90,10 @@ initializedVariable' :: forall a m v . (MonadBuild m, TensorType a) => OpParams -> Tensor v a -> m (Variable a) initializedVariable' params initializer = do -- The shape is not known initially. - (Variable h Nothing :: Variable a) <- variableInternal params Nothing + variables <- variableInternal params Nothing + h <- pure $ case variables of + (Variable h Nothing :: Variable a) -> h + _ -> error "variableInternal is empty" initializer' <- renderValue initializer i <- CoreOps.assignVariableOp h initializer' addInitializer =<< group i diff --git a/tensorflow-ops/tests/EmbeddingOpsTest.hs b/tensorflow-ops/tests/EmbeddingOpsTest.hs index 6dcf1e5..135bde2 100644 --- a/tensorflow-ops/tests/EmbeddingOpsTest.hs +++ b/tensorflow-ops/tests/EmbeddingOpsTest.hs @@ -161,7 +161,8 @@ instance Arbitrary a => Arbitrary (LookupExample a) where let maxDim = fromIntegral (ceiling doubleMaxDim :: Int64) doubleMaxDim :: Double doubleMaxDim = 100 ** (1 / fromIntegral rank) - shape@(firstDim : _) <- vectorOf rank (choose (1, maxDim)) + shape <- vectorOf rank (choose (1, maxDim)) + let firstDim = head shape values <- vectorOf (fromIntegral $ product shape) arbitrary numParts <- choose (2, 15) indSize <- choose (0, fromIntegral $ firstDim - 1) diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs index 24cb427..7b28c40 100644 --- a/tensorflow-ops/tests/GradientTest.hs +++ b/tensorflow-ops/tests/GradientTest.hs @@ -26,7 +26,7 @@ import Test.Framework (defaultMain, Test) import Lens.Family2 ((^..), (.~)) import Test.Framework.Providers.HUnit (testCase) -import Test.HUnit ((@=?), assertEqual) +import Test.HUnit ((@=?), assertEqual, assertFailure) import qualified Data.Vector as V import System.Random (randomIO, randomRIO) import Control.Monad(forM_, replicateM, zipWithM) @@ -557,13 +557,19 @@ matMulGradGrad = testCase "matMulGradGrad" $ do x <- TF.render $ TF.zeros $ TF.Shape [batch, 1] w <- TF.zeroInitializedVariable $ TF.Shape [1, width] let f = x `TF.matMul` TF.readValue w - [dfdx] <- TF.gradients f [x] + l1 <- TF.gradients f [x] + let dfdx = head l1 -- avoid MonadFail let f'x = TF.reduceSum dfdx - [dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w) + l2 <- TF.gradients f'x [w] -- take gradient again (this time over w) + let dfdw = head l2 return [TF.readValue w, TF.expr dfdw] TF.runSession $ do - [w, dfdw] <- TF.build tower + l <- TF.build tower + (w, dfdw) <- + case l of + [w, dfdw] -> pure (w, dfdw) + _ -> liftIO $ assertFailure "pattern-match failure in matMulGradMad" (wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw) liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32) @@ -589,7 +595,11 @@ matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ show txw) return (x, w, ds) TF.runSession $ do - (x, w, [dx, dw]) <- TF.build dfBuild + (x, w, d) <- TF.build dfBuild + (dx, dw) <- + case d of + [dx, dw] -> pure (dx, dw) + _ -> liftIO $ assertFailure "pattern-match failure in matMulTransposeGradient" xShape <- TF.run $ TF.shape x dxShape <- TF.run $ TF.shape dx liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32) @@ -616,7 +626,11 @@ batchMatMulGradient = testCase "batchMatMulGradients" $ do return (x, dfs) (xShape, dxShape) <- TF.runSession $ do - (x, [dx]) <- TF.build dfBuild + (x, dl) <- TF.build dfBuild + dx <- + case dl of + [dx] -> pure dx + _ -> liftIO $ assertFailure "pattern-match failure in batchMatMulGradient" TF.run (TF.shape x, TF.shape dx) assertEqual "Shape of gradient must match shape of input" xShape (dxShape :: V.Vector Int32) @@ -633,13 +647,19 @@ batchMatMulGradGrad = testCase "batchMatMulGradGrad" $ do x <- TF.render $ TF.zeros $ TF.Shape [batch, height, 1] w <- TF.zeroInitializedVariable $ TF.Shape [batch, 1, width] let f = x `TF.batchMatMul` TF.readValue w - [dfdx] <- TF.gradients f [x] + l1 <- TF.gradients f [x] + let dfdx = head l1 let f'x = TF.sum dfdx (TF.vector [1, 2 :: Int32]) - [dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w) + l2 <- TF.gradients f'x [w] -- take gradient again (this time over w) + let dfdw = head l2 return [TF.readValue w, TF.expr dfdw] TF.runSession $ do - [w, dfdw] <- TF.build tower + l <- TF.build tower + (w, dfdw) <- + case l of + [w, dfdw] -> pure (w, dfdw) + _ -> liftIO $ assertFailure "pattern-match failure in batchMatMulGradGrad" (wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw) liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32) @@ -665,7 +685,11 @@ batchMatMulAdjointGradient axw = testCase ("batchMatMulAdjointGradients " ++ sho return (x, w, ds) TF.runSession $ do - (x, w, [dx, dw]) <- TF.build dfBuild + (x, w, d) <- TF.build dfBuild + (dx, dw) <- + case d of + [dx, dw] -> pure (dx, dw) + _ -> liftIO $ assertFailure "pattern-match failure in batchMatMulAdjointGradient" xShape <- TF.run $ TF.shape x dxShape <- TF.run $ TF.shape dx liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)