1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 11:03:34 +02:00

Make code --pedantic.

Enforce pedantic build mode in CI.
Our imports drifted really far from where they should be.
This commit is contained in:
Greg Steuck 2016-11-17 11:34:24 -08:00
parent fc3d398ca9
commit 252885f414
8 changed files with 22 additions and 16 deletions

View File

@ -8,4 +8,4 @@ IMAGE_NAME=tensorflow/haskell/ci_build:v0
git submodule update
docker build -t $IMAGE_NAME -f ci_build/Dockerfile .
docker run $IMAGE_NAME stack test
docker run $IMAGE_NAME stack build --pedantic --test

View File

@ -15,7 +15,7 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedLists #-}
import Control.Monad (zipWithM, when, forM, forM_)
import Control.Monad (zipWithM, when, forM_)
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64)
import Data.List (genericLength)
@ -34,7 +34,8 @@ import qualified TensorFlow.Types as TF
import TensorFlow.Examples.MNIST.InputData
import TensorFlow.Examples.MNIST.Parse
numPixels = 28^2 :: Int64
numPixels, numLabels :: Int64
numPixels = 28*28 :: Int64
numLabels = 10 :: Int64
-- | Create tensor with random values where the stddev depends on the width.
@ -44,6 +45,7 @@ randomParam width (TF.Shape shape) =
where
stddev = TF.scalar (1 / sqrt (fromIntegral width))
reduceMean :: TF.Tensor TF.Value Float -> TF.Tensor TF.Value Float
reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32))
-- Types must match due to model structure.
@ -108,6 +110,7 @@ createModel = do
] errorRateTensor
}
main :: IO ()
main = TF.runSession $ do
-- Read training and test data.
trainingImages <- liftIO (readMNISTSamples =<< trainingImageData)

View File

@ -22,7 +22,7 @@ module TensorFlow.NN
import Prelude hiding ( log
, exp
)
import TensorFlow.Build ( Build(..)
import TensorFlow.Build ( Build
, render
, withNameScope
)
@ -32,7 +32,7 @@ import TensorFlow.GenOps.Core ( greaterEqual
, exp
)
import TensorFlow.Tensor ( Tensor(..)
, Value(..)
, Value
)
import TensorFlow.Types ( TensorType(..)
, OneOf

View File

@ -24,7 +24,7 @@ module TensorFlow.EmbeddingOps where
import Control.Monad (zipWithM)
import Data.Int (Int32, Int64)
import TensorFlow.Build (Build, colocateWith, render)
import TensorFlow.Ops (scalar, shape, vector) -- Also Num instance for Tensor
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value)
import TensorFlow.Types (OneOf, TensorType)
import qualified TensorFlow.GenOps.Core as CoreOps

View File

@ -94,7 +94,7 @@ import TensorFlow.Tensor
, tensorOutput
, tensorAttr
)
import TensorFlow.Types (OneOf, TensorType, attrLens)
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef
(NodeDef, attr, input, op, name)
@ -688,10 +688,15 @@ numOutputs o =
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Value Int32
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
allDimensions :: Tensor Value Int32
allDimensions = vector [-1 :: Int32]
rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Value Int32
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1
-- lookupAttr :: NodeDef -> ( -> _
lookupAttr :: Attribute a1 => NodeDef -> Text -> a1
lookupAttr nodeDef attrName = nodeDef ^. attr . at attrName . non def . attrLens

View File

@ -36,7 +36,6 @@ import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
import Control.Monad (when)
import Data.Bits (Bits, toIntegralSized)
import Data.Data (Data, dataTypeName, dataTypeOf)
import Data.Int (Int64)
import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable)

View File

@ -13,8 +13,9 @@
-- limitations under the License.
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
module TensorFlow.Output
( ControlNode(..)
@ -150,8 +151,8 @@ opControlInputs = lens _opControlInputs (\o x -> o {_opControlInputs = x})
-- code into a Build function
instance IsString Output where
fromString s = case break (==':') s of
(n, ':':ixStr)
| [(ix, "")] <- read ixStr -> Output (fromInteger ix) $ assigned n
(n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr
-> Output (fromInteger ix) $ assigned n
_ -> Output 0 $ assigned s
where assigned n = Rendered $ def & name .~ Text.pack n

View File

@ -46,24 +46,22 @@ import Data.ByteString (ByteString)
import Data.Default (Default, def)
import Data.Functor.Identity (runIdentity)
import Data.Monoid ((<>))
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Data.ProtoLens (showMessage)
import Data.Set (Set)
import Data.Text.Encoding (encodeUtf8)
import Data.ProtoLens (def, showMessage)
import Lens.Family2 (Lens', (^.), (&), (.~))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.Graph (node)
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import TensorFlow.Build
import TensorFlow.Nodes
import TensorFlow.Output (NodeName, unNodeName)
import TensorFlow.Tensor
import qualified Data.ByteString.Builder as Builder
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified TensorFlow.Internal.FFI as FFI
import qualified TensorFlow.Internal.Raw as Raw
-- | An action for logging.
type Tracer = Builder.Builder -> IO ()