Fix MonadFail-related errors to support ghc 8.8

This commit is contained in:
Mike Sperber 2020-04-14 01:48:43 +02:00 committed by GitHub
parent 739f6618f4
commit 0f322b2e06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 12 deletions

View File

@ -90,7 +90,10 @@ initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
=> OpParams -> Tensor v a -> m (Variable a)
initializedVariable' params initializer = do
-- The shape is not known initially.
(Variable h Nothing :: Variable a) <- variableInternal params Nothing
variables <- variableInternal params Nothing
h <- pure $ case variables of
(Variable h Nothing :: Variable a) -> h
_ -> error "variableInternal is empty"
initializer' <- renderValue initializer
i <- CoreOps.assignVariableOp h initializer'
addInitializer =<< group i

View File

@ -161,7 +161,8 @@ instance Arbitrary a => Arbitrary (LookupExample a) where
let maxDim = fromIntegral (ceiling doubleMaxDim :: Int64)
doubleMaxDim :: Double
doubleMaxDim = 100 ** (1 / fromIntegral rank)
shape@(firstDim : _) <- vectorOf rank (choose (1, maxDim))
shape <- vectorOf rank (choose (1, maxDim))
let firstDim = head shape
values <- vectorOf (fromIntegral $ product shape) arbitrary
numParts <- choose (2, 15)
indSize <- choose (0, fromIntegral $ firstDim - 1)

View File

@ -26,7 +26,7 @@ import Test.Framework (defaultMain, Test)
import Lens.Family2 ((^..), (.~))
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?), assertEqual)
import Test.HUnit ((@=?), assertEqual, assertFailure)
import qualified Data.Vector as V
import System.Random (randomIO, randomRIO)
import Control.Monad(forM_, replicateM, zipWithM)
@ -557,13 +557,19 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
x <- TF.render $ TF.zeros $ TF.Shape [batch, 1]
w <- TF.zeroInitializedVariable $ TF.Shape [1, width]
let f = x `TF.matMul` TF.readValue w
[dfdx] <- TF.gradients f [x]
l1 <- TF.gradients f [x]
let dfdx = head l1 -- avoid MonadFail
let f'x = TF.reduceSum dfdx
[dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
l2 <- TF.gradients f'x [w] -- take gradient again (this time over w)
let dfdw = head l2
return [TF.readValue w, TF.expr dfdw]
TF.runSession $ do
[w, dfdw] <- TF.build tower
l <- TF.build tower
(w, dfdw) <-
case l of
[w, dfdw] -> pure (w, dfdw)
_ -> liftIO $ assertFailure "pattern-match failure in matMulGradMad"
(wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw)
liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32)
@ -589,7 +595,11 @@ matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ show txw)
return (x, w, ds)
TF.runSession $ do
(x, w, [dx, dw]) <- TF.build dfBuild
(x, w, d) <- TF.build dfBuild
(dx, dw) <-
case d of
[dx, dw] -> pure (dx, dw)
_ -> liftIO $ assertFailure "pattern-match failure in matMulTransposeGradient"
xShape <- TF.run $ TF.shape x
dxShape <- TF.run $ TF.shape dx
liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)
@ -616,7 +626,11 @@ batchMatMulGradient = testCase "batchMatMulGradients" $ do
return (x, dfs)
(xShape, dxShape) <- TF.runSession $ do
(x, [dx]) <- TF.build dfBuild
(x, dl) <- TF.build dfBuild
dx <-
case dl of
[dx] -> pure dx
_ -> liftIO $ assertFailure "pattern-match failure in batchMatMulGradient"
TF.run (TF.shape x, TF.shape dx)
assertEqual "Shape of gradient must match shape of input" xShape (dxShape :: V.Vector Int32)
@ -633,13 +647,19 @@ batchMatMulGradGrad = testCase "batchMatMulGradGrad" $ do
x <- TF.render $ TF.zeros $ TF.Shape [batch, height, 1]
w <- TF.zeroInitializedVariable $ TF.Shape [batch, 1, width]
let f = x `TF.batchMatMul` TF.readValue w
[dfdx] <- TF.gradients f [x]
l1 <- TF.gradients f [x]
let dfdx = head l1
let f'x = TF.sum dfdx (TF.vector [1, 2 :: Int32])
[dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
l2 <- TF.gradients f'x [w] -- take gradient again (this time over w)
let dfdw = head l2
return [TF.readValue w, TF.expr dfdw]
TF.runSession $ do
[w, dfdw] <- TF.build tower
l <- TF.build tower
(w, dfdw) <-
case l of
[w, dfdw] -> pure (w, dfdw)
_ -> liftIO $ assertFailure "pattern-match failure in batchMatMulGradGrad"
(wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw)
liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32)
@ -665,7 +685,11 @@ batchMatMulAdjointGradient axw = testCase ("batchMatMulAdjointGradients " ++ sho
return (x, w, ds)
TF.runSession $ do
(x, w, [dx, dw]) <- TF.build dfBuild
(x, w, d) <- TF.build dfBuild
(dx, dw) <-
case d of
[dx, dw] -> pure (dx, dw)
_ -> liftIO $ assertFailure "pattern-match failure in batchMatMulAdjointGradient"
xShape <- TF.run $ TF.shape x
dxShape <- TF.run $ TF.shape dx
liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)