mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 11:03:34 +02:00
Merge branch 'master' into meangrad
This commit is contained in:
commit
d87a94fc50
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,3 +1,4 @@
|
|||
**/.stack-work
|
||||
.stack/
|
||||
tensorflow-mnist-input-data/data/*.gz
|
||||
.DS_Store
|
||||
|
|
|
@ -11,8 +11,8 @@ let
|
|||
pkgs = import nixpkgs {};
|
||||
in
|
||||
pkgs.haskell.lib.buildStackProject {
|
||||
# Either use specified GHC or use GHC 8.4.4 (which we need for LTS 12.26)
|
||||
ghc = if isNull ghc then pkgs.haskell.compiler.ghc844 else ghc;
|
||||
# Either use specified GHC or use GHC 8.6.4 (which we need for LTS 13.13)
|
||||
ghc = if isNull ghc then pkgs.haskell.compiler.ghc864 else ghc;
|
||||
extraArgs = "--system-ghc";
|
||||
name = "tf-env";
|
||||
buildInputs = with pkgs; [ snappy zlib protobuf libtensorflow ];
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
resolver: lts-12.26
|
||||
resolver: lts-13.13
|
||||
|
||||
packages:
|
||||
- tensorflow
|
||||
|
@ -15,6 +15,7 @@ packages:
|
|||
|
||||
extra-deps:
|
||||
- snappy-framing-0.1.2
|
||||
- snappy-0.2.0.2
|
||||
|
||||
# For Mac OS X, whose linker doesn't use this path by default
|
||||
# unless you run `xcode-select --install`.
|
||||
|
|
|
@ -15,7 +15,7 @@ cabal-version: >=1.24
|
|||
library
|
||||
exposed-modules: TensorFlow.GenOps.Core
|
||||
build-depends: bytestring
|
||||
, proto-lens == 0.3.*
|
||||
, proto-lens == 0.4.*
|
||||
, tensorflow == 0.2.*
|
||||
, base >= 4.7 && < 5
|
||||
, lens-family
|
||||
|
@ -26,7 +26,7 @@ custom-setup
|
|||
setup-depends: Cabal
|
||||
, bytestring
|
||||
, directory
|
||||
, proto-lens == 0.3.*
|
||||
, proto-lens == 0.4.*
|
||||
, tensorflow-opgen == 0.2.*
|
||||
, tensorflow == 0.2.*
|
||||
, base >= 4.7 && < 5
|
||||
|
|
|
@ -60,7 +60,7 @@ import Control.Monad.Trans.Resource (runResourceT)
|
|||
import Data.ByteString (ByteString)
|
||||
import Data.Conduit ((.|))
|
||||
import Data.Conduit.TQueue (sourceTBMQueue)
|
||||
import Data.Default (def)
|
||||
import Data.ProtoLens.Default(def)
|
||||
import Data.Int (Int64)
|
||||
import Data.Word (Word8, Word16)
|
||||
import Data.ProtoLens (encodeMessage)
|
||||
|
|
|
@ -24,7 +24,7 @@ library
|
|||
, filepath
|
||||
, hostname
|
||||
, lens-family
|
||||
, proto-lens == 0.3.*
|
||||
, proto-lens == 0.4.*
|
||||
, resourcet
|
||||
, stm
|
||||
, stm-chans
|
||||
|
|
|
@ -17,7 +17,7 @@ module Main where
|
|||
|
||||
import Control.Monad.Trans.Resource (runResourceT)
|
||||
import Data.Conduit ((.|))
|
||||
import Data.Default (def)
|
||||
import Data.ProtoLens.Message (defMessage)
|
||||
import Data.List ((\\))
|
||||
import Data.ProtoLens (encodeMessage, decodeMessageOrDie)
|
||||
import Lens.Family2 ((^.), (.~), (&))
|
||||
|
@ -45,9 +45,9 @@ testEventWriter :: Test
|
|||
testEventWriter = testCase "EventWriter" $
|
||||
withSystemTempDirectory "event_writer_logs" $ \dir -> do
|
||||
assertEqual "No file before" [] =<< listDirectory dir
|
||||
let expected = [ (def :: Event) & step .~ 10
|
||||
, def & step .~ 222
|
||||
, def & step .~ 8
|
||||
let expected = [ (defMessage :: Event) & step .~ 10
|
||||
, defMessage & step .~ 222
|
||||
, defMessage & step .~ 8
|
||||
]
|
||||
withEventWriter dir $ \eventWriter ->
|
||||
mapM_ (logEvent eventWriter) expected
|
||||
|
@ -66,7 +66,7 @@ testLogGraph = testCase "LogGraph" $
|
|||
withSystemTempDirectory "event_writer_logs" $ \dir -> do
|
||||
let graphBuild = noOp :: Build ControlNode
|
||||
expectedGraph = asGraphDef graphBuild
|
||||
expectedGraphEvent = (def :: Event) & graphDef .~ (encodeMessage expectedGraph)
|
||||
expectedGraphEvent = (defMessage :: Event) & graphDef .~ (encodeMessage expectedGraph)
|
||||
|
||||
withEventWriter dir $ \eventWriter ->
|
||||
logGraph eventWriter graphBuild
|
||||
|
|
|
@ -20,7 +20,7 @@ library
|
|||
exposed-modules: TensorFlow.Examples.MNIST.Parse
|
||||
, TensorFlow.Examples.MNIST.TrainedGraph
|
||||
other-modules: Paths_tensorflow_mnist
|
||||
build-depends: proto-lens == 0.3.*
|
||||
build-depends: proto-lens == 0.4.*
|
||||
, base >= 4.7 && < 5
|
||||
, binary
|
||||
, bytestring
|
||||
|
|
|
@ -24,12 +24,12 @@ import qualified Data.Text.IO as Text
|
|||
import Lens.Family2 ((&), (.~), (^.))
|
||||
import Prelude hiding (abs)
|
||||
import Proto.Tensorflow.Core.Framework.Graph
|
||||
( GraphDef(..) )
|
||||
( GraphDef )
|
||||
import Proto.Tensorflow.Core.Framework.Graph_Fields
|
||||
( version
|
||||
, node )
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||
( NodeDef(..) )
|
||||
( NodeDef )
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef_Fields (op)
|
||||
import System.IO as IO
|
||||
import TensorFlow.Examples.MNIST.InputData
|
||||
|
|
|
@ -51,7 +51,8 @@ module TensorFlow.OpGen
|
|||
|
||||
import Data.Foldable (toList)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.ProtoLens (def, showMessage)
|
||||
import Data.ProtoLens.Default(def)
|
||||
import Data.ProtoLens (showMessage)
|
||||
import Data.List (sortOn)
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import qualified Data.List.NonEmpty as NE
|
||||
|
|
|
@ -16,7 +16,7 @@ library
|
|||
hs-source-dirs: src
|
||||
exposed-modules: TensorFlow.OpGen.ParsedOp
|
||||
, TensorFlow.OpGen
|
||||
build-depends: proto-lens == 0.3.*
|
||||
build-depends: proto-lens == 0.4.*
|
||||
, tensorflow-proto == 0.2.*
|
||||
, base >= 4.7 && < 5
|
||||
, bytestring
|
||||
|
|
|
@ -31,11 +31,12 @@ import Control.Monad (forM, zipWithM)
|
|||
import Control.Monad.State.Strict (State, evalState, gets, modify)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Complex (Complex)
|
||||
import Data.Default (def)
|
||||
import Data.ProtoLens.Default(def)
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.Foldable (foldlM)
|
||||
import Data.List (foldl', sortBy)
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.IntSet as IntSet
|
||||
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
|
||||
import Data.Ord (comparing)
|
||||
import Data.ProtoLens.TextFormat (showMessage)
|
||||
|
@ -165,6 +166,11 @@ gradients y xs = build $ do
|
|||
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
|
||||
. flip Map.lookup
|
||||
let (gr, nodeMap) = createGraph yName nodeDefLookup
|
||||
xnodes = mapMaybe (\x -> nodeMap ^. (at . outputNodeName . renderedOutput $ x)) xs
|
||||
-- make a set of the nodes reachable from the xnodes
|
||||
-- The xnodes are not part of this set (unless reachable from another xnode)
|
||||
reachableSet = computeReachableSet xnodes gr
|
||||
|
||||
-- Set gradient of y to one.
|
||||
-- TODO: nicer
|
||||
let initPending :: Map.Map FGL.Node (PendingGradients a)
|
||||
|
@ -175,7 +181,7 @@ gradients y xs = build $ do
|
|||
.~ [yOne]
|
||||
)
|
||||
-- Calculate the gradients of y w.r.t. each node in the graph.
|
||||
gradientMap <- graphGrads gr initPending
|
||||
gradientMap <- graphGrads gr reachableSet initPending
|
||||
-- Lookup the gradients for each x.
|
||||
forM xs $ \x ->
|
||||
let Output i xName = renderedOutput x
|
||||
|
@ -183,6 +189,13 @@ gradients y xs = build $ do
|
|||
n <- nodeMap ^. at xName
|
||||
gradientMap ^. at n . nonEmpty . outputIxAt i
|
||||
|
||||
-- | Compute a set of nodes reachable from the start nodes
|
||||
--
|
||||
-- the start nodes are excluded, unless reachable from another start node
|
||||
computeReachableSet :: [FGL.Node] -> Graph -> IntSet.IntSet
|
||||
computeReachableSet vs g =
|
||||
IntSet.fromList $ concatMap (drop 1 . FGL.preorder) (FGL.dff vs g)
|
||||
|
||||
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
|
||||
outputIxAt = intAt . unOutputIx
|
||||
|
||||
|
@ -245,16 +258,15 @@ nonEmpty = anon mempty null
|
|||
-- | Calculate the gradients for every node in a graph.
|
||||
graphGrads :: forall a. GradientCompatible a
|
||||
=> Graph
|
||||
-> IntSet.IntSet
|
||||
-> Map FGL.Node (PendingGradients a)
|
||||
-- ^ Initial gradients (usually just 1 for the node of interest).
|
||||
-> Build (Map FGL.Node (Gradients a))
|
||||
graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrder
|
||||
graphGrads gr reachableSet initPending = view gradientsResult <$> foldlM go initState nodeOrder
|
||||
where
|
||||
initState = GradientsState initPending Map.empty
|
||||
-- Reverse topological sort.
|
||||
-- TODO(fmayle): Filter out nodes that are not successors of any x in xs to
|
||||
-- avoid calculating gradients that won't be used.
|
||||
nodeOrder = FGL.topsort $ FGL.grev gr
|
||||
nodeOrder = FGL.topsort . FGL.grev $ gr
|
||||
go :: GradientsState a -> Int -> Build (GradientsState a)
|
||||
go state node = do
|
||||
-- Aggregate the accumulated gradients for this node.
|
||||
|
@ -263,11 +275,17 @@ graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrd
|
|||
if null outputGrads
|
||||
then pure state
|
||||
else do
|
||||
let ctx = FGL.context gr node
|
||||
inputGrads <- calculateInputGrads ctx outputGrads gr
|
||||
-- Calculate the gradients for each of the node's inputs.
|
||||
let nextState = state & gradientsResult %~ Map.insert node outputGrads
|
||||
pure $ updatePendingGradients ctx inputGrads nextState
|
||||
-- Only consider nodes that are reachable from the inputs to
|
||||
-- avoid calculating gradients that won't be used.
|
||||
if node `IntSet.member` reachableSet
|
||||
then do
|
||||
let ctx = FGL.context gr node
|
||||
inputGrads <- calculateInputGrads ctx outputGrads gr
|
||||
-- Calculate the gradients for each of the node's inputs.
|
||||
pure $ updatePendingGradients ctx inputGrads nextState
|
||||
else
|
||||
pure nextState
|
||||
|
||||
-- | Reduce accumulated gradients for each output to one Tensor.
|
||||
sumPendingGradient :: GradientCompatible a
|
||||
|
@ -839,12 +857,9 @@ opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
|
|||
-- through each read.
|
||||
opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
|
||||
|
||||
-- TODO(fmayle): These can go away if we properly prune the graph.
|
||||
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
||||
opGrad "Placeholder" _ _ _ = []
|
||||
opGrad "StopGradient" _ _ _ = [Nothing]
|
||||
opGrad "VarHandleOp" _ _ _ = []
|
||||
opGrad "Variable" _ _ _ = []
|
||||
|
||||
opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz]
|
||||
where
|
||||
|
|
|
@ -158,7 +158,7 @@ import Data.Complex (Complex)
|
|||
import Data.Int (Int32, Int64)
|
||||
import Data.Word (Word16)
|
||||
import Prelude hiding (abs, sum, concat)
|
||||
import Data.ProtoLens (def)
|
||||
import Data.ProtoLens.Default(def)
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Lens.Family2 ((.~), (&))
|
||||
import Text.Printf (printf)
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
{-# LANGUAGE RecursiveDo #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE NoMonadFailDesugaring #-}
|
||||
module TensorFlow.Variable
|
||||
( Variable
|
||||
, variable
|
||||
|
|
|
@ -21,7 +21,7 @@ library
|
|||
, TensorFlow.NN
|
||||
, TensorFlow.Queue
|
||||
, TensorFlow.Variable
|
||||
build-depends: proto-lens == 0.3.*
|
||||
build-depends: proto-lens == 0.4.*
|
||||
, base >= 4.7 && < 5
|
||||
, bytestring
|
||||
, fgl
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE NoMonadFailDesugaring #-}
|
||||
|
||||
-- | Tests for EmbeddingOps.
|
||||
module Main where
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
{-# LANGUAGE NoMonomorphismRestriction #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE NoMonadFailDesugaring #-}
|
||||
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.List (sort)
|
||||
|
@ -32,7 +33,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
|
|||
import Control.Monad.IO.Class (liftIO)
|
||||
|
||||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape)
|
||||
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag)
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
|
||||
import qualified TensorFlow.Output as TF
|
||||
|
@ -123,6 +124,65 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
|
|||
]
|
||||
sort expected @=? sort ops
|
||||
|
||||
testGradientIncidental :: Test
|
||||
testGradientIncidental = testCase "testGradientIncidental" $ do
|
||||
let grads = do
|
||||
x <- TF.render $ TF.scalar (3 :: Float)
|
||||
b <- TF.render $ TF.scalar (4 :: Float)
|
||||
w <- TF.render $ TF.diag $ TF.vector [ 1.0 :: Float ]
|
||||
let incidental = b `TF.mul` w
|
||||
let y = (x `TF.mul` b) `TF.add` incidental
|
||||
TF.gradients y [x]
|
||||
|
||||
-- Assert that the gradients are right.
|
||||
[dx] <- TF.runSession $ grads >>= TF.run
|
||||
4 @=? TF.unScalar dx
|
||||
-- Assert that the graph has the expected ops.
|
||||
let graphDef = TF.asGraphDef grads
|
||||
putStrLn $ showMessage graphDef
|
||||
let ops = graphDef ^.. node . traverse . op
|
||||
expected = [ "Add"
|
||||
, "BroadcastGradientArgs"
|
||||
, "BroadcastGradientArgs"
|
||||
, "Const"
|
||||
, "Const"
|
||||
, "Const"
|
||||
, "Const"
|
||||
, "Diag"
|
||||
, "Fill"
|
||||
, "Mul"
|
||||
, "Mul"
|
||||
, "Mul"
|
||||
, "Mul"
|
||||
, "Reshape"
|
||||
, "Reshape"
|
||||
, "Reshape"
|
||||
, "Reshape"
|
||||
, "Shape"
|
||||
, "Shape"
|
||||
, "Shape"
|
||||
, "Shape"
|
||||
, "Shape"
|
||||
, "Sum"
|
||||
, "Sum"
|
||||
, "Sum"
|
||||
, "Sum"
|
||||
]
|
||||
sort expected @=? sort ops
|
||||
|
||||
testGradientPruning :: Test
|
||||
testGradientPruning = testCase "testGradientPruning" $ do
|
||||
let grads = do
|
||||
x <- TF.render $ TF.scalar (3 :: Float)
|
||||
b <- TF.render $ TF.scalar (4 :: Float)
|
||||
bx <- TF.render $ b `TF.mul` x
|
||||
let y = bx `TF.add` b
|
||||
TF.gradients y [x, bx]
|
||||
|
||||
-- Assert that the gradients are right.
|
||||
[dx, dxb] <- TF.runSession $ grads >>= TF.run
|
||||
4 @=? TF.unScalar dx
|
||||
1 @=? TF.unScalar dxb
|
||||
|
||||
-- Test that identical "stateful" ops work with createGraph.
|
||||
testCreateGraphStateful :: Test
|
||||
|
@ -562,6 +622,8 @@ main :: IO ()
|
|||
main = defaultMain
|
||||
[ testGradientSimple
|
||||
, testGradientDisconnected
|
||||
, testGradientIncidental
|
||||
, testGradientPruning
|
||||
, testCreateGraphStateful
|
||||
, testCreateGraphNameScopes
|
||||
, testDiamond
|
||||
|
|
|
@ -97,16 +97,16 @@ library
|
|||
, Proto.Tensorflow.Core.Util.SavedTensorSlice_Fields
|
||||
, Proto.Tensorflow.Core.Util.TestLog
|
||||
, Proto.Tensorflow.Core.Util.TestLog_Fields
|
||||
build-depends: proto-lens == 0.3.*
|
||||
, proto-lens-protoc == 0.3.*
|
||||
, proto-lens-protobuf-types == 0.3.*
|
||||
build-depends: proto-lens == 0.4.*
|
||||
, proto-lens-runtime == 0.4.*
|
||||
, proto-lens-protobuf-types == 0.4.*
|
||||
, base >= 4.7 && < 5
|
||||
default-language: Haskell2010
|
||||
include-dirs: .
|
||||
|
||||
custom-setup
|
||||
setup-depends: Cabal
|
||||
, proto-lens-protoc == 0.3.*
|
||||
, proto-lens-setup == 0.4.*
|
||||
, base >= 4.7 && < 5
|
||||
source-repository head
|
||||
type: git
|
||||
|
|
|
@ -61,12 +61,13 @@ module TensorFlow.Build
|
|||
, withNodeDependencies
|
||||
) where
|
||||
|
||||
import Data.ProtoLens.Message(defMessage)
|
||||
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
||||
import Control.Monad.Fix (MonadFix(..))
|
||||
import Control.Monad.IO.Class (MonadIO(..))
|
||||
import Control.Monad.Fail (MonadFail(..))
|
||||
import Control.Monad.Trans.Class (MonadTrans(..))
|
||||
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
|
||||
import Data.Default (def)
|
||||
import Data.Functor.Identity (Identity(..))
|
||||
import qualified Data.Map.Strict as Map
|
||||
import Data.Monoid ((<>))
|
||||
|
@ -191,7 +192,7 @@ summaries = lens _summaries (\g x -> g { _summaries = x })
|
|||
newtype BuildT m a = BuildT (StateT GraphState m a)
|
||||
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
|
||||
MonadState GraphState, MonadThrow, MonadCatch, MonadMask,
|
||||
MonadFix)
|
||||
MonadFix, MonadFail)
|
||||
|
||||
-- | An action for building nodes in a TensorFlow graph.
|
||||
type Build = BuildT Identity
|
||||
|
@ -236,7 +237,7 @@ addInitializer (ControlNode i) = build $ initializationNodes %= (i:)
|
|||
-- | Produce a GraphDef proto representation of the nodes that are rendered in
|
||||
-- the given 'Build' action.
|
||||
asGraphDef :: Build a -> GraphDef
|
||||
asGraphDef b = def & node .~ gs ^. nodeBuffer
|
||||
asGraphDef b = defMessage & node .~ gs ^. nodeBuffer
|
||||
where
|
||||
gs = snd $ runIdentity $ runBuildT b
|
||||
|
||||
|
@ -285,7 +286,7 @@ getPendingNode o = do
|
|||
let controlInputs
|
||||
= map makeDep (o ^. opControlInputs ++ Set.toList controls)
|
||||
return $ PendingNode scope (o ^. opName)
|
||||
$ def & op .~ (unOpType (o ^. opType) :: Text)
|
||||
$ defMessage & op .~ (unOpType (o ^. opType) :: Text)
|
||||
& attr .~ _opAttrs o
|
||||
& input .~ (inputs ++ controlInputs)
|
||||
& device .~ dev
|
||||
|
|
|
@ -35,14 +35,14 @@ module TensorFlow.Output
|
|||
, PendingNodeName(..)
|
||||
) where
|
||||
|
||||
import Data.ProtoLens.Message(defMessage)
|
||||
import qualified Data.Map.Strict as Map
|
||||
import Data.String (IsString(..))
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as Text
|
||||
import Lens.Family2 (Lens')
|
||||
import Lens.Family2.Unchecked (lens)
|
||||
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..))
|
||||
import Data.Default (def)
|
||||
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue)
|
||||
import TensorFlow.Types (Attribute, attrLens)
|
||||
|
||||
-- | A type of graph node which has no outputs. These nodes are
|
||||
|
@ -108,7 +108,7 @@ opType = lens _opType (\o x -> o { _opType = x})
|
|||
|
||||
opAttr :: Attribute a => Text -> Lens' OpDef a
|
||||
opAttr n = lens _opAttrs (\o x -> o {_opAttrs = x})
|
||||
. lens (Map.findWithDefault def n) (flip (Map.insert n))
|
||||
. lens (Map.findWithDefault defMessage n) (flip (Map.insert n))
|
||||
. attrLens
|
||||
|
||||
opInputs :: Lens' OpDef [Output]
|
||||
|
|
|
@ -37,7 +37,9 @@ module TensorFlow.Session (
|
|||
asyncProdNodes,
|
||||
) where
|
||||
|
||||
import Data.ProtoLens.Message(defMessage)
|
||||
import Control.Monad (forever, unless, void)
|
||||
import Control.Monad.Fail (MonadFail(..))
|
||||
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
||||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||
import Control.Monad.Trans.Class (MonadTrans, lift)
|
||||
|
@ -78,7 +80,7 @@ data SessionState
|
|||
newtype SessionT m a
|
||||
= Session (ReaderT SessionState (BuildT m) a)
|
||||
deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch,
|
||||
MonadMask)
|
||||
MonadMask, MonadFail)
|
||||
|
||||
instance MonadTrans SessionT where
|
||||
lift = Session . lift . lift
|
||||
|
@ -100,7 +102,7 @@ data Options = Options
|
|||
instance Default Options where
|
||||
def = Options
|
||||
{ _sessionTarget = ""
|
||||
, _sessionConfig = def
|
||||
, _sessionConfig = defMessage
|
||||
, _sessionTracer = const (return ())
|
||||
}
|
||||
|
||||
|
@ -142,7 +144,7 @@ extend = do
|
|||
trace <- Session (asks tracer)
|
||||
nodesToExtend <- build flushNodeBuffer
|
||||
unless (null nodesToExtend) $ liftIO $ do
|
||||
let graphDef = (def :: GraphDef) & node .~ nodesToExtend
|
||||
let graphDef = (defMessage :: GraphDef) & node .~ nodesToExtend
|
||||
trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
|
||||
FFI.extendGraph session graphDef
|
||||
-- Now that all the nodes are created, run the initializers.
|
||||
|
|
|
@ -64,9 +64,9 @@ module TensorFlow.Types
|
|||
, AllTensorTypes
|
||||
) where
|
||||
|
||||
import Data.ProtoLens.Message(defMessage)
|
||||
import Data.Functor.Identity (Identity(..))
|
||||
import Data.Complex (Complex)
|
||||
import Data.Default (def)
|
||||
import Data.Int (Int8, Int16, Int32, Int64)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Monoid ((<>))
|
||||
|
@ -88,8 +88,8 @@ import qualified Data.ByteString.Lazy as L
|
|||
import qualified Data.Vector as V
|
||||
import qualified Data.Vector.Storable as S
|
||||
import Proto.Tensorflow.Core.Framework.AttrValue
|
||||
( AttrValue(..)
|
||||
, AttrValue'ListValue(..)
|
||||
( AttrValue
|
||||
, AttrValue'ListValue
|
||||
)
|
||||
import Proto.Tensorflow.Core.Framework.AttrValue_Fields
|
||||
( b
|
||||
|
@ -105,7 +105,7 @@ import Proto.Tensorflow.Core.Framework.AttrValue_Fields
|
|||
import Proto.Tensorflow.Core.Framework.ResourceHandle
|
||||
(ResourceHandleProto)
|
||||
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
|
||||
(TensorProto(..))
|
||||
(TensorProto)
|
||||
import Proto.Tensorflow.Core.Framework.Tensor_Fields as Tensor
|
||||
( boolVal
|
||||
, doubleVal
|
||||
|
@ -119,7 +119,7 @@ import Proto.Tensorflow.Core.Framework.Tensor_Fields as Tensor
|
|||
)
|
||||
|
||||
import Proto.Tensorflow.Core.Framework.TensorShape
|
||||
(TensorShapeProto(..))
|
||||
(TensorShapeProto)
|
||||
import Proto.Tensorflow.Core.Framework.TensorShape_Fields
|
||||
( dim
|
||||
, size
|
||||
|
@ -400,7 +400,7 @@ protoShape = iso protoToShape shapeToProto
|
|||
protoToShape p = fromMaybe (error msg) (view protoMaybeShape p)
|
||||
where msg = "Can't convert TensorShapeProto with unknown rank to Shape: "
|
||||
++ showMessageShort p
|
||||
shapeToProto s' = def & protoMaybeShape .~ Just s'
|
||||
shapeToProto s' = defMessage & protoMaybeShape .~ Just s'
|
||||
|
||||
protoMaybeShape :: Lens' TensorShapeProto (Maybe Shape)
|
||||
protoMaybeShape = iso protoToShape shapeToProto
|
||||
|
@ -412,9 +412,9 @@ protoMaybeShape = iso protoToShape shapeToProto
|
|||
else Just (Shape (p ^.. dim . traverse . size))
|
||||
shapeToProto :: Maybe Shape -> TensorShapeProto
|
||||
shapeToProto Nothing =
|
||||
def & unknownRank .~ True
|
||||
defMessage & unknownRank .~ True
|
||||
shapeToProto (Just (Shape ds)) =
|
||||
def & dim .~ fmap (\d -> def & size .~ d) ds
|
||||
defMessage & dim .~ fmap (\d -> defMessage & size .~ d) ds
|
||||
|
||||
|
||||
class Attribute a where
|
||||
|
|
|
@ -36,7 +36,7 @@ library
|
|||
, TensorFlow.Types
|
||||
other-modules: TensorFlow.Internal.Raw
|
||||
build-tools: c2hs
|
||||
build-depends: proto-lens == 0.3.*
|
||||
build-depends: proto-lens == 0.4.*
|
||||
, tensorflow-proto == 0.2.*
|
||||
, base >= 4.7 && < 5
|
||||
, async
|
||||
|
|
Loading…
Reference in New Issue
Block a user