mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-27 11:15:03 +01:00
Fix MonadFail-related errors to support ghc 8.8
This commit is contained in:
parent
739f6618f4
commit
0f322b2e06
3 changed files with 40 additions and 12 deletions
|
@ -90,7 +90,10 @@ initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
|
|||
=> OpParams -> Tensor v a -> m (Variable a)
|
||||
initializedVariable' params initializer = do
|
||||
-- The shape is not known initially.
|
||||
(Variable h Nothing :: Variable a) <- variableInternal params Nothing
|
||||
variables <- variableInternal params Nothing
|
||||
h <- pure $ case variables of
|
||||
(Variable h Nothing :: Variable a) -> h
|
||||
_ -> error "variableInternal is empty"
|
||||
initializer' <- renderValue initializer
|
||||
i <- CoreOps.assignVariableOp h initializer'
|
||||
addInitializer =<< group i
|
||||
|
|
|
@ -161,7 +161,8 @@ instance Arbitrary a => Arbitrary (LookupExample a) where
|
|||
let maxDim = fromIntegral (ceiling doubleMaxDim :: Int64)
|
||||
doubleMaxDim :: Double
|
||||
doubleMaxDim = 100 ** (1 / fromIntegral rank)
|
||||
shape@(firstDim : _) <- vectorOf rank (choose (1, maxDim))
|
||||
shape <- vectorOf rank (choose (1, maxDim))
|
||||
let firstDim = head shape
|
||||
values <- vectorOf (fromIntegral $ product shape) arbitrary
|
||||
numParts <- choose (2, 15)
|
||||
indSize <- choose (0, fromIntegral $ firstDim - 1)
|
||||
|
|
|
@ -26,7 +26,7 @@ import Test.Framework (defaultMain, Test)
|
|||
import Lens.Family2 ((^..), (.~))
|
||||
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@=?), assertEqual)
|
||||
import Test.HUnit ((@=?), assertEqual, assertFailure)
|
||||
import qualified Data.Vector as V
|
||||
import System.Random (randomIO, randomRIO)
|
||||
import Control.Monad(forM_, replicateM, zipWithM)
|
||||
|
@ -557,13 +557,19 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
|
|||
x <- TF.render $ TF.zeros $ TF.Shape [batch, 1]
|
||||
w <- TF.zeroInitializedVariable $ TF.Shape [1, width]
|
||||
let f = x `TF.matMul` TF.readValue w
|
||||
[dfdx] <- TF.gradients f [x]
|
||||
l1 <- TF.gradients f [x]
|
||||
let dfdx = head l1 -- avoid MonadFail
|
||||
let f'x = TF.reduceSum dfdx
|
||||
[dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
|
||||
l2 <- TF.gradients f'x [w] -- take gradient again (this time over w)
|
||||
let dfdw = head l2
|
||||
return [TF.readValue w, TF.expr dfdw]
|
||||
|
||||
TF.runSession $ do
|
||||
[w, dfdw] <- TF.build tower
|
||||
l <- TF.build tower
|
||||
(w, dfdw) <-
|
||||
case l of
|
||||
[w, dfdw] -> pure (w, dfdw)
|
||||
_ -> liftIO $ assertFailure "pattern-match failure in matMulGradMad"
|
||||
(wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw)
|
||||
liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32)
|
||||
|
||||
|
@ -589,7 +595,11 @@ matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ show txw)
|
|||
return (x, w, ds)
|
||||
|
||||
TF.runSession $ do
|
||||
(x, w, [dx, dw]) <- TF.build dfBuild
|
||||
(x, w, d) <- TF.build dfBuild
|
||||
(dx, dw) <-
|
||||
case d of
|
||||
[dx, dw] -> pure (dx, dw)
|
||||
_ -> liftIO $ assertFailure "pattern-match failure in matMulTransposeGradient"
|
||||
xShape <- TF.run $ TF.shape x
|
||||
dxShape <- TF.run $ TF.shape dx
|
||||
liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)
|
||||
|
@ -616,7 +626,11 @@ batchMatMulGradient = testCase "batchMatMulGradients" $ do
|
|||
return (x, dfs)
|
||||
|
||||
(xShape, dxShape) <- TF.runSession $ do
|
||||
(x, [dx]) <- TF.build dfBuild
|
||||
(x, dl) <- TF.build dfBuild
|
||||
dx <-
|
||||
case dl of
|
||||
[dx] -> pure dx
|
||||
_ -> liftIO $ assertFailure "pattern-match failure in batchMatMulGradient"
|
||||
TF.run (TF.shape x, TF.shape dx)
|
||||
|
||||
assertEqual "Shape of gradient must match shape of input" xShape (dxShape :: V.Vector Int32)
|
||||
|
@ -633,13 +647,19 @@ batchMatMulGradGrad = testCase "batchMatMulGradGrad" $ do
|
|||
x <- TF.render $ TF.zeros $ TF.Shape [batch, height, 1]
|
||||
w <- TF.zeroInitializedVariable $ TF.Shape [batch, 1, width]
|
||||
let f = x `TF.batchMatMul` TF.readValue w
|
||||
[dfdx] <- TF.gradients f [x]
|
||||
l1 <- TF.gradients f [x]
|
||||
let dfdx = head l1
|
||||
let f'x = TF.sum dfdx (TF.vector [1, 2 :: Int32])
|
||||
[dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
|
||||
l2 <- TF.gradients f'x [w] -- take gradient again (this time over w)
|
||||
let dfdw = head l2
|
||||
return [TF.readValue w, TF.expr dfdw]
|
||||
|
||||
TF.runSession $ do
|
||||
[w, dfdw] <- TF.build tower
|
||||
l <- TF.build tower
|
||||
(w, dfdw) <-
|
||||
case l of
|
||||
[w, dfdw] -> pure (w, dfdw)
|
||||
_ -> liftIO $ assertFailure "pattern-match failure in batchMatMulGradGrad"
|
||||
(wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw)
|
||||
liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32)
|
||||
|
||||
|
@ -665,7 +685,11 @@ batchMatMulAdjointGradient axw = testCase ("batchMatMulAdjointGradients " ++ sho
|
|||
return (x, w, ds)
|
||||
|
||||
TF.runSession $ do
|
||||
(x, w, [dx, dw]) <- TF.build dfBuild
|
||||
(x, w, d) <- TF.build dfBuild
|
||||
(dx, dw) <-
|
||||
case d of
|
||||
[dx, dw] -> pure (dx, dw)
|
||||
_ -> liftIO $ assertFailure "pattern-match failure in batchMatMulAdjointGradient"
|
||||
xShape <- TF.run $ TF.shape x
|
||||
dxShape <- TF.run $ TF.shape dx
|
||||
liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)
|
||||
|
|
Loading…
Add table
Reference in a new issue