mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Make code --pedantic (#35)
* Enforce pedantic build mode in CI. * Our imports drifted really far from where they should be.
This commit is contained in:
parent
69fdbf677f
commit
2b5e41ffeb
21 changed files with 101 additions and 56 deletions
|
@ -8,4 +8,4 @@ IMAGE_NAME=tensorflow/haskell/ci_build:v0
|
||||||
|
|
||||||
git submodule update
|
git submodule update
|
||||||
docker build -t $IMAGE_NAME -f ci_build/Dockerfile .
|
docker build -t $IMAGE_NAME -f ci_build/Dockerfile .
|
||||||
docker run $IMAGE_NAME stack test
|
docker run $IMAGE_NAME stack build --pedantic --test
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE OverloadedLists #-}
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
|
|
||||||
import Control.Monad (zipWithM, when, forM, forM_)
|
import Control.Monad (zipWithM, when, forM_)
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
import Data.List (genericLength)
|
import Data.List (genericLength)
|
||||||
|
@ -34,7 +34,8 @@ import qualified TensorFlow.Types as TF
|
||||||
import TensorFlow.Examples.MNIST.InputData
|
import TensorFlow.Examples.MNIST.InputData
|
||||||
import TensorFlow.Examples.MNIST.Parse
|
import TensorFlow.Examples.MNIST.Parse
|
||||||
|
|
||||||
numPixels = 28^2 :: Int64
|
numPixels, numLabels :: Int64
|
||||||
|
numPixels = 28*28 :: Int64
|
||||||
numLabels = 10 :: Int64
|
numLabels = 10 :: Int64
|
||||||
|
|
||||||
-- | Create tensor with random values where the stddev depends on the width.
|
-- | Create tensor with random values where the stddev depends on the width.
|
||||||
|
@ -44,6 +45,7 @@ randomParam width (TF.Shape shape) =
|
||||||
where
|
where
|
||||||
stddev = TF.scalar (1 / sqrt (fromIntegral width))
|
stddev = TF.scalar (1 / sqrt (fromIntegral width))
|
||||||
|
|
||||||
|
reduceMean :: TF.Tensor TF.Value Float -> TF.Tensor TF.Value Float
|
||||||
reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32))
|
reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32))
|
||||||
|
|
||||||
-- Types must match due to model structure.
|
-- Types must match due to model structure.
|
||||||
|
@ -108,6 +110,7 @@ createModel = do
|
||||||
] errorRateTensor
|
] errorRateTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
main = TF.runSession $ do
|
main = TF.runSession $ do
|
||||||
-- Read training and test data.
|
-- Read training and test data.
|
||||||
trainingImages <- liftIO (readMNISTSamples =<< trainingImageData)
|
trainingImages <- liftIO (readMNISTSamples =<< trainingImageData)
|
||||||
|
|
|
@ -52,12 +52,14 @@ import TensorFlow.Nodes (unScalar)
|
||||||
import TensorFlow.Session
|
import TensorFlow.Session
|
||||||
(runSession, run, run_, runWithFeeds, build, buildAnd)
|
(runSession, run, run_, runWithFeeds, build, buildAnd)
|
||||||
import TensorFlow.Types (TensorType(..), Shape(..))
|
import TensorFlow.Types (TensorType(..), Shape(..))
|
||||||
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.HUnit ((@=?), Assertion)
|
import Test.HUnit ((@=?), Assertion)
|
||||||
import Google.Test
|
import Google.Test
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
-- | Test that a file can be read and the GraphDef proto correctly parsed.
|
-- | Test that a file can be read and the GraphDef proto correctly parsed.
|
||||||
|
testReadMessageFromFileOrDie :: Test
|
||||||
testReadMessageFromFileOrDie = testCase "testReadMessageFromFileOrDie" $ do
|
testReadMessageFromFileOrDie = testCase "testReadMessageFromFileOrDie" $ do
|
||||||
-- Check the function on a known well-formatted file.
|
-- Check the function on a known well-formatted file.
|
||||||
mnist <- readMessageFromFileOrDie =<< mnistPb :: IO GraphDef
|
mnist <- readMessageFromFileOrDie =<< mnistPb :: IO GraphDef
|
||||||
|
@ -72,6 +74,7 @@ testReadMessageFromFileOrDie = testCase "testReadMessageFromFileOrDie" $ do
|
||||||
|
|
||||||
-- | Parse the test set for label and image data. Will only fail if the file is
|
-- | Parse the test set for label and image data. Will only fail if the file is
|
||||||
-- missing or incredibly corrupt.
|
-- missing or incredibly corrupt.
|
||||||
|
testReadMNIST :: Test
|
||||||
testReadMNIST = testCase "testReadMNIST" $ do
|
testReadMNIST = testCase "testReadMNIST" $ do
|
||||||
imageData <- readMNISTSamples =<< testImageData
|
imageData <- readMNISTSamples =<< testImageData
|
||||||
10000 @=? length imageData
|
10000 @=? length imageData
|
||||||
|
@ -84,6 +87,7 @@ testNodeName n g = n @=? opName
|
||||||
opName = head (gDef^.node)^.op
|
opName = head (gDef^.node)^.op
|
||||||
gDef = asGraphDef $ render g
|
gDef = asGraphDef $ render g
|
||||||
|
|
||||||
|
testGraphDefGen :: Test
|
||||||
testGraphDefGen = testCase "testGraphDefGen" $ do
|
testGraphDefGen = testCase "testGraphDefGen" $ do
|
||||||
-- Test the inferred operation type.
|
-- Test the inferred operation type.
|
||||||
let f0 :: Tensor Value Float
|
let f0 :: Tensor Value Float
|
||||||
|
@ -101,6 +105,7 @@ testGraphDefGen = testCase "testGraphDefGen" $ do
|
||||||
testNodeName "Mul" $ (1 + f0) * 2
|
testNodeName "Mul" $ (1 + f0) * 2
|
||||||
|
|
||||||
-- | Convert a simple graph to GraphDef, load it, run it, and check the output.
|
-- | Convert a simple graph to GraphDef, load it, run it, and check the output.
|
||||||
|
testGraphDefExec :: Test
|
||||||
testGraphDefExec = testCase "testGraphDefExec" $ do
|
testGraphDefExec = testCase "testGraphDefExec" $ do
|
||||||
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
|
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
|
||||||
runSession $ do
|
runSession $ do
|
||||||
|
@ -110,6 +115,7 @@ testGraphDefExec = testCase "testGraphDefExec" $ do
|
||||||
|
|
||||||
-- | Load MNIST from a GraphDef and the weights from a checkpoint and run on
|
-- | Load MNIST from a GraphDef and the weights from a checkpoint and run on
|
||||||
-- sample data.
|
-- sample data.
|
||||||
|
testMNISTExec :: Test
|
||||||
testMNISTExec = testCase "testMNISTExec" $ do
|
testMNISTExec = testCase "testMNISTExec" $ do
|
||||||
-- Switch to unicode to enable pretty printing of MNIST digits.
|
-- Switch to unicode to enable pretty printing of MNIST digits.
|
||||||
IO.hSetEncoding IO.stdout IO.utf8
|
IO.hSetEncoding IO.stdout IO.utf8
|
||||||
|
|
|
@ -22,7 +22,7 @@ module TensorFlow.NN
|
||||||
import Prelude hiding ( log
|
import Prelude hiding ( log
|
||||||
, exp
|
, exp
|
||||||
)
|
)
|
||||||
import TensorFlow.Build ( Build(..)
|
import TensorFlow.Build ( Build
|
||||||
, render
|
, render
|
||||||
, withNameScope
|
, withNameScope
|
||||||
)
|
)
|
||||||
|
@ -32,7 +32,7 @@ import TensorFlow.GenOps.Core ( greaterEqual
|
||||||
, exp
|
, exp
|
||||||
)
|
)
|
||||||
import TensorFlow.Tensor ( Tensor(..)
|
import TensorFlow.Tensor ( Tensor(..)
|
||||||
, Value(..)
|
, Value
|
||||||
)
|
)
|
||||||
import TensorFlow.Types ( TensorType(..)
|
import TensorFlow.Types ( TensorType(..)
|
||||||
, OneOf
|
, OneOf
|
||||||
|
|
|
@ -12,28 +12,22 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
{-# LANGUAGE OverloadedLists #-}
|
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
|
||||||
{-# LANGUAGE NoMonomorphismRestriction #-}
|
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
|
|
||||||
module Main where
|
module Main where
|
||||||
|
|
||||||
import Data.Maybe (fromMaybe)
|
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
import TensorFlow.Test (assertAllClose)
|
import TensorFlow.Test (assertAllClose)
|
||||||
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.HUnit ((@?))
|
|
||||||
import Test.HUnit.Lang (Assertion(..))
|
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
import qualified TensorFlow.Build as TF
|
import qualified TensorFlow.Build as TF
|
||||||
import qualified TensorFlow.Gradient as TF
|
import qualified TensorFlow.Gradient as TF
|
||||||
|
import qualified TensorFlow.Nodes as TF
|
||||||
import qualified TensorFlow.NN as TF
|
import qualified TensorFlow.NN as TF
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
import qualified TensorFlow.Session as TF
|
import qualified TensorFlow.Session as TF
|
||||||
import qualified TensorFlow.Tensor as TF
|
|
||||||
import qualified TensorFlow.Types as TF
|
|
||||||
|
|
||||||
-- | These tests are ported from:
|
-- | These tests are ported from:
|
||||||
--
|
--
|
||||||
|
@ -46,9 +40,9 @@ sigmoidXentWithLogits :: Floating a => Ord a => [a] -> [a] -> [a]
|
||||||
sigmoidXentWithLogits logits' targets' =
|
sigmoidXentWithLogits logits' targets' =
|
||||||
let sig = map (\x -> 1 / (1 + exp (-x))) logits'
|
let sig = map (\x -> 1 / (1 + exp (-x))) logits'
|
||||||
eps = 0.0001
|
eps = 0.0001
|
||||||
pred = map (\p -> min (max p eps) (1 - eps)) sig
|
predictions = map (\p -> min (max p eps) (1 - eps)) sig
|
||||||
xent y z = (-z) * (log y) - (1 - z) * log (1 - y)
|
xent y z = (-z) * (log y) - (1 - z) * log (1 - y)
|
||||||
in zipWith xent pred targets'
|
in zipWith xent predictions targets'
|
||||||
|
|
||||||
|
|
||||||
data Inputs = Inputs {
|
data Inputs = Inputs {
|
||||||
|
@ -64,6 +58,7 @@ defInputs = Inputs {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
testLogisticOutput :: Test
|
||||||
testLogisticOutput = testCase "testLogisticOutput" $ do
|
testLogisticOutput = testCase "testLogisticOutput" $ do
|
||||||
let inputs = defInputs
|
let inputs = defInputs
|
||||||
vLogits = TF.vector $ logits inputs
|
vLogits = TF.vector $ logits inputs
|
||||||
|
@ -75,6 +70,7 @@ testLogisticOutput = testCase "testLogisticOutput" $ do
|
||||||
assertAllClose r ourLoss
|
assertAllClose r ourLoss
|
||||||
|
|
||||||
|
|
||||||
|
testLogisticOutputMultipleDim :: Test
|
||||||
testLogisticOutputMultipleDim =
|
testLogisticOutputMultipleDim =
|
||||||
testCase "testLogisticOutputMultipleDim" $ do
|
testCase "testLogisticOutputMultipleDim" $ do
|
||||||
let inputs = defInputs
|
let inputs = defInputs
|
||||||
|
@ -88,6 +84,7 @@ testLogisticOutputMultipleDim =
|
||||||
assertAllClose r ourLoss
|
assertAllClose r ourLoss
|
||||||
|
|
||||||
|
|
||||||
|
testGradientAtZero :: Test
|
||||||
testGradientAtZero = testCase "testGradientAtZero" $ do
|
testGradientAtZero = testCase "testGradientAtZero" $ do
|
||||||
let inputs = defInputs { logits = [0, 0], targets = [0, 1] }
|
let inputs = defInputs { logits = [0, 0], targets = [0, 1] }
|
||||||
vLogits = TF.vector $ logits inputs
|
vLogits = TF.vector $ logits inputs
|
||||||
|
@ -100,10 +97,9 @@ testGradientAtZero = testCase "testGradientAtZero" $ do
|
||||||
|
|
||||||
assertAllClose (head r) (V.fromList [0.5, -0.5])
|
assertAllClose (head r) (V.fromList [0.5, -0.5])
|
||||||
|
|
||||||
|
run :: TF.Fetchable t a => TF.Build t -> IO a
|
||||||
run = TF.runSession . TF.buildAnd TF.run
|
run = TF.runSession . TF.buildAnd TF.run
|
||||||
|
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = googleTest [ testGradientAtZero
|
main = googleTest [ testGradientAtZero
|
||||||
, testLogisticOutput
|
, testLogisticOutput
|
||||||
|
|
|
@ -24,7 +24,7 @@ module TensorFlow.EmbeddingOps where
|
||||||
import Control.Monad (zipWithM)
|
import Control.Monad (zipWithM)
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
import TensorFlow.Build (Build, colocateWith, render)
|
import TensorFlow.Build (Build, colocateWith, render)
|
||||||
import TensorFlow.Ops (scalar, shape, vector) -- Also Num instance for Tensor
|
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
|
||||||
import TensorFlow.Tensor (Tensor, Value)
|
import TensorFlow.Tensor (Tensor, Value)
|
||||||
import TensorFlow.Types (OneOf, TensorType)
|
import TensorFlow.Types (OneOf, TensorType)
|
||||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
|
|
@ -95,7 +95,7 @@ import TensorFlow.Tensor
|
||||||
, tensorOutput
|
, tensorOutput
|
||||||
, tensorAttr
|
, tensorAttr
|
||||||
)
|
)
|
||||||
import TensorFlow.Types (OneOf, TensorType, attrLens)
|
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
|
||||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||||
(NodeDef, attr, input, op, name)
|
(NodeDef, attr, input, op, name)
|
||||||
|
|
||||||
|
@ -406,7 +406,7 @@ toT = Tensor ValueKind
|
||||||
|
|
||||||
-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for
|
-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for
|
||||||
-- simple slicing operations.
|
-- simple slicing operations.
|
||||||
flatSlice :: forall v1 t i . (TensorType t)
|
flatSlice :: forall v1 t . (TensorType t)
|
||||||
=> Tensor v1 t -- ^ __input__
|
=> Tensor v1 t -- ^ __input__
|
||||||
-> Int32 -- ^ __begin__: specifies the offset into the first dimension of
|
-> Int32 -- ^ __begin__: specifies the offset into the first dimension of
|
||||||
-- 'input' to slice from.
|
-- 'input' to slice from.
|
||||||
|
@ -415,7 +415,7 @@ flatSlice :: forall v1 t i . (TensorType t)
|
||||||
-- are included in the slice (i.e. this is equivalent to setting
|
-- are included in the slice (i.e. this is equivalent to setting
|
||||||
-- size = input.dim_size(0) - begin).
|
-- size = input.dim_size(0) - begin).
|
||||||
-> Tensor Value t -- ^ __output__
|
-> Tensor Value t -- ^ __output__
|
||||||
flatSlice input begin size = CoreOps.slice input (vector [begin]) (vector [size])
|
flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])
|
||||||
|
|
||||||
|
|
||||||
-- | The gradient function for an op type.
|
-- | The gradient function for an op type.
|
||||||
|
@ -703,10 +703,14 @@ numOutputs o =
|
||||||
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
|
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
|
||||||
|
|
||||||
-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
|
-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
|
||||||
|
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Value Int32
|
||||||
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
|
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
|
||||||
|
|
||||||
|
allDimensions :: Tensor Value Int32
|
||||||
allDimensions = vector [-1 :: Int32]
|
allDimensions = vector [-1 :: Int32]
|
||||||
|
|
||||||
|
rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Value Int32
|
||||||
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1
|
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1
|
||||||
|
|
||||||
|
lookupAttr :: Attribute a1 => NodeDef -> Text -> a1
|
||||||
lookupAttr nodeDef attrName = nodeDef ^. attr . at attrName . non def . attrLens
|
lookupAttr nodeDef attrName = nodeDef ^. attr . at attrName . non def . attrLens
|
||||||
|
|
|
@ -17,6 +17,7 @@ module Main where
|
||||||
|
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.HUnit ((@=?))
|
import Test.HUnit ((@=?))
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
|
@ -26,6 +27,7 @@ import qualified TensorFlow.Session as TF
|
||||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
|
||||||
-- | Test split and concat are inverses.
|
-- | Test split and concat are inverses.
|
||||||
|
testSplit :: Test
|
||||||
testSplit = testCase "testSplit" $ TF.runSession $ do
|
testSplit = testCase "testSplit" $ TF.runSession $ do
|
||||||
let original = TF.constant [2, 3] [0..5 :: Float]
|
let original = TF.constant [2, 3] [0..5 :: Float]
|
||||||
splitList = CoreOps.split 3 dim original
|
splitList = CoreOps.split 3 dim original
|
||||||
|
|
|
@ -59,12 +59,14 @@ import TensorFlow.Session
|
||||||
, runSession
|
, runSession
|
||||||
, run_
|
, run_
|
||||||
)
|
)
|
||||||
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.HUnit ((@=?))
|
import Test.HUnit ((@=?))
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
-- | Test named behavior.
|
-- | Test named behavior.
|
||||||
|
testNamed :: Test
|
||||||
testNamed = testCase "testNamed" $ do
|
testNamed = testCase "testNamed" $ do
|
||||||
let graph = named "foo" <$> variable [] >>= render :: Build (Tensor Ref Float)
|
let graph = named "foo" <$> variable [] >>= render :: Build (Tensor Ref Float)
|
||||||
nodeDef :: NodeDef
|
nodeDef :: NodeDef
|
||||||
|
@ -73,6 +75,7 @@ testNamed = testCase "testNamed" $ do
|
||||||
"foo" @=? (nodeDef ^. name)
|
"foo" @=? (nodeDef ^. name)
|
||||||
|
|
||||||
-- | Test named deRef behavior.
|
-- | Test named deRef behavior.
|
||||||
|
testNamedDeRef :: Test
|
||||||
testNamedDeRef = testCase "testNamedDeRef" $ do
|
testNamedDeRef = testCase "testNamedDeRef" $ do
|
||||||
let graph = named "foo" <$> do
|
let graph = named "foo" <$> do
|
||||||
v :: Tensor Ref Float <- variable []
|
v :: Tensor Ref Float <- variable []
|
||||||
|
@ -84,11 +87,13 @@ testNamedDeRef = testCase "testNamedDeRef" $ do
|
||||||
|
|
||||||
-- | Test that "run" will render and extend any pure ops that haven't already
|
-- | Test that "run" will render and extend any pure ops that haven't already
|
||||||
-- been rendered.
|
-- been rendered.
|
||||||
|
testPureRender :: Test
|
||||||
testPureRender = testCase "testPureRender" $ runSession $ do
|
testPureRender = testCase "testPureRender" $ runSession $ do
|
||||||
result <- run $ 2 `add` 2
|
result <- run $ 2 `add` 2
|
||||||
liftIO $ 4 @=? (unScalar result :: Float)
|
liftIO $ 4 @=? (unScalar result :: Float)
|
||||||
|
|
||||||
-- | Test that "run" assigns any previously accumulated initializers.
|
-- | Test that "run" assigns any previously accumulated initializers.
|
||||||
|
testInitializedVariable :: Test
|
||||||
testInitializedVariable =
|
testInitializedVariable =
|
||||||
testCase "testInitializedVariable" $ runSession $ do
|
testCase "testInitializedVariable" $ runSession $ do
|
||||||
(formula, reset) <- build $ do
|
(formula, reset) <- build $ do
|
||||||
|
@ -101,6 +106,7 @@ testInitializedVariable =
|
||||||
rerunResult <- run formula
|
rerunResult <- run formula
|
||||||
liftIO $ 25 @=? (unScalar rerunResult :: Float)
|
liftIO $ 25 @=? (unScalar rerunResult :: Float)
|
||||||
|
|
||||||
|
testInitializedVariableShape :: Test
|
||||||
testInitializedVariableShape =
|
testInitializedVariableShape =
|
||||||
testCase "testInitializedVariableShape" $ runSession $ do
|
testCase "testInitializedVariableShape" $ runSession $ do
|
||||||
vector <- build $ initializedVariable (constant [1] [42 :: Float])
|
vector <- build $ initializedVariable (constant [1] [42 :: Float])
|
||||||
|
@ -108,6 +114,7 @@ testInitializedVariableShape =
|
||||||
liftIO $ [42] @=? (result :: V.Vector Float)
|
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||||
|
|
||||||
-- | Test nameScoped behavior.
|
-- | Test nameScoped behavior.
|
||||||
|
testNameScoped :: Test
|
||||||
testNameScoped = testCase "testNameScoped" $ do
|
testNameScoped = testCase "testNameScoped" $ do
|
||||||
let graph = withNameScope "foo" $ variable [] :: Build (Tensor Ref Float)
|
let graph = withNameScope "foo" $ variable [] :: Build (Tensor Ref Float)
|
||||||
nodeDef :: NodeDef
|
nodeDef :: NodeDef
|
||||||
|
@ -116,6 +123,7 @@ testNameScoped = testCase "testNameScoped" $ do
|
||||||
"Variable" @=? (nodeDef ^. op)
|
"Variable" @=? (nodeDef ^. op)
|
||||||
|
|
||||||
-- | Test combined named and nameScoped behavior.
|
-- | Test combined named and nameScoped behavior.
|
||||||
|
testNamedAndScoped :: Test
|
||||||
testNamedAndScoped = testCase "testNamedAndScoped" $ do
|
testNamedAndScoped = testCase "testNamedAndScoped" $ do
|
||||||
let graph :: Build (Tensor Ref Float)
|
let graph :: Build (Tensor Ref Float)
|
||||||
graph = withNameScope "foo1" ((named "bar1" <$> variable []) >>= render)
|
graph = withNameScope "foo1" ((named "bar1" <$> variable []) >>= render)
|
||||||
|
@ -133,6 +141,7 @@ flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a]
|
||||||
flushed field = sort . map field <$> liftBuild flushNodeBuffer
|
flushed field = sort . map field <$> liftBuild flushNodeBuffer
|
||||||
|
|
||||||
-- | Test the interaction of rendering, CSE and scoping.
|
-- | Test the interaction of rendering, CSE and scoping.
|
||||||
|
testRenderDedup :: Test
|
||||||
testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
|
testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
|
||||||
liftBuild renderNodes
|
liftBuild renderNodes
|
||||||
names <- flushed (^. name)
|
names <- flushed (^. name)
|
||||||
|
@ -154,6 +163,7 @@ testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
-- | Test the interaction of rendering, CSE and scoping.
|
-- | Test the interaction of rendering, CSE and scoping.
|
||||||
|
testDeviceColocation :: Test
|
||||||
testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do
|
testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do
|
||||||
liftBuild renderNodes
|
liftBuild renderNodes
|
||||||
devices <- flushed (\x -> (x ^. name, x ^. device))
|
devices <- flushed (\x -> (x ^. name, x ^. device))
|
||||||
|
|
|
@ -22,9 +22,9 @@ import Data.Int (Int32, Int64)
|
||||||
import Data.List (genericLength)
|
import Data.List (genericLength)
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
import TensorFlow.EmbeddingOps (embeddingLookup)
|
import TensorFlow.EmbeddingOps (embeddingLookup)
|
||||||
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||||||
import Test.HUnit.Lang (Assertion(..))
|
import Test.HUnit ((@=?))
|
||||||
import Test.HUnit ((@=?), (@?))
|
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
|
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
|
||||||
import Test.QuickCheck.Monadic (monadicIO, run)
|
import Test.QuickCheck.Monadic (monadicIO, run)
|
||||||
|
@ -46,13 +46,14 @@ buildAndRun = TF.runSession . TF.buildAnd TF.run
|
||||||
|
|
||||||
|
|
||||||
-- | Tries to perform a simple embedding lookup, with two partitions.
|
-- | Tries to perform a simple embedding lookup, with two partitions.
|
||||||
|
testEmbeddingLookupHasRightShapeWithPartition :: Test
|
||||||
testEmbeddingLookupHasRightShapeWithPartition =
|
testEmbeddingLookupHasRightShapeWithPartition =
|
||||||
testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
|
testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
|
||||||
let shape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
|
let embShape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
|
||||||
let embedding1 = [1, 1, 1 :: Int32]
|
let embedding1 = [1, 1, 1 :: Int32]
|
||||||
let embedding2 = [0, 0, 0 :: Int32]
|
let embedding2 = [0, 0, 0 :: Int32]
|
||||||
let embedding = [ TF.constant shape embedding1
|
let embedding = [ TF.constant embShape embedding1
|
||||||
, TF.constant shape embedding2
|
, TF.constant embShape embedding2
|
||||||
]
|
]
|
||||||
|
|
||||||
let idValues = [0, 1 :: Int32]
|
let idValues = [0, 1 :: Int32]
|
||||||
|
@ -71,15 +72,16 @@ testEmbeddingLookupHasRightShapeWithPartition =
|
||||||
|
|
||||||
|
|
||||||
-- | Tries to perform a simple embedding lookup, with only a single partition.
|
-- | Tries to perform a simple embedding lookup, with only a single partition.
|
||||||
|
testEmbeddingLookupHasRightShape :: Test
|
||||||
testEmbeddingLookupHasRightShape =
|
testEmbeddingLookupHasRightShape =
|
||||||
testCase "testEmbeddingLookupHasRightShape" $ do
|
testCase "testEmbeddingLookupHasRightShape" $ do
|
||||||
-- Consider a 3-dim embedding of two items
|
-- Consider a 3-dim embedding of two items
|
||||||
let shape = TF.Shape [2, 3]
|
let embShape = TF.Shape [2, 3]
|
||||||
let embeddingInit = [ 1, 1, 1
|
let embeddingInit = [ 1, 1, 1
|
||||||
, 0, 0, 0 :: Int32
|
, 0, 0, 0 :: Int32
|
||||||
]
|
]
|
||||||
|
|
||||||
let embedding = TF.constant shape embeddingInit
|
let embedding = TF.constant embShape embeddingInit
|
||||||
let idValues = [0, 1 :: Int32]
|
let idValues = [0, 1 :: Int32]
|
||||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||||
let op = embeddingLookup [embedding] ids
|
let op = embeddingLookup [embedding] ids
|
||||||
|
@ -96,6 +98,7 @@ testEmbeddingLookupHasRightShape =
|
||||||
|
|
||||||
|
|
||||||
-- | Check that we can calculate gradients w.r.t embeddings.
|
-- | Check that we can calculate gradients w.r.t embeddings.
|
||||||
|
testEmbeddingLookupGradients :: Test
|
||||||
testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
||||||
-- Agrees with "embedding", so gradient should be zero.
|
-- Agrees with "embedding", so gradient should be zero.
|
||||||
let xVals = V.fromList ([20, 20 :: Float])
|
let xVals = V.fromList ([20, 20 :: Float])
|
||||||
|
@ -103,14 +106,14 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
||||||
|
|
||||||
gs <- TF.runSession $ do
|
gs <- TF.runSession $ do
|
||||||
grads <- TF.build $ do
|
grads <- TF.build $ do
|
||||||
let shape = TF.Shape [2, 1]
|
let embShape = TF.Shape [2, 1]
|
||||||
let embeddingInit = [1, 20 ::Float]
|
let embeddingInit = [1, 20 ::Float]
|
||||||
let idValues = [1, 1 :: Int32]
|
let idValues = [1, 1 :: Int32]
|
||||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||||
|
|
||||||
x <- TF.placeholder (TF.Shape [2])
|
x <- TF.placeholder (TF.Shape [2])
|
||||||
embedding <- TF.initializedVariable
|
embedding <- TF.initializedVariable
|
||||||
=<< TF.render (TF.constant shape embeddingInit)
|
=<< TF.render (TF.constant embShape embeddingInit)
|
||||||
|
|
||||||
op <- embeddingLookup [embedding] ids
|
op <- embeddingLookup [embedding] ids
|
||||||
let twoNorm = CoreOps.square $ TF.abs (op - x)
|
let twoNorm = CoreOps.square $ TF.abs (op - x)
|
||||||
|
@ -163,7 +166,9 @@ instance Arbitrary a => Arbitrary (LookupExample a) where
|
||||||
arbitrary = do
|
arbitrary = do
|
||||||
rank <- choose (1, 4)
|
rank <- choose (1, 4)
|
||||||
-- Takes rank-th root of 100 to cap the tensor size.
|
-- Takes rank-th root of 100 to cap the tensor size.
|
||||||
let maxDim = fromIntegral $ ceiling $ 100 ** (1 / fromIntegral rank)
|
let maxDim = fromIntegral (ceiling doubleMaxDim :: Int64)
|
||||||
|
doubleMaxDim :: Double
|
||||||
|
doubleMaxDim = 100 ** (1 / fromIntegral rank)
|
||||||
shape@(firstDim : _) <- vectorOf rank (choose (1, maxDim))
|
shape@(firstDim : _) <- vectorOf rank (choose (1, maxDim))
|
||||||
values <- vectorOf (fromIntegral $ product shape) arbitrary
|
values <- vectorOf (fromIntegral $ product shape) arbitrary
|
||||||
numParts <- choose (2, 15)
|
numParts <- choose (2, 15)
|
||||||
|
|
|
@ -19,6 +19,7 @@ module Main where
|
||||||
|
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
import Data.Int (Int32)
|
import Data.Int (Int32)
|
||||||
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.HUnit ((@=?))
|
import Test.HUnit ((@=?))
|
||||||
import Google.Test
|
import Google.Test
|
||||||
|
@ -29,6 +30,7 @@ import TensorFlow.Ops
|
||||||
import TensorFlow.Session
|
import TensorFlow.Session
|
||||||
|
|
||||||
-- | Test fetching multiple outputs from an op.
|
-- | Test fetching multiple outputs from an op.
|
||||||
|
testMultipleOutputs :: Test
|
||||||
testMultipleOutputs = testCase "testMultipleOutputs" $
|
testMultipleOutputs = testCase "testMultipleOutputs" $
|
||||||
runSession $ do
|
runSession $ do
|
||||||
(values, indices) <-
|
(values, indices) <-
|
||||||
|
@ -37,6 +39,7 @@ testMultipleOutputs = testCase "testMultipleOutputs" $
|
||||||
liftIO $ [1, 3] @=? V.toList (indices :: V.Vector Int32)
|
liftIO $ [1, 3] @=? V.toList (indices :: V.Vector Int32)
|
||||||
|
|
||||||
-- | Test op with variable number of inputs.
|
-- | Test op with variable number of inputs.
|
||||||
|
testVarargs :: Test
|
||||||
testVarargs = testCase "testVarargs" $
|
testVarargs = testCase "testVarargs" $
|
||||||
runSession $ do
|
runSession $ do
|
||||||
xs <- run $ pack $ map scalar [1..8]
|
xs <- run $ pack $ map scalar [1..8]
|
||||||
|
|
|
@ -12,9 +12,7 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
{-# LANGUAGE OverloadedLists #-}
|
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE NoMonomorphismRestriction #-}
|
|
||||||
|
|
||||||
module Main where
|
module Main where
|
||||||
|
|
||||||
|
@ -22,6 +20,7 @@ import Control.Monad.IO.Class (liftIO)
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
import System.IO.Temp (withSystemTempDirectory)
|
import System.IO.Temp (withSystemTempDirectory)
|
||||||
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.HUnit ((@=?))
|
import Test.HUnit ((@=?))
|
||||||
import qualified Data.ByteString.Char8 as B8
|
import qualified Data.ByteString.Char8 as B8
|
||||||
|
@ -33,25 +32,31 @@ import qualified TensorFlow.Nodes as TF
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
import qualified TensorFlow.Session as TF
|
import qualified TensorFlow.Session as TF
|
||||||
import qualified TensorFlow.Tensor as TF
|
import qualified TensorFlow.Tensor as TF
|
||||||
|
import qualified TensorFlow.Types as TF
|
||||||
|
|
||||||
-- | Test that one can easily determine number of elements in the tensor.
|
-- | Test that one can easily determine number of elements in the tensor.
|
||||||
|
testSize :: Test
|
||||||
testSize = testCase "testSize" $ do
|
testSize = testCase "testSize" $ do
|
||||||
x <- eval $ TF.size (TF.constant [2, 3] [0..5 :: Float])
|
x <- eval $ TF.size (TF.constant (TF.Shape [2, 3]) [0..5 :: Float])
|
||||||
TF.Scalar (2 * 3 :: Int32) @=? x
|
TF.Scalar (2 * 3 :: Int32) @=? x
|
||||||
|
|
||||||
|
eval :: TF.Fetchable t a => t -> IO a
|
||||||
eval = TF.runSession . TF.buildAnd TF.run . return
|
eval = TF.runSession . TF.buildAnd TF.run . return
|
||||||
|
|
||||||
-- | Confirms that the original example from Python code works.
|
-- | Confirms that the original example from Python code works.
|
||||||
|
testReducedShape :: Test
|
||||||
testReducedShape = testCase "testReducedShape" $ do
|
testReducedShape = testCase "testReducedShape" $ do
|
||||||
x <- eval $ TF.reducedShape (TF.vector [2, 3, 5, 7 :: Int64])
|
x <- eval $ TF.reducedShape (TF.vector [2, 3, 5, 7 :: Int64])
|
||||||
(TF.vector [1, 2 :: Int32])
|
(TF.vector [1, 2 :: Int32])
|
||||||
V.fromList [2, 1, 1, 7 :: Int32] @=? x
|
V.fromList [2, 1, 1, 7 :: Int32] @=? x
|
||||||
|
|
||||||
|
testSaveRestore :: Test
|
||||||
testSaveRestore = testCase "testSaveRestore" $
|
testSaveRestore = testCase "testSaveRestore" $
|
||||||
withSystemTempDirectory "" $ \dirPath -> do
|
withSystemTempDirectory "" $ \dirPath -> do
|
||||||
let path = B8.pack $ dirPath ++ "/checkpoint"
|
let path = B8.pack $ dirPath ++ "/checkpoint"
|
||||||
var :: TF.Build (TF.Tensor TF.Ref Float)
|
var :: TF.Build (TF.Tensor TF.Ref Float)
|
||||||
var = TF.render =<< TF.named "a" <$> TF.zeroInitializedVariable []
|
var = TF.render =<<
|
||||||
|
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape [])
|
||||||
TF.runSession $ do
|
TF.runSession $ do
|
||||||
v <- TF.build var
|
v <- TF.build var
|
||||||
TF.buildAnd TF.run_ $ TF.assign v 134
|
TF.buildAnd TF.run_ $ TF.assign v 134
|
||||||
|
|
|
@ -21,11 +21,10 @@ import Control.Concurrent.MVar (newEmptyMVar, putMVar, tryReadMVar)
|
||||||
import Data.ByteString.Builder (toLazyByteString)
|
import Data.ByteString.Builder (toLazyByteString)
|
||||||
import Data.ByteString.Lazy (isPrefixOf)
|
import Data.ByteString.Lazy (isPrefixOf)
|
||||||
import Data.Default (def)
|
import Data.Default (def)
|
||||||
import Data.Monoid ((<>))
|
|
||||||
import Lens.Family2 ((&), (.~))
|
import Lens.Family2 ((&), (.~))
|
||||||
import Test.Framework (defaultMain)
|
import Test.Framework (defaultMain)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.HUnit ((@=?), assertBool, assertFailure)
|
import Test.HUnit (assertBool, assertFailure)
|
||||||
|
|
||||||
import qualified TensorFlow.Core as TF
|
import qualified TensorFlow.Core as TF
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
|
@ -44,6 +43,7 @@ testTracing = do
|
||||||
assertBool ("Unexpected log entry " ++ show got)
|
assertBool ("Unexpected log entry " ++ show got)
|
||||||
("Session.extend" `isPrefixOf` got)
|
("Session.extend" `isPrefixOf` got)
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
main = defaultMain
|
main = defaultMain
|
||||||
[ testCase "Tracing" testTracing
|
[ testCase "Tracing" testTracing
|
||||||
]
|
]
|
||||||
|
|
|
@ -19,11 +19,14 @@
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE TypeFamilies #-}
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||||
|
-- Purposely disabled to confirm doubleFuncNoSig can be written without type.
|
||||||
|
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
|
||||||
|
|
||||||
import Control.Monad (replicateM)
|
import Control.Monad (replicateM)
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
import Data.Int (Int64)
|
import Data.Int (Int64)
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||||||
import Test.HUnit ((@=?))
|
import Test.HUnit ((@=?))
|
||||||
|
@ -43,10 +46,11 @@ instance Arbitrary B.ByteString where
|
||||||
|
|
||||||
-- Test encoding tensors, feeding them through tensorflow, and decoding the
|
-- Test encoding tensors, feeding them through tensorflow, and decoding the
|
||||||
-- results.
|
-- results.
|
||||||
|
testFFIRoundTrip :: Test
|
||||||
testFFIRoundTrip = testCase "testFFIRoundTrip" $
|
testFFIRoundTrip = testCase "testFFIRoundTrip" $
|
||||||
TF.runSession $ do
|
TF.runSession $ do
|
||||||
let floatData = V.fromList [1..6 :: Float]
|
let floatData = V.fromList [1..6 :: Float]
|
||||||
stringData = V.fromList [B8.pack (show x) | x <- [1..6]]
|
stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]]
|
||||||
f <- TF.build $ TF.placeholder [2,3]
|
f <- TF.build $ TF.placeholder [2,3]
|
||||||
s <- TF.build $ TF.placeholder [2,3]
|
s <- TF.build $ TF.placeholder [2,3]
|
||||||
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
|
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
|
||||||
|
@ -78,12 +82,15 @@ encodeDecodeProp :: (TF.TensorType a, Eq a) => TensorDataInputs a -> Bool
|
||||||
encodeDecodeProp (TensorDataInputs shape vec) =
|
encodeDecodeProp (TensorDataInputs shape vec) =
|
||||||
TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec
|
TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec
|
||||||
|
|
||||||
|
testEncodeDecodeQcFloat :: Test
|
||||||
testEncodeDecodeQcFloat = testProperty "testEncodeDecodeQcFloat"
|
testEncodeDecodeQcFloat = testProperty "testEncodeDecodeQcFloat"
|
||||||
(encodeDecodeProp :: TensorDataInputs Float -> Bool)
|
(encodeDecodeProp :: TensorDataInputs Float -> Bool)
|
||||||
|
|
||||||
|
testEncodeDecodeQcInt64 :: Test
|
||||||
testEncodeDecodeQcInt64 = testProperty "testEncodeDecodeQcInt64"
|
testEncodeDecodeQcInt64 = testProperty "testEncodeDecodeQcInt64"
|
||||||
(encodeDecodeProp :: TensorDataInputs Int64 -> Bool)
|
(encodeDecodeProp :: TensorDataInputs Int64 -> Bool)
|
||||||
|
|
||||||
|
testEncodeDecodeQcString :: Test
|
||||||
testEncodeDecodeQcString = testProperty "testEncodeDecodeQcString"
|
testEncodeDecodeQcString = testProperty "testEncodeDecodeQcString"
|
||||||
(encodeDecodeProp :: TensorDataInputs B.ByteString -> Bool)
|
(encodeDecodeProp :: TensorDataInputs B.ByteString -> Bool)
|
||||||
|
|
||||||
|
@ -101,6 +108,7 @@ doubleFunc = doubleOrFloatFunc . doubleOrInt64Func
|
||||||
-- can't simplify the type all the way to `Double -> Double`.
|
-- can't simplify the type all the way to `Double -> Double`.
|
||||||
doubleFuncNoSig = doubleOrFloatFunc . doubleOrInt64Func
|
doubleFuncNoSig = doubleOrFloatFunc . doubleOrInt64Func
|
||||||
|
|
||||||
|
typeConstraintTests :: Test
|
||||||
typeConstraintTests = testCase "type constraints" $ do
|
typeConstraintTests = testCase "type constraints" $ do
|
||||||
42 @=? doubleOrInt64Func (42 :: Double)
|
42 @=? doubleOrInt64Func (42 :: Double)
|
||||||
42 @=? doubleOrInt64Func (42 :: Int64)
|
42 @=? doubleOrInt64Func (42 :: Int64)
|
||||||
|
|
|
@ -31,11 +31,13 @@ import TensorFlow.Session
|
||||||
, runSession
|
, runSession
|
||||||
, run_
|
, run_
|
||||||
)
|
)
|
||||||
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.HUnit ((@=?))
|
import Test.HUnit ((@=?))
|
||||||
import qualified Data.ByteString as BS
|
import qualified Data.ByteString as BS
|
||||||
|
|
||||||
-- | Test basic queue behaviors.
|
-- | Test basic queue behaviors.
|
||||||
|
testBasic :: Test
|
||||||
testBasic = testCase "testBasic" $ runSession $ do
|
testBasic = testCase "testBasic" $ runSession $ do
|
||||||
(q :: Queue2 Int64 BS.ByteString) <- build $ makeQueue2 1 ""
|
(q :: Queue2 Int64 BS.ByteString) <- build $ makeQueue2 1 ""
|
||||||
buildAnd run_ (enqueue q 42 (scalar "Hi"))
|
buildAnd run_ (enqueue q 42 (scalar "Hi"))
|
||||||
|
@ -47,6 +49,7 @@ testBasic = testCase "testBasic" $ runSession $ do
|
||||||
liftIO $ (Scalar 56, Scalar "Bar") @=? y
|
liftIO $ (Scalar 56, Scalar "Bar") @=? y
|
||||||
|
|
||||||
-- | Test queue pumping.
|
-- | Test queue pumping.
|
||||||
|
testPump :: Test
|
||||||
testPump = testCase "testPump" $ runSession $ do
|
testPump = testCase "testPump" $ runSession $ do
|
||||||
(deq, pump) <- build $ do
|
(deq, pump) <- build $ do
|
||||||
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue"
|
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue"
|
||||||
|
@ -61,6 +64,7 @@ testPump = testCase "testPump" $ runSession $ do
|
||||||
liftIO $ (Scalar 31, Scalar "Baz") @=? x
|
liftIO $ (Scalar 31, Scalar "Baz") @=? x
|
||||||
liftIO $ (Scalar 31, Scalar "Baz") @=? y
|
liftIO $ (Scalar 31, Scalar "Baz") @=? y
|
||||||
|
|
||||||
|
testAsync :: Test
|
||||||
testAsync = testCase "testAsync" $ runSession $ do
|
testAsync = testCase "testAsync" $ runSession $ do
|
||||||
(deq, pump) <- build $ do
|
(deq, pump) <- build $ do
|
||||||
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 ""
|
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 ""
|
||||||
|
|
|
@ -20,8 +20,7 @@ module TensorFlow.Test
|
||||||
|
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
import Test.HUnit ((@?))
|
import Test.HUnit ((@?))
|
||||||
import Test.HUnit.Lang (Assertion(..))
|
import Test.HUnit.Lang (Assertion)
|
||||||
|
|
||||||
-- | Compares that the vectors are element-by-element equal within the given
|
-- | Compares that the vectors are element-by-element equal within the given
|
||||||
-- tolerance. Raises an assertion and prints some information if not.
|
-- tolerance. Raises an assertion and prints some information if not.
|
||||||
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
|
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
|
||||||
|
@ -31,4 +30,3 @@ assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
|
||||||
where
|
where
|
||||||
absDiff x y = abs (x - y)
|
absDiff x y = abs (x - y)
|
||||||
tol = 0.001 :: Float
|
tol = 0.001 :: Float
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,6 @@ import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
|
||||||
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
|
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
|
||||||
import Control.Monad (when)
|
import Control.Monad (when)
|
||||||
import Data.Bits (Bits, toIntegralSized)
|
import Data.Bits (Bits, toIntegralSized)
|
||||||
import Data.Data (Data, dataTypeName, dataTypeOf)
|
|
||||||
import Data.Int (Int64)
|
import Data.Int (Int64)
|
||||||
import Data.Maybe (fromMaybe)
|
import Data.Maybe (fromMaybe)
|
||||||
import Data.Typeable (Typeable)
|
import Data.Typeable (Typeable)
|
||||||
|
|
|
@ -13,8 +13,9 @@
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||||
{-# LANGUAGE Rank2Types #-}
|
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE Rank2Types #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
module TensorFlow.Output
|
module TensorFlow.Output
|
||||||
( ControlNode(..)
|
( ControlNode(..)
|
||||||
|
@ -150,8 +151,8 @@ opControlInputs = lens _opControlInputs (\o x -> o {_opControlInputs = x})
|
||||||
-- code into a Build function
|
-- code into a Build function
|
||||||
instance IsString Output where
|
instance IsString Output where
|
||||||
fromString s = case break (==':') s of
|
fromString s = case break (==':') s of
|
||||||
(n, ':':ixStr)
|
(n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr
|
||||||
| [(ix, "")] <- read ixStr -> Output (fromInteger ix) $ assigned n
|
-> Output (fromInteger ix) $ assigned n
|
||||||
_ -> Output 0 $ assigned s
|
_ -> Output 0 $ assigned s
|
||||||
where assigned n = Rendered $ def & name .~ Text.pack n
|
where assigned n = Rendered $ def & name .~ Text.pack n
|
||||||
|
|
||||||
|
|
|
@ -46,24 +46,22 @@ import Data.ByteString (ByteString)
|
||||||
import Data.Default (Default, def)
|
import Data.Default (Default, def)
|
||||||
import Data.Functor.Identity (runIdentity)
|
import Data.Functor.Identity (runIdentity)
|
||||||
import Data.Monoid ((<>))
|
import Data.Monoid ((<>))
|
||||||
import qualified Data.Map.Strict as Map
|
import Data.ProtoLens (showMessage)
|
||||||
import qualified Data.Set as Set
|
|
||||||
import Data.Set (Set)
|
import Data.Set (Set)
|
||||||
import Data.Text.Encoding (encodeUtf8)
|
import Data.Text.Encoding (encodeUtf8)
|
||||||
import Data.ProtoLens (def, showMessage)
|
|
||||||
import Lens.Family2 (Lens', (^.), (&), (.~))
|
import Lens.Family2 (Lens', (^.), (&), (.~))
|
||||||
import Lens.Family2.Unchecked (lens)
|
import Lens.Family2.Unchecked (lens)
|
||||||
import Proto.Tensorflow.Core.Framework.Graph (node)
|
import Proto.Tensorflow.Core.Framework.Graph (node)
|
||||||
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
||||||
|
|
||||||
import TensorFlow.Build
|
import TensorFlow.Build
|
||||||
import TensorFlow.Nodes
|
import TensorFlow.Nodes
|
||||||
import TensorFlow.Output (NodeName, unNodeName)
|
import TensorFlow.Output (NodeName, unNodeName)
|
||||||
import TensorFlow.Tensor
|
import TensorFlow.Tensor
|
||||||
|
|
||||||
import qualified Data.ByteString.Builder as Builder
|
import qualified Data.ByteString.Builder as Builder
|
||||||
|
import qualified Data.Map.Strict as Map
|
||||||
|
import qualified Data.Set as Set
|
||||||
import qualified TensorFlow.Internal.FFI as FFI
|
import qualified TensorFlow.Internal.FFI as FFI
|
||||||
import qualified TensorFlow.Internal.Raw as Raw
|
|
||||||
|
|
||||||
-- | An action for logging.
|
-- | An action for logging.
|
||||||
type Tracer = Builder.Builder -> IO ()
|
type Tracer = Builder.Builder -> IO ()
|
||||||
|
|
|
@ -33,6 +33,7 @@ testParseAll = do
|
||||||
. not . null . view op)
|
. not . null . view op)
|
||||||
(decodeMessage opList :: Either String OpList)
|
(decodeMessage opList :: Either String OpList)
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
main = defaultMain
|
main = defaultMain
|
||||||
[ testCase "ParseAllOps" testParseAll
|
[ testCase "ParseAllOps" testParseAll
|
||||||
]
|
]
|
||||||
|
|
|
@ -16,11 +16,13 @@ module Main where
|
||||||
|
|
||||||
import Data.ByteString.Builder (toLazyByteString)
|
import Data.ByteString.Builder (toLazyByteString)
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
||||||
import qualified Data.Attoparsec.ByteString.Lazy as Atto
|
import qualified Data.Attoparsec.ByteString.Lazy as Atto
|
||||||
|
|
||||||
import TensorFlow.Internal.VarInt
|
import TensorFlow.Internal.VarInt
|
||||||
|
|
||||||
|
testEncodeDecode :: Test
|
||||||
testEncodeDecode = testProperty "testEncodeDecode" $ \x ->
|
testEncodeDecode = testProperty "testEncodeDecode" $ \x ->
|
||||||
let bytes = toLazyByteString (putVarInt x)
|
let bytes = toLazyByteString (putVarInt x)
|
||||||
in case Atto.eitherResult $ Atto.parse getVarInt bytes of
|
in case Atto.eitherResult $ Atto.parse getVarInt bytes of
|
||||||
|
|
Loading…
Reference in a new issue