1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 11:03:34 +02:00

Merge branch 'master' into depthconvgrad

This commit is contained in:
fkm3 2019-04-22 00:04:06 -04:00 committed by GitHub
commit 13c23d86c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 141 additions and 57 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
**/.stack-work
.stack/
tensorflow-mnist-input-data/data/*.gz
.DS_Store

View File

@ -11,8 +11,8 @@ let
pkgs = import nixpkgs {};
in
pkgs.haskell.lib.buildStackProject {
# Either use specified GHC or use GHC 8.4.4 (which we need for LTS 12.26)
ghc = if isNull ghc then pkgs.haskell.compiler.ghc844 else ghc;
# Either use specified GHC or use GHC 8.6.4 (which we need for LTS 13.13)
ghc = if isNull ghc then pkgs.haskell.compiler.ghc864 else ghc;
extraArgs = "--system-ghc";
name = "tf-env";
buildInputs = with pkgs; [ snappy zlib protobuf libtensorflow ];

View File

@ -1,4 +1,4 @@
resolver: lts-12.26
resolver: lts-13.13
packages:
- tensorflow
@ -15,6 +15,7 @@ packages:
extra-deps:
- snappy-framing-0.1.2
- snappy-0.2.0.2
# For Mac OS X, whose linker doesn't use this path by default
# unless you run `xcode-select --install`.

View File

@ -15,7 +15,7 @@ cabal-version: >=1.24
library
exposed-modules: TensorFlow.GenOps.Core
build-depends: bytestring
, proto-lens == 0.3.*
, proto-lens == 0.4.*
, tensorflow == 0.2.*
, base >= 4.7 && < 5
, lens-family
@ -26,7 +26,7 @@ custom-setup
setup-depends: Cabal
, bytestring
, directory
, proto-lens == 0.3.*
, proto-lens == 0.4.*
, tensorflow-opgen == 0.2.*
, tensorflow == 0.2.*
, base >= 4.7 && < 5

View File

@ -60,7 +60,7 @@ import Control.Monad.Trans.Resource (runResourceT)
import Data.ByteString (ByteString)
import Data.Conduit ((.|))
import Data.Conduit.TQueue (sourceTBMQueue)
import Data.Default (def)
import Data.ProtoLens.Default(def)
import Data.Int (Int64)
import Data.Word (Word8, Word16)
import Data.ProtoLens (encodeMessage)

View File

@ -24,7 +24,7 @@ library
, filepath
, hostname
, lens-family
, proto-lens == 0.3.*
, proto-lens == 0.4.*
, resourcet
, stm
, stm-chans

View File

@ -17,7 +17,7 @@ module Main where
import Control.Monad.Trans.Resource (runResourceT)
import Data.Conduit ((.|))
import Data.Default (def)
import Data.ProtoLens.Message (defMessage)
import Data.List ((\\))
import Data.ProtoLens (encodeMessage, decodeMessageOrDie)
import Lens.Family2 ((^.), (.~), (&))
@ -45,9 +45,9 @@ testEventWriter :: Test
testEventWriter = testCase "EventWriter" $
withSystemTempDirectory "event_writer_logs" $ \dir -> do
assertEqual "No file before" [] =<< listDirectory dir
let expected = [ (def :: Event) & step .~ 10
, def & step .~ 222
, def & step .~ 8
let expected = [ (defMessage :: Event) & step .~ 10
, defMessage & step .~ 222
, defMessage & step .~ 8
]
withEventWriter dir $ \eventWriter ->
mapM_ (logEvent eventWriter) expected
@ -66,7 +66,7 @@ testLogGraph = testCase "LogGraph" $
withSystemTempDirectory "event_writer_logs" $ \dir -> do
let graphBuild = noOp :: Build ControlNode
expectedGraph = asGraphDef graphBuild
expectedGraphEvent = (def :: Event) & graphDef .~ (encodeMessage expectedGraph)
expectedGraphEvent = (defMessage :: Event) & graphDef .~ (encodeMessage expectedGraph)
withEventWriter dir $ \eventWriter ->
logGraph eventWriter graphBuild

View File

@ -20,7 +20,7 @@ library
exposed-modules: TensorFlow.Examples.MNIST.Parse
, TensorFlow.Examples.MNIST.TrainedGraph
other-modules: Paths_tensorflow_mnist
build-depends: proto-lens == 0.3.*
build-depends: proto-lens == 0.4.*
, base >= 4.7 && < 5
, binary
, bytestring

View File

@ -24,12 +24,12 @@ import qualified Data.Text.IO as Text
import Lens.Family2 ((&), (.~), (^.))
import Prelude hiding (abs)
import Proto.Tensorflow.Core.Framework.Graph
( GraphDef(..) )
( GraphDef )
import Proto.Tensorflow.Core.Framework.Graph_Fields
( version
, node )
import Proto.Tensorflow.Core.Framework.NodeDef
( NodeDef(..) )
( NodeDef )
import Proto.Tensorflow.Core.Framework.NodeDef_Fields (op)
import System.IO as IO
import TensorFlow.Examples.MNIST.InputData

View File

@ -51,7 +51,8 @@ module TensorFlow.OpGen
import Data.Foldable (toList)
import Data.Maybe (fromMaybe)
import Data.ProtoLens (def, showMessage)
import Data.ProtoLens.Default(def)
import Data.ProtoLens (showMessage)
import Data.List (sortOn)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NE

View File

@ -16,7 +16,7 @@ library
hs-source-dirs: src
exposed-modules: TensorFlow.OpGen.ParsedOp
, TensorFlow.OpGen
build-depends: proto-lens == 0.3.*
build-depends: proto-lens == 0.4.*
, tensorflow-proto == 0.2.*
, base >= 4.7 && < 5
, bytestring

View File

@ -31,11 +31,12 @@ import Control.Monad (forM, zipWithM)
import Control.Monad.State.Strict (State, evalState, gets, modify)
import Data.ByteString (ByteString)
import Data.Complex (Complex)
import Data.Default (def)
import Data.ProtoLens.Default(def)
import Data.Int (Int32, Int64)
import Data.Foldable (foldlM)
import Data.List (foldl', sortBy)
import Data.Map.Strict (Map)
import qualified Data.IntSet as IntSet
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
import Data.Ord (comparing)
import Data.ProtoLens.TextFormat (showMessage)
@ -165,6 +166,11 @@ gradients y xs = build $ do
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
. flip Map.lookup
let (gr, nodeMap) = createGraph yName nodeDefLookup
xnodes = mapMaybe (\x -> nodeMap ^. (at . outputNodeName . renderedOutput $ x)) xs
-- make a set of the nodes reachable from the xnodes
-- The xnodes are not part of this set (unless reachable from another xnode)
reachableSet = computeReachableSet xnodes gr
-- Set gradient of y to one.
-- TODO: nicer
let initPending :: Map.Map FGL.Node (PendingGradients a)
@ -175,7 +181,7 @@ gradients y xs = build $ do
.~ [yOne]
)
-- Calculate the gradients of y w.r.t. each node in the graph.
gradientMap <- graphGrads gr initPending
gradientMap <- graphGrads gr reachableSet initPending
-- Lookup the gradients for each x.
forM xs $ \x ->
let Output i xName = renderedOutput x
@ -183,6 +189,13 @@ gradients y xs = build $ do
n <- nodeMap ^. at xName
gradientMap ^. at n . nonEmpty . outputIxAt i
-- | Compute a set of nodes reachable from the start nodes
--
-- the start nodes are excluded, unless reachable from another start node
computeReachableSet :: [FGL.Node] -> Graph -> IntSet.IntSet
computeReachableSet vs g =
IntSet.fromList $ concatMap (drop 1 . FGL.preorder) (FGL.dff vs g)
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
outputIxAt = intAt . unOutputIx
@ -245,16 +258,15 @@ nonEmpty = anon mempty null
-- | Calculate the gradients for every node in a graph.
graphGrads :: forall a. GradientCompatible a
=> Graph
-> IntSet.IntSet
-> Map FGL.Node (PendingGradients a)
-- ^ Initial gradients (usually just 1 for the node of interest).
-> Build (Map FGL.Node (Gradients a))
graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrder
graphGrads gr reachableSet initPending = view gradientsResult <$> foldlM go initState nodeOrder
where
initState = GradientsState initPending Map.empty
-- Reverse topological sort.
-- TODO(fmayle): Filter out nodes that are not successors of any x in xs to
-- avoid calculating gradients that won't be used.
nodeOrder = FGL.topsort $ FGL.grev gr
nodeOrder = FGL.topsort . FGL.grev $ gr
go :: GradientsState a -> Int -> Build (GradientsState a)
go state node = do
-- Aggregate the accumulated gradients for this node.
@ -263,11 +275,17 @@ graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrd
if null outputGrads
then pure state
else do
let ctx = FGL.context gr node
inputGrads <- calculateInputGrads ctx outputGrads gr
-- Calculate the gradients for each of the node's inputs.
let nextState = state & gradientsResult %~ Map.insert node outputGrads
pure $ updatePendingGradients ctx inputGrads nextState
-- Only consider nodes that are reachable from the inputs to
-- avoid calculating gradients that won't be used.
if node `IntSet.member` reachableSet
then do
let ctx = FGL.context gr node
inputGrads <- calculateInputGrads ctx outputGrads gr
-- Calculate the gradients for each of the node's inputs.
pure $ updatePendingGradients ctx inputGrads nextState
else
pure nextState
-- | Reduce accumulated gradients for each output to one Tensor.
sumPendingGradient :: GradientCompatible a
@ -874,11 +892,8 @@ opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
-- through each read.
opGrad "ReadVariableOp" _ _ [dz] = [Just $ expr dz]
-- TODO(fmayle): These can go away if we properly prune the graph.
opGrad "Const" _ _ _ = [Nothing, Nothing]
opGrad "Placeholder" _ _ _ = []
opGrad "VarHandleOp" _ _ _ = []
opGrad "Variable" _ _ _ = []
opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz]
where

View File

@ -158,7 +158,7 @@ import Data.Complex (Complex)
import Data.Int (Int32, Int64)
import Data.Word (Word16)
import Prelude hiding (abs, sum, concat)
import Data.ProtoLens (def)
import Data.ProtoLens.Default(def)
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 ((.~), (&))
import Text.Printf (printf)

View File

@ -11,6 +11,7 @@
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NoMonadFailDesugaring #-}
module TensorFlow.Variable
( Variable
, variable

View File

@ -21,7 +21,7 @@ library
, TensorFlow.NN
, TensorFlow.Queue
, TensorFlow.Variable
build-depends: proto-lens == 0.3.*
build-depends: proto-lens == 0.4.*
, base >= 4.7 && < 5
, bytestring
, fgl

View File

@ -15,6 +15,7 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE NoMonadFailDesugaring #-}
-- | Tests for EmbeddingOps.
module Main where

View File

@ -16,6 +16,7 @@
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NoMonadFailDesugaring #-}
import Data.Int (Int32, Int64)
import Data.List (sort)
@ -32,8 +33,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
import Control.Monad.IO.Class (liftIO)
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape)
import qualified TensorFlow.GenOps.Core as TF (depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput')
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput')
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
import qualified TensorFlow.Output as TF
@ -124,6 +124,65 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
]
sort expected @=? sort ops
testGradientIncidental :: Test
testGradientIncidental = testCase "testGradientIncidental" $ do
let grads = do
x <- TF.render $ TF.scalar (3 :: Float)
b <- TF.render $ TF.scalar (4 :: Float)
w <- TF.render $ TF.diag $ TF.vector [ 1.0 :: Float ]
let incidental = b `TF.mul` w
let y = (x `TF.mul` b) `TF.add` incidental
TF.gradients y [x]
-- Assert that the gradients are right.
[dx] <- TF.runSession $ grads >>= TF.run
4 @=? TF.unScalar dx
-- Assert that the graph has the expected ops.
let graphDef = TF.asGraphDef grads
putStrLn $ showMessage graphDef
let ops = graphDef ^.. node . traverse . op
expected = [ "Add"
, "BroadcastGradientArgs"
, "BroadcastGradientArgs"
, "Const"
, "Const"
, "Const"
, "Const"
, "Diag"
, "Fill"
, "Mul"
, "Mul"
, "Mul"
, "Mul"
, "Reshape"
, "Reshape"
, "Reshape"
, "Reshape"
, "Shape"
, "Shape"
, "Shape"
, "Shape"
, "Shape"
, "Sum"
, "Sum"
, "Sum"
, "Sum"
]
sort expected @=? sort ops
testGradientPruning :: Test
testGradientPruning = testCase "testGradientPruning" $ do
let grads = do
x <- TF.render $ TF.scalar (3 :: Float)
b <- TF.render $ TF.scalar (4 :: Float)
bx <- TF.render $ b `TF.mul` x
let y = bx `TF.add` b
TF.gradients y [x, bx]
-- Assert that the gradients are right.
[dx, dxb] <- TF.runSession $ grads >>= TF.run
4 @=? TF.unScalar dx
1 @=? TF.unScalar dxb
-- Test that identical "stateful" ops work with createGraph.
testCreateGraphStateful :: Test
@ -588,6 +647,8 @@ main :: IO ()
main = defaultMain
[ testGradientSimple
, testGradientDisconnected
, testGradientIncidental
, testGradientPruning
, testCreateGraphStateful
, testCreateGraphNameScopes
, testDiamond

View File

@ -97,16 +97,16 @@ library
, Proto.Tensorflow.Core.Util.SavedTensorSlice_Fields
, Proto.Tensorflow.Core.Util.TestLog
, Proto.Tensorflow.Core.Util.TestLog_Fields
build-depends: proto-lens == 0.3.*
, proto-lens-protoc == 0.3.*
, proto-lens-protobuf-types == 0.3.*
build-depends: proto-lens == 0.4.*
, proto-lens-runtime == 0.4.*
, proto-lens-protobuf-types == 0.4.*
, base >= 4.7 && < 5
default-language: Haskell2010
include-dirs: .
custom-setup
setup-depends: Cabal
, proto-lens-protoc == 0.3.*
, proto-lens-setup == 0.4.*
, base >= 4.7 && < 5
source-repository head
type: git

View File

@ -61,12 +61,13 @@ module TensorFlow.Build
, withNodeDependencies
) where
import Data.ProtoLens.Message(defMessage)
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
import Control.Monad.Fix (MonadFix(..))
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Fail (MonadFail(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
import Data.Default (def)
import Data.Functor.Identity (Identity(..))
import qualified Data.Map.Strict as Map
import Data.Monoid ((<>))
@ -191,7 +192,7 @@ summaries = lens _summaries (\g x -> g { _summaries = x })
newtype BuildT m a = BuildT (StateT GraphState m a)
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
MonadState GraphState, MonadThrow, MonadCatch, MonadMask,
MonadFix)
MonadFix, MonadFail)
-- | An action for building nodes in a TensorFlow graph.
type Build = BuildT Identity
@ -236,7 +237,7 @@ addInitializer (ControlNode i) = build $ initializationNodes %= (i:)
-- | Produce a GraphDef proto representation of the nodes that are rendered in
-- the given 'Build' action.
asGraphDef :: Build a -> GraphDef
asGraphDef b = def & node .~ gs ^. nodeBuffer
asGraphDef b = defMessage & node .~ gs ^. nodeBuffer
where
gs = snd $ runIdentity $ runBuildT b
@ -285,7 +286,7 @@ getPendingNode o = do
let controlInputs
= map makeDep (o ^. opControlInputs ++ Set.toList controls)
return $ PendingNode scope (o ^. opName)
$ def & op .~ (unOpType (o ^. opType) :: Text)
$ defMessage & op .~ (unOpType (o ^. opType) :: Text)
& attr .~ _opAttrs o
& input .~ (inputs ++ controlInputs)
& device .~ dev

View File

@ -35,14 +35,14 @@ module TensorFlow.Output
, PendingNodeName(..)
) where
import Data.ProtoLens.Message(defMessage)
import qualified Data.Map.Strict as Map
import Data.String (IsString(..))
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 (Lens')
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..))
import Data.Default (def)
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue)
import TensorFlow.Types (Attribute, attrLens)
-- | A type of graph node which has no outputs. These nodes are
@ -108,7 +108,7 @@ opType = lens _opType (\o x -> o { _opType = x})
opAttr :: Attribute a => Text -> Lens' OpDef a
opAttr n = lens _opAttrs (\o x -> o {_opAttrs = x})
. lens (Map.findWithDefault def n) (flip (Map.insert n))
. lens (Map.findWithDefault defMessage n) (flip (Map.insert n))
. attrLens
opInputs :: Lens' OpDef [Output]

View File

@ -37,7 +37,9 @@ module TensorFlow.Session (
asyncProdNodes,
) where
import Data.ProtoLens.Message(defMessage)
import Control.Monad (forever, unless, void)
import Control.Monad.Fail (MonadFail(..))
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (MonadTrans, lift)
@ -78,7 +80,7 @@ data SessionState
newtype SessionT m a
= Session (ReaderT SessionState (BuildT m) a)
deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch,
MonadMask)
MonadMask, MonadFail)
instance MonadTrans SessionT where
lift = Session . lift . lift
@ -100,7 +102,7 @@ data Options = Options
instance Default Options where
def = Options
{ _sessionTarget = ""
, _sessionConfig = def
, _sessionConfig = defMessage
, _sessionTracer = const (return ())
}
@ -142,7 +144,7 @@ extend = do
trace <- Session (asks tracer)
nodesToExtend <- build flushNodeBuffer
unless (null nodesToExtend) $ liftIO $ do
let graphDef = (def :: GraphDef) & node .~ nodesToExtend
let graphDef = (defMessage :: GraphDef) & node .~ nodesToExtend
trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
FFI.extendGraph session graphDef
-- Now that all the nodes are created, run the initializers.

View File

@ -64,9 +64,9 @@ module TensorFlow.Types
, AllTensorTypes
) where
import Data.ProtoLens.Message(defMessage)
import Data.Functor.Identity (Identity(..))
import Data.Complex (Complex)
import Data.Default (def)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Maybe (fromMaybe)
import Data.Monoid ((<>))
@ -88,8 +88,8 @@ import qualified Data.ByteString.Lazy as L
import qualified Data.Vector as V
import qualified Data.Vector.Storable as S
import Proto.Tensorflow.Core.Framework.AttrValue
( AttrValue(..)
, AttrValue'ListValue(..)
( AttrValue
, AttrValue'ListValue
)
import Proto.Tensorflow.Core.Framework.AttrValue_Fields
( b
@ -105,7 +105,7 @@ import Proto.Tensorflow.Core.Framework.AttrValue_Fields
import Proto.Tensorflow.Core.Framework.ResourceHandle
(ResourceHandleProto)
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
(TensorProto(..))
(TensorProto)
import Proto.Tensorflow.Core.Framework.Tensor_Fields as Tensor
( boolVal
, doubleVal
@ -119,7 +119,7 @@ import Proto.Tensorflow.Core.Framework.Tensor_Fields as Tensor
)
import Proto.Tensorflow.Core.Framework.TensorShape
(TensorShapeProto(..))
(TensorShapeProto)
import Proto.Tensorflow.Core.Framework.TensorShape_Fields
( dim
, size
@ -400,7 +400,7 @@ protoShape = iso protoToShape shapeToProto
protoToShape p = fromMaybe (error msg) (view protoMaybeShape p)
where msg = "Can't convert TensorShapeProto with unknown rank to Shape: "
++ showMessageShort p
shapeToProto s' = def & protoMaybeShape .~ Just s'
shapeToProto s' = defMessage & protoMaybeShape .~ Just s'
protoMaybeShape :: Lens' TensorShapeProto (Maybe Shape)
protoMaybeShape = iso protoToShape shapeToProto
@ -412,9 +412,9 @@ protoMaybeShape = iso protoToShape shapeToProto
else Just (Shape (p ^.. dim . traverse . size))
shapeToProto :: Maybe Shape -> TensorShapeProto
shapeToProto Nothing =
def & unknownRank .~ True
defMessage & unknownRank .~ True
shapeToProto (Just (Shape ds)) =
def & dim .~ fmap (\d -> def & size .~ d) ds
defMessage & dim .~ fmap (\d -> defMessage & size .~ d) ds
class Attribute a where

View File

@ -36,7 +36,7 @@ library
, TensorFlow.Types
other-modules: TensorFlow.Internal.Raw
build-tools: c2hs
build-depends: proto-lens == 0.3.*
build-depends: proto-lens == 0.4.*
, tensorflow-proto == 0.2.*
, base >= 4.7 && < 5
, async