Make code --pedantic (#35)

* Enforce pedantic build mode in CI.
* Our imports drifted really far from where they should be.
This commit is contained in:
Greg Steuck 2016-11-18 10:42:02 -08:00 committed by GitHub
parent 69fdbf677f
commit 2b5e41ffeb
21 changed files with 101 additions and 56 deletions

View File

@ -8,4 +8,4 @@ IMAGE_NAME=tensorflow/haskell/ci_build:v0
git submodule update git submodule update
docker build -t $IMAGE_NAME -f ci_build/Dockerfile . docker build -t $IMAGE_NAME -f ci_build/Dockerfile .
docker run $IMAGE_NAME stack test docker run $IMAGE_NAME stack build --pedantic --test

View File

@ -15,7 +15,7 @@
{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedLists #-}
import Control.Monad (zipWithM, when, forM, forM_) import Control.Monad (zipWithM, when, forM_)
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64) import Data.Int (Int32, Int64)
import Data.List (genericLength) import Data.List (genericLength)
@ -34,7 +34,8 @@ import qualified TensorFlow.Types as TF
import TensorFlow.Examples.MNIST.InputData import TensorFlow.Examples.MNIST.InputData
import TensorFlow.Examples.MNIST.Parse import TensorFlow.Examples.MNIST.Parse
numPixels = 28^2 :: Int64 numPixels, numLabels :: Int64
numPixels = 28*28 :: Int64
numLabels = 10 :: Int64 numLabels = 10 :: Int64
-- | Create tensor with random values where the stddev depends on the width. -- | Create tensor with random values where the stddev depends on the width.
@ -44,6 +45,7 @@ randomParam width (TF.Shape shape) =
where where
stddev = TF.scalar (1 / sqrt (fromIntegral width)) stddev = TF.scalar (1 / sqrt (fromIntegral width))
reduceMean :: TF.Tensor TF.Value Float -> TF.Tensor TF.Value Float
reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32)) reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32))
-- Types must match due to model structure. -- Types must match due to model structure.
@ -108,6 +110,7 @@ createModel = do
] errorRateTensor ] errorRateTensor
} }
main :: IO ()
main = TF.runSession $ do main = TF.runSession $ do
-- Read training and test data. -- Read training and test data.
trainingImages <- liftIO (readMNISTSamples =<< trainingImageData) trainingImages <- liftIO (readMNISTSamples =<< trainingImageData)

View File

@ -52,12 +52,14 @@ import TensorFlow.Nodes (unScalar)
import TensorFlow.Session import TensorFlow.Session
(runSession, run, run_, runWithFeeds, build, buildAnd) (runSession, run, run_, runWithFeeds, build, buildAnd)
import TensorFlow.Types (TensorType(..), Shape(..)) import TensorFlow.Types (TensorType(..), Shape(..))
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?), Assertion) import Test.HUnit ((@=?), Assertion)
import Google.Test import Google.Test
import qualified Data.Vector as V import qualified Data.Vector as V
-- | Test that a file can be read and the GraphDef proto correctly parsed. -- | Test that a file can be read and the GraphDef proto correctly parsed.
testReadMessageFromFileOrDie :: Test
testReadMessageFromFileOrDie = testCase "testReadMessageFromFileOrDie" $ do testReadMessageFromFileOrDie = testCase "testReadMessageFromFileOrDie" $ do
-- Check the function on a known well-formatted file. -- Check the function on a known well-formatted file.
mnist <- readMessageFromFileOrDie =<< mnistPb :: IO GraphDef mnist <- readMessageFromFileOrDie =<< mnistPb :: IO GraphDef
@ -72,6 +74,7 @@ testReadMessageFromFileOrDie = testCase "testReadMessageFromFileOrDie" $ do
-- | Parse the test set for label and image data. Will only fail if the file is -- | Parse the test set for label and image data. Will only fail if the file is
-- missing or incredibly corrupt. -- missing or incredibly corrupt.
testReadMNIST :: Test
testReadMNIST = testCase "testReadMNIST" $ do testReadMNIST = testCase "testReadMNIST" $ do
imageData <- readMNISTSamples =<< testImageData imageData <- readMNISTSamples =<< testImageData
10000 @=? length imageData 10000 @=? length imageData
@ -84,6 +87,7 @@ testNodeName n g = n @=? opName
opName = head (gDef^.node)^.op opName = head (gDef^.node)^.op
gDef = asGraphDef $ render g gDef = asGraphDef $ render g
testGraphDefGen :: Test
testGraphDefGen = testCase "testGraphDefGen" $ do testGraphDefGen = testCase "testGraphDefGen" $ do
-- Test the inferred operation type. -- Test the inferred operation type.
let f0 :: Tensor Value Float let f0 :: Tensor Value Float
@ -101,6 +105,7 @@ testGraphDefGen = testCase "testGraphDefGen" $ do
testNodeName "Mul" $ (1 + f0) * 2 testNodeName "Mul" $ (1 + f0) * 2
-- | Convert a simple graph to GraphDef, load it, run it, and check the output. -- | Convert a simple graph to GraphDef, load it, run it, and check the output.
testGraphDefExec :: Test
testGraphDefExec = testCase "testGraphDefExec" $ do testGraphDefExec = testCase "testGraphDefExec" $ do
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10 let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
runSession $ do runSession $ do
@ -110,6 +115,7 @@ testGraphDefExec = testCase "testGraphDefExec" $ do
-- | Load MNIST from a GraphDef and the weights from a checkpoint and run on -- | Load MNIST from a GraphDef and the weights from a checkpoint and run on
-- sample data. -- sample data.
testMNISTExec :: Test
testMNISTExec = testCase "testMNISTExec" $ do testMNISTExec = testCase "testMNISTExec" $ do
-- Switch to unicode to enable pretty printing of MNIST digits. -- Switch to unicode to enable pretty printing of MNIST digits.
IO.hSetEncoding IO.stdout IO.utf8 IO.hSetEncoding IO.stdout IO.utf8

View File

@ -22,7 +22,7 @@ module TensorFlow.NN
import Prelude hiding ( log import Prelude hiding ( log
, exp , exp
) )
import TensorFlow.Build ( Build(..) import TensorFlow.Build ( Build
, render , render
, withNameScope , withNameScope
) )
@ -32,7 +32,7 @@ import TensorFlow.GenOps.Core ( greaterEqual
, exp , exp
) )
import TensorFlow.Tensor ( Tensor(..) import TensorFlow.Tensor ( Tensor(..)
, Value(..) , Value
) )
import TensorFlow.Types ( TensorType(..) import TensorFlow.Types ( TensorType(..)
, OneOf , OneOf

View File

@ -12,28 +12,22 @@
-- See the License for the specific language governing permissions and -- See the License for the specific language governing permissions and
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedLists #-}
module Main where module Main where
import Data.Maybe (fromMaybe)
import Google.Test (googleTest) import Google.Test (googleTest)
import TensorFlow.Test (assertAllClose) import TensorFlow.Test (assertAllClose)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@?))
import Test.HUnit.Lang (Assertion(..))
import qualified Data.Vector as V import qualified Data.Vector as V
import qualified TensorFlow.Build as TF import qualified TensorFlow.Build as TF
import qualified TensorFlow.Gradient as TF import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.NN as TF import qualified TensorFlow.NN as TF
import qualified TensorFlow.Ops as TF import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
-- | These tests are ported from: -- | These tests are ported from:
-- --
@ -46,9 +40,9 @@ sigmoidXentWithLogits :: Floating a => Ord a => [a] -> [a] -> [a]
sigmoidXentWithLogits logits' targets' = sigmoidXentWithLogits logits' targets' =
let sig = map (\x -> 1 / (1 + exp (-x))) logits' let sig = map (\x -> 1 / (1 + exp (-x))) logits'
eps = 0.0001 eps = 0.0001
pred = map (\p -> min (max p eps) (1 - eps)) sig predictions = map (\p -> min (max p eps) (1 - eps)) sig
xent y z = (-z) * (log y) - (1 - z) * log (1 - y) xent y z = (-z) * (log y) - (1 - z) * log (1 - y)
in zipWith xent pred targets' in zipWith xent predictions targets'
data Inputs = Inputs { data Inputs = Inputs {
@ -64,6 +58,7 @@ defInputs = Inputs {
} }
testLogisticOutput :: Test
testLogisticOutput = testCase "testLogisticOutput" $ do testLogisticOutput = testCase "testLogisticOutput" $ do
let inputs = defInputs let inputs = defInputs
vLogits = TF.vector $ logits inputs vLogits = TF.vector $ logits inputs
@ -75,6 +70,7 @@ testLogisticOutput = testCase "testLogisticOutput" $ do
assertAllClose r ourLoss assertAllClose r ourLoss
testLogisticOutputMultipleDim :: Test
testLogisticOutputMultipleDim = testLogisticOutputMultipleDim =
testCase "testLogisticOutputMultipleDim" $ do testCase "testLogisticOutputMultipleDim" $ do
let inputs = defInputs let inputs = defInputs
@ -88,6 +84,7 @@ testLogisticOutputMultipleDim =
assertAllClose r ourLoss assertAllClose r ourLoss
testGradientAtZero :: Test
testGradientAtZero = testCase "testGradientAtZero" $ do testGradientAtZero = testCase "testGradientAtZero" $ do
let inputs = defInputs { logits = [0, 0], targets = [0, 1] } let inputs = defInputs { logits = [0, 0], targets = [0, 1] }
vLogits = TF.vector $ logits inputs vLogits = TF.vector $ logits inputs
@ -100,10 +97,9 @@ testGradientAtZero = testCase "testGradientAtZero" $ do
assertAllClose (head r) (V.fromList [0.5, -0.5]) assertAllClose (head r) (V.fromList [0.5, -0.5])
run :: TF.Fetchable t a => TF.Build t -> IO a
run = TF.runSession . TF.buildAnd TF.run run = TF.runSession . TF.buildAnd TF.run
main :: IO () main :: IO ()
main = googleTest [ testGradientAtZero main = googleTest [ testGradientAtZero
, testLogisticOutput , testLogisticOutput

View File

@ -24,7 +24,7 @@ module TensorFlow.EmbeddingOps where
import Control.Monad (zipWithM) import Control.Monad (zipWithM)
import Data.Int (Int32, Int64) import Data.Int (Int32, Int64)
import TensorFlow.Build (Build, colocateWith, render) import TensorFlow.Build (Build, colocateWith, render)
import TensorFlow.Ops (scalar, shape, vector) -- Also Num instance for Tensor import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value) import TensorFlow.Tensor (Tensor, Value)
import TensorFlow.Types (OneOf, TensorType) import TensorFlow.Types (OneOf, TensorType)
import qualified TensorFlow.GenOps.Core as CoreOps import qualified TensorFlow.GenOps.Core as CoreOps

View File

@ -95,7 +95,7 @@ import TensorFlow.Tensor
, tensorOutput , tensorOutput
, tensorAttr , tensorAttr
) )
import TensorFlow.Types (OneOf, TensorType, attrLens) import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef import Proto.Tensorflow.Core.Framework.NodeDef
(NodeDef, attr, input, op, name) (NodeDef, attr, input, op, name)
@ -406,7 +406,7 @@ toT = Tensor ValueKind
-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for -- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for
-- simple slicing operations. -- simple slicing operations.
flatSlice :: forall v1 t i . (TensorType t) flatSlice :: forall v1 t . (TensorType t)
=> Tensor v1 t -- ^ __input__ => Tensor v1 t -- ^ __input__
-> Int32 -- ^ __begin__: specifies the offset into the first dimension of -> Int32 -- ^ __begin__: specifies the offset into the first dimension of
-- 'input' to slice from. -- 'input' to slice from.
@ -415,7 +415,7 @@ flatSlice :: forall v1 t i . (TensorType t)
-- are included in the slice (i.e. this is equivalent to setting -- are included in the slice (i.e. this is equivalent to setting
-- size = input.dim_size(0) - begin). -- size = input.dim_size(0) - begin).
-> Tensor Value t -- ^ __output__ -> Tensor Value t -- ^ __output__
flatSlice input begin size = CoreOps.slice input (vector [begin]) (vector [size]) flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])
-- | The gradient function for an op type. -- | The gradient function for an op type.
@ -703,10 +703,14 @@ numOutputs o =
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op) _ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0` -- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Value Int32
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1) safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
allDimensions :: Tensor Value Int32
allDimensions = vector [-1 :: Int32] allDimensions = vector [-1 :: Int32]
rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Value Int32
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1 rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1
lookupAttr :: Attribute a1 => NodeDef -> Text -> a1
lookupAttr nodeDef attrName = nodeDef ^. attr . at attrName . non def . attrLens lookupAttr nodeDef attrName = nodeDef ^. attr . at attrName . non def . attrLens

View File

@ -17,6 +17,7 @@ module Main where
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import Google.Test (googleTest) import Google.Test (googleTest)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?)) import Test.HUnit ((@=?))
import qualified Data.Vector as V import qualified Data.Vector as V
@ -26,6 +27,7 @@ import qualified TensorFlow.Session as TF
import qualified TensorFlow.GenOps.Core as CoreOps import qualified TensorFlow.GenOps.Core as CoreOps
-- | Test split and concat are inverses. -- | Test split and concat are inverses.
testSplit :: Test
testSplit = testCase "testSplit" $ TF.runSession $ do testSplit = testCase "testSplit" $ TF.runSession $ do
let original = TF.constant [2, 3] [0..5 :: Float] let original = TF.constant [2, 3] [0..5 :: Float]
splitList = CoreOps.split 3 dim original splitList = CoreOps.split 3 dim original

View File

@ -59,12 +59,14 @@ import TensorFlow.Session
, runSession , runSession
, run_ , run_
) )
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?)) import Test.HUnit ((@=?))
import Google.Test (googleTest) import Google.Test (googleTest)
import qualified Data.Vector as V import qualified Data.Vector as V
-- | Test named behavior. -- | Test named behavior.
testNamed :: Test
testNamed = testCase "testNamed" $ do testNamed = testCase "testNamed" $ do
let graph = named "foo" <$> variable [] >>= render :: Build (Tensor Ref Float) let graph = named "foo" <$> variable [] >>= render :: Build (Tensor Ref Float)
nodeDef :: NodeDef nodeDef :: NodeDef
@ -73,6 +75,7 @@ testNamed = testCase "testNamed" $ do
"foo" @=? (nodeDef ^. name) "foo" @=? (nodeDef ^. name)
-- | Test named deRef behavior. -- | Test named deRef behavior.
testNamedDeRef :: Test
testNamedDeRef = testCase "testNamedDeRef" $ do testNamedDeRef = testCase "testNamedDeRef" $ do
let graph = named "foo" <$> do let graph = named "foo" <$> do
v :: Tensor Ref Float <- variable [] v :: Tensor Ref Float <- variable []
@ -84,11 +87,13 @@ testNamedDeRef = testCase "testNamedDeRef" $ do
-- | Test that "run" will render and extend any pure ops that haven't already -- | Test that "run" will render and extend any pure ops that haven't already
-- been rendered. -- been rendered.
testPureRender :: Test
testPureRender = testCase "testPureRender" $ runSession $ do testPureRender = testCase "testPureRender" $ runSession $ do
result <- run $ 2 `add` 2 result <- run $ 2 `add` 2
liftIO $ 4 @=? (unScalar result :: Float) liftIO $ 4 @=? (unScalar result :: Float)
-- | Test that "run" assigns any previously accumulated initializers. -- | Test that "run" assigns any previously accumulated initializers.
testInitializedVariable :: Test
testInitializedVariable = testInitializedVariable =
testCase "testInitializedVariable" $ runSession $ do testCase "testInitializedVariable" $ runSession $ do
(formula, reset) <- build $ do (formula, reset) <- build $ do
@ -101,6 +106,7 @@ testInitializedVariable =
rerunResult <- run formula rerunResult <- run formula
liftIO $ 25 @=? (unScalar rerunResult :: Float) liftIO $ 25 @=? (unScalar rerunResult :: Float)
testInitializedVariableShape :: Test
testInitializedVariableShape = testInitializedVariableShape =
testCase "testInitializedVariableShape" $ runSession $ do testCase "testInitializedVariableShape" $ runSession $ do
vector <- build $ initializedVariable (constant [1] [42 :: Float]) vector <- build $ initializedVariable (constant [1] [42 :: Float])
@ -108,6 +114,7 @@ testInitializedVariableShape =
liftIO $ [42] @=? (result :: V.Vector Float) liftIO $ [42] @=? (result :: V.Vector Float)
-- | Test nameScoped behavior. -- | Test nameScoped behavior.
testNameScoped :: Test
testNameScoped = testCase "testNameScoped" $ do testNameScoped = testCase "testNameScoped" $ do
let graph = withNameScope "foo" $ variable [] :: Build (Tensor Ref Float) let graph = withNameScope "foo" $ variable [] :: Build (Tensor Ref Float)
nodeDef :: NodeDef nodeDef :: NodeDef
@ -116,6 +123,7 @@ testNameScoped = testCase "testNameScoped" $ do
"Variable" @=? (nodeDef ^. op) "Variable" @=? (nodeDef ^. op)
-- | Test combined named and nameScoped behavior. -- | Test combined named and nameScoped behavior.
testNamedAndScoped :: Test
testNamedAndScoped = testCase "testNamedAndScoped" $ do testNamedAndScoped = testCase "testNamedAndScoped" $ do
let graph :: Build (Tensor Ref Float) let graph :: Build (Tensor Ref Float)
graph = withNameScope "foo1" ((named "bar1" <$> variable []) >>= render) graph = withNameScope "foo1" ((named "bar1" <$> variable []) >>= render)
@ -133,6 +141,7 @@ flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a]
flushed field = sort . map field <$> liftBuild flushNodeBuffer flushed field = sort . map field <$> liftBuild flushNodeBuffer
-- | Test the interaction of rendering, CSE and scoping. -- | Test the interaction of rendering, CSE and scoping.
testRenderDedup :: Test
testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
liftBuild renderNodes liftBuild renderNodes
names <- flushed (^. name) names <- flushed (^. name)
@ -154,6 +163,7 @@ testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
return () return ()
-- | Test the interaction of rendering, CSE and scoping. -- | Test the interaction of rendering, CSE and scoping.
testDeviceColocation :: Test
testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do
liftBuild renderNodes liftBuild renderNodes
devices <- flushed (\x -> (x ^. name, x ^. device)) devices <- flushed (\x -> (x ^. name, x ^. device))

View File

@ -22,9 +22,9 @@ import Data.Int (Int32, Int64)
import Data.List (genericLength) import Data.List (genericLength)
import Google.Test (googleTest) import Google.Test (googleTest)
import TensorFlow.EmbeddingOps (embeddingLookup) import TensorFlow.EmbeddingOps (embeddingLookup)
import Test.Framework (Test)
import Test.Framework.Providers.QuickCheck2 (testProperty) import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.HUnit.Lang (Assertion(..)) import Test.HUnit ((@=?))
import Test.HUnit ((@=?), (@?))
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf) import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
import Test.QuickCheck.Monadic (monadicIO, run) import Test.QuickCheck.Monadic (monadicIO, run)
@ -46,13 +46,14 @@ buildAndRun = TF.runSession . TF.buildAnd TF.run
-- | Tries to perform a simple embedding lookup, with two partitions. -- | Tries to perform a simple embedding lookup, with two partitions.
testEmbeddingLookupHasRightShapeWithPartition = testEmbeddingLookupHasRightShapeWithPartition :: Test
testEmbeddingLookupHasRightShapeWithPartition =
testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
let shape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items. let embShape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
let embedding1 = [1, 1, 1 :: Int32] let embedding1 = [1, 1, 1 :: Int32]
let embedding2 = [0, 0, 0 :: Int32] let embedding2 = [0, 0, 0 :: Int32]
let embedding = [ TF.constant shape embedding1 let embedding = [ TF.constant embShape embedding1
, TF.constant shape embedding2 , TF.constant embShape embedding2
] ]
let idValues = [0, 1 :: Int32] let idValues = [0, 1 :: Int32]
@ -71,15 +72,16 @@ testEmbeddingLookupHasRightShapeWithPartition =
-- | Tries to perform a simple embedding lookup, with only a single partition. -- | Tries to perform a simple embedding lookup, with only a single partition.
testEmbeddingLookupHasRightShape = testEmbeddingLookupHasRightShape :: Test
testEmbeddingLookupHasRightShape =
testCase "testEmbeddingLookupHasRightShape" $ do testCase "testEmbeddingLookupHasRightShape" $ do
-- Consider a 3-dim embedding of two items -- Consider a 3-dim embedding of two items
let shape = TF.Shape [2, 3] let embShape = TF.Shape [2, 3]
let embeddingInit = [ 1, 1, 1 let embeddingInit = [ 1, 1, 1
, 0, 0, 0 :: Int32 , 0, 0, 0 :: Int32
] ]
let embedding = TF.constant shape embeddingInit let embedding = TF.constant embShape embeddingInit
let idValues = [0, 1 :: Int32] let idValues = [0, 1 :: Int32]
let ids = TF.constant (TF.Shape [1, 2]) idValues let ids = TF.constant (TF.Shape [1, 2]) idValues
let op = embeddingLookup [embedding] ids let op = embeddingLookup [embedding] ids
@ -96,6 +98,7 @@ testEmbeddingLookupHasRightShape =
-- | Check that we can calculate gradients w.r.t embeddings. -- | Check that we can calculate gradients w.r.t embeddings.
testEmbeddingLookupGradients :: Test
testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
-- Agrees with "embedding", so gradient should be zero. -- Agrees with "embedding", so gradient should be zero.
let xVals = V.fromList ([20, 20 :: Float]) let xVals = V.fromList ([20, 20 :: Float])
@ -103,15 +106,15 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
gs <- TF.runSession $ do gs <- TF.runSession $ do
grads <- TF.build $ do grads <- TF.build $ do
let shape = TF.Shape [2, 1] let embShape = TF.Shape [2, 1]
let embeddingInit = [1, 20 ::Float] let embeddingInit = [1, 20 ::Float]
let idValues = [1, 1 :: Int32] let idValues = [1, 1 :: Int32]
let ids = TF.constant (TF.Shape [1, 2]) idValues let ids = TF.constant (TF.Shape [1, 2]) idValues
x <- TF.placeholder (TF.Shape [2]) x <- TF.placeholder (TF.Shape [2])
embedding <- TF.initializedVariable embedding <- TF.initializedVariable
=<< TF.render (TF.constant shape embeddingInit) =<< TF.render (TF.constant embShape embeddingInit)
op <- embeddingLookup [embedding] ids op <- embeddingLookup [embedding] ids
let twoNorm = CoreOps.square $ TF.abs (op - x) let twoNorm = CoreOps.square $ TF.abs (op - x)
loss = TF.mean twoNorm (TF.scalar (0 :: Int32)) loss = TF.mean twoNorm (TF.scalar (0 :: Int32))
@ -163,7 +166,9 @@ instance Arbitrary a => Arbitrary (LookupExample a) where
arbitrary = do arbitrary = do
rank <- choose (1, 4) rank <- choose (1, 4)
-- Takes rank-th root of 100 to cap the tensor size. -- Takes rank-th root of 100 to cap the tensor size.
let maxDim = fromIntegral $ ceiling $ 100 ** (1 / fromIntegral rank) let maxDim = fromIntegral (ceiling doubleMaxDim :: Int64)
doubleMaxDim :: Double
doubleMaxDim = 100 ** (1 / fromIntegral rank)
shape@(firstDim : _) <- vectorOf rank (choose (1, maxDim)) shape@(firstDim : _) <- vectorOf rank (choose (1, maxDim))
values <- vectorOf (fromIntegral $ product shape) arbitrary values <- vectorOf (fromIntegral $ product shape) arbitrary
numParts <- choose (2, 15) numParts <- choose (2, 15)

View File

@ -19,6 +19,7 @@ module Main where
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32) import Data.Int (Int32)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?)) import Test.HUnit ((@=?))
import Google.Test import Google.Test
@ -29,6 +30,7 @@ import TensorFlow.Ops
import TensorFlow.Session import TensorFlow.Session
-- | Test fetching multiple outputs from an op. -- | Test fetching multiple outputs from an op.
testMultipleOutputs :: Test
testMultipleOutputs = testCase "testMultipleOutputs" $ testMultipleOutputs = testCase "testMultipleOutputs" $
runSession $ do runSession $ do
(values, indices) <- (values, indices) <-
@ -37,6 +39,7 @@ testMultipleOutputs = testCase "testMultipleOutputs" $
liftIO $ [1, 3] @=? V.toList (indices :: V.Vector Int32) liftIO $ [1, 3] @=? V.toList (indices :: V.Vector Int32)
-- | Test op with variable number of inputs. -- | Test op with variable number of inputs.
testVarargs :: Test
testVarargs = testCase "testVarargs" $ testVarargs = testCase "testVarargs" $
runSession $ do runSession $ do
xs <- run $ pack $ map scalar [1..8] xs <- run $ pack $ map scalar [1..8]

View File

@ -12,9 +12,7 @@
-- See the License for the specific language governing permissions and -- See the License for the specific language governing permissions and
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
module Main where module Main where
@ -22,6 +20,7 @@ import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64) import Data.Int (Int32, Int64)
import Google.Test (googleTest) import Google.Test (googleTest)
import System.IO.Temp (withSystemTempDirectory) import System.IO.Temp (withSystemTempDirectory)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?)) import Test.HUnit ((@=?))
import qualified Data.ByteString.Char8 as B8 import qualified Data.ByteString.Char8 as B8
@ -33,25 +32,31 @@ import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.Ops as TF import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
-- | Test that one can easily determine number of elements in the tensor. -- | Test that one can easily determine number of elements in the tensor.
testSize :: Test
testSize = testCase "testSize" $ do testSize = testCase "testSize" $ do
x <- eval $ TF.size (TF.constant [2, 3] [0..5 :: Float]) x <- eval $ TF.size (TF.constant (TF.Shape [2, 3]) [0..5 :: Float])
TF.Scalar (2 * 3 :: Int32) @=? x TF.Scalar (2 * 3 :: Int32) @=? x
eval :: TF.Fetchable t a => t -> IO a
eval = TF.runSession . TF.buildAnd TF.run . return eval = TF.runSession . TF.buildAnd TF.run . return
-- | Confirms that the original example from Python code works. -- | Confirms that the original example from Python code works.
testReducedShape :: Test
testReducedShape = testCase "testReducedShape" $ do testReducedShape = testCase "testReducedShape" $ do
x <- eval $ TF.reducedShape (TF.vector [2, 3, 5, 7 :: Int64]) x <- eval $ TF.reducedShape (TF.vector [2, 3, 5, 7 :: Int64])
(TF.vector [1, 2 :: Int32]) (TF.vector [1, 2 :: Int32])
V.fromList [2, 1, 1, 7 :: Int32] @=? x V.fromList [2, 1, 1, 7 :: Int32] @=? x
testSaveRestore :: Test
testSaveRestore = testCase "testSaveRestore" $ testSaveRestore = testCase "testSaveRestore" $
withSystemTempDirectory "" $ \dirPath -> do withSystemTempDirectory "" $ \dirPath -> do
let path = B8.pack $ dirPath ++ "/checkpoint" let path = B8.pack $ dirPath ++ "/checkpoint"
var :: TF.Build (TF.Tensor TF.Ref Float) var :: TF.Build (TF.Tensor TF.Ref Float)
var = TF.render =<< TF.named "a" <$> TF.zeroInitializedVariable [] var = TF.render =<<
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape [])
TF.runSession $ do TF.runSession $ do
v <- TF.build var v <- TF.build var
TF.buildAnd TF.run_ $ TF.assign v 134 TF.buildAnd TF.run_ $ TF.assign v 134

View File

@ -21,11 +21,10 @@ import Control.Concurrent.MVar (newEmptyMVar, putMVar, tryReadMVar)
import Data.ByteString.Builder (toLazyByteString) import Data.ByteString.Builder (toLazyByteString)
import Data.ByteString.Lazy (isPrefixOf) import Data.ByteString.Lazy (isPrefixOf)
import Data.Default (def) import Data.Default (def)
import Data.Monoid ((<>))
import Lens.Family2 ((&), (.~)) import Lens.Family2 ((&), (.~))
import Test.Framework (defaultMain) import Test.Framework (defaultMain)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?), assertBool, assertFailure) import Test.HUnit (assertBool, assertFailure)
import qualified TensorFlow.Core as TF import qualified TensorFlow.Core as TF
import qualified TensorFlow.Ops as TF import qualified TensorFlow.Ops as TF
@ -44,6 +43,7 @@ testTracing = do
assertBool ("Unexpected log entry " ++ show got) assertBool ("Unexpected log entry " ++ show got)
("Session.extend" `isPrefixOf` got) ("Session.extend" `isPrefixOf` got)
main :: IO ()
main = defaultMain main = defaultMain
[ testCase "Tracing" testTracing [ testCase "Tracing" testTracing
] ]

View File

@ -19,11 +19,14 @@
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-} {-# OPTIONS_GHC -fno-warn-orphans #-}
-- Purposely disabled to confirm doubleFuncNoSig can be written without type.
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
import Control.Monad (replicateM) import Control.Monad (replicateM)
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64) import Data.Int (Int64)
import Google.Test (googleTest) import Google.Test (googleTest)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
import Test.Framework.Providers.QuickCheck2 (testProperty) import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.HUnit ((@=?)) import Test.HUnit ((@=?))
@ -43,10 +46,11 @@ instance Arbitrary B.ByteString where
-- Test encoding tensors, feeding them through tensorflow, and decoding the -- Test encoding tensors, feeding them through tensorflow, and decoding the
-- results. -- results.
testFFIRoundTrip :: Test
testFFIRoundTrip = testCase "testFFIRoundTrip" $ testFFIRoundTrip = testCase "testFFIRoundTrip" $
TF.runSession $ do TF.runSession $ do
let floatData = V.fromList [1..6 :: Float] let floatData = V.fromList [1..6 :: Float]
stringData = V.fromList [B8.pack (show x) | x <- [1..6]] stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]]
f <- TF.build $ TF.placeholder [2,3] f <- TF.build $ TF.placeholder [2,3]
s <- TF.build $ TF.placeholder [2,3] s <- TF.build $ TF.placeholder [2,3]
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData) let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
@ -78,12 +82,15 @@ encodeDecodeProp :: (TF.TensorType a, Eq a) => TensorDataInputs a -> Bool
encodeDecodeProp (TensorDataInputs shape vec) = encodeDecodeProp (TensorDataInputs shape vec) =
TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec
testEncodeDecodeQcFloat :: Test
testEncodeDecodeQcFloat = testProperty "testEncodeDecodeQcFloat" testEncodeDecodeQcFloat = testProperty "testEncodeDecodeQcFloat"
(encodeDecodeProp :: TensorDataInputs Float -> Bool) (encodeDecodeProp :: TensorDataInputs Float -> Bool)
testEncodeDecodeQcInt64 :: Test
testEncodeDecodeQcInt64 = testProperty "testEncodeDecodeQcInt64" testEncodeDecodeQcInt64 = testProperty "testEncodeDecodeQcInt64"
(encodeDecodeProp :: TensorDataInputs Int64 -> Bool) (encodeDecodeProp :: TensorDataInputs Int64 -> Bool)
testEncodeDecodeQcString :: Test
testEncodeDecodeQcString = testProperty "testEncodeDecodeQcString" testEncodeDecodeQcString = testProperty "testEncodeDecodeQcString"
(encodeDecodeProp :: TensorDataInputs B.ByteString -> Bool) (encodeDecodeProp :: TensorDataInputs B.ByteString -> Bool)
@ -101,6 +108,7 @@ doubleFunc = doubleOrFloatFunc . doubleOrInt64Func
-- can't simplify the type all the way to `Double -> Double`. -- can't simplify the type all the way to `Double -> Double`.
doubleFuncNoSig = doubleOrFloatFunc . doubleOrInt64Func doubleFuncNoSig = doubleOrFloatFunc . doubleOrInt64Func
typeConstraintTests :: Test
typeConstraintTests = testCase "type constraints" $ do typeConstraintTests = testCase "type constraints" $ do
42 @=? doubleOrInt64Func (42 :: Double) 42 @=? doubleOrInt64Func (42 :: Double)
42 @=? doubleOrInt64Func (42 :: Int64) 42 @=? doubleOrInt64Func (42 :: Int64)

View File

@ -31,11 +31,13 @@ import TensorFlow.Session
, runSession , runSession
, run_ , run_
) )
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?)) import Test.HUnit ((@=?))
import qualified Data.ByteString as BS import qualified Data.ByteString as BS
-- | Test basic queue behaviors. -- | Test basic queue behaviors.
testBasic :: Test
testBasic = testCase "testBasic" $ runSession $ do testBasic = testCase "testBasic" $ runSession $ do
(q :: Queue2 Int64 BS.ByteString) <- build $ makeQueue2 1 "" (q :: Queue2 Int64 BS.ByteString) <- build $ makeQueue2 1 ""
buildAnd run_ (enqueue q 42 (scalar "Hi")) buildAnd run_ (enqueue q 42 (scalar "Hi"))
@ -47,6 +49,7 @@ testBasic = testCase "testBasic" $ runSession $ do
liftIO $ (Scalar 56, Scalar "Bar") @=? y liftIO $ (Scalar 56, Scalar "Bar") @=? y
-- | Test queue pumping. -- | Test queue pumping.
testPump :: Test
testPump = testCase "testPump" $ runSession $ do testPump = testCase "testPump" $ runSession $ do
(deq, pump) <- build $ do (deq, pump) <- build $ do
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue" q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue"
@ -61,6 +64,7 @@ testPump = testCase "testPump" $ runSession $ do
liftIO $ (Scalar 31, Scalar "Baz") @=? x liftIO $ (Scalar 31, Scalar "Baz") @=? x
liftIO $ (Scalar 31, Scalar "Baz") @=? y liftIO $ (Scalar 31, Scalar "Baz") @=? y
testAsync :: Test
testAsync = testCase "testAsync" $ runSession $ do testAsync = testCase "testAsync" $ runSession $ do
(deq, pump) <- build $ do (deq, pump) <- build $ do
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "" q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 ""

View File

@ -20,8 +20,7 @@ module TensorFlow.Test
import qualified Data.Vector as V import qualified Data.Vector as V
import Test.HUnit ((@?)) import Test.HUnit ((@?))
import Test.HUnit.Lang (Assertion(..)) import Test.HUnit.Lang (Assertion)
-- | Compares that the vectors are element-by-element equal within the given -- | Compares that the vectors are element-by-element equal within the given
-- tolerance. Raises an assertion and prints some information if not. -- tolerance. Raises an assertion and prints some information if not.
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
@ -31,4 +30,3 @@ assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
where where
absDiff x y = abs (x - y) absDiff x y = abs (x - y)
tol = 0.001 :: Float tol = 0.001 :: Float

View File

@ -36,7 +36,6 @@ import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Exception (Exception, throwIO, bracket, finally, mask_) import Control.Exception (Exception, throwIO, bracket, finally, mask_)
import Control.Monad (when) import Control.Monad (when)
import Data.Bits (Bits, toIntegralSized) import Data.Bits (Bits, toIntegralSized)
import Data.Data (Data, dataTypeName, dataTypeOf)
import Data.Int (Int64) import Data.Int (Int64)
import Data.Maybe (fromMaybe) import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable) import Data.Typeable (Typeable)

View File

@ -13,8 +13,9 @@
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
module TensorFlow.Output module TensorFlow.Output
( ControlNode(..) ( ControlNode(..)
@ -150,8 +151,8 @@ opControlInputs = lens _opControlInputs (\o x -> o {_opControlInputs = x})
-- code into a Build function -- code into a Build function
instance IsString Output where instance IsString Output where
fromString s = case break (==':') s of fromString s = case break (==':') s of
(n, ':':ixStr) (n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr
| [(ix, "")] <- read ixStr -> Output (fromInteger ix) $ assigned n -> Output (fromInteger ix) $ assigned n
_ -> Output 0 $ assigned s _ -> Output 0 $ assigned s
where assigned n = Rendered $ def & name .~ Text.pack n where assigned n = Rendered $ def & name .~ Text.pack n

View File

@ -46,24 +46,22 @@ import Data.ByteString (ByteString)
import Data.Default (Default, def) import Data.Default (Default, def)
import Data.Functor.Identity (runIdentity) import Data.Functor.Identity (runIdentity)
import Data.Monoid ((<>)) import Data.Monoid ((<>))
import qualified Data.Map.Strict as Map import Data.ProtoLens (showMessage)
import qualified Data.Set as Set
import Data.Set (Set) import Data.Set (Set)
import Data.Text.Encoding (encodeUtf8) import Data.Text.Encoding (encodeUtf8)
import Data.ProtoLens (def, showMessage)
import Lens.Family2 (Lens', (^.), (&), (.~)) import Lens.Family2 (Lens', (^.), (&), (.~))
import Lens.Family2.Unchecked (lens) import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.Graph (node) import Proto.Tensorflow.Core.Framework.Graph (node)
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto) import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import TensorFlow.Build import TensorFlow.Build
import TensorFlow.Nodes import TensorFlow.Nodes
import TensorFlow.Output (NodeName, unNodeName) import TensorFlow.Output (NodeName, unNodeName)
import TensorFlow.Tensor import TensorFlow.Tensor
import qualified Data.ByteString.Builder as Builder import qualified Data.ByteString.Builder as Builder
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified TensorFlow.Internal.FFI as FFI import qualified TensorFlow.Internal.FFI as FFI
import qualified TensorFlow.Internal.Raw as Raw
-- | An action for logging. -- | An action for logging.
type Tracer = Builder.Builder -> IO () type Tracer = Builder.Builder -> IO ()

View File

@ -33,6 +33,7 @@ testParseAll = do
. not . null . view op) . not . null . view op)
(decodeMessage opList :: Either String OpList) (decodeMessage opList :: Either String OpList)
main :: IO ()
main = defaultMain main = defaultMain
[ testCase "ParseAllOps" testParseAll [ testCase "ParseAllOps" testParseAll
] ]

View File

@ -16,11 +16,13 @@ module Main where
import Data.ByteString.Builder (toLazyByteString) import Data.ByteString.Builder (toLazyByteString)
import Google.Test (googleTest) import Google.Test (googleTest)
import Test.Framework (Test)
import Test.Framework.Providers.QuickCheck2 (testProperty) import Test.Framework.Providers.QuickCheck2 (testProperty)
import qualified Data.Attoparsec.ByteString.Lazy as Atto import qualified Data.Attoparsec.ByteString.Lazy as Atto
import TensorFlow.Internal.VarInt import TensorFlow.Internal.VarInt
testEncodeDecode :: Test
testEncodeDecode = testProperty "testEncodeDecode" $ \x -> testEncodeDecode = testProperty "testEncodeDecode" $ \x ->
let bytes = toLazyByteString (putVarInt x) let bytes = toLazyByteString (putVarInt x)
in case Atto.eitherResult $ Atto.parse getVarInt bytes of in case Atto.eitherResult $ Atto.parse getVarInt bytes of