From 2c5c879037726001053a02cdfbca62642ac858e5 Mon Sep 17 00:00:00 2001
From: Judah Jacobson
Date: Sat, 18 Mar 2017 12:08:53 -0700
Subject: [PATCH] Introduce a MonadBuild class, and remove `buildAnd`. (#83)

This change adds a class that both `Build` and `Session` are instances of:

    class MonadBuild m where
        build :: Build a -> m a

All stateful ops (generated and manually written) are now polymorphic in
`MonadBuild`, returning their result in any such monad rather than only in
`Build`. For example:

    assign_ :: (MonadBuild m, TensorType t)
            => Tensor Ref t -> Tensor v t -> m (Tensor Ref t)

This lets us remove a bunch of spurious calls to `build` in user code. It
also lets us replace the pattern `buildAnd run foo` with the simpler pattern
`foo >>= run` (or `run =<< foo`, which is sometimes nicer when `foo` is a
complicated expression).

I went ahead and deleted `buildAnd` altogether, since it seems to lead to
confusion; in particular, a few tests had `buildAnd run . pure`, which is
actually equivalent to just `run`. (Two short usage sketches follow the
diff below.)
---
 README.md                                     |  8 ++--
 tensorflow-mnist/tests/ParseTest.hs           |  6 +--
 tensorflow-nn/src/TensorFlow/NN.hs            |  6 +--
 tensorflow-nn/tests/NNTest.hs                 |  5 +--
 tensorflow-opgen/src/TensorFlow/OpGen.hs      | 20 ++++++---
 tensorflow-ops/src/TensorFlow/EmbeddingOps.hs |  9 ++--
 tensorflow-ops/src/TensorFlow/Gradient.hs     | 13 +++---
 tensorflow-ops/src/TensorFlow/Ops.hs          | 43 ++++++++++---------
 tensorflow-ops/tests/BuildTest.hs             | 24 ++++-------
 tensorflow-ops/tests/DataFlowOpsTest.hs       |  2 +-
 tensorflow-ops/tests/EmbeddingOpsTest.hs      | 22 ++++------
 tensorflow-ops/tests/GradientTest.hs          | 21 ++++-----
 tensorflow-ops/tests/OpsTest.hs               | 14 +++---
 tensorflow-ops/tests/RegressionTest.hs        |  8 ++--
 tensorflow-ops/tests/TracingTest.hs           |  2 +-
 tensorflow-queue/src/TensorFlow/Queue.hs      | 29 ++++++-------
 tensorflow-queue/tensorflow-queue.cabal       |  1 +
 tensorflow-queue/tests/QueueTest.hs           | 11 +++--
 tensorflow/src/TensorFlow/Build.hs            | 42 ++++++++++--------
 tensorflow/src/TensorFlow/ControlFlow.hs      |  8 ++--
 tensorflow/src/TensorFlow/Core.hs             |  3 +-
 tensorflow/src/TensorFlow/Session.hs          | 17 ++------
 22 files changed, 152 insertions(+), 162 deletions(-)

diff --git a/README.md b/README.md
index 983d400..62793f7 100644
--- a/README.md
+++ b/README.md
@@ -45,13 +45,13 @@ fit xData yData = TF.runSession $ do
     let x = TF.vector xData
         y = TF.vector yData
     -- Create scalar variables for slope and intercept.
-    w <- TF.build (TF.initializedVariable 0)
-    b <- TF.build (TF.initializedVariable 0)
+    w <- TF.initializedVariable 0
+    b <- TF.initializedVariable 0
     -- Define the loss function.
     let yHat = (x `TF.mul` w) `TF.add` b
         loss = TF.square (yHat `TF.sub` y)
     -- Optimize with gradient descent.
-    trainStep <- TF.build (gradientDescent 0.001 loss [w, b])
+    trainStep <- gradientDescent 0.001 loss [w, b]
     replicateM_ 1000 (TF.run trainStep)
     -- Return the learned parameters.
(TF.Scalar w', TF.Scalar b') <- TF.run (w, b) @@ -60,7 +60,7 @@ fit xData yData = TF.runSession $ do gradientDescent :: Float -> TF.Tensor TF.Value Float -> [TF.Tensor TF.Ref Float] - -> TF.Build TF.ControlNode + -> TF.Session TF.ControlNode gradientDescent alpha loss params = do let applyGrad param grad = TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad)) diff --git a/tensorflow-mnist/tests/ParseTest.hs b/tensorflow-mnist/tests/ParseTest.hs index 89b3d4a..95e8f11 100644 --- a/tensorflow-mnist/tests/ParseTest.hs +++ b/tensorflow-mnist/tests/ParseTest.hs @@ -49,7 +49,7 @@ import TensorFlow.Tensor ) import TensorFlow.Ops import TensorFlow.Session - (runSession, run, run_, runWithFeeds, build, buildAnd) + (runSession, run, run_, runWithFeeds, build) import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar) import Test.Framework (Test) import Test.Framework.Providers.HUnit (testCase) @@ -108,7 +108,7 @@ testGraphDefExec :: Test testGraphDefExec = testCase "testGraphDefExec" $ do let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10 runSession $ do - build $ addGraphDef graphDef + addGraphDef graphDef x <- run $ tensorFromName ValueKind "Mul_2" liftIO $ (50 :: Float) @=? unScalar x @@ -147,7 +147,7 @@ testMNISTExec = testCase "testMNISTExec" $ do wtsCkptPath <- liftIO wtsCkpt biasCkptPath <- liftIO biasCkpt -- Run those restoring nodes on the graph in the current session. - buildAnd run_ $ (sequence :: Monad m => [m a] -> m [a]) + run_ =<< (sequence :: Monad m => [m a] -> m [a]) [ restore wtsCkptPath wts , restoreFromName biasCkptPath "bias" bias ] diff --git a/tensorflow-nn/src/TensorFlow/NN.hs b/tensorflow-nn/src/TensorFlow/NN.hs index 5b6f60f..0ae1c05 100644 --- a/tensorflow-nn/src/TensorFlow/NN.hs +++ b/tensorflow-nn/src/TensorFlow/NN.hs @@ -23,7 +23,7 @@ module TensorFlow.NN import Prelude hiding ( log , exp ) -import TensorFlow.Build ( Build +import TensorFlow.Build ( MonadBuild , render , withNameScope ) @@ -71,10 +71,10 @@ import TensorFlow.Ops ( zerosLike -- -- `logits` and `targets` must have the same type and shape. sigmoidCrossEntropyWithLogits - :: (OneOf '[Float, Double] a, TensorType a, Num a) + :: (MonadBuild m, OneOf '[Float, Double] a, TensorType a, Num a) => Tensor Value a -- ^ __logits__ -> Tensor Value a -- ^ __targets__ - -> Build (Tensor Value a) + -> m (Tensor Value a) sigmoidCrossEntropyWithLogits logits targets = do logits' <- render logits targets' <- render targets diff --git a/tensorflow-nn/tests/NNTest.hs b/tensorflow-nn/tests/NNTest.hs index 23cd92c..d91dd70 100644 --- a/tensorflow-nn/tests/NNTest.hs +++ b/tensorflow-nn/tests/NNTest.hs @@ -22,7 +22,6 @@ import TensorFlow.Test (assertAllClose) import Test.Framework (Test) import Test.Framework.Providers.HUnit (testCase) import qualified Data.Vector as V -import qualified TensorFlow.Build as TF import qualified TensorFlow.Gradient as TF import qualified TensorFlow.Nodes as TF import qualified TensorFlow.NN as TF @@ -97,8 +96,8 @@ testGradientAtZero = testCase "testGradientAtZero" $ do assertAllClose (head r) (V.fromList [0.5, -0.5]) -run :: TF.Fetchable t a => TF.Build t -> IO a -run = TF.runSession . TF.buildAnd TF.run +run :: TF.Fetchable t a => TF.Session t -> IO a +run = TF.runSession . 
(>>= TF.run)
 
 main :: IO ()
 main = googleTest [ testGradientAtZero
diff --git a/tensorflow-opgen/src/TensorFlow/OpGen.hs b/tensorflow-opgen/src/TensorFlow/OpGen.hs
index c8fa43b..bf2cfe8 100644
--- a/tensorflow-opgen/src/TensorFlow/OpGen.hs
+++ b/tensorflow-opgen/src/TensorFlow/OpGen.hs
@@ -220,9 +220,12 @@ renderHaskellAttrName :: Attr a -> Doc
 renderHaskellAttrName = renderHaskellName . attrName
 
 functionBody :: ParsedOp -> Doc
-functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
+functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOpParts))
            </> indent indentation (sep tensorArgs)
   where
+    maybeLift
+        | parsedOpIsMonadic pOp = "build $"
+        | otherwise = ""
     buildFunction
         | null outputListsSizes = "buildOp"
         | otherwise = "buildListOp" <+>
@@ -277,13 +280,18 @@ typeSig pOp = constraints
                 ++ [outputs])
   where
     constraints
-        | null (inferredTypeAttrs pOp) = empty
-        | otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
+        | null classConstraints = empty
+        | otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
     typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
                   Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
                  ++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
-    classConstraints = tuple $ map tensorArgConstraint
-                     $ inferredTypeAttrs pOp
+                 ++ if parsedOpIsMonadic pOp then ["m'"] else []
+    -- Use m' as the type parameter to avoid clashing with an attribute name.
+    monadConstraint
+        | parsedOpIsMonadic pOp = ["MonadBuild m'"]
+        | otherwise = []
+    classConstraints = monadConstraint ++ map tensorArgConstraint
+                            (inferredTypeAttrs pOp)
     signatureFold = folddoc (\x y -> x </> "->" <+> y)
     attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
     renderAttrType (AttrSingle a) = renderAttrBaseType a
@@ -304,7 +312,7 @@ typeSig pOp = constraints
         [a] -> wrapOutput (tensorArg a) <+> "-- ^" <+> argComment a
         as -> wrapOutput (tuple (map tensorArg as)) <+/> resultComment as
     wrapOutput o
-        | parsedOpIsMonadic pOp = "Build" <+> parens o
+        | parsedOpIsMonadic pOp = "m'" <+> parens o
         | otherwise = o
 
 -- | Render an op input or output.
diff --git a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs
index 43da0e5..df1866f 100644
--- a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs
+++ b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs
@@ -24,7 +24,7 @@ module TensorFlow.EmbeddingOps where
 
 import Control.Monad (zipWithM)
 import Data.Int (Int32, Int64)
-import TensorFlow.Build (Build, colocateWith, render)
+import TensorFlow.Build (MonadBuild, colocateWith, render)
 import TensorFlow.Ops (shape, vector)  -- Also Num instance for Tensor
 import TensorFlow.Tensor (Tensor, Value)
 import TensorFlow.Types (OneOf, TensorType)
@@ -44,8 +44,9 @@ import qualified TensorFlow.GenOps.Core as CoreOps
 --
 -- The results of the lookup are concatenated into a dense
 -- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
-embeddingLookup :: forall a b v .
-                   ( TensorType a
+embeddingLookup :: forall a b v m .
+                   ( MonadBuild m
+                   , TensorType a
                    , OneOf '[Int64, Int32] b
                    , Num b
                    )
@@ -58,7 +59,7 @@ embeddingLookup :: forall a b v .
                 -- containing the ids to be looked up in `params`.
                 -- The ids are required to have fewer than 2^31
                 -- entries.
-                -> Build (Tensor Value a)
+                -> m (Tensor Value a)
                 -- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids) embeddingLookup params@(p0 : _) ids = do diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index 4d2833d..f58d4ad 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -56,7 +56,9 @@ import qualified Data.Text as Text import qualified TensorFlow.GenOps.Core as CoreOps import TensorFlow.Build - ( Build + ( MonadBuild + , Build + , build , render , renderNodeName , renderedNodeDefs @@ -111,16 +113,17 @@ type GradientCompatible a = -- | Gradient of @y@ w.r.t. each element of @xs@. -gradients :: forall a v1 v2 . ( Num (Tensor v1 a) +gradients :: forall a v1 v2 m . (MonadBuild m + , Num (Tensor v1 a) -- TODO(gnezdo): remove indirect constraint. - -- It's a wart inherited from Num instance. + -- It's a wart inherited from Num instance. , v1 ~ Value , GradientCompatible a ) => Tensor v1 a -- ^ The output of the graph. -> [Tensor v2 a] -- ^ Tensors for which gradients are computed. - -> Build [Tensor Value a] -gradients y xs = do + -> m [Tensor Value a] +gradients y xs = build $ do -- The gradients are computed using "reverse accumulation", similarly to -- what is described here: -- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs index 7e33696..086430d 100644 --- a/tensorflow-ops/src/TensorFlow/Ops.hs +++ b/tensorflow-ops/src/TensorFlow/Ops.hs @@ -155,20 +155,20 @@ matTranspose :: forall a v . TensorType a => Tensor v a -> Tensor Value a matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32]) -placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a) +placeholder :: forall a m . (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a) placeholder shape' = - buildOp $ opDef "Placeholder" + build $ buildOp $ opDef "Placeholder" & opAttr "dtype" .~ tensorType (undefined :: a) & opAttr "shape" .~ shape' -- | Creates a variable initialized to the given value. -- Initialization happens next time session runs. -initializedVariable :: forall a . TensorType a - => Tensor Value a -> Build (Tensor Ref a) +initializedVariable :: forall a m . (MonadBuild m, TensorType a) + => Tensor Value a -> m (Tensor Ref a) initializedVariable initializer = do v <- CoreOps.variable [] -- The shape is not known initially. (i :: Tensor Ref a) <- - buildOp (opDef "Assign" + build $ buildOp (opDef "Assign" & opAttr "T" .~ tensorType (undefined :: a) & opAttr "use_locking" .~ True & opAttr "validate_shape" .~ False @@ -179,32 +179,32 @@ initializedVariable initializer = do -- | Creates a zero-initialized variable with the given shape. zeroInitializedVariable - :: (TensorType a, Num a) => - TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a) + :: (MonadBuild m, TensorType a, Num a) => + TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a) zeroInitializedVariable = initializedVariable . zeros -- TODO: Support heterogeneous list of tensors. -save :: forall a v . TensorType a +save :: forall a m v . (MonadBuild m, TensorType a) => ByteString -- ^ File path. -> [Tensor v a] -- ^ Tensors to save. - -> Build ControlNode + -> m ControlNode save path xs = do let toByteStringTensor = scalar . encodeUtf8 . unNodeName - names <- mapM (fmap toByteStringTensor . renderNodeName) xs + names <- mapM (fmap toByteStringTensor . build . 
renderNodeName) xs let types = replicate (length xs) (tensorType (undefined :: a)) let saveOp = buildOp $ opDef "Save" & opAttr "T" .~ types - saveOp (scalar path) (CoreOps.pack names) xs + build $ saveOp (scalar path) (CoreOps.pack names) xs -- | Restore a tensor's value from a checkpoint file. -- -- This version allows restoring from a checkpoint file that uses a different -- tensor name than the variable. -restoreFromName :: forall a . TensorType a +restoreFromName :: forall a m . (MonadBuild m, TensorType a) => ByteString -- ^ File path. -> ByteString -- ^ Tensor name override. -> Tensor Ref a -- ^ Tensor to restore. - -> Build ControlNode + -> m ControlNode restoreFromName path name x = do let restoreOp = buildOp $ opDef "Restore" & opAttr "dt" .~ tensorType (undefined :: a) @@ -212,12 +212,12 @@ restoreFromName path name x = do (restoreOp (scalar path) (scalar name) :: Tensor Value a) -- | Restore a tensor's value from a checkpoint file. -restore :: forall a . TensorType a +restore :: forall a m . (MonadBuild m, TensorType a) => ByteString -- ^ File path. -> Tensor Ref a -- ^ Tensor to restore. - -> Build ControlNode + -> m ControlNode restore path x = do - name <- encodeUtf8 . unNodeName <$> renderNodeName x + name <- encodeUtf8 . unNodeName <$> build (renderNodeName x) restoreFromName path name x -- | Create a constant tensor. @@ -264,12 +264,13 @@ scalar :: forall a . TensorType a => a -> Tensor Value a scalar x = constant [] [x] -- Random tensor from the unit normal distribution with bounded values. -truncatedNormal :: forall a v . TensorType a +truncatedNormal :: forall a m v . (MonadBuild m, TensorType a) => Tensor v Int64 -- ^ Shape. - -> Build (Tensor Value a) -truncatedNormal = buildOp $ opDef "TruncatedNormal" - & opAttr "dtype" .~ tensorType (undefined :: a) - & opAttr "T" .~ tensorType (undefined :: Int64) + -> m (Tensor Value a) +truncatedNormal + = build . buildOp (opDef "TruncatedNormal" + & opAttr "dtype" .~ tensorType (undefined :: a) + & opAttr "T" .~ tensorType (undefined :: Int64)) zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0) diff --git a/tensorflow-ops/tests/BuildTest.hs b/tensorflow-ops/tests/BuildTest.hs index d8bf859..c75cf09 100644 --- a/tensorflow-ops/tests/BuildTest.hs +++ b/tensorflow-ops/tests/BuildTest.hs @@ -19,7 +19,6 @@ module Main where import Control.Monad.IO.Class (liftIO) -import Data.Functor.Identity (runIdentity) import Lens.Family2 ((^.)) import Data.List (sort) import Proto.Tensorflow.Core.Framework.Graph @@ -35,7 +34,6 @@ import TensorFlow.Build , asGraphDef , evalBuildT , flushNodeBuffer - , hoistBuildT , render , withDevice , colocateWith @@ -53,9 +51,7 @@ import TensorFlow.Ops import TensorFlow.Output (Device(..)) import TensorFlow.Tensor (Tensor, Value, Ref) import TensorFlow.Session - ( build - , buildAnd - , run + ( run , runSession , run_ ) @@ -82,7 +78,7 @@ testNamedDeRef = testCase "testNamedDeRef" $ do assign v 5 -- TODO: Implement TensorFlow get_variable and test it. runSession $ do - out <- buildAnd run graph + out <- graph >>= run liftIO $ 5 @=? 
(unScalar out :: Float) -- | Test that "run" will render and extend any pure ops that haven't already @@ -96,7 +92,7 @@ testPureRender = testCase "testPureRender" $ runSession $ do testInitializedVariable :: Test testInitializedVariable = testCase "testInitializedVariable" $ runSession $ do - (formula, reset) <- build $ do + (formula, reset) <- do v <- initializedVariable 42 r <- assign v 24 return (1 `add` v, r) @@ -109,7 +105,7 @@ testInitializedVariable = testInitializedVariableShape :: Test testInitializedVariableShape = testCase "testInitializedVariableShape" $ runSession $ do - vector <- build $ initializedVariable (constant [1] [42 :: Float]) + vector <- initializedVariable (constant [1] [42 :: Float]) result <- run vector liftIO $ [42] @=? (result :: V.Vector Float) @@ -132,23 +128,19 @@ testNamedAndScoped = testCase "testNamedAndScoped" $ do "RefIdentity" @=? (nodeDef ^. op) "foo1/bar1" @=? (nodeDef ^. name) --- | Lift a Build action into a context for HUnit to run. -liftBuild :: Build a -> BuildT IO a -liftBuild = hoistBuildT (return . runIdentity) - -- | Flush the node buffer and sort the nodes by name (for more stable tests). flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a] -flushed field = sort . map field <$> liftBuild flushNodeBuffer +flushed field = sort . map field <$> flushNodeBuffer -- | Test the interaction of rendering, CSE and scoping. testRenderDedup :: Test testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do - liftBuild renderNodes + renderNodes names <- flushed (^. name) liftIO $ ["Const_1", "Variable_0", "Variable_2"] @=? names -- Render the nodes in a different scope, which should cause them -- to be distinct from the previous ones. - liftBuild $ withNameScope "foo" renderNodes + withNameScope "foo" renderNodes scopedNames <- flushed (^. name) liftIO $ ["foo/Const_4", "foo/Variable_3", "foo/Variable_5"] @=? scopedNames where @@ -165,7 +157,7 @@ testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do -- | Test the interaction of rendering, CSE and scoping. testDeviceColocation :: Test testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do - liftBuild renderNodes + renderNodes devices <- flushed (\x -> (x ^. name, x ^. device)) liftIO $ [ ("Add_2","dev0") , ("Const_1","dev0") diff --git a/tensorflow-ops/tests/DataFlowOpsTest.hs b/tensorflow-ops/tests/DataFlowOpsTest.hs index 789df51..38fdb5a 100644 --- a/tensorflow-ops/tests/DataFlowOpsTest.hs +++ b/tensorflow-ops/tests/DataFlowOpsTest.hs @@ -45,7 +45,7 @@ testDynamicPartitionStitchInverse (StitchExample numParts values partitions) = restitch = CoreOps.dynamicStitch restitchIndices splitParts in monadicIO $ run $ do fromIntegral numParts @=? length splitParts - valuesOut <- TF.runSession $ TF.buildAnd TF.run $ return restitch + valuesOut <- TF.runSession $ TF.run restitch V.fromList values @=? valuesOut data StitchExample a = StitchExample Int64 [a] [Int32] diff --git a/tensorflow-ops/tests/EmbeddingOpsTest.hs b/tensorflow-ops/tests/EmbeddingOpsTest.hs index 45e5647..137f7f6 100644 --- a/tensorflow-ops/tests/EmbeddingOpsTest.hs +++ b/tensorflow-ops/tests/EmbeddingOpsTest.hs @@ -39,11 +39,6 @@ import qualified TensorFlow.Tensor as TF import qualified TensorFlow.Types as TF import qualified TensorFlow.Gradient as TF import qualified TensorFlow.Build as TF -import qualified TensorFlow.Nodes as TF - - -buildAndRun :: TF.Fetchable t a => TF.Build t -> IO a -buildAndRun = TF.runSession . TF.buildAnd TF.run -- | Tries to perform a simple embedding lookup, with two partitions. 
@@ -61,9 +56,9 @@ testEmbeddingLookupHasRightShapeWithPartition = let ids = TF.constant (TF.Shape [1, 2]) idValues let op = embeddingLookup embedding ids - (values, shape) <- buildAndRun $ do + (values, shape) <- TF.runSession $ do vs <- op - return (vs, TF.shape vs) + TF.run (vs, TF.shape vs) -- This is the shape that is returned in the equiv. Python. shape @=? V.fromList [1, 2, 3] @@ -87,9 +82,9 @@ testEmbeddingLookupHasRightShape = let ids = TF.constant (TF.Shape [1, 2]) idValues let op = embeddingLookup [embedding] ids - (values, shape) <- buildAndRun $ do + (values, shape) <- TF.runSession $ do vs <- op - return (vs, TF.shape vs) + TF.run (vs, TF.shape vs) -- This is the shape that is returned in the equiv. Python. shape @=? V.fromList [1, 2, 3] @@ -106,7 +101,6 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do let shape = TF.Shape [2] gs <- TF.runSession $ do - grads <- TF.build $ do let embShape = TF.Shape [2, 1] let embeddingInit = [1, 20 ::Float] let idValues = [1, 1 :: Int32] @@ -121,9 +115,9 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do loss = TF.mean twoNorm (TF.scalar (0 :: Int32)) grad <- fmap head (TF.gradients loss [embedding]) - return $ \xs -> TF.runWithFeeds [TF.feed x xs] grad - - grads (TF.encodeTensorData shape xVals :: TF.TensorData Float) + TF.runWithFeeds + [TF.feed x $ TF.encodeTensorData shape xVals] + grad -- Gradients should be zero (or close) assertAllClose gs (V.fromList ([0, 0 :: Float])) @@ -148,7 +142,7 @@ testEmbeddingLookupUndoesSplit shapedValues = TF.constant shape values in monadicIO $ run $ do (shapeOut, got, want :: V.Vector a) <- - TF.runSession $ TF.buildAnd TF.run $ do + TF.runSession $ TF.run =<< do embeddings <- embeddingLookup modShardedValues indicesVector return (TF.cast (TF.shape embeddings), embeddings, directs) -- Checks the explicitly documented invariant of embeddingLookup. diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs index 2f4dd30..765ed22 100644 --- a/tensorflow-ops/tests/GradientTest.hs +++ b/tensorflow-ops/tests/GradientTest.hs @@ -13,6 +13,7 @@ -- limitations under the License. {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE NoMonomorphismRestriction #-} {-# LANGUAGE ScopedTypeVariables #-} import Data.Int (Int32) @@ -40,7 +41,7 @@ testGradientSimple = testCase "testGradientSimple" $ do y = x*x + b grads = TF.gradients y [x, b] -- Assert that the gradients are right. - [dx, db] <- TF.runSession $ TF.buildAnd TF.run grads + [dx, db] <- TF.runSession $ grads >>= TF.run 6 @=? TF.unScalar dx 1 @=? TF.unScalar db -- Assert that the graph has the expected ops. @@ -91,7 +92,7 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do b = TF.scalar (4 :: Float) grads = TF.gradients x [x, b] -- Assert that the gradients are right. - [dx, db] <- TF.runSession $ TF.buildAnd TF.run grads + [dx, db] <- TF.runSession $ grads >>= TF.run 1 @=? TF.unScalar dx 0 @=? TF.unScalar db -- Assert that the graph has the expected ops. @@ -113,11 +114,11 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do -- Test that identical "stateful" ops work with createGraph. 
testCreateGraphStateful :: Test testCreateGraphStateful = testCase "testCreateGraphStateful" $ do - [dx, dy] <- TF.runSession $ TF.buildAnd TF.run $ do + [dx, dy] <- TF.runSession $ do let shape = TF.constant (TF.Shape [1]) [1] x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape - TF.gradients (x + y*3) [x, y] + TF.gradients (x + y*3) [x, y] >>= TF.run -- If this test fails, it will likely be caused by an exception within -- `TF.gradients`. These asserts are extra. 1 @=? TF.unScalar dx @@ -127,11 +128,11 @@ testCreateGraphStateful = testCase "testCreateGraphStateful" $ do -- Test that name scopes work with createGraph. testCreateGraphNameScopes :: Test testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do - [dx] <- TF.runSession $ TF.buildAnd TF.run $ do + [dx] <- TF.runSession $ do let shape = TF.constant (TF.Shape [1]) [1] x :: TF.Tensor TF.Value Float <- TF.withNameScope "foo" (TF.truncatedNormal shape) - TF.gradients x [x] + TF.gradients x [x] >>= TF.run -- If this test fails, it will likely be caused by an exception within -- `TF.gradients`. This assert is extra. 1 @=? TF.unScalar dx @@ -140,20 +141,20 @@ testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do -- Test that createGraph can handle graphs with diamond shapes. testDiamond :: Test testDiamond = testCase "testDiamond" $ do - [dx] <- TF.runSession $ TF.buildAnd TF.run $ do + [dx] <- TF.runSession $ do let x = TF.vector [1] y = x*x z = y*y - TF.gradients z [x] + TF.gradients z [x] >>= TF.run (4 :: Float) @=? TF.unScalar dx testMaxGradient :: Test testMaxGradient = testCase "testMaxGradient" $ do - [dx] <- TF.runSession $ TF.buildAnd TF.run $ do + [dx] <- TF.runSession $ do let x = TF.vector [1, 2, 3, 0, 1 :: Float] y = TF.max x (0 :: TF.Tensor TF.Value Int32) - TF.gradients y [x] + TF.gradients y [x] >>= TF.run V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx diff --git a/tensorflow-ops/tests/OpsTest.hs b/tensorflow-ops/tests/OpsTest.hs index 876b2d1..20f796e 100644 --- a/tensorflow-ops/tests/OpsTest.hs +++ b/tensorflow-ops/tests/OpsTest.hs @@ -41,7 +41,7 @@ testSize = testCase "testSize" $ do TF.Scalar (2 * 3 :: Int32) @=? x eval :: TF.Fetchable t a => t -> IO a -eval = TF.runSession . TF.buildAnd TF.run . return +eval = TF.runSession . TF.run -- | Confirms that the original example from Python code works. testReducedShape :: Test @@ -54,16 +54,16 @@ testSaveRestore :: Test testSaveRestore = testCase "testSaveRestore" $ withSystemTempDirectory "" $ \dirPath -> do let path = B8.pack $ dirPath ++ "/checkpoint" - var :: TF.Build (TF.Tensor TF.Ref Float) + var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float) var = TF.render =<< TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape []) TF.runSession $ do - v <- TF.build var - TF.buildAnd TF.run_ $ TF.assign v 134 - TF.buildAnd TF.run_ $ TF.save path [v] + v <- var + TF.assign v 134 >>= TF.run_ + TF.save path [v] >>= TF.run_ result <- TF.runSession $ do - v <- TF.build var - TF.buildAnd TF.run_ $ TF.restore path v + v <- var + TF.restore path v >>= TF.run_ TF.run v liftIO $ TF.Scalar 134 @=? result diff --git a/tensorflow-ops/tests/RegressionTest.hs b/tensorflow-ops/tests/RegressionTest.hs index 67e087b..ec83eed 100644 --- a/tensorflow-ops/tests/RegressionTest.hs +++ b/tensorflow-ops/tests/RegressionTest.hs @@ -25,13 +25,13 @@ fit xData yData = TF.runSession $ do let x = TF.vector xData y = TF.vector yData -- Create scalar variables for slope and intercept. 
- w <- TF.build (TF.initializedVariable 0) - b <- TF.build (TF.initializedVariable 0) + w <- TF.initializedVariable 0 + b <- TF.initializedVariable 0 -- Define the loss function. let yHat = (x `TF.mul` w) `TF.add` b loss = TF.square (yHat `TF.sub` y) -- Optimize with gradient descent. - trainStep <- TF.build (gradientDescent 0.001 loss [w, b]) + trainStep <- gradientDescent 0.001 loss [w, b] replicateM_ 1000 (TF.run trainStep) -- Return the learned parameters. (TF.Scalar w', TF.Scalar b') <- TF.run (w, b) @@ -40,7 +40,7 @@ fit xData yData = TF.runSession $ do gradientDescent :: Float -> TF.Tensor TF.Value Float -> [TF.Tensor TF.Ref Float] - -> TF.Build TF.ControlNode + -> TF.Session TF.ControlNode gradientDescent alpha loss params = do let applyGrad param grad = TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad)) diff --git a/tensorflow-ops/tests/TracingTest.hs b/tensorflow-ops/tests/TracingTest.hs index 9e7c1c5..fdf949d 100644 --- a/tensorflow-ops/tests/TracingTest.hs +++ b/tensorflow-ops/tests/TracingTest.hs @@ -35,7 +35,7 @@ testTracing = do loggedValue <- newEmptyMVar TF.runSessionWithOptions (def & TF.sessionTracer .~ putMVar loggedValue) - (TF.buildAnd TF.run_ (pure (TF.scalar (0 :: Float)))) + (TF.run_ (TF.scalar (0 :: Float))) tryReadMVar loggedValue >>= maybe (assertFailure "Logging never happened") expectedFormat where expectedFormat x = diff --git a/tensorflow-queue/src/TensorFlow/Queue.hs b/tensorflow-queue/src/TensorFlow/Queue.hs index f906f4a..2e0a5cf 100644 --- a/tensorflow-queue/src/TensorFlow/Queue.hs +++ b/tensorflow-queue/src/TensorFlow/Queue.hs @@ -24,10 +24,11 @@ import Data.ByteString (ByteString) import Data.Int (Int64) import Data.Proxy (Proxy(..)) import Lens.Family2 ((.~), (&)) -import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef) +import TensorFlow.Build (ControlNode, MonadBuild, build, addInitializer, opAttr, opDef) import TensorFlow.BuildOp (buildOp) import TensorFlow.ControlFlow (group) -import TensorFlow.Tensor (Ref, Tensor, TensorList) +import qualified TensorFlow.GenOps.Core as CoreOps +import TensorFlow.Tensor (Ref, Value, Tensor, TensorList) import TensorFlow.Types (TensorTypes, fromTensorTypes) -- | A queue carrying tuples. @@ -36,36 +37,30 @@ data Queue (as :: [*]) = Queue { handle :: Handle } type Handle = Tensor Ref ByteString -- | Adds the given values to the queue. -enqueue :: forall as v . TensorTypes as +enqueue :: forall as v m . (MonadBuild m, TensorTypes as) => Queue as -> TensorList v as - -> Build ControlNode -enqueue q = - buildOp (opDef "QueueEnqueue" - & opAttr "Tcomponents" .~ fromTensorTypes (Proxy :: Proxy as)) - (handle q) + -> m ControlNode +enqueue = CoreOps.queueEnqueue . handle -- | Retrieves the values from the queue. -dequeue :: forall as . TensorTypes as +dequeue :: forall as m . (MonadBuild m, TensorTypes as) => Queue as - -> Build (TensorList Ref as) + -> m (TensorList Value as) -- ^ Dequeued tensors. They are coupled in a sense -- that values appear together, even if they are -- not consumed together. -dequeue q = - buildOp (opDef "QueueDequeue" - & opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)) - (handle q) +dequeue = CoreOps.queueDequeue . handle -- | Creates a new queue with the given capacity and shared name. -makeQueue :: forall as . TensorTypes as +makeQueue :: forall as m . (MonadBuild m, TensorTypes as) => Int64 -- ^ The upper bound on the number of elements in -- this queue. Negative numbers mean no limit. 
-> ByteString -- ^ If non-empty, this queue will be shared -- under the given name across multiple sessions. - -> Build (Queue as) + -> m (Queue as) makeQueue capacity sharedName = do - q <- buildOp (opDef "FIFOQueue" + q <- build $ buildOp (opDef "FIFOQueue" & opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as) & opAttr "shared_name" .~ sharedName & opAttr "capacity" .~ capacity diff --git a/tensorflow-queue/tensorflow-queue.cabal b/tensorflow-queue/tensorflow-queue.cabal index dcf2c8a..f14f243 100644 --- a/tensorflow-queue/tensorflow-queue.cabal +++ b/tensorflow-queue/tensorflow-queue.cabal @@ -39,6 +39,7 @@ Test-Suite QueueTest , lens-family , google-shim , tensorflow + , tensorflow-core-ops , tensorflow-ops , tensorflow-queue , test-framework diff --git a/tensorflow-queue/tests/QueueTest.hs b/tensorflow-queue/tests/QueueTest.hs index 5aa0e54..cb4c6ab 100644 --- a/tensorflow-queue/tests/QueueTest.hs +++ b/tensorflow-queue/tests/QueueTest.hs @@ -27,7 +27,6 @@ import TensorFlow.Queue import TensorFlow.Session ( asyncProdNodes , build - , buildAnd , run , runSession , run_ @@ -41,12 +40,12 @@ import qualified Data.ByteString as BS testBasic :: Test testBasic = testCase "testBasic" $ runSession $ do q :: Queue [Int64, BS.ByteString] <- build $ makeQueue 1 "" - buildAnd run_ $ enqueue q $ 42 :/ scalar "Hi" :/ Nil - x <- buildAnd run (dequeue q) + run_ =<< enqueue q (42 :/ scalar "Hi" :/ Nil) + x <- run =<< dequeue q liftIO $ (Scalar 42 /:/ Scalar "Hi" /:/ Nil) @=? x - buildAnd run_ $ enqueue q $ 56 :/ scalar "Bar" :/ Nil - y <- buildAnd run (dequeue q) + run_ =<< enqueue q (56 :/ scalar "Bar" :/ Nil) + y <- run =<< dequeue q -- Note: we use explicit "Scalar" here to specify the type that was -- fetched. Equivalently we could write -- 56 /:/ "Bar" /:/ Nil :: List [Scalar Int64, Scalar BS.ByteString] @@ -74,7 +73,7 @@ testPump = testCase "testPump" $ runSession $ do testAsync :: Test testAsync = testCase "testAsync" $ runSession $ do - (deq, pump) <- build $ do + (deq, pump) <- do q :: Queue [Int64, BS.ByteString] <- makeQueue 2 "" (,) <$> dequeue q <*> enqueue q (10 :/ scalar "Async" :/ Nil) diff --git a/tensorflow/src/TensorFlow/Build.hs b/tensorflow/src/TensorFlow/Build.hs index 8724fb1..9cdd43b 100644 --- a/tensorflow/src/TensorFlow/Build.hs +++ b/tensorflow/src/TensorFlow/Build.hs @@ -37,6 +37,7 @@ module TensorFlow.Build , renderedNodeDefs , BuildT , Build + , MonadBuild(..) , addInitializer , hoistBuildT , evalBuildT @@ -212,9 +213,16 @@ runBuildT (BuildT f) = runStateT f initGraphState evalBuildT :: Monad m => BuildT m a -> m a evalBuildT (BuildT f) = evalStateT f initGraphState +-- | Lift a 'Build' action into a monad, including any explicit op renderings. +class Monad m => MonadBuild m where + build :: Build a -> m a + +instance Monad m => MonadBuild (BuildT m) where + build = hoistBuildT $ return . runIdentity + -- | Get all the NodeDefs that have accumulated so far, and clear that buffer. -flushNodeBuffer :: Monad m => BuildT m [NodeDef] -flushNodeBuffer = do +flushNodeBuffer :: MonadBuild m => m [NodeDef] +flushNodeBuffer = build $ do ns <- use nodeBuffer nodeBuffer .= [] return ns @@ -229,8 +237,8 @@ flushInitializers = do -- | Registers the given node to be executed before the next -- 'TensorFlow.Session.run'. 
-addInitializer :: ControlNode -> Build () -addInitializer (ControlNode o) = do +addInitializer :: MonadBuild m => ControlNode -> m () +addInitializer (ControlNode o) = build $ do i <- getOrAddOp o initializationNodes %= (i:) @@ -242,8 +250,8 @@ asGraphDef b = def & node .~ gs ^. nodeBuffer gs = snd $ runIdentity $ runBuildT b -- TODO: check against existing nodes for conflicts? -addGraphDef :: GraphDef -> Build () -addGraphDef g = nodeBuffer <>= g ^. node +addGraphDef :: MonadBuild m => GraphDef -> m () +addGraphDef g = build $ nodeBuffer <>= g ^. node -- | Render the given op if it hasn't been rendered already, and return its -- name. @@ -318,34 +326,34 @@ renderOutput (Output (OutputIx i) o) = do -- | Modify some part of the state, run an action, and restore the state -- after that action is done. -withStateLens :: MonadState s m => Lens' s a -> (a -> a) -> m b -> m b +withStateLens :: MonadBuild m => Lens' GraphState a -> (a -> a) -> m b -> m b withStateLens accessor f act = do - old <- use accessor - accessor %= f + old <- build $ use accessor + build $ accessor %= f result <- act - accessor .= old + build $ accessor .= old return result -- | Set a device for all nodes rendered in the given 'Build' action -- (unless further overridden by another use of withDevice). -withDevice :: Maybe Device -> Build a -> Build a +withDevice :: MonadBuild m => Maybe Device -> m a -> m a withDevice d = withStateLens defaultDevice (const d) -- | Places all nodes rendered in the given 'Build' action on the same -- device as the given Tensor (see also 'withDevice'). Make sure that -- the action has side effects of rendering the desired tensors. A pure -- return would not have the desired effect. -colocateWith :: forall a v b . Tensor v b -> Build a -> Build a +colocateWith :: MonadBuild m => forall a v b . Tensor v b -> m a -> m a colocateWith t x = do - d <- Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp) + d <- build $ Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp) withDevice (Just d) x -- | Prepend a scope to all nodes rendered in the given 'Build' action. -withNameScope :: Text -> Build a -> Build a +withNameScope :: MonadBuild m => Text -> m a -> m a withNameScope s = withStateLens currentScope (Scope s :) -- | Add control inputs to all nodes rendered in the given 'Build' action. -withNodeDependencies :: Set NodeName -> Build a -> Build a +withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes) -- | Render a 'Tensor', fixing its name, scope, device and control inputs from @@ -355,8 +363,8 @@ withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes) -- This operation is idempotent; @render >=> render === render@. However, -- rendering a (previously un-rendered) 'Tensor' in two different contexts -- may result in two different 'Tensor's. -render :: Tensor v a -> Build (Tensor v a) -render = tensorOutput $ outputOp $ fmap Rendered . resolveOp +render :: MonadBuild m => Tensor v a -> m (Tensor v a) +render = build . tensorOutput (outputOp $ fmap Rendered . resolveOp) -- | Render a 'Tensor' and get its node's name. 
renderNodeName :: Tensor v a -> Build NodeName diff --git a/tensorflow/src/TensorFlow/ControlFlow.hs b/tensorflow/src/TensorFlow/ControlFlow.hs index 9b3f112..2a57b22 100644 --- a/tensorflow/src/TensorFlow/ControlFlow.hs +++ b/tensorflow/src/TensorFlow/ControlFlow.hs @@ -40,9 +40,9 @@ import TensorFlow.Types -- | Modify a 'Build' action, such that all new ops rendered in it will depend -- on the nodes in the first argument. -withControlDependencies :: Nodes t => t -> Build a -> Build a +withControlDependencies :: (MonadBuild m, Nodes t) => t -> m a -> m a withControlDependencies deps act = do - nodes <- getNodes deps + nodes <- build $ getNodes deps withNodeDependencies nodes act -- TODO(judahjacobson): Reimplement withDependencies. @@ -51,9 +51,9 @@ withControlDependencies deps act = do -- -- When this op finishes, all ops in the input @n@ have finished. This op has -- no output. -group :: Nodes t => t -> Build ControlNode +group :: (MonadBuild m, Nodes t) => t -> m ControlNode group deps = do - nodes <- Set.toList <$> getNodes deps + nodes <- build $ Set.toList <$> getNodes deps -- TODO: slicker way return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes diff --git a/tensorflow/src/TensorFlow/Core.hs b/tensorflow/src/TensorFlow/Core.hs index 8af036e..67781aa 100644 --- a/tensorflow/src/TensorFlow/Core.hs +++ b/tensorflow/src/TensorFlow/Core.hs @@ -31,8 +31,7 @@ module TensorFlow.Core , runSession , runSessionWithOptions -- ** Building graphs - , build - , buildAnd + , MonadBuild(..) -- ** Running graphs , Fetchable , Nodes diff --git a/tensorflow/src/TensorFlow/Session.hs b/tensorflow/src/TensorFlow/Session.hs index a9a0182..bc24e3d 100644 --- a/tensorflow/src/TensorFlow/Session.hs +++ b/tensorflow/src/TensorFlow/Session.hs @@ -26,8 +26,7 @@ module TensorFlow.Session ( sessionTracer, runSession, runSessionWithOptions, - build, - buildAnd, + MonadBuild(..), extend, addGraphDef, run, @@ -44,7 +43,6 @@ import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Reader (ReaderT(..), ask, asks) import Data.ByteString (ByteString) import Data.Default (Default, def) -import Data.Functor.Identity (runIdentity) import Data.Monoid ((<>)) import Data.ProtoLens (showMessage) import Data.Set (Set) @@ -124,10 +122,8 @@ runSessionWithOptions options (Session m) = FFI.setSessionTarget (options ^. sessionTarget) opt FFI.setSessionConfig (options ^. sessionConfig) opt --- | Lift a 'Build' action into a 'Session', including any explicit op --- renderings. -build :: Build a -> Session a -build = Session . lift . hoistBuildT (return . runIdentity) +instance MonadBuild Session where + build = Session . lift . build -- | Add all pending rendered nodes to the TensorFlow graph and runs -- any pending initializers. @@ -147,13 +143,6 @@ extend = do unless (null initializers) $ void $ liftIO $ FFI.run session [] [] (toNodeNames initializers) --- | Helper combinator for doing something with the result of a 'Build' action. --- Example usage: --- --- > buildAnd run :: Fetchable t a => Build t -> Session a -buildAnd :: (a -> Session b) -> Build a -> Session b -buildAnd f m = build m >>= f - -- | Run a subgraph 't', rendering any dependent nodes that aren't already -- rendered, and fetch the corresponding values for 'a'. run :: Fetchable t a => t -> Session a
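For reference, a minimal before/after sketch of the migration this patch
enables. It assumes the `tensorflow` and `tensorflow-ops` APIs as of this
commit; the helper name `example` is made up for illustration and does not
exist in the repository.

    import qualified TensorFlow.Core as TF
    import qualified TensorFlow.Ops as TF

    -- Before this change, stateful ops lived in `Build`, so `Session`
    -- code had to lift them explicitly:
    --
    --     w <- TF.build (TF.initializedVariable 0)
    --     TF.buildAnd TF.run_ (TF.assign w 5)
    --
    -- After this change, `Session` is a `MonadBuild` instance, so the
    -- same ops run directly, and `buildAnd run_ foo` becomes
    -- `run_ =<< foo`:
    example :: IO (TF.Scalar Float)    -- hypothetical helper
    example = TF.runSession $ do
        w <- TF.initializedVariable 0  -- no TF.build wrapper needed
        TF.run_ =<< TF.assign w 5      -- was: TF.buildAnd TF.run_ (TF.assign w 5)
        TF.run w                       -- fetch the variable's final value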
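The same class lets downstream code define graph-building helpers once,
polymorphic in `MonadBuild`, and call them from `Build` or `Session` alike
without any lifting. This is the pattern the patch applies to `save`,
`restore`, `gradients`, and the queue ops above. A sketch, with a made-up
helper name:

    initializedPair :: TF.MonadBuild m
                    => m (TF.Tensor TF.Ref Float, TF.Tensor TF.Ref Float)
    initializedPair =
        (,) <$> TF.initializedVariable 0 <*> TF.initializedVariable 1

Because the helper is polymorphic, `TF.runSession (initializedPair >>= TF.run)`
typechecks directly, and the same definition can be reused inside pure `Build`
graph construction with no changes at the call site.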