mirror of
https://github.com/tensorflow/haskell.git
synced 2024-12-03 16:29:46 +01:00
upgrade to ghc 8.6.4 (#237)
This commit is contained in:
parent
c0f87dc0bc
commit
7316062c10
23 changed files with 52 additions and 43 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,3 +1,4 @@
|
||||||
**/.stack-work
|
**/.stack-work
|
||||||
.stack/
|
.stack/
|
||||||
tensorflow-mnist-input-data/data/*.gz
|
tensorflow-mnist-input-data/data/*.gz
|
||||||
|
.DS_Store
|
||||||
|
|
|
@ -11,8 +11,8 @@ let
|
||||||
pkgs = import nixpkgs {};
|
pkgs = import nixpkgs {};
|
||||||
in
|
in
|
||||||
pkgs.haskell.lib.buildStackProject {
|
pkgs.haskell.lib.buildStackProject {
|
||||||
# Either use specified GHC or use GHC 8.4.4 (which we need for LTS 12.26)
|
# Either use specified GHC or use GHC 8.6.4 (which we need for LTS 13.13)
|
||||||
ghc = if isNull ghc then pkgs.haskell.compiler.ghc844 else ghc;
|
ghc = if isNull ghc then pkgs.haskell.compiler.ghc864 else ghc;
|
||||||
extraArgs = "--system-ghc";
|
extraArgs = "--system-ghc";
|
||||||
name = "tf-env";
|
name = "tf-env";
|
||||||
buildInputs = with pkgs; [ snappy zlib protobuf libtensorflow ];
|
buildInputs = with pkgs; [ snappy zlib protobuf libtensorflow ];
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
resolver: lts-12.26
|
resolver: lts-13.13
|
||||||
|
|
||||||
packages:
|
packages:
|
||||||
- tensorflow
|
- tensorflow
|
||||||
|
@ -15,6 +15,7 @@ packages:
|
||||||
|
|
||||||
extra-deps:
|
extra-deps:
|
||||||
- snappy-framing-0.1.2
|
- snappy-framing-0.1.2
|
||||||
|
- snappy-0.2.0.2
|
||||||
|
|
||||||
# For Mac OS X, whose linker doesn't use this path by default
|
# For Mac OS X, whose linker doesn't use this path by default
|
||||||
# unless you run `xcode-select --install`.
|
# unless you run `xcode-select --install`.
|
||||||
|
|
|
@ -15,7 +15,7 @@ cabal-version: >=1.24
|
||||||
library
|
library
|
||||||
exposed-modules: TensorFlow.GenOps.Core
|
exposed-modules: TensorFlow.GenOps.Core
|
||||||
build-depends: bytestring
|
build-depends: bytestring
|
||||||
, proto-lens == 0.3.*
|
, proto-lens == 0.4.*
|
||||||
, tensorflow == 0.2.*
|
, tensorflow == 0.2.*
|
||||||
, base >= 4.7 && < 5
|
, base >= 4.7 && < 5
|
||||||
, lens-family
|
, lens-family
|
||||||
|
@ -26,7 +26,7 @@ custom-setup
|
||||||
setup-depends: Cabal
|
setup-depends: Cabal
|
||||||
, bytestring
|
, bytestring
|
||||||
, directory
|
, directory
|
||||||
, proto-lens == 0.3.*
|
, proto-lens == 0.4.*
|
||||||
, tensorflow-opgen == 0.2.*
|
, tensorflow-opgen == 0.2.*
|
||||||
, tensorflow == 0.2.*
|
, tensorflow == 0.2.*
|
||||||
, base >= 4.7 && < 5
|
, base >= 4.7 && < 5
|
||||||
|
|
|
@ -60,7 +60,7 @@ import Control.Monad.Trans.Resource (runResourceT)
|
||||||
import Data.ByteString (ByteString)
|
import Data.ByteString (ByteString)
|
||||||
import Data.Conduit ((.|))
|
import Data.Conduit ((.|))
|
||||||
import Data.Conduit.TQueue (sourceTBMQueue)
|
import Data.Conduit.TQueue (sourceTBMQueue)
|
||||||
import Data.Default (def)
|
import Data.ProtoLens.Default(def)
|
||||||
import Data.Int (Int64)
|
import Data.Int (Int64)
|
||||||
import Data.Word (Word8, Word16)
|
import Data.Word (Word8, Word16)
|
||||||
import Data.ProtoLens (encodeMessage)
|
import Data.ProtoLens (encodeMessage)
|
||||||
|
|
|
@ -24,7 +24,7 @@ library
|
||||||
, filepath
|
, filepath
|
||||||
, hostname
|
, hostname
|
||||||
, lens-family
|
, lens-family
|
||||||
, proto-lens == 0.3.*
|
, proto-lens == 0.4.*
|
||||||
, resourcet
|
, resourcet
|
||||||
, stm
|
, stm
|
||||||
, stm-chans
|
, stm-chans
|
||||||
|
|
|
@ -17,7 +17,7 @@ module Main where
|
||||||
|
|
||||||
import Control.Monad.Trans.Resource (runResourceT)
|
import Control.Monad.Trans.Resource (runResourceT)
|
||||||
import Data.Conduit ((.|))
|
import Data.Conduit ((.|))
|
||||||
import Data.Default (def)
|
import Data.ProtoLens.Message (defMessage)
|
||||||
import Data.List ((\\))
|
import Data.List ((\\))
|
||||||
import Data.ProtoLens (encodeMessage, decodeMessageOrDie)
|
import Data.ProtoLens (encodeMessage, decodeMessageOrDie)
|
||||||
import Lens.Family2 ((^.), (.~), (&))
|
import Lens.Family2 ((^.), (.~), (&))
|
||||||
|
@ -45,9 +45,9 @@ testEventWriter :: Test
|
||||||
testEventWriter = testCase "EventWriter" $
|
testEventWriter = testCase "EventWriter" $
|
||||||
withSystemTempDirectory "event_writer_logs" $ \dir -> do
|
withSystemTempDirectory "event_writer_logs" $ \dir -> do
|
||||||
assertEqual "No file before" [] =<< listDirectory dir
|
assertEqual "No file before" [] =<< listDirectory dir
|
||||||
let expected = [ (def :: Event) & step .~ 10
|
let expected = [ (defMessage :: Event) & step .~ 10
|
||||||
, def & step .~ 222
|
, defMessage & step .~ 222
|
||||||
, def & step .~ 8
|
, defMessage & step .~ 8
|
||||||
]
|
]
|
||||||
withEventWriter dir $ \eventWriter ->
|
withEventWriter dir $ \eventWriter ->
|
||||||
mapM_ (logEvent eventWriter) expected
|
mapM_ (logEvent eventWriter) expected
|
||||||
|
@ -66,7 +66,7 @@ testLogGraph = testCase "LogGraph" $
|
||||||
withSystemTempDirectory "event_writer_logs" $ \dir -> do
|
withSystemTempDirectory "event_writer_logs" $ \dir -> do
|
||||||
let graphBuild = noOp :: Build ControlNode
|
let graphBuild = noOp :: Build ControlNode
|
||||||
expectedGraph = asGraphDef graphBuild
|
expectedGraph = asGraphDef graphBuild
|
||||||
expectedGraphEvent = (def :: Event) & graphDef .~ (encodeMessage expectedGraph)
|
expectedGraphEvent = (defMessage :: Event) & graphDef .~ (encodeMessage expectedGraph)
|
||||||
|
|
||||||
withEventWriter dir $ \eventWriter ->
|
withEventWriter dir $ \eventWriter ->
|
||||||
logGraph eventWriter graphBuild
|
logGraph eventWriter graphBuild
|
||||||
|
|
|
@ -20,7 +20,7 @@ library
|
||||||
exposed-modules: TensorFlow.Examples.MNIST.Parse
|
exposed-modules: TensorFlow.Examples.MNIST.Parse
|
||||||
, TensorFlow.Examples.MNIST.TrainedGraph
|
, TensorFlow.Examples.MNIST.TrainedGraph
|
||||||
other-modules: Paths_tensorflow_mnist
|
other-modules: Paths_tensorflow_mnist
|
||||||
build-depends: proto-lens == 0.3.*
|
build-depends: proto-lens == 0.4.*
|
||||||
, base >= 4.7 && < 5
|
, base >= 4.7 && < 5
|
||||||
, binary
|
, binary
|
||||||
, bytestring
|
, bytestring
|
||||||
|
|
|
@ -24,12 +24,12 @@ import qualified Data.Text.IO as Text
|
||||||
import Lens.Family2 ((&), (.~), (^.))
|
import Lens.Family2 ((&), (.~), (^.))
|
||||||
import Prelude hiding (abs)
|
import Prelude hiding (abs)
|
||||||
import Proto.Tensorflow.Core.Framework.Graph
|
import Proto.Tensorflow.Core.Framework.Graph
|
||||||
( GraphDef(..) )
|
( GraphDef )
|
||||||
import Proto.Tensorflow.Core.Framework.Graph_Fields
|
import Proto.Tensorflow.Core.Framework.Graph_Fields
|
||||||
( version
|
( version
|
||||||
, node )
|
, node )
|
||||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||||
( NodeDef(..) )
|
( NodeDef )
|
||||||
import Proto.Tensorflow.Core.Framework.NodeDef_Fields (op)
|
import Proto.Tensorflow.Core.Framework.NodeDef_Fields (op)
|
||||||
import System.IO as IO
|
import System.IO as IO
|
||||||
import TensorFlow.Examples.MNIST.InputData
|
import TensorFlow.Examples.MNIST.InputData
|
||||||
|
|
|
@ -51,7 +51,8 @@ module TensorFlow.OpGen
|
||||||
|
|
||||||
import Data.Foldable (toList)
|
import Data.Foldable (toList)
|
||||||
import Data.Maybe (fromMaybe)
|
import Data.Maybe (fromMaybe)
|
||||||
import Data.ProtoLens (def, showMessage)
|
import Data.ProtoLens.Default(def)
|
||||||
|
import Data.ProtoLens (showMessage)
|
||||||
import Data.List (sortOn)
|
import Data.List (sortOn)
|
||||||
import Data.List.NonEmpty (NonEmpty)
|
import Data.List.NonEmpty (NonEmpty)
|
||||||
import qualified Data.List.NonEmpty as NE
|
import qualified Data.List.NonEmpty as NE
|
||||||
|
|
|
@ -16,7 +16,7 @@ library
|
||||||
hs-source-dirs: src
|
hs-source-dirs: src
|
||||||
exposed-modules: TensorFlow.OpGen.ParsedOp
|
exposed-modules: TensorFlow.OpGen.ParsedOp
|
||||||
, TensorFlow.OpGen
|
, TensorFlow.OpGen
|
||||||
build-depends: proto-lens == 0.3.*
|
build-depends: proto-lens == 0.4.*
|
||||||
, tensorflow-proto == 0.2.*
|
, tensorflow-proto == 0.2.*
|
||||||
, base >= 4.7 && < 5
|
, base >= 4.7 && < 5
|
||||||
, bytestring
|
, bytestring
|
||||||
|
|
|
@ -31,7 +31,7 @@ import Control.Monad (forM, zipWithM)
|
||||||
import Control.Monad.State.Strict (State, evalState, gets, modify)
|
import Control.Monad.State.Strict (State, evalState, gets, modify)
|
||||||
import Data.ByteString (ByteString)
|
import Data.ByteString (ByteString)
|
||||||
import Data.Complex (Complex)
|
import Data.Complex (Complex)
|
||||||
import Data.Default (def)
|
import Data.ProtoLens.Default(def)
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
import Data.Foldable (foldlM)
|
import Data.Foldable (foldlM)
|
||||||
import Data.List (foldl', sortBy)
|
import Data.List (foldl', sortBy)
|
||||||
|
|
|
@ -158,7 +158,7 @@ import Data.Complex (Complex)
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
import Data.Word (Word16)
|
import Data.Word (Word16)
|
||||||
import Prelude hiding (abs, sum, concat)
|
import Prelude hiding (abs, sum, concat)
|
||||||
import Data.ProtoLens (def)
|
import Data.ProtoLens.Default(def)
|
||||||
import Data.Text.Encoding (encodeUtf8)
|
import Data.Text.Encoding (encodeUtf8)
|
||||||
import Lens.Family2 ((.~), (&))
|
import Lens.Family2 ((.~), (&))
|
||||||
import Text.Printf (printf)
|
import Text.Printf (printf)
|
||||||
|
|
|
@ -11,6 +11,7 @@
|
||||||
{-# LANGUAGE RecursiveDo #-}
|
{-# LANGUAGE RecursiveDo #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE NoMonadFailDesugaring #-}
|
||||||
module TensorFlow.Variable
|
module TensorFlow.Variable
|
||||||
( Variable
|
( Variable
|
||||||
, variable
|
, variable
|
||||||
|
|
|
@ -21,7 +21,7 @@ library
|
||||||
, TensorFlow.NN
|
, TensorFlow.NN
|
||||||
, TensorFlow.Queue
|
, TensorFlow.Queue
|
||||||
, TensorFlow.Variable
|
, TensorFlow.Variable
|
||||||
build-depends: proto-lens == 0.3.*
|
build-depends: proto-lens == 0.4.*
|
||||||
, base >= 4.7 && < 5
|
, base >= 4.7 && < 5
|
||||||
, bytestring
|
, bytestring
|
||||||
, fgl
|
, fgl
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE RankNTypes #-}
|
{-# LANGUAGE RankNTypes #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
{-# LANGUAGE NoMonadFailDesugaring #-}
|
||||||
|
|
||||||
-- | Tests for EmbeddingOps.
|
-- | Tests for EmbeddingOps.
|
||||||
module Main where
|
module Main where
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
{-# LANGUAGE NoMonomorphismRestriction #-}
|
{-# LANGUAGE NoMonomorphismRestriction #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
|
{-# LANGUAGE NoMonadFailDesugaring #-}
|
||||||
|
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
import Data.List (sort)
|
import Data.List (sort)
|
||||||
|
|
|
@ -97,16 +97,16 @@ library
|
||||||
, Proto.Tensorflow.Core.Util.SavedTensorSlice_Fields
|
, Proto.Tensorflow.Core.Util.SavedTensorSlice_Fields
|
||||||
, Proto.Tensorflow.Core.Util.TestLog
|
, Proto.Tensorflow.Core.Util.TestLog
|
||||||
, Proto.Tensorflow.Core.Util.TestLog_Fields
|
, Proto.Tensorflow.Core.Util.TestLog_Fields
|
||||||
build-depends: proto-lens == 0.3.*
|
build-depends: proto-lens == 0.4.*
|
||||||
, proto-lens-protoc == 0.3.*
|
, proto-lens-runtime == 0.4.*
|
||||||
, proto-lens-protobuf-types == 0.3.*
|
, proto-lens-protobuf-types == 0.4.*
|
||||||
, base >= 4.7 && < 5
|
, base >= 4.7 && < 5
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
include-dirs: .
|
include-dirs: .
|
||||||
|
|
||||||
custom-setup
|
custom-setup
|
||||||
setup-depends: Cabal
|
setup-depends: Cabal
|
||||||
, proto-lens-protoc == 0.3.*
|
, proto-lens-setup == 0.4.*
|
||||||
, base >= 4.7 && < 5
|
, base >= 4.7 && < 5
|
||||||
source-repository head
|
source-repository head
|
||||||
type: git
|
type: git
|
||||||
|
|
|
@ -61,12 +61,13 @@ module TensorFlow.Build
|
||||||
, withNodeDependencies
|
, withNodeDependencies
|
||||||
) where
|
) where
|
||||||
|
|
||||||
|
import Data.ProtoLens.Message(defMessage)
|
||||||
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
||||||
import Control.Monad.Fix (MonadFix(..))
|
import Control.Monad.Fix (MonadFix(..))
|
||||||
import Control.Monad.IO.Class (MonadIO(..))
|
import Control.Monad.IO.Class (MonadIO(..))
|
||||||
|
import Control.Monad.Fail (MonadFail(..))
|
||||||
import Control.Monad.Trans.Class (MonadTrans(..))
|
import Control.Monad.Trans.Class (MonadTrans(..))
|
||||||
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
|
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
|
||||||
import Data.Default (def)
|
|
||||||
import Data.Functor.Identity (Identity(..))
|
import Data.Functor.Identity (Identity(..))
|
||||||
import qualified Data.Map.Strict as Map
|
import qualified Data.Map.Strict as Map
|
||||||
import Data.Monoid ((<>))
|
import Data.Monoid ((<>))
|
||||||
|
@ -191,7 +192,7 @@ summaries = lens _summaries (\g x -> g { _summaries = x })
|
||||||
newtype BuildT m a = BuildT (StateT GraphState m a)
|
newtype BuildT m a = BuildT (StateT GraphState m a)
|
||||||
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
|
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
|
||||||
MonadState GraphState, MonadThrow, MonadCatch, MonadMask,
|
MonadState GraphState, MonadThrow, MonadCatch, MonadMask,
|
||||||
MonadFix)
|
MonadFix, MonadFail)
|
||||||
|
|
||||||
-- | An action for building nodes in a TensorFlow graph.
|
-- | An action for building nodes in a TensorFlow graph.
|
||||||
type Build = BuildT Identity
|
type Build = BuildT Identity
|
||||||
|
@ -236,7 +237,7 @@ addInitializer (ControlNode i) = build $ initializationNodes %= (i:)
|
||||||
-- | Produce a GraphDef proto representation of the nodes that are rendered in
|
-- | Produce a GraphDef proto representation of the nodes that are rendered in
|
||||||
-- the given 'Build' action.
|
-- the given 'Build' action.
|
||||||
asGraphDef :: Build a -> GraphDef
|
asGraphDef :: Build a -> GraphDef
|
||||||
asGraphDef b = def & node .~ gs ^. nodeBuffer
|
asGraphDef b = defMessage & node .~ gs ^. nodeBuffer
|
||||||
where
|
where
|
||||||
gs = snd $ runIdentity $ runBuildT b
|
gs = snd $ runIdentity $ runBuildT b
|
||||||
|
|
||||||
|
@ -285,7 +286,7 @@ getPendingNode o = do
|
||||||
let controlInputs
|
let controlInputs
|
||||||
= map makeDep (o ^. opControlInputs ++ Set.toList controls)
|
= map makeDep (o ^. opControlInputs ++ Set.toList controls)
|
||||||
return $ PendingNode scope (o ^. opName)
|
return $ PendingNode scope (o ^. opName)
|
||||||
$ def & op .~ (unOpType (o ^. opType) :: Text)
|
$ defMessage & op .~ (unOpType (o ^. opType) :: Text)
|
||||||
& attr .~ _opAttrs o
|
& attr .~ _opAttrs o
|
||||||
& input .~ (inputs ++ controlInputs)
|
& input .~ (inputs ++ controlInputs)
|
||||||
& device .~ dev
|
& device .~ dev
|
||||||
|
|
|
@ -35,14 +35,14 @@ module TensorFlow.Output
|
||||||
, PendingNodeName(..)
|
, PendingNodeName(..)
|
||||||
) where
|
) where
|
||||||
|
|
||||||
|
import Data.ProtoLens.Message(defMessage)
|
||||||
import qualified Data.Map.Strict as Map
|
import qualified Data.Map.Strict as Map
|
||||||
import Data.String (IsString(..))
|
import Data.String (IsString(..))
|
||||||
import Data.Text (Text)
|
import Data.Text (Text)
|
||||||
import qualified Data.Text as Text
|
import qualified Data.Text as Text
|
||||||
import Lens.Family2 (Lens')
|
import Lens.Family2 (Lens')
|
||||||
import Lens.Family2.Unchecked (lens)
|
import Lens.Family2.Unchecked (lens)
|
||||||
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..))
|
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue)
|
||||||
import Data.Default (def)
|
|
||||||
import TensorFlow.Types (Attribute, attrLens)
|
import TensorFlow.Types (Attribute, attrLens)
|
||||||
|
|
||||||
-- | A type of graph node which has no outputs. These nodes are
|
-- | A type of graph node which has no outputs. These nodes are
|
||||||
|
@ -108,7 +108,7 @@ opType = lens _opType (\o x -> o { _opType = x})
|
||||||
|
|
||||||
opAttr :: Attribute a => Text -> Lens' OpDef a
|
opAttr :: Attribute a => Text -> Lens' OpDef a
|
||||||
opAttr n = lens _opAttrs (\o x -> o {_opAttrs = x})
|
opAttr n = lens _opAttrs (\o x -> o {_opAttrs = x})
|
||||||
. lens (Map.findWithDefault def n) (flip (Map.insert n))
|
. lens (Map.findWithDefault defMessage n) (flip (Map.insert n))
|
||||||
. attrLens
|
. attrLens
|
||||||
|
|
||||||
opInputs :: Lens' OpDef [Output]
|
opInputs :: Lens' OpDef [Output]
|
||||||
|
|
|
@ -37,7 +37,9 @@ module TensorFlow.Session (
|
||||||
asyncProdNodes,
|
asyncProdNodes,
|
||||||
) where
|
) where
|
||||||
|
|
||||||
|
import Data.ProtoLens.Message(defMessage)
|
||||||
import Control.Monad (forever, unless, void)
|
import Control.Monad (forever, unless, void)
|
||||||
|
import Control.Monad.Fail (MonadFail(..))
|
||||||
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
||||||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||||
import Control.Monad.Trans.Class (MonadTrans, lift)
|
import Control.Monad.Trans.Class (MonadTrans, lift)
|
||||||
|
@ -78,7 +80,7 @@ data SessionState
|
||||||
newtype SessionT m a
|
newtype SessionT m a
|
||||||
= Session (ReaderT SessionState (BuildT m) a)
|
= Session (ReaderT SessionState (BuildT m) a)
|
||||||
deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch,
|
deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch,
|
||||||
MonadMask)
|
MonadMask, MonadFail)
|
||||||
|
|
||||||
instance MonadTrans SessionT where
|
instance MonadTrans SessionT where
|
||||||
lift = Session . lift . lift
|
lift = Session . lift . lift
|
||||||
|
@ -100,7 +102,7 @@ data Options = Options
|
||||||
instance Default Options where
|
instance Default Options where
|
||||||
def = Options
|
def = Options
|
||||||
{ _sessionTarget = ""
|
{ _sessionTarget = ""
|
||||||
, _sessionConfig = def
|
, _sessionConfig = defMessage
|
||||||
, _sessionTracer = const (return ())
|
, _sessionTracer = const (return ())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -142,7 +144,7 @@ extend = do
|
||||||
trace <- Session (asks tracer)
|
trace <- Session (asks tracer)
|
||||||
nodesToExtend <- build flushNodeBuffer
|
nodesToExtend <- build flushNodeBuffer
|
||||||
unless (null nodesToExtend) $ liftIO $ do
|
unless (null nodesToExtend) $ liftIO $ do
|
||||||
let graphDef = (def :: GraphDef) & node .~ nodesToExtend
|
let graphDef = (defMessage :: GraphDef) & node .~ nodesToExtend
|
||||||
trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
|
trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
|
||||||
FFI.extendGraph session graphDef
|
FFI.extendGraph session graphDef
|
||||||
-- Now that all the nodes are created, run the initializers.
|
-- Now that all the nodes are created, run the initializers.
|
||||||
|
|
|
@ -64,9 +64,9 @@ module TensorFlow.Types
|
||||||
, AllTensorTypes
|
, AllTensorTypes
|
||||||
) where
|
) where
|
||||||
|
|
||||||
|
import Data.ProtoLens.Message(defMessage)
|
||||||
import Data.Functor.Identity (Identity(..))
|
import Data.Functor.Identity (Identity(..))
|
||||||
import Data.Complex (Complex)
|
import Data.Complex (Complex)
|
||||||
import Data.Default (def)
|
|
||||||
import Data.Int (Int8, Int16, Int32, Int64)
|
import Data.Int (Int8, Int16, Int32, Int64)
|
||||||
import Data.Maybe (fromMaybe)
|
import Data.Maybe (fromMaybe)
|
||||||
import Data.Monoid ((<>))
|
import Data.Monoid ((<>))
|
||||||
|
@ -88,8 +88,8 @@ import qualified Data.ByteString.Lazy as L
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
import qualified Data.Vector.Storable as S
|
import qualified Data.Vector.Storable as S
|
||||||
import Proto.Tensorflow.Core.Framework.AttrValue
|
import Proto.Tensorflow.Core.Framework.AttrValue
|
||||||
( AttrValue(..)
|
( AttrValue
|
||||||
, AttrValue'ListValue(..)
|
, AttrValue'ListValue
|
||||||
)
|
)
|
||||||
import Proto.Tensorflow.Core.Framework.AttrValue_Fields
|
import Proto.Tensorflow.Core.Framework.AttrValue_Fields
|
||||||
( b
|
( b
|
||||||
|
@ -105,7 +105,7 @@ import Proto.Tensorflow.Core.Framework.AttrValue_Fields
|
||||||
import Proto.Tensorflow.Core.Framework.ResourceHandle
|
import Proto.Tensorflow.Core.Framework.ResourceHandle
|
||||||
(ResourceHandleProto)
|
(ResourceHandleProto)
|
||||||
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
|
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
|
||||||
(TensorProto(..))
|
(TensorProto)
|
||||||
import Proto.Tensorflow.Core.Framework.Tensor_Fields as Tensor
|
import Proto.Tensorflow.Core.Framework.Tensor_Fields as Tensor
|
||||||
( boolVal
|
( boolVal
|
||||||
, doubleVal
|
, doubleVal
|
||||||
|
@ -119,7 +119,7 @@ import Proto.Tensorflow.Core.Framework.Tensor_Fields as Tensor
|
||||||
)
|
)
|
||||||
|
|
||||||
import Proto.Tensorflow.Core.Framework.TensorShape
|
import Proto.Tensorflow.Core.Framework.TensorShape
|
||||||
(TensorShapeProto(..))
|
(TensorShapeProto)
|
||||||
import Proto.Tensorflow.Core.Framework.TensorShape_Fields
|
import Proto.Tensorflow.Core.Framework.TensorShape_Fields
|
||||||
( dim
|
( dim
|
||||||
, size
|
, size
|
||||||
|
@ -400,7 +400,7 @@ protoShape = iso protoToShape shapeToProto
|
||||||
protoToShape p = fromMaybe (error msg) (view protoMaybeShape p)
|
protoToShape p = fromMaybe (error msg) (view protoMaybeShape p)
|
||||||
where msg = "Can't convert TensorShapeProto with unknown rank to Shape: "
|
where msg = "Can't convert TensorShapeProto with unknown rank to Shape: "
|
||||||
++ showMessageShort p
|
++ showMessageShort p
|
||||||
shapeToProto s' = def & protoMaybeShape .~ Just s'
|
shapeToProto s' = defMessage & protoMaybeShape .~ Just s'
|
||||||
|
|
||||||
protoMaybeShape :: Lens' TensorShapeProto (Maybe Shape)
|
protoMaybeShape :: Lens' TensorShapeProto (Maybe Shape)
|
||||||
protoMaybeShape = iso protoToShape shapeToProto
|
protoMaybeShape = iso protoToShape shapeToProto
|
||||||
|
@ -412,9 +412,9 @@ protoMaybeShape = iso protoToShape shapeToProto
|
||||||
else Just (Shape (p ^.. dim . traverse . size))
|
else Just (Shape (p ^.. dim . traverse . size))
|
||||||
shapeToProto :: Maybe Shape -> TensorShapeProto
|
shapeToProto :: Maybe Shape -> TensorShapeProto
|
||||||
shapeToProto Nothing =
|
shapeToProto Nothing =
|
||||||
def & unknownRank .~ True
|
defMessage & unknownRank .~ True
|
||||||
shapeToProto (Just (Shape ds)) =
|
shapeToProto (Just (Shape ds)) =
|
||||||
def & dim .~ fmap (\d -> def & size .~ d) ds
|
defMessage & dim .~ fmap (\d -> defMessage & size .~ d) ds
|
||||||
|
|
||||||
|
|
||||||
class Attribute a where
|
class Attribute a where
|
||||||
|
|
|
@ -36,7 +36,7 @@ library
|
||||||
, TensorFlow.Types
|
, TensorFlow.Types
|
||||||
other-modules: TensorFlow.Internal.Raw
|
other-modules: TensorFlow.Internal.Raw
|
||||||
build-tools: c2hs
|
build-tools: c2hs
|
||||||
build-depends: proto-lens == 0.3.*
|
build-depends: proto-lens == 0.4.*
|
||||||
, tensorflow-proto == 0.2.*
|
, tensorflow-proto == 0.2.*
|
||||||
, base >= 4.7 && < 5
|
, base >= 4.7 && < 5
|
||||||
, async
|
, async
|
||||||
|
|
Loading…
Reference in a new issue