Mirror of https://github.com/tensorflow/haskell.git
Fix MonadFail-related errors to support ghc 8.8
commit 0f322b2e06
parent 739f6618f4

3 changed files with 40 additions and 12 deletions
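Background for the diff below: with GHC 8.8 (base 4.13), fail is no longer a method of Monad. A refutable pattern on the left of <- in do-notation desugars to MonadFail.fail, so it only type-checks when the surrounding monad has a MonadFail instance. The monads touched by this commit evidently could not rely on such an instance, so every failable pattern bind is rewritten into a plain bind followed by an explicit match. A minimal, self-contained sketch of the error and of the shape of the fix, not part of this commit; Either String stands in for any monad without MonadFail:

    -- Sketch only: Either String stands in for a monad with no MonadFail instance.
    -- broken :: Either String Int
    -- broken = do
    --     [x] <- Right [1, 2, 3]   -- GHC 8.8: "No instance for (MonadFail (Either String))"
    --     pure x

    -- The shape of fix used throughout this commit: bind the whole value,
    -- then match explicitly and handle the remaining cases by hand.
    fixed :: Either String Int
    fixed = do
        xs <- Right [1, 2, 3]
        x <- case xs of
            (x : _) -> pure x
            []      -> Left "unexpected empty list"
        pure x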
@@ -90,7 +90,10 @@ initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
                      => OpParams -> Tensor v a -> m (Variable a)
 initializedVariable' params initializer = do
     -- The shape is not known initially.
-    (Variable h Nothing :: Variable a) <- variableInternal params Nothing
+    variables <- variableInternal params Nothing
+    h <- pure $ case variables of
+        (Variable h Nothing :: Variable a) -> h
+        _ -> error "variableInternal is empty"
     initializer' <- renderValue initializer
     i <- CoreOps.assignVariableOp h initializer'
     addInitializer =<< group i
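Note on the hunk above: the pattern (Variable h Nothing :: Variable a) is refutable because of the Nothing sub-pattern, and initializedVariable' runs in an arbitrary MonadBuild m with only a Monad constraint, so under GHC 8.8 the old bind would demand MonadFail m. The new code binds the value and matches in a pure case, turning the branch that cannot occur into error. A self-contained sketch of the same rewrite; Var, mkVar and useHandle are made-up stand-ins for the library types:

    -- Toy type standing in for the library's Variable.
    data Var a = Var Int (Maybe a)

    useHandle :: Monad m => m (Var a) -> m Int
    useHandle mkVar = do
        -- `Var h Nothing <- mkVar` would need (MonadFail m) here, because the
        -- Nothing sub-pattern makes the match refutable.
        v <- mkVar
        h <- pure $ case v of
            Var h Nothing -> h
            _             -> error "mkVar returned an initialized Var"
        pure h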
@@ -161,7 +161,8 @@ instance Arbitrary a => Arbitrary (LookupExample a) where
         let maxDim = fromIntegral (ceiling doubleMaxDim :: Int64)
             doubleMaxDim :: Double
             doubleMaxDim = 100 ** (1 / fromIntegral rank)
-        shape@(firstDim : _) <- vectorOf rank (choose (1, maxDim))
+        shape <- vectorOf rank (choose (1, maxDim))
+        let firstDim = head shape
         values <- vectorOf (fromIntegral $ product shape) arbitrary
         numParts <- choose (2, 15)
         indSize <- choose (0, fromIntegral $ firstDim - 1)
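Note on the hunk above: shape@(firstDim : _) is refutable (it rejects an empty list), so the bind would need a MonadFail instance for QuickCheck's Gen, which this code evidently could not rely on. Binding the list and taking head is safe here because rank is at least 1, so vectorOf rank never yields an empty list. A compilable sketch of the same pattern; genShape and its bounds are made up for illustration:

    import Test.QuickCheck (Gen, choose, vectorOf)

    -- Mirrors the hunk above: avoid the refutable pattern `shape@(firstDim : _)`
    -- by binding the list and using head, safe because rank >= 1 guarantees a
    -- non-empty list.
    genShape :: Int -> Gen ([Int], Int)
    genShape rank = do
        shape <- vectorOf rank (choose (1, 10 :: Int))
        let firstDim = head shape
        pure (shape, firstDim)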
@@ -26,7 +26,7 @@ import Test.Framework (defaultMain, Test)
 import Lens.Family2 ((^..), (.~))

 import Test.Framework.Providers.HUnit (testCase)
-import Test.HUnit ((@=?), assertEqual)
+import Test.HUnit ((@=?), assertEqual, assertFailure)
 import qualified Data.Vector as V
 import System.Random (randomIO, randomRIO)
 import Control.Monad(forM_, replicateM, zipWithM)
@@ -557,13 +557,19 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
             x <- TF.render $ TF.zeros $ TF.Shape [batch, 1]
             w <- TF.zeroInitializedVariable $ TF.Shape [1, width]
             let f = x `TF.matMul` TF.readValue w
-            [dfdx] <- TF.gradients f [x]
+            l1 <- TF.gradients f [x]
+            let dfdx = head l1 -- avoid MonadFail
             let f'x = TF.reduceSum dfdx
-            [dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
+            l2 <- TF.gradients f'x [w] -- take gradient again (this time over w)
+            let dfdw = head l2
             return [TF.readValue w, TF.expr dfdw]

     TF.runSession $ do
-        [w, dfdw] <- TF.build tower
+        l <- TF.build tower
+        (w, dfdw) <-
+            case l of
+                [w, dfdw] -> pure (w, dfdw)
+                _ -> liftIO $ assertFailure "pattern-match failure in matMulGradMad"
         (wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw)
         liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32)

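The test hunks use two styles of fix: head on the result of TF.gradients inside the graph-building code (where an empty result would be a bug anyway), and an explicit case whose fallback calls assertFailure inside the Session code, which is why assertFailure was added to the Test.HUnit import above. The fallback type-checks because assertFailure returns IO a in HUnit 1.6 and later, so liftIO can place it in any branch. A reusable sketch of that pattern, not part of the commit; expectTwo is a made-up name:

    import Control.Monad.IO.Class (MonadIO, liftIO)
    import Test.HUnit (assertFailure)

    -- Succeed on a two-element list, otherwise fail the test.
    -- Assumes HUnit >= 1.6, where assertFailure returns IO a rather than IO ().
    expectTwo :: MonadIO m => String -> [a] -> m (a, a)
    expectTwo what xs = case xs of
        [a, b] -> pure (a, b)
        _      -> liftIO $ assertFailure ("pattern-match failure in " ++ what)

Hypothetically, the session code above could then read (w, dfdw) <- expectTwo "matMulGradGrad" =<< TF.build tower.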
@@ -589,7 +595,11 @@ matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ show txw)
             return (x, w, ds)

     TF.runSession $ do
-        (x, w, [dx, dw]) <- TF.build dfBuild
+        (x, w, d) <- TF.build dfBuild
+        (dx, dw) <-
+            case d of
+                [dx, dw] -> pure (dx, dw)
+                _ -> liftIO $ assertFailure "pattern-match failure in matMulTransposeGradient"
         xShape <- TF.run $ TF.shape x
         dxShape <- TF.run $ TF.shape dx
         liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)
@@ -616,7 +626,11 @@ batchMatMulGradient = testCase "batchMatMulGradients" $ do
             return (x, dfs)

     (xShape, dxShape) <- TF.runSession $ do
-        (x, [dx]) <- TF.build dfBuild
+        (x, dl) <- TF.build dfBuild
+        dx <-
+            case dl of
+                [dx] -> pure dx
+                _ -> liftIO $ assertFailure "pattern-match failure in batchMatMulGradient"
         TF.run (TF.shape x, TF.shape dx)

     assertEqual "Shape of gradient must match shape of input" xShape (dxShape :: V.Vector Int32)
@@ -633,13 +647,19 @@ batchMatMulGradGrad = testCase "batchMatMulGradGrad" $ do
             x <- TF.render $ TF.zeros $ TF.Shape [batch, height, 1]
             w <- TF.zeroInitializedVariable $ TF.Shape [batch, 1, width]
             let f = x `TF.batchMatMul` TF.readValue w
-            [dfdx] <- TF.gradients f [x]
+            l1 <- TF.gradients f [x]
+            let dfdx = head l1
             let f'x = TF.sum dfdx (TF.vector [1, 2 :: Int32])
-            [dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
+            l2 <- TF.gradients f'x [w] -- take gradient again (this time over w)
+            let dfdw = head l2
             return [TF.readValue w, TF.expr dfdw]

     TF.runSession $ do
-        [w, dfdw] <- TF.build tower
+        l <- TF.build tower
+        (w, dfdw) <-
+            case l of
+                [w, dfdw] -> pure (w, dfdw)
+                _ -> liftIO $ assertFailure "pattern-match failure in batchMatMulGradGrad"
         (wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw)
         liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32)

@@ -665,7 +685,11 @@ batchMatMulAdjointGradient axw = testCase ("batchMatMulAdjointGradients " ++ show axw)
             return (x, w, ds)

     TF.runSession $ do
-        (x, w, [dx, dw]) <- TF.build dfBuild
+        (x, w, d) <- TF.build dfBuild
+        (dx, dw) <-
+            case d of
+                [dx, dw] -> pure (dx, dw)
+                _ -> liftIO $ assertFailure "pattern-match failure in batchMatMulAdjointGradient"
         xShape <- TF.run $ TF.shape x
         dxShape <- TF.run $ TF.shape dx
         liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)