Mirror of https://github.com/tensorflow/haskell.git (synced 2024-12-24 02:29:46 +01:00)
Make code --pedantic (#35)
* Enforce pedantic build mode in CI.
* Our imports drifted really far from where they should be.
Parent: 69fdbf677f
Commit: 2b5e41ffeb
21 changed files with 101 additions and 56 deletions
@@ -8,4 +8,4 @@ IMAGE_NAME=tensorflow/haskell/ci_build:v0
git submodule update
docker build -t $IMAGE_NAME -f ci_build/Dockerfile .
docker run $IMAGE_NAME stack test
docker run $IMAGE_NAME stack build --pedantic --test
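Note on the CI change above: the plain `stack test` run is replaced with `stack build --pedantic --test`. As far as I know, stack's `--pedantic` flag amounts to compiling every module with GHC's `-Wall -Werror`, so any warning fails the build. Expressed as a per-module pragma (a sketch with an illustrative module name, not part of this commit), the same effect looks like:

{-# OPTIONS_GHC -Wall -Werror #-}
module PedanticDemo (answer) where

-- Explicit top-level signature: avoids -Wmissing-signatures, which would be
-- fatal under -Werror.
answer :: Int
answer = 6 * 7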
@@ -15,7 +15,7 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedLists #-}

import Control.Monad (zipWithM, when, forM, forM_)
import Control.Monad (zipWithM, when, forM_)
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64)
import Data.List (genericLength)

@@ -34,7 +34,8 @@ import qualified TensorFlow.Types as TF
import TensorFlow.Examples.MNIST.InputData
import TensorFlow.Examples.MNIST.Parse

numPixels = 28^2 :: Int64
numPixels, numLabels :: Int64
numPixels = 28*28 :: Int64
numLabels = 10 :: Int64

-- | Create tensor with random values where the stddev depends on the width.

@@ -44,6 +45,7 @@ randomParam width (TF.Shape shape) =
where
stddev = TF.scalar (1 / sqrt (fromIntegral width))

reduceMean :: TF.Tensor TF.Value Float -> TF.Tensor TF.Value Float
reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32))

-- Types must match due to model structure.

@@ -108,6 +110,7 @@ createModel = do
] errorRateTensor
}

main :: IO ()
main = TF.runSession $ do
-- Read training and test data.
trainingImages <- liftIO (readMNISTSamples =<< trainingImageData)
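Two warnings behind the `numPixels` change above are worth spelling out. The binding had no stand-alone type signature (`-Wmissing-signatures`), and the exponent in `28^2` is an unconstrained `Integral` literal that GHC would default to `Integer`, which `-Wtype-defaults` reports; presumably that is why the commit writes the product out and adds the signature. A minimal stand-alone sketch of the fixed form (module name is illustrative):

module MnistSizes (numPixels, numLabels) where

import Data.Int (Int64)

-- Explicit signatures satisfy -Wmissing-signatures; spelling the product out
-- (or annotating the exponent of (^)) leaves nothing for GHC to default.
numPixels, numLabels :: Int64
numPixels = 28 * 28
numLabels = 10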
@@ -52,12 +52,14 @@ import TensorFlow.Nodes (unScalar)
import TensorFlow.Session
(runSession, run, run_, runWithFeeds, build, buildAnd)
import TensorFlow.Types (TensorType(..), Shape(..))
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?), Assertion)
import Google.Test
import qualified Data.Vector as V

-- | Test that a file can be read and the GraphDef proto correctly parsed.
testReadMessageFromFileOrDie :: Test
testReadMessageFromFileOrDie = testCase "testReadMessageFromFileOrDie" $ do
-- Check the function on a known well-formatted file.
mnist <- readMessageFromFileOrDie =<< mnistPb :: IO GraphDef

@@ -72,6 +74,7 @@ testReadMessageFromFileOrDie = testCase "testReadMessageFromFileOrDie" $ do

-- | Parse the test set for label and image data. Will only fail if the file is
-- missing or incredibly corrupt.
testReadMNIST :: Test
testReadMNIST = testCase "testReadMNIST" $ do
imageData <- readMNISTSamples =<< testImageData
10000 @=? length imageData

@@ -84,6 +87,7 @@ testNodeName n g = n @=? opName
opName = head (gDef^.node)^.op
gDef = asGraphDef $ render g

testGraphDefGen :: Test
testGraphDefGen = testCase "testGraphDefGen" $ do
-- Test the inferred operation type.
let f0 :: Tensor Value Float

@@ -101,6 +105,7 @@ testGraphDefGen = testCase "testGraphDefGen" $ do
testNodeName "Mul" $ (1 + f0) * 2

-- | Convert a simple graph to GraphDef, load it, run it, and check the output.
testGraphDefExec :: Test
testGraphDefExec = testCase "testGraphDefExec" $ do
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
runSession $ do

@@ -110,6 +115,7 @@ testGraphDefExec = testCase "testGraphDefExec" $ do

-- | Load MNIST from a GraphDef and the weights from a checkpoint and run on
-- sample data.
testMNISTExec :: Test
testMNISTExec = testCase "testMNISTExec" $ do
-- Switch to unicode to enable pretty printing of MNIST digits.
IO.hSetEncoding IO.stdout IO.utf8
@@ -22,7 +22,7 @@ module TensorFlow.NN
import Prelude hiding ( log
, exp
)
import TensorFlow.Build ( Build(..)
import TensorFlow.Build ( Build
, render
, withNameScope
)

@@ -32,7 +32,7 @@ import TensorFlow.GenOps.Core ( greaterEqual
, exp
)
import TensorFlow.Tensor ( Tensor(..)
, Value(..)
, Value
)
import TensorFlow.Types ( TensorType(..)
, OneOf
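Both hunks above only tighten import lists: `Build(..)` becomes `Build` and `Value(..)` becomes `Value`, presumably because the `(..)` forms pulled in constructors the module never uses and the pedantic build's `-Wall` flagged them as redundant imports. A small self-contained analogue, using a base module rather than the TensorFlow ones:

module ImportDemo (stripSpaces) where

-- Importing more than the module uses (say, isDigit alongside isSpace) would
-- be flagged by -Wunused-imports and therefore fail a --pedantic build.
import Data.Char (isSpace)

stripSpaces :: String -> String
stripSpaces = filter (not . isSpace)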
@@ -12,28 +12,22 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedLists #-}

module Main where

import Data.Maybe (fromMaybe)
import Google.Test (googleTest)
import TensorFlow.Test (assertAllClose)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@?))
import Test.HUnit.Lang (Assertion(..))
import qualified Data.Vector as V
import qualified TensorFlow.Build as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.NN as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF

-- | These tests are ported from:
--

@@ -46,9 +40,9 @@ sigmoidXentWithLogits :: Floating a => Ord a => [a] -> [a] -> [a]
sigmoidXentWithLogits logits' targets' =
let sig = map (\x -> 1 / (1 + exp (-x))) logits'
eps = 0.0001
pred = map (\p -> min (max p eps) (1 - eps)) sig
predictions = map (\p -> min (max p eps) (1 - eps)) sig
xent y z = (-z) * (log y) - (1 - z) * log (1 - y)
in zipWith xent pred targets'
in zipWith xent predictions targets'

data Inputs = Inputs {

@@ -64,6 +58,7 @@ defInputs = Inputs {
}

testLogisticOutput :: Test
testLogisticOutput = testCase "testLogisticOutput" $ do
let inputs = defInputs
vLogits = TF.vector $ logits inputs

@@ -75,6 +70,7 @@ testLogisticOutput = testCase "testLogisticOutput" $ do
assertAllClose r ourLoss

testLogisticOutputMultipleDim :: Test
testLogisticOutputMultipleDim =
testCase "testLogisticOutputMultipleDim" $ do
let inputs = defInputs

@@ -88,6 +84,7 @@ testLogisticOutputMultipleDim =
assertAllClose r ourLoss

testGradientAtZero :: Test
testGradientAtZero = testCase "testGradientAtZero" $ do
let inputs = defInputs { logits = [0, 0], targets = [0, 1] }
vLogits = TF.vector $ logits inputs

@@ -100,10 +97,9 @@ testGradientAtZero = testCase "testGradientAtZero" $ do

assertAllClose (head r) (V.fromList [0.5, -0.5])

run :: TF.Fetchable t a => TF.Build t -> IO a
run = TF.runSession . TF.buildAnd TF.run

main :: IO ()
main = googleTest [ testGradientAtZero
, testLogisticOutput
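One rename in this file is worth a note: the local binding `pred` shadows `pred` from the Prelude, which `-Wname-shadowing` (part of `-Wall`) reports, hence `predictions`. A stripped-down sketch of the shape of the fix, with illustrative names rather than the test's real code:

module ClampDemo (clampProbs) where

-- Clamp probabilities away from 0 and 1; naming the result 'pred' would
-- shadow Prelude.pred and fail a -Wall -Werror build.
clampProbs :: [Double] -> [Double]
clampProbs sig =
    let eps = 0.0001
        predictions = map (\p -> min (max p eps) (1 - eps)) sig
    in predictions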
@@ -24,7 +24,7 @@ module TensorFlow.EmbeddingOps where
import Control.Monad (zipWithM)
import Data.Int (Int32, Int64)
import TensorFlow.Build (Build, colocateWith, render)
import TensorFlow.Ops (scalar, shape, vector) -- Also Num instance for Tensor
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value)
import TensorFlow.Types (OneOf, TensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
@@ -95,7 +95,7 @@ import TensorFlow.Tensor
, tensorOutput
, tensorAttr
)
import TensorFlow.Types (OneOf, TensorType, attrLens)
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef
(NodeDef, attr, input, op, name)

@@ -406,7 +406,7 @@ toT = Tensor ValueKind

-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for
-- simple slicing operations.
flatSlice :: forall v1 t i . (TensorType t)
flatSlice :: forall v1 t . (TensorType t)
=> Tensor v1 t -- ^ __input__
-> Int32 -- ^ __begin__: specifies the offset into the first dimension of
-- 'input' to slice from.

@@ -415,7 +415,7 @@ flatSlice :: forall v1 t i . (TensorType t)
-- are included in the slice (i.e. this is equivalent to setting
-- size = input.dim_size(0) - begin).
-> Tensor Value t -- ^ __output__
flatSlice input begin size = CoreOps.slice input (vector [begin]) (vector [size])
flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])

-- | The gradient function for an op type.

@@ -703,10 +703,14 @@ numOutputs o =
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)

-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Value Int32
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)

allDimensions :: Tensor Value Int32
allDimensions = vector [-1 :: Int32]

rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Value Int32
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1

lookupAttr :: Attribute a1 => NodeDef -> Text -> a1
lookupAttr nodeDef attrName = nodeDef ^. attr . at attrName . non def . attrLens
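The `flatSlice` change above does two things: it drops the unused type variable `i` from the `forall`, and it renames the first argument from `input` to `t`. The rename most likely avoids shadowing the `input` lens imported from the NodeDef proto module a few lines earlier, which `-Wname-shadowing` would turn into a build failure under `-Werror`. A self-contained sketch of the same hazard using only the Prelude (names here are illustrative):

module SliceDemo (countLines) where

-- A parameter named 'lines' would shadow Prelude.lines and trip
-- -Wname-shadowing; picking a neutral name keeps -Wall -Werror happy.
countLines :: String -> Int
countLines s = length (lines s)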
@@ -17,6 +17,7 @@ module Main where

import Control.Monad.IO.Class (liftIO)
import Google.Test (googleTest)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
import qualified Data.Vector as V

@@ -26,6 +27,7 @@ import qualified TensorFlow.Session as TF
import qualified TensorFlow.GenOps.Core as CoreOps

-- | Test split and concat are inverses.
testSplit :: Test
testSplit = testCase "testSplit" $ TF.runSession $ do
let original = TF.constant [2, 3] [0..5 :: Float]
splitList = CoreOps.split 3 dim original
@@ -59,12 +59,14 @@ import TensorFlow.Session
, runSession
, run_
)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
import Google.Test (googleTest)
import qualified Data.Vector as V

-- | Test named behavior.
testNamed :: Test
testNamed = testCase "testNamed" $ do
let graph = named "foo" <$> variable [] >>= render :: Build (Tensor Ref Float)
nodeDef :: NodeDef

@@ -73,6 +75,7 @@ testNamed = testCase "testNamed" $ do
"foo" @=? (nodeDef ^. name)

-- | Test named deRef behavior.
testNamedDeRef :: Test
testNamedDeRef = testCase "testNamedDeRef" $ do
let graph = named "foo" <$> do
v :: Tensor Ref Float <- variable []

@@ -84,11 +87,13 @@ testNamedDeRef = testCase "testNamedDeRef" $ do

-- | Test that "run" will render and extend any pure ops that haven't already
-- been rendered.
testPureRender :: Test
testPureRender = testCase "testPureRender" $ runSession $ do
result <- run $ 2 `add` 2
liftIO $ 4 @=? (unScalar result :: Float)

-- | Test that "run" assigns any previously accumulated initializers.
testInitializedVariable :: Test
testInitializedVariable =
testCase "testInitializedVariable" $ runSession $ do
(formula, reset) <- build $ do

@@ -101,6 +106,7 @@ testInitializedVariable =
rerunResult <- run formula
liftIO $ 25 @=? (unScalar rerunResult :: Float)

testInitializedVariableShape :: Test
testInitializedVariableShape =
testCase "testInitializedVariableShape" $ runSession $ do
vector <- build $ initializedVariable (constant [1] [42 :: Float])

@@ -108,6 +114,7 @@ testInitializedVariableShape =
liftIO $ [42] @=? (result :: V.Vector Float)

-- | Test nameScoped behavior.
testNameScoped :: Test
testNameScoped = testCase "testNameScoped" $ do
let graph = withNameScope "foo" $ variable [] :: Build (Tensor Ref Float)
nodeDef :: NodeDef

@@ -116,6 +123,7 @@ testNameScoped = testCase "testNameScoped" $ do
"Variable" @=? (nodeDef ^. op)

-- | Test combined named and nameScoped behavior.
testNamedAndScoped :: Test
testNamedAndScoped = testCase "testNamedAndScoped" $ do
let graph :: Build (Tensor Ref Float)
graph = withNameScope "foo1" ((named "bar1" <$> variable []) >>= render)

@@ -133,6 +141,7 @@ flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a]
flushed field = sort . map field <$> liftBuild flushNodeBuffer

-- | Test the interaction of rendering, CSE and scoping.
testRenderDedup :: Test
testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
liftBuild renderNodes
names <- flushed (^. name)

@@ -154,6 +163,7 @@ testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
return ()

-- | Test the interaction of rendering, CSE and scoping.
testDeviceColocation :: Test
testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do
liftBuild renderNodes
devices <- flushed (\x -> (x ^. name, x ^. device))
@@ -22,9 +22,9 @@ import Data.Int (Int32, Int64)
import Data.List (genericLength)
import Google.Test (googleTest)
import TensorFlow.EmbeddingOps (embeddingLookup)
import Test.Framework (Test)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.HUnit.Lang (Assertion(..))
import Test.HUnit ((@=?), (@?))
import Test.HUnit ((@=?))
import Test.Framework.Providers.HUnit (testCase)
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
import Test.QuickCheck.Monadic (monadicIO, run)

@@ -46,13 +46,14 @@ buildAndRun = TF.runSession . TF.buildAnd TF.run

-- | Tries to perform a simple embedding lookup, with two partitions.
testEmbeddingLookupHasRightShapeWithPartition =
testEmbeddingLookupHasRightShapeWithPartition :: Test
testEmbeddingLookupHasRightShapeWithPartition =
testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
let shape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
let embShape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
let embedding1 = [1, 1, 1 :: Int32]
let embedding2 = [0, 0, 0 :: Int32]
let embedding = [ TF.constant shape embedding1
, TF.constant shape embedding2
let embedding = [ TF.constant embShape embedding1
, TF.constant embShape embedding2
]

let idValues = [0, 1 :: Int32]

@@ -71,15 +72,16 @@ testEmbeddingLookupHasRightShapeWithPartition =

-- | Tries to perform a simple embedding lookup, with only a single partition.
testEmbeddingLookupHasRightShape =
testEmbeddingLookupHasRightShape :: Test
testEmbeddingLookupHasRightShape =
testCase "testEmbeddingLookupHasRightShape" $ do
-- Consider a 3-dim embedding of two items
let shape = TF.Shape [2, 3]
let embShape = TF.Shape [2, 3]
let embeddingInit = [ 1, 1, 1
, 0, 0, 0 :: Int32
]

let embedding = TF.constant shape embeddingInit
let embedding = TF.constant embShape embeddingInit
let idValues = [0, 1 :: Int32]
let ids = TF.constant (TF.Shape [1, 2]) idValues
let op = embeddingLookup [embedding] ids

@@ -96,6 +98,7 @@ testEmbeddingLookupHasRightShape =

-- | Check that we can calculate gradients w.r.t embeddings.
testEmbeddingLookupGradients :: Test
testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
-- Agrees with "embedding", so gradient should be zero.
let xVals = V.fromList ([20, 20 :: Float])

@@ -103,15 +106,15 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do

gs <- TF.runSession $ do
grads <- TF.build $ do
let shape = TF.Shape [2, 1]
let embShape = TF.Shape [2, 1]
let embeddingInit = [1, 20 ::Float]
let idValues = [1, 1 :: Int32]
let ids = TF.constant (TF.Shape [1, 2]) idValues

x <- TF.placeholder (TF.Shape [2])
embedding <- TF.initializedVariable
=<< TF.render (TF.constant shape embeddingInit)

x <- TF.placeholder (TF.Shape [2])
embedding <- TF.initializedVariable
=<< TF.render (TF.constant embShape embeddingInit)

op <- embeddingLookup [embedding] ids
let twoNorm = CoreOps.square $ TF.abs (op - x)
loss = TF.mean twoNorm (TF.scalar (0 :: Int32))

@@ -163,7 +166,9 @@ instance Arbitrary a => Arbitrary (LookupExample a) where
arbitrary = do
rank <- choose (1, 4)
-- Takes rank-th root of 100 to cap the tensor size.
let maxDim = fromIntegral $ ceiling $ 100 ** (1 / fromIntegral rank)
let maxDim = fromIntegral (ceiling doubleMaxDim :: Int64)
doubleMaxDim :: Double
doubleMaxDim = 100 ** (1 / fromIntegral rank)
shape@(firstDim : _) <- vectorOf rank (choose (1, maxDim))
values <- vectorOf (fromIntegral $ product shape) arbitrary
numParts <- choose (2, 15)
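Two recurring fixes show up in this file. Local bindings named `shape` are renamed to `embShape`, presumably because they shadowed another `shape` in scope and `-Wname-shadowing` objected. And the `maxDim` computation gains explicit types: without them GHC has to default both the floating-point intermediate and the result of `ceiling`, which `-Wtype-defaults` reports. A stand-alone sketch of the annotated form (the helper name comes from the hunk, the wrapper function and module name are mine):

module LookupDemo (maxDimFor) where

import Data.Int (Int64)

-- Take the rank-th root of 100 to cap a random tensor's size, with every
-- intermediate type pinned so -Wtype-defaults stays silent.
maxDimFor :: Int -> Int64
maxDimFor rank = fromIntegral (ceiling doubleMaxDim :: Int64)
  where
    doubleMaxDim :: Double
    doubleMaxDim = 100 ** (1 / fromIntegral rank)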
@@ -19,6 +19,7 @@ module Main where

import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
import Google.Test

@@ -29,6 +30,7 @@ import TensorFlow.Ops
import TensorFlow.Session

-- | Test fetching multiple outputs from an op.
testMultipleOutputs :: Test
testMultipleOutputs = testCase "testMultipleOutputs" $
runSession $ do
(values, indices) <-

@@ -37,6 +39,7 @@ testMultipleOutputs = testCase "testMultipleOutputs" $
liftIO $ [1, 3] @=? V.toList (indices :: V.Vector Int32)

-- | Test op with variable number of inputs.
testVarargs :: Test
testVarargs = testCase "testVarargs" $
runSession $ do
xs <- run $ pack $ map scalar [1..8]
@@ -12,9 +12,7 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NoMonomorphismRestriction #-}

module Main where

@@ -22,6 +20,7 @@ import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64)
import Google.Test (googleTest)
import System.IO.Temp (withSystemTempDirectory)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
import qualified Data.ByteString.Char8 as B8

@@ -33,25 +32,31 @@ import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF

-- | Test that one can easily determine number of elements in the tensor.
testSize :: Test
testSize = testCase "testSize" $ do
x <- eval $ TF.size (TF.constant [2, 3] [0..5 :: Float])
x <- eval $ TF.size (TF.constant (TF.Shape [2, 3]) [0..5 :: Float])
TF.Scalar (2 * 3 :: Int32) @=? x

eval :: TF.Fetchable t a => t -> IO a
eval = TF.runSession . TF.buildAnd TF.run . return

-- | Confirms that the original example from Python code works.
testReducedShape :: Test
testReducedShape = testCase "testReducedShape" $ do
x <- eval $ TF.reducedShape (TF.vector [2, 3, 5, 7 :: Int64])
(TF.vector [1, 2 :: Int32])
V.fromList [2, 1, 1, 7 :: Int32] @=? x

testSaveRestore :: Test
testSaveRestore = testCase "testSaveRestore" $
withSystemTempDirectory "" $ \dirPath -> do
let path = B8.pack $ dirPath ++ "/checkpoint"
var :: TF.Build (TF.Tensor TF.Ref Float)
var = TF.render =<< TF.named "a" <$> TF.zeroInitializedVariable []
var = TF.render =<<
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape [])
TF.runSession $ do
v <- TF.build var
TF.buildAnd TF.run_ $ TF.assign v 134
@@ -21,11 +21,10 @@ import Control.Concurrent.MVar (newEmptyMVar, putMVar, tryReadMVar)
import Data.ByteString.Builder (toLazyByteString)
import Data.ByteString.Lazy (isPrefixOf)
import Data.Default (def)
import Data.Monoid ((<>))
import Lens.Family2 ((&), (.~))
import Test.Framework (defaultMain)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?), assertBool, assertFailure)
import Test.HUnit (assertBool, assertFailure)

import qualified TensorFlow.Core as TF
import qualified TensorFlow.Ops as TF

@@ -44,6 +43,7 @@ testTracing = do
assertBool ("Unexpected log entry " ++ show got)
("Session.extend" `isPrefixOf` got)

main :: IO ()
main = defaultMain
[ testCase "Tracing" testTracing
]
@@ -19,11 +19,14 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- Purposely disabled to confirm doubleFuncNoSig can be written without type.
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}

import Control.Monad (replicateM)
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import Google.Test (googleTest)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.HUnit ((@=?))

@@ -43,10 +46,11 @@ instance Arbitrary B.ByteString where

-- Test encoding tensors, feeding them through tensorflow, and decoding the
-- results.
testFFIRoundTrip :: Test
testFFIRoundTrip = testCase "testFFIRoundTrip" $
TF.runSession $ do
let floatData = V.fromList [1..6 :: Float]
stringData = V.fromList [B8.pack (show x) | x <- [1..6]]
stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]]
f <- TF.build $ TF.placeholder [2,3]
s <- TF.build $ TF.placeholder [2,3]
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)

@@ -78,12 +82,15 @@ encodeDecodeProp :: (TF.TensorType a, Eq a) => TensorDataInputs a -> Bool
encodeDecodeProp (TensorDataInputs shape vec) =
TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec

testEncodeDecodeQcFloat :: Test
testEncodeDecodeQcFloat = testProperty "testEncodeDecodeQcFloat"
(encodeDecodeProp :: TensorDataInputs Float -> Bool)

testEncodeDecodeQcInt64 :: Test
testEncodeDecodeQcInt64 = testProperty "testEncodeDecodeQcInt64"
(encodeDecodeProp :: TensorDataInputs Int64 -> Bool)

testEncodeDecodeQcString :: Test
testEncodeDecodeQcString = testProperty "testEncodeDecodeQcString"
(encodeDecodeProp :: TensorDataInputs B.ByteString -> Bool)

@@ -101,6 +108,7 @@ doubleFunc = doubleOrFloatFunc . doubleOrInt64Func
-- can't simplify the type all the way to `Double -> Double`.
doubleFuncNoSig = doubleOrFloatFunc . doubleOrInt64Func

typeConstraintTests :: Test
typeConstraintTests = testCase "type constraints" $ do
42 @=? doubleOrInt64Func (42 :: Double)
42 @=? doubleOrInt64Func (42 :: Int64)
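This file shows the escape hatch rather than a fix: `doubleFuncNoSig` is intentionally left without a signature, so instead of appeasing `-Wmissing-signatures` the module switches that one warning off with an `OPTIONS_GHC` pragma while the rest of `-Wall` stays fatal. (The `[1..6::Integer]` annotation in `testFFIRoundTrip` is the same defaulting fix seen elsewhere in the commit.) A minimal sketch of the opt-out pattern, with an illustrative module name:

{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module OptOutDemo (answerNoSig) where

-- No top-level signature on purpose; the module still builds under
-- --pedantic (-Wall -Werror) because only this warning is disabled.
answerNoSig = 42 :: Int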
@@ -31,11 +31,13 @@ import TensorFlow.Session
, runSession
, run_
)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
import qualified Data.ByteString as BS

-- | Test basic queue behaviors.
testBasic :: Test
testBasic = testCase "testBasic" $ runSession $ do
(q :: Queue2 Int64 BS.ByteString) <- build $ makeQueue2 1 ""
buildAnd run_ (enqueue q 42 (scalar "Hi"))

@@ -47,6 +49,7 @@ testBasic = testCase "testBasic" $ runSession $ do
liftIO $ (Scalar 56, Scalar "Bar") @=? y

-- | Test queue pumping.
testPump :: Test
testPump = testCase "testPump" $ runSession $ do
(deq, pump) <- build $ do
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue"

@@ -61,6 +64,7 @@ testPump = testCase "testPump" $ runSession $ do
liftIO $ (Scalar 31, Scalar "Baz") @=? x
liftIO $ (Scalar 31, Scalar "Baz") @=? y

testAsync :: Test
testAsync = testCase "testAsync" $ runSession $ do
(deq, pump) <- build $ do
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 ""
@@ -20,8 +20,7 @@ module TensorFlow.Test

import qualified Data.Vector as V
import Test.HUnit ((@?))
import Test.HUnit.Lang (Assertion(..))
import Test.HUnit.Lang (Assertion)

-- | Compares that the vectors are element-by-element equal within the given
-- tolerance. Raises an assertion and prints some information if not.
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion

@@ -31,4 +30,3 @@ assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
where
absDiff x y = abs (x - y)
tol = 0.001 :: Float
@@ -36,7 +36,6 @@ import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
import Control.Monad (when)
import Data.Bits (Bits, toIntegralSized)
import Data.Data (Data, dataTypeName, dataTypeOf)
import Data.Int (Int64)
import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable)
@@ -13,8 +13,9 @@
-- limitations under the License.

{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}

module TensorFlow.Output
( ControlNode(..)

@@ -150,8 +151,8 @@ opControlInputs = lens _opControlInputs (\o x -> o {_opControlInputs = x})
-- code into a Build function
instance IsString Output where
fromString s = case break (==':') s of
(n, ':':ixStr)
| [(ix, "")] <- read ixStr -> Output (fromInteger ix) $ assigned n
(n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr
-> Output (fromInteger ix) $ assigned n
_ -> Output 0 $ assigned s
where assigned n = Rendered $ def & name .~ Text.pack n
@@ -46,24 +46,22 @@ import Data.ByteString (ByteString)
import Data.Default (Default, def)
import Data.Functor.Identity (runIdentity)
import Data.Monoid ((<>))
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Data.ProtoLens (showMessage)
import Data.Set (Set)
import Data.Text.Encoding (encodeUtf8)
import Data.ProtoLens (def, showMessage)
import Lens.Family2 (Lens', (^.), (&), (.~))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.Graph (node)
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)

import TensorFlow.Build
import TensorFlow.Nodes
import TensorFlow.Output (NodeName, unNodeName)
import TensorFlow.Tensor

import qualified Data.ByteString.Builder as Builder
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified TensorFlow.Internal.FFI as FFI
import qualified TensorFlow.Internal.Raw as Raw

-- | An action for logging.
type Tracer = Builder.Builder -> IO ()
@@ -33,6 +33,7 @@ testParseAll = do
. not . null . view op)
(decodeMessage opList :: Either String OpList)

main :: IO ()
main = defaultMain
[ testCase "ParseAllOps" testParseAll
]
@@ -16,11 +16,13 @@ module Main where

import Data.ByteString.Builder (toLazyByteString)
import Google.Test (googleTest)
import Test.Framework (Test)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import qualified Data.Attoparsec.ByteString.Lazy as Atto

import TensorFlow.Internal.VarInt

testEncodeDecode :: Test
testEncodeDecode = testProperty "testEncodeDecode" $ \x ->
let bytes = toLazyByteString (putVarInt x)
in case Atto.eitherResult $ Atto.parse getVarInt bytes of