mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
upgrade to ghc 8.6.4 (#237)
This commit is contained in:
parent
c0f87dc0bc
commit
7316062c10
23 changed files with 52 additions and 43 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,3 +1,4 @@
|
|||
**/.stack-work
|
||||
.stack/
|
||||
tensorflow-mnist-input-data/data/*.gz
|
||||
.DS_Store
|
||||
|
|
|
@ -11,8 +11,8 @@ let
|
|||
pkgs = import nixpkgs {};
|
||||
in
|
||||
pkgs.haskell.lib.buildStackProject {
|
||||
# Either use specified GHC or use GHC 8.4.4 (which we need for LTS 12.26)
|
||||
ghc = if isNull ghc then pkgs.haskell.compiler.ghc844 else ghc;
|
||||
# Either use specified GHC or use GHC 8.6.4 (which we need for LTS 13.13)
|
||||
ghc = if isNull ghc then pkgs.haskell.compiler.ghc864 else ghc;
|
||||
extraArgs = "--system-ghc";
|
||||
name = "tf-env";
|
||||
buildInputs = with pkgs; [ snappy zlib protobuf libtensorflow ];
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
resolver: lts-12.26
|
||||
resolver: lts-13.13
|
||||
|
||||
packages:
|
||||
- tensorflow
|
||||
|
@ -15,6 +15,7 @@ packages:
|
|||
|
||||
extra-deps:
|
||||
- snappy-framing-0.1.2
|
||||
- snappy-0.2.0.2
|
||||
|
||||
# For Mac OS X, whose linker doesn't use this path by default
|
||||
# unless you run `xcode-select --install`.
|
||||
|
|
|
@ -15,7 +15,7 @@ cabal-version: >=1.24
|
|||
library
|
||||
exposed-modules: TensorFlow.GenOps.Core
|
||||
build-depends: bytestring
|
||||
, proto-lens == 0.3.*
|
||||
, proto-lens == 0.4.*
|
||||
, tensorflow == 0.2.*
|
||||
, base >= 4.7 && < 5
|
||||
, lens-family
|
||||
|
@ -26,7 +26,7 @@ custom-setup
|
|||
setup-depends: Cabal
|
||||
, bytestring
|
||||
, directory
|
||||
, proto-lens == 0.3.*
|
||||
, proto-lens == 0.4.*
|
||||
, tensorflow-opgen == 0.2.*
|
||||
, tensorflow == 0.2.*
|
||||
, base >= 4.7 && < 5
|
||||
|
|
|
@ -60,7 +60,7 @@ import Control.Monad.Trans.Resource (runResourceT)
|
|||
import Data.ByteString (ByteString)
|
||||
import Data.Conduit ((.|))
|
||||
import Data.Conduit.TQueue (sourceTBMQueue)
|
||||
import Data.Default (def)
|
||||
import Data.ProtoLens.Default(def)
|
||||
import Data.Int (Int64)
|
||||
import Data.Word (Word8, Word16)
|
||||
import Data.ProtoLens (encodeMessage)
|
||||
|
|
|
@ -24,7 +24,7 @@ library
|
|||
, filepath
|
||||
, hostname
|
||||
, lens-family
|
||||
, proto-lens == 0.3.*
|
||||
, proto-lens == 0.4.*
|
||||
, resourcet
|
||||
, stm
|
||||
, stm-chans
|
||||
|
|
|
@ -17,7 +17,7 @@ module Main where
|
|||
|
||||
import Control.Monad.Trans.Resource (runResourceT)
|
||||
import Data.Conduit ((.|))
|
||||
import Data.Default (def)
|
||||
import Data.ProtoLens.Message (defMessage)
|
||||
import Data.List ((\\))
|
||||
import Data.ProtoLens (encodeMessage, decodeMessageOrDie)
|
||||
import Lens.Family2 ((^.), (.~), (&))
|
||||
|
@ -45,9 +45,9 @@ testEventWriter :: Test
|
|||
testEventWriter = testCase "EventWriter" $
|
||||
withSystemTempDirectory "event_writer_logs" $ \dir -> do
|
||||
assertEqual "No file before" [] =<< listDirectory dir
|
||||
let expected = [ (def :: Event) & step .~ 10
|
||||
, def & step .~ 222
|
||||
, def & step .~ 8
|
||||
let expected = [ (defMessage :: Event) & step .~ 10
|
||||
, defMessage & step .~ 222
|
||||
, defMessage & step .~ 8
|
||||
]
|
||||
withEventWriter dir $ \eventWriter ->
|
||||
mapM_ (logEvent eventWriter) expected
|
||||
|
@ -66,7 +66,7 @@ testLogGraph = testCase "LogGraph" $
|
|||
withSystemTempDirectory "event_writer_logs" $ \dir -> do
|
||||
let graphBuild = noOp :: Build ControlNode
|
||||
expectedGraph = asGraphDef graphBuild
|
||||
expectedGraphEvent = (def :: Event) & graphDef .~ (encodeMessage expectedGraph)
|
||||
expectedGraphEvent = (defMessage :: Event) & graphDef .~ (encodeMessage expectedGraph)
|
||||
|
||||
withEventWriter dir $ \eventWriter ->
|
||||
logGraph eventWriter graphBuild
|
||||
|
|
|
@ -20,7 +20,7 @@ library
|
|||
exposed-modules: TensorFlow.Examples.MNIST.Parse
|
||||
, TensorFlow.Examples.MNIST.TrainedGraph
|
||||
other-modules: Paths_tensorflow_mnist
|
||||
build-depends: proto-lens == 0.3.*
|
||||
build-depends: proto-lens == 0.4.*
|
||||
, base >= 4.7 && < 5
|
||||
, binary
|
||||
, bytestring
|
||||
|
|
|
@ -24,12 +24,12 @@ import qualified Data.Text.IO as Text
|
|||
import Lens.Family2 ((&), (.~), (^.))
|
||||
import Prelude hiding (abs)
|
||||
import Proto.Tensorflow.Core.Framework.Graph
|
||||
( GraphDef(..) )
|
||||
( GraphDef )
|
||||
import Proto.Tensorflow.Core.Framework.Graph_Fields
|
||||
( version
|
||||
, node )
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||
( NodeDef(..) )
|
||||
( NodeDef )
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef_Fields (op)
|
||||
import System.IO as IO
|
||||
import TensorFlow.Examples.MNIST.InputData
|
||||
|
|
|
@ -51,7 +51,8 @@ module TensorFlow.OpGen
|
|||
|
||||
import Data.Foldable (toList)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.ProtoLens (def, showMessage)
|
||||
import Data.ProtoLens.Default(def)
|
||||
import Data.ProtoLens (showMessage)
|
||||
import Data.List (sortOn)
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import qualified Data.List.NonEmpty as NE
|
||||
|
|
|
@ -16,7 +16,7 @@ library
|
|||
hs-source-dirs: src
|
||||
exposed-modules: TensorFlow.OpGen.ParsedOp
|
||||
, TensorFlow.OpGen
|
||||
build-depends: proto-lens == 0.3.*
|
||||
build-depends: proto-lens == 0.4.*
|
||||
, tensorflow-proto == 0.2.*
|
||||
, base >= 4.7 && < 5
|
||||
, bytestring
|
||||
|
|
|
@ -31,7 +31,7 @@ import Control.Monad (forM, zipWithM)
|
|||
import Control.Monad.State.Strict (State, evalState, gets, modify)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Complex (Complex)
|
||||
import Data.Default (def)
|
||||
import Data.ProtoLens.Default(def)
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.Foldable (foldlM)
|
||||
import Data.List (foldl', sortBy)
|
||||
|
|
|
@ -158,7 +158,7 @@ import Data.Complex (Complex)
|
|||
import Data.Int (Int32, Int64)
|
||||
import Data.Word (Word16)
|
||||
import Prelude hiding (abs, sum, concat)
|
||||
import Data.ProtoLens (def)
|
||||
import Data.ProtoLens.Default(def)
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Lens.Family2 ((.~), (&))
|
||||
import Text.Printf (printf)
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
{-# LANGUAGE RecursiveDo #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE NoMonadFailDesugaring #-}
|
||||
module TensorFlow.Variable
|
||||
( Variable
|
||||
, variable
|
||||
|
|
|
@ -21,7 +21,7 @@ library
|
|||
, TensorFlow.NN
|
||||
, TensorFlow.Queue
|
||||
, TensorFlow.Variable
|
||||
build-depends: proto-lens == 0.3.*
|
||||
build-depends: proto-lens == 0.4.*
|
||||
, base >= 4.7 && < 5
|
||||
, bytestring
|
||||
, fgl
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE NoMonadFailDesugaring #-}
|
||||
|
||||
-- | Tests for EmbeddingOps.
|
||||
module Main where
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
{-# LANGUAGE NoMonomorphismRestriction #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE NoMonadFailDesugaring #-}
|
||||
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.List (sort)
|
||||
|
|
|
@ -97,16 +97,16 @@ library
|
|||
, Proto.Tensorflow.Core.Util.SavedTensorSlice_Fields
|
||||
, Proto.Tensorflow.Core.Util.TestLog
|
||||
, Proto.Tensorflow.Core.Util.TestLog_Fields
|
||||
build-depends: proto-lens == 0.3.*
|
||||
, proto-lens-protoc == 0.3.*
|
||||
, proto-lens-protobuf-types == 0.3.*
|
||||
build-depends: proto-lens == 0.4.*
|
||||
, proto-lens-runtime == 0.4.*
|
||||
, proto-lens-protobuf-types == 0.4.*
|
||||
, base >= 4.7 && < 5
|
||||
default-language: Haskell2010
|
||||
include-dirs: .
|
||||
|
||||
custom-setup
|
||||
setup-depends: Cabal
|
||||
, proto-lens-protoc == 0.3.*
|
||||
, proto-lens-setup == 0.4.*
|
||||
, base >= 4.7 && < 5
|
||||
source-repository head
|
||||
type: git
|
||||
|
|
|
@ -61,12 +61,13 @@ module TensorFlow.Build
|
|||
, withNodeDependencies
|
||||
) where
|
||||
|
||||
import Data.ProtoLens.Message(defMessage)
|
||||
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
||||
import Control.Monad.Fix (MonadFix(..))
|
||||
import Control.Monad.IO.Class (MonadIO(..))
|
||||
import Control.Monad.Fail (MonadFail(..))
|
||||
import Control.Monad.Trans.Class (MonadTrans(..))
|
||||
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
|
||||
import Data.Default (def)
|
||||
import Data.Functor.Identity (Identity(..))
|
||||
import qualified Data.Map.Strict as Map
|
||||
import Data.Monoid ((<>))
|
||||
|
@ -191,7 +192,7 @@ summaries = lens _summaries (\g x -> g { _summaries = x })
|
|||
newtype BuildT m a = BuildT (StateT GraphState m a)
|
||||
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
|
||||
MonadState GraphState, MonadThrow, MonadCatch, MonadMask,
|
||||
MonadFix)
|
||||
MonadFix, MonadFail)
|
||||
|
||||
-- | An action for building nodes in a TensorFlow graph.
|
||||
type Build = BuildT Identity
|
||||
|
@ -236,7 +237,7 @@ addInitializer (ControlNode i) = build $ initializationNodes %= (i:)
|
|||
-- | Produce a GraphDef proto representation of the nodes that are rendered in
|
||||
-- the given 'Build' action.
|
||||
asGraphDef :: Build a -> GraphDef
|
||||
asGraphDef b = def & node .~ gs ^. nodeBuffer
|
||||
asGraphDef b = defMessage & node .~ gs ^. nodeBuffer
|
||||
where
|
||||
gs = snd $ runIdentity $ runBuildT b
|
||||
|
||||
|
@ -285,7 +286,7 @@ getPendingNode o = do
|
|||
let controlInputs
|
||||
= map makeDep (o ^. opControlInputs ++ Set.toList controls)
|
||||
return $ PendingNode scope (o ^. opName)
|
||||
$ def & op .~ (unOpType (o ^. opType) :: Text)
|
||||
$ defMessage & op .~ (unOpType (o ^. opType) :: Text)
|
||||
& attr .~ _opAttrs o
|
||||
& input .~ (inputs ++ controlInputs)
|
||||
& device .~ dev
|
||||
|
|
|
@ -35,14 +35,14 @@ module TensorFlow.Output
|
|||
, PendingNodeName(..)
|
||||
) where
|
||||
|
||||
import Data.ProtoLens.Message(defMessage)
|
||||
import qualified Data.Map.Strict as Map
|
||||
import Data.String (IsString(..))
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as Text
|
||||
import Lens.Family2 (Lens')
|
||||
import Lens.Family2.Unchecked (lens)
|
||||
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..))
|
||||
import Data.Default (def)
|
||||
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue)
|
||||
import TensorFlow.Types (Attribute, attrLens)
|
||||
|
||||
-- | A type of graph node which has no outputs. These nodes are
|
||||
|
@ -108,7 +108,7 @@ opType = lens _opType (\o x -> o { _opType = x})
|
|||
|
||||
opAttr :: Attribute a => Text -> Lens' OpDef a
|
||||
opAttr n = lens _opAttrs (\o x -> o {_opAttrs = x})
|
||||
. lens (Map.findWithDefault def n) (flip (Map.insert n))
|
||||
. lens (Map.findWithDefault defMessage n) (flip (Map.insert n))
|
||||
. attrLens
|
||||
|
||||
opInputs :: Lens' OpDef [Output]
|
||||
|
|
|
@ -37,7 +37,9 @@ module TensorFlow.Session (
|
|||
asyncProdNodes,
|
||||
) where
|
||||
|
||||
import Data.ProtoLens.Message(defMessage)
|
||||
import Control.Monad (forever, unless, void)
|
||||
import Control.Monad.Fail (MonadFail(..))
|
||||
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
||||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||
import Control.Monad.Trans.Class (MonadTrans, lift)
|
||||
|
@ -78,7 +80,7 @@ data SessionState
|
|||
newtype SessionT m a
|
||||
= Session (ReaderT SessionState (BuildT m) a)
|
||||
deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch,
|
||||
MonadMask)
|
||||
MonadMask, MonadFail)
|
||||
|
||||
instance MonadTrans SessionT where
|
||||
lift = Session . lift . lift
|
||||
|
@ -100,7 +102,7 @@ data Options = Options
|
|||
instance Default Options where
|
||||
def = Options
|
||||
{ _sessionTarget = ""
|
||||
, _sessionConfig = def
|
||||
, _sessionConfig = defMessage
|
||||
, _sessionTracer = const (return ())
|
||||
}
|
||||
|
||||
|
@ -142,7 +144,7 @@ extend = do
|
|||
trace <- Session (asks tracer)
|
||||
nodesToExtend <- build flushNodeBuffer
|
||||
unless (null nodesToExtend) $ liftIO $ do
|
||||
let graphDef = (def :: GraphDef) & node .~ nodesToExtend
|
||||
let graphDef = (defMessage :: GraphDef) & node .~ nodesToExtend
|
||||
trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
|
||||
FFI.extendGraph session graphDef
|
||||
-- Now that all the nodes are created, run the initializers.
|
||||
|
|
|
@ -64,9 +64,9 @@ module TensorFlow.Types
|
|||
, AllTensorTypes
|
||||
) where
|
||||
|
||||
import Data.ProtoLens.Message(defMessage)
|
||||
import Data.Functor.Identity (Identity(..))
|
||||
import Data.Complex (Complex)
|
||||
import Data.Default (def)
|
||||
import Data.Int (Int8, Int16, Int32, Int64)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Monoid ((<>))
|
||||
|
@ -88,8 +88,8 @@ import qualified Data.ByteString.Lazy as L
|
|||
import qualified Data.Vector as V
|
||||
import qualified Data.Vector.Storable as S
|
||||
import Proto.Tensorflow.Core.Framework.AttrValue
|
||||
( AttrValue(..)
|
||||
, AttrValue'ListValue(..)
|
||||
( AttrValue
|
||||
, AttrValue'ListValue
|
||||
)
|
||||
import Proto.Tensorflow.Core.Framework.AttrValue_Fields
|
||||
( b
|
||||
|
@ -105,7 +105,7 @@ import Proto.Tensorflow.Core.Framework.AttrValue_Fields
|
|||
import Proto.Tensorflow.Core.Framework.ResourceHandle
|
||||
(ResourceHandleProto)
|
||||
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
|
||||
(TensorProto(..))
|
||||
(TensorProto)
|
||||
import Proto.Tensorflow.Core.Framework.Tensor_Fields as Tensor
|
||||
( boolVal
|
||||
, doubleVal
|
||||
|
@ -119,7 +119,7 @@ import Proto.Tensorflow.Core.Framework.Tensor_Fields as Tensor
|
|||
)
|
||||
|
||||
import Proto.Tensorflow.Core.Framework.TensorShape
|
||||
(TensorShapeProto(..))
|
||||
(TensorShapeProto)
|
||||
import Proto.Tensorflow.Core.Framework.TensorShape_Fields
|
||||
( dim
|
||||
, size
|
||||
|
@ -400,7 +400,7 @@ protoShape = iso protoToShape shapeToProto
|
|||
protoToShape p = fromMaybe (error msg) (view protoMaybeShape p)
|
||||
where msg = "Can't convert TensorShapeProto with unknown rank to Shape: "
|
||||
++ showMessageShort p
|
||||
shapeToProto s' = def & protoMaybeShape .~ Just s'
|
||||
shapeToProto s' = defMessage & protoMaybeShape .~ Just s'
|
||||
|
||||
protoMaybeShape :: Lens' TensorShapeProto (Maybe Shape)
|
||||
protoMaybeShape = iso protoToShape shapeToProto
|
||||
|
@ -412,9 +412,9 @@ protoMaybeShape = iso protoToShape shapeToProto
|
|||
else Just (Shape (p ^.. dim . traverse . size))
|
||||
shapeToProto :: Maybe Shape -> TensorShapeProto
|
||||
shapeToProto Nothing =
|
||||
def & unknownRank .~ True
|
||||
defMessage & unknownRank .~ True
|
||||
shapeToProto (Just (Shape ds)) =
|
||||
def & dim .~ fmap (\d -> def & size .~ d) ds
|
||||
defMessage & dim .~ fmap (\d -> defMessage & size .~ d) ds
|
||||
|
||||
|
||||
class Attribute a where
|
||||
|
|
|
@ -36,7 +36,7 @@ library
|
|||
, TensorFlow.Types
|
||||
other-modules: TensorFlow.Internal.Raw
|
||||
build-tools: c2hs
|
||||
build-depends: proto-lens == 0.3.*
|
||||
build-depends: proto-lens == 0.4.*
|
||||
, tensorflow-proto == 0.2.*
|
||||
, base >= 4.7 && < 5
|
||||
, async
|
||||
|
|
Loading…
Reference in a new issue