mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 11:03:34 +02:00
Make code --pedantic.
Enforce pedantic build mode in CI. Our imports drifted really far from where they should be.
This commit is contained in:
parent
fc3d398ca9
commit
252885f414
|
@ -8,4 +8,4 @@ IMAGE_NAME=tensorflow/haskell/ci_build:v0
|
|||
|
||||
git submodule update
|
||||
docker build -t $IMAGE_NAME -f ci_build/Dockerfile .
|
||||
docker run $IMAGE_NAME stack test
|
||||
docker run $IMAGE_NAME stack build --pedantic --test
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
|
||||
import Control.Monad (zipWithM, when, forM, forM_)
|
||||
import Control.Monad (zipWithM, when, forM_)
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.List (genericLength)
|
||||
|
@ -34,7 +34,8 @@ import qualified TensorFlow.Types as TF
|
|||
import TensorFlow.Examples.MNIST.InputData
|
||||
import TensorFlow.Examples.MNIST.Parse
|
||||
|
||||
numPixels = 28^2 :: Int64
|
||||
numPixels, numLabels :: Int64
|
||||
numPixels = 28*28 :: Int64
|
||||
numLabels = 10 :: Int64
|
||||
|
||||
-- | Create tensor with random values where the stddev depends on the width.
|
||||
|
@ -44,6 +45,7 @@ randomParam width (TF.Shape shape) =
|
|||
where
|
||||
stddev = TF.scalar (1 / sqrt (fromIntegral width))
|
||||
|
||||
reduceMean :: TF.Tensor TF.Value Float -> TF.Tensor TF.Value Float
|
||||
reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32))
|
||||
|
||||
-- Types must match due to model structure.
|
||||
|
@ -108,6 +110,7 @@ createModel = do
|
|||
] errorRateTensor
|
||||
}
|
||||
|
||||
main :: IO ()
|
||||
main = TF.runSession $ do
|
||||
-- Read training and test data.
|
||||
trainingImages <- liftIO (readMNISTSamples =<< trainingImageData)
|
||||
|
|
|
@ -22,7 +22,7 @@ module TensorFlow.NN
|
|||
import Prelude hiding ( log
|
||||
, exp
|
||||
)
|
||||
import TensorFlow.Build ( Build(..)
|
||||
import TensorFlow.Build ( Build
|
||||
, render
|
||||
, withNameScope
|
||||
)
|
||||
|
@ -32,7 +32,7 @@ import TensorFlow.GenOps.Core ( greaterEqual
|
|||
, exp
|
||||
)
|
||||
import TensorFlow.Tensor ( Tensor(..)
|
||||
, Value(..)
|
||||
, Value
|
||||
)
|
||||
import TensorFlow.Types ( TensorType(..)
|
||||
, OneOf
|
||||
|
|
|
@ -24,7 +24,7 @@ module TensorFlow.EmbeddingOps where
|
|||
import Control.Monad (zipWithM)
|
||||
import Data.Int (Int32, Int64)
|
||||
import TensorFlow.Build (Build, colocateWith, render)
|
||||
import TensorFlow.Ops (scalar, shape, vector) -- Also Num instance for Tensor
|
||||
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
|
||||
import TensorFlow.Tensor (Tensor, Value)
|
||||
import TensorFlow.Types (OneOf, TensorType)
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
|
|
|
@ -94,7 +94,7 @@ import TensorFlow.Tensor
|
|||
, tensorOutput
|
||||
, tensorAttr
|
||||
)
|
||||
import TensorFlow.Types (OneOf, TensorType, attrLens)
|
||||
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||
(NodeDef, attr, input, op, name)
|
||||
|
||||
|
@ -688,10 +688,15 @@ numOutputs o =
|
|||
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
|
||||
|
||||
-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
|
||||
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Value Int32
|
||||
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
|
||||
|
||||
allDimensions :: Tensor Value Int32
|
||||
allDimensions = vector [-1 :: Int32]
|
||||
|
||||
rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Value Int32
|
||||
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1
|
||||
|
||||
-- lookupAttr :: NodeDef -> ( -> _
|
||||
lookupAttr :: Attribute a1 => NodeDef -> Text -> a1
|
||||
lookupAttr nodeDef attrName = nodeDef ^. attr . at attrName . non def . attrLens
|
||||
|
|
|
@ -36,7 +36,6 @@ import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
|
|||
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
|
||||
import Control.Monad (when)
|
||||
import Data.Bits (Bits, toIntegralSized)
|
||||
import Data.Data (Data, dataTypeName, dataTypeOf)
|
||||
import Data.Int (Int64)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Typeable (Typeable)
|
||||
|
|
|
@ -13,8 +13,9 @@
|
|||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE Rank2Types #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE Rank2Types #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module TensorFlow.Output
|
||||
( ControlNode(..)
|
||||
|
@ -150,8 +151,8 @@ opControlInputs = lens _opControlInputs (\o x -> o {_opControlInputs = x})
|
|||
-- code into a Build function
|
||||
instance IsString Output where
|
||||
fromString s = case break (==':') s of
|
||||
(n, ':':ixStr)
|
||||
| [(ix, "")] <- read ixStr -> Output (fromInteger ix) $ assigned n
|
||||
(n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr
|
||||
-> Output (fromInteger ix) $ assigned n
|
||||
_ -> Output 0 $ assigned s
|
||||
where assigned n = Rendered $ def & name .~ Text.pack n
|
||||
|
||||
|
|
|
@ -46,24 +46,22 @@ import Data.ByteString (ByteString)
|
|||
import Data.Default (Default, def)
|
||||
import Data.Functor.Identity (runIdentity)
|
||||
import Data.Monoid ((<>))
|
||||
import qualified Data.Map.Strict as Map
|
||||
import qualified Data.Set as Set
|
||||
import Data.ProtoLens (showMessage)
|
||||
import Data.Set (Set)
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Data.ProtoLens (def, showMessage)
|
||||
import Lens.Family2 (Lens', (^.), (&), (.~))
|
||||
import Lens.Family2.Unchecked (lens)
|
||||
import Proto.Tensorflow.Core.Framework.Graph (node)
|
||||
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
||||
|
||||
import TensorFlow.Build
|
||||
import TensorFlow.Nodes
|
||||
import TensorFlow.Output (NodeName, unNodeName)
|
||||
import TensorFlow.Tensor
|
||||
|
||||
import qualified Data.ByteString.Builder as Builder
|
||||
import qualified Data.Map.Strict as Map
|
||||
import qualified Data.Set as Set
|
||||
import qualified TensorFlow.Internal.FFI as FFI
|
||||
import qualified TensorFlow.Internal.Raw as Raw
|
||||
|
||||
-- | An action for logging.
|
||||
type Tracer = Builder.Builder -> IO ()
|
||||
|
|
Loading…
Reference in New Issue
Block a user