From d62c61469558cc089ef461bfcf412b9579eaaacc Mon Sep 17 00:00:00 2001 From: Judah Jacobson Date: Thu, 6 Apr 2017 15:10:33 -0700 Subject: [PATCH] Distinguish between "rendered" and "unrendered" Tensors. (#88) Distinguish between "rendered" and "unrendered" Tensors. There are now three types of `Tensor`: - `Tensor Value a`: rendered value - `Tensor Ref a`: rendered reference - `Tensor Build a` : unrendered value The extra bookkeeping makes it easier to track (and enforce) which tensors are rendered or not. For examples where this has been confusing in the past, see With this change, pure ops look similar to before, returning `Tensor Build` instead of `Tensor Value`. "Stateful" (monadic) ops are unchanged. For example: add :: OneOf [..] t => Tensor v'1 t -> Tensor v'2 t -> Tensor Build t assign :: (MonadBuild m, TensorType t) => Tensor Ref t -> Tensor v'2 t -> m (Tensor Ref t) The `gradients` function now requires that the variables over which it's differentiating are pre-rendered: gradients :: (..., Rendered v2) => Tensor v1 a -> [Tensor v2 a] -> m [Tensor Value a] (`Rendered v2` means that `v2` is either a `Ref` or a `Value`.) Additionally, the implementation of `gradients` now takes care to render every intermediate value when performing the reverse accumulation. I suspect this fixes an exponential blowup for complicated expressions. --- README.md | 2 +- tensorflow-logging/src/TensorFlow/Logging.hs | 14 +- tensorflow-mnist/app/Main.hs | 8 +- tensorflow-mnist/tests/ParseTest.hs | 22 +- tensorflow-nn/src/TensorFlow/NN.hs | 16 +- tensorflow-nn/tests/NNTest.hs | 35 +- tensorflow-opgen/src/TensorFlow/OpGen.hs | 51 +-- .../src/TensorFlow/OpGen/ParsedOp.hs | 30 +- tensorflow-ops/src/TensorFlow/EmbeddingOps.hs | 11 +- tensorflow-ops/src/TensorFlow/Gradient.hs | 183 +++++----- tensorflow-ops/src/TensorFlow/Ops.hs | 76 ++-- tensorflow-ops/tensorflow-ops.cabal | 1 + tensorflow-ops/tests/ArrayOpsTest.hs | 6 +- tensorflow-ops/tests/BuildTest.hs | 14 +- tensorflow-ops/tests/DataFlowOpsTest.hs | 6 +- tensorflow-ops/tests/EmbeddingOpsTest.hs | 52 ++- tensorflow-ops/tests/GradientTest.hs | 26 +- tensorflow-ops/tests/OpsTest.hs | 9 +- tensorflow-ops/tests/RegressionTest.hs | 2 +- tensorflow-ops/tests/TypesTest.hs | 9 +- tensorflow-queue/src/TensorFlow/Queue.hs | 2 +- tensorflow/src/TensorFlow/Build.hs | 105 ++---- tensorflow/src/TensorFlow/BuildOp.hs | 329 +++++++++--------- tensorflow/src/TensorFlow/ControlFlow.hs | 12 +- tensorflow/src/TensorFlow/Core.hs | 2 +- tensorflow/src/TensorFlow/Nodes.hs | 14 +- tensorflow/src/TensorFlow/Output.hs | 41 +-- tensorflow/src/TensorFlow/Session.hs | 6 +- tensorflow/src/TensorFlow/Tensor.hs | 160 +++++++-- 29 files changed, 636 insertions(+), 608 deletions(-) diff --git a/README.md b/README.md index 62793f7..8ebb889 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ fit xData yData = TF.runSession $ do return (w', b') gradientDescent :: Float - -> TF.Tensor TF.Value Float + -> TF.Tensor TF.Build Float -> [TF.Tensor TF.Ref Float] -> TF.Session TF.ControlNode gradientDescent alpha loss params = do diff --git a/tensorflow-logging/src/TensorFlow/Logging.hs b/tensorflow-logging/src/TensorFlow/Logging.hs index da1c29f..f0d5439 100644 --- a/tensorflow-logging/src/TensorFlow/Logging.hs +++ b/tensorflow-logging/src/TensorFlow/Logging.hs @@ -67,10 +67,10 @@ import Proto.Tensorflow.Core.Framework.Summary (Summary) import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step, summary, wallTime) import System.Directory (createDirectoryIfMissing) import System.FilePath (()) -import TensorFlow.Build (Build, render, SummaryTensor, addSummary, collectAllSummaries) +import TensorFlow.Build (MonadBuild) import TensorFlow.Ops (scalar) import TensorFlow.Records.Conduit (sinkTFRecords) -import TensorFlow.Tensor (Tensor) +import TensorFlow.Tensor (Tensor, render, SummaryTensor, addSummary, collectAllSummaries) import TensorFlow.Types (TensorType, type(/=)) import Text.Printf (printf) import qualified Data.ByteString.Lazy as L @@ -141,19 +141,19 @@ doubleWallTime = asDouble <$> getCurrentTime -- | Adds a 'CoreOps.histogramSummary' node. The tag argument is intentionally -- limited to a single value for simplicity. histogramSummary :: - (TensorType t, t /= ByteString, t /= Bool) + (MonadBuild m, TensorType t, t /= ByteString, t /= Bool) -- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t) - => ByteString -> Tensor v t -> Build () + => ByteString -> Tensor v t -> m () histogramSummary tag = addSummary . CoreOps.histogramSummary (scalar tag) -- | Adds a 'CoreOps.scalarSummary' node. scalarSummary :: - (TensorType t, t /= ByteString, t /= Bool) + (TensorType t, t /= ByteString, t /= Bool, MonadBuild m) -- (TensorType t, -- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t) - => ByteString -> Tensor v t -> Build () + => ByteString -> Tensor v t -> m () scalarSummary tag = addSummary . CoreOps.scalarSummary (scalar tag) -- | Merge all summaries accumulated in the 'Build' into one summary. -mergeAllSummaries :: Build SummaryTensor +mergeAllSummaries :: MonadBuild m => m SummaryTensor mergeAllSummaries = collectAllSummaries >>= render . CoreOps.mergeSummary diff --git a/tensorflow-mnist/app/Main.hs b/tensorflow-mnist/app/Main.hs index 5475469..ccb032b 100644 --- a/tensorflow-mnist/app/Main.hs +++ b/tensorflow-mnist/app/Main.hs @@ -34,13 +34,13 @@ numPixels = 28*28 :: Int64 numLabels = 10 :: Int64 -- | Create tensor with random values where the stddev depends on the width. -randomParam :: Int64 -> TF.Shape -> TF.Build (TF.Tensor TF.Value Float) +randomParam :: Int64 -> TF.Shape -> TF.Build (TF.Tensor TF.Build Float) randomParam width (TF.Shape shape) = - (* stddev) <$> TF.truncatedNormal (TF.vector shape) + (`TF.mul` stddev) <$> TF.truncatedNormal (TF.vector shape) where stddev = TF.scalar (1 / sqrt (fromIntegral width)) -reduceMean :: TF.Tensor TF.Value Float -> TF.Tensor TF.Value Float +reduceMean :: TF.Tensor TF.Build Float -> TF.Tensor TF.Build Float reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32)) -- Types must match due to model structure. @@ -87,7 +87,7 @@ createModel = do grads <- TF.gradients loss params let lr = TF.scalar 0.00001 - applyGrad param grad = TF.assign param $ param `TF.sub` (lr * grad) + applyGrad param grad = TF.assign param $ param `TF.sub` (lr `TF.mul` grad) trainStep <- TF.group =<< zipWithM applyGrad params grads let correctPredictions = TF.equal predict labels diff --git a/tensorflow-mnist/tests/ParseTest.hs b/tensorflow-mnist/tests/ParseTest.hs index 95e8f11..9a0f4db 100644 --- a/tensorflow-mnist/tests/ParseTest.hs +++ b/tensorflow-mnist/tests/ParseTest.hs @@ -37,15 +37,15 @@ import TensorFlow.Examples.MNIST.TrainedGraph import TensorFlow.Build ( asGraphDef , addGraphDef - , render + , Build ) import TensorFlow.Tensor ( Tensor(..) , Ref - , Value , feed - , TensorKind(..) + , render , tensorFromName + , tensorValueFromName ) import TensorFlow.Ops import TensorFlow.Session @@ -80,7 +80,7 @@ testReadMNIST = testCase "testReadMNIST" $ do labelData <- readMNISTLabels =<< testLabelData 10000 @=? length labelData -testNodeName :: Text -> Tensor v a -> Assertion +testNodeName :: Text -> Tensor Build a -> Assertion testNodeName n g = n @=? opName where opName = head (gDef^.node)^.op @@ -89,7 +89,7 @@ testNodeName n g = n @=? opName testGraphDefGen :: Test testGraphDefGen = testCase "testGraphDefGen" $ do -- Test the inferred operation type. - let f0 :: Tensor Value Float + let f0 :: Tensor Build Float f0 = 0 testNodeName "Const" f0 testNodeName "Add" $ 1 + f0 @@ -109,7 +109,7 @@ testGraphDefExec = testCase "testGraphDefExec" $ do let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10 runSession $ do addGraphDef graphDef - x <- run $ tensorFromName ValueKind "Mul_2" + x <- run $ tensorValueFromName "Mul_2" liftIO $ (50 :: Float) @=? unScalar x -- | Load MNIST from a GraphDef and the weights from a checkpoint and run on @@ -142,8 +142,8 @@ testMNISTExec = testCase "testMNISTExec" $ do build $ addGraphDef $ mnist & version .~ 0 -- Define nodes that restore saved weights and biases. let bias, wts :: Tensor Ref Float - bias = tensorFromName RefKind "Variable" - wts = tensorFromName RefKind "weights" + bias = tensorFromName "Variable" + wts = tensorFromName "weights" wtsCkptPath <- liftIO wtsCkpt biasCkptPath <- liftIO biasCkpt -- Run those restoring nodes on the graph in the current session. @@ -155,12 +155,12 @@ testMNISTExec = testCase "testMNISTExec" $ do let ty = encodeTensorData [10] oneHotLabels where oneHotLabels = V.replicate 10 (0 :: Float) V.// updates updates = [(fromIntegral label, 1)] - let feeds = [ feed (tensorFromName ValueKind "x-input") tensorSample - , feed (tensorFromName ValueKind "y-input") ty + let feeds = [ feed (tensorValueFromName "x-input") tensorSample + , feed (tensorValueFromName "y-input") ty ] -- Run the graph with the input feeds and read the ArgMax'd result from -- the test (not training) side of the evaluation. - x <- runWithFeeds feeds $ tensorFromName ValueKind "test/ArgMax" + x <- runWithFeeds feeds $ tensorValueFromName "test/ArgMax" -- Print the trained model's predicted outcome. liftIO $ putStrLn $ "Expectation: " ++ show label ++ "\n" ++ "Prediction: " ++ show (unScalar x :: Int64) diff --git a/tensorflow-nn/src/TensorFlow/NN.hs b/tensorflow-nn/src/TensorFlow/NN.hs index 0ae1c05..9baaeae 100644 --- a/tensorflow-nn/src/TensorFlow/NN.hs +++ b/tensorflow-nn/src/TensorFlow/NN.hs @@ -24,7 +24,6 @@ import Prelude hiding ( log , exp ) import TensorFlow.Build ( MonadBuild - , render , withNameScope ) import TensorFlow.GenOps.Core ( greaterEqual @@ -33,6 +32,7 @@ import TensorFlow.GenOps.Core ( greaterEqual , exp ) import TensorFlow.Tensor ( Tensor(..) + , render , Value ) import TensorFlow.Types ( TensorType(..) @@ -40,6 +40,8 @@ import TensorFlow.Types ( TensorType(..) ) import TensorFlow.Ops ( zerosLike , add + , mul + , neg ) -- | Computes sigmoid cross entropy given `logits`. @@ -76,13 +78,11 @@ sigmoidCrossEntropyWithLogits -> Tensor Value a -- ^ __targets__ -> m (Tensor Value a) sigmoidCrossEntropyWithLogits logits targets = do - logits' <- render logits - targets' <- render targets - let zeros = zerosLike logits' - cond = logits' `greaterEqual` zeros - relu_logits = select cond logits' zeros - neg_abs_logits = select cond (-logits') logits' + let zeros = zerosLike logits + cond = logits `greaterEqual` zeros + relu_logits = select cond logits zeros + neg_abs_logits = select cond (neg logits) logits withNameScope "logistic_loss" $ do - left <- render $ relu_logits - logits' * targets' + left <- render $ relu_logits - logits `mul` targets right <- render $ log (1 + exp neg_abs_logits) withNameScope "sigmoid_add" $ render $ left `add` right diff --git a/tensorflow-nn/tests/NNTest.hs b/tensorflow-nn/tests/NNTest.hs index d91dd70..4051212 100644 --- a/tensorflow-nn/tests/NNTest.hs +++ b/tensorflow-nn/tests/NNTest.hs @@ -23,10 +23,9 @@ import Test.Framework (Test) import Test.Framework.Providers.HUnit (testCase) import qualified Data.Vector as V import qualified TensorFlow.Gradient as TF -import qualified TensorFlow.Nodes as TF import qualified TensorFlow.NN as TF import qualified TensorFlow.Ops as TF -import qualified TensorFlow.Session as TF +import qualified TensorFlow.Core as TF -- | These tests are ported from: -- @@ -60,12 +59,11 @@ defInputs = Inputs { testLogisticOutput :: Test testLogisticOutput = testCase "testLogisticOutput" $ do let inputs = defInputs - vLogits = TF.vector $ logits inputs - vTargets = TF.vector $ targets inputs - tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets - ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs) - - r <- run tfLoss + r <- run $ do + vLogits <- TF.render $ TF.vector $ logits inputs + vTargets <- TF.render $ TF.vector $ targets inputs + TF.sigmoidCrossEntropyWithLogits vLogits vTargets + let ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs) assertAllClose r ourLoss @@ -74,23 +72,22 @@ testLogisticOutputMultipleDim = testCase "testLogisticOutputMultipleDim" $ do let inputs = defInputs shape = [2, 2, 2] - vLogits = TF.constant shape (logits inputs) - vTargets = TF.constant shape (targets inputs) - tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets - ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs) - - r <- run tfLoss + r <- run $ do + vLogits <- TF.render $ TF.constant shape (logits inputs) + vTargets <- TF.render $ TF.constant shape (targets inputs) + TF.sigmoidCrossEntropyWithLogits vLogits vTargets + let ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs) assertAllClose r ourLoss testGradientAtZero :: Test testGradientAtZero = testCase "testGradientAtZero" $ do - let inputs = defInputs { logits = [0, 0], targets = [0, 1] } - vLogits = TF.vector $ logits inputs - vTargets = TF.vector $ targets inputs - tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets - r <- run $ do + let inputs = defInputs { logits = [0, 0], targets = [0, 1] } + vTargets <- TF.render $ TF.vector $ targets inputs + vLogits <- TF.render $ TF.vector $ logits inputs + let tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets + l <- tfLoss TF.gradients l [vLogits] diff --git a/tensorflow-opgen/src/TensorFlow/OpGen.hs b/tensorflow-opgen/src/TensorFlow/OpGen.hs index ac1e02d..c512168 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen.hs @@ -231,21 +231,24 @@ renderHaskellAttrName :: Attr a -> Doc renderHaskellAttrName = renderHaskellName . attrName functionBody :: ParsedOp -> Doc -functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOpParts)) - indent indentation (sep tensorArgs) +functionBody pOp + | parsedOpIsMonadic pOp + = "build $ do" + indent indentation (bindOpInputsVar + "buildOp" <+> outputListsSizes <+> opDef) + | otherwise + = "pureOp" <+> outputListsSizes <+> "$ do" + indent indentation (bindOpInputsVar "return" <+> opDef) where - maybeLift - | parsedOpIsMonadic pOp = "build $" - | otherwise = "" - buildFunction - | null outputListsSizes = "buildOp" - | otherwise = "buildListOp" <+> - brackets (commasep $ - map renderHaskellName outputListsSizes) - outputListsSizes = [ a - | ParsedArg { parsedArgCase = ListArg { argLength = a } } - <- parsedOutputs pOp] - buildOpParts = + outputListsSizes = brackets $ commasep + [ renderHaskellName a + | ParsedArg { parsedArgCase = ListArg { argLength = a } } + <- parsedOutputs pOp + ] + opInputsVar = "op'inputs" + bindOpInputsVar = opInputsVar <+> "<- fmap Prelude.concat $ Prelude.sequence" + <+> brackets (commasep $ map (\a -> "buildInputs" <+> a) tensorArgs) + opDef = parens $ hang 0 $ stack $ "opDef" <+> renderQuotedTFName (parsedOpName pOp) : -- Renders type parameter arguments. [ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> inferredTypeExpr a @@ -259,10 +262,9 @@ functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOp [ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n | a <- inferredListSizeAttrs pOp, let n = attrName a ] ++ - ["& op'options"] - - - tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp + ["& op'options & opInputs .~" <+> opInputsVar] + tensorArgs = renderTensorArg <$> parsedInputs pOp + renderTensorArg = renderHaskellName . parsedArgName inferredTypeExpr a | typeParamIsList $ attrInfo a = "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a @@ -296,7 +298,7 @@ typeSig pre pOp = constraints | null classConstraints = empty | otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>" typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp, - Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]] + Just (ArgSomeTensor v) <- [argKind $ parsedArgCase k]] ++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp] ++ if parsedOpIsMonadic pOp then ["m'"] else [] -- Use m' as the type parameter to avoid clashing with an attribute name. @@ -327,7 +329,7 @@ typeSig pre pOp = constraints wrapOutput o | parsedOpIsMonadic pOp = "m'" <+> parens o | otherwise = o - + -- | Render an op input or output. -- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle" tensorArg :: ParsedArg -> Doc @@ -336,12 +338,13 @@ tensorArg p = case parsedArgCase p of SimpleArg { argType = t, argCaseKind = k } -> tensorType t k ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k MixedListArg {argTypeAttr = t, argCaseKind = k} - -> "TensorList" <+> kind k <+> renderHaskellName t + -> "TensorList" <+> parens (kind k) <+> renderHaskellName t where kind k = case k of ArgTensorRef -> "Ref" ArgTensorValue -> "Value" - ArgTensorEither v' -> strictText v' + ArgTensorBuild -> "Build" + ArgSomeTensor v -> strictText v tensorType t k = let a = case t of ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt @@ -350,7 +353,7 @@ tensorArg p = case parsedArgCase p of attrComment :: Attr a -> Doc attrComment a = argComment' (attrName a) (attrDescription a) - + argComment :: ParsedArg -> Doc argComment a = argComment' (parsedArgName a) (parsedArgDescription a) @@ -364,7 +367,7 @@ bold n = "__" <> n <> "__" -- | Comment for the outputs of an op. -- For example: -- -- ^ (__output1__, __output2__) --- -- +-- -- -- -- * __output1__: description1 -- -- -- -- * __output2__: description2 diff --git a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs index c7a5b69..2fbe4e4 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs @@ -141,7 +141,8 @@ data ArgType data ArgKind = ArgTensorRef -- Tensor Ref a | ArgTensorValue -- Tensor Value a - | ArgTensorEither Text -- Tensor v a; the Text is the variable `v` + | ArgTensorBuild -- Tensor Build a + | ArgSomeTensor Text -- Tensor v a; the Text is the variable 'v'. deriving (Eq) isRefCase :: ParsedArgCase -> Bool @@ -219,15 +220,17 @@ parseOp o = ParsedOp { parsedOpName = makeName $ o ^. name , parsedOpSummary = o ^. summary , parsedOpDescription = o ^. description - , parsedOpIsMonadic = o ^. isStateful - || any (isRefCase . parsedArgCase) parsedInputs , .. } where - parsedInputs = zipWith (\a v -> parseArg a (inputTensorKind a v)) - (o ^. inputArg) tensorKindParams - tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]] - parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg) + parsedOpIsMonadic = o ^. isStateful + || any (isRefCase . parsedArgCase) parsedInputs + || null (o ^. outputArg) + parsedInputs = zipWith (\t a -> parseArg a (inputTensorKind t a)) + tensorKindParams (o ^. inputArg) + tensorKindParams = ["v'" <> Text.pack (show x) | x <- [1::Integer ..]] + parsedOutputs = map (\a -> parseArg a (outputTensorKind parsedOpIsMonadic a)) + (o ^. outputArg) -- Integer attributes that can be inferred from the size of at least one -- input list. inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs) @@ -246,15 +249,16 @@ parseOp o = ParsedOp $ o ^. attr -- TODO(judahjacobson): Some arguments should be refs. -inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind -inputTensorKind a v +inputTensorKind :: Text -> OpDef'ArgDef -> ArgKind +inputTensorKind v a | a ^. isRef = ArgTensorRef - | otherwise = ArgTensorEither v + | otherwise = ArgSomeTensor v -outputTensorKind :: OpDef'ArgDef -> ArgKind -outputTensorKind a +outputTensorKind :: Bool -> OpDef'ArgDef -> ArgKind +outputTensorKind isMonadic a | a ^. isRef = ArgTensorRef - | otherwise = ArgTensorValue + | isMonadic = ArgTensorValue + | otherwise = ArgTensorBuild getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType getExplicitInputAttr o implicitAttrs a diff --git a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs index df1866f..7b9e7fc 100644 --- a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs +++ b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs @@ -24,9 +24,9 @@ module TensorFlow.EmbeddingOps where import Control.Monad (zipWithM) import Data.Int (Int32, Int64) -import TensorFlow.Build (MonadBuild, colocateWith, render) +import TensorFlow.Build (MonadBuild) import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor -import TensorFlow.Tensor (Tensor, Value) +import TensorFlow.Tensor (Tensor, Value, Rendered, colocateWith, render) import TensorFlow.Types (OneOf, TensorType) import qualified TensorFlow.GenOps.Core as CoreOps @@ -44,17 +44,18 @@ import qualified TensorFlow.GenOps.Core as CoreOps -- -- The results of the lookup are concatenated into a dense -- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`. -embeddingLookup :: forall a b v m . +embeddingLookup :: forall a b v1 v2 m . ( MonadBuild m + , Rendered v1 , TensorType a , OneOf '[Int64, Int32] b , Num b ) - => [Tensor v a] + => [Tensor v1 a] -- ^ A list of tensors which can be concatenated along -- dimension 0. Each `Tensor` must be appropriately -- sized for `mod` partition strategy. - -> Tensor Value b + -> Tensor v2 b -- ^ A `Tensor` with type `int32` or `int64` -- containing the ids to be looked up in `params`. -- The ids are required to have fewer than 2^31 diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index 099a196..52467c8 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -31,6 +31,7 @@ import Data.ByteString (ByteString) import Data.Complex (Complex) import Data.Default (def) import Data.Int (Int32, Int64) +import Data.Foldable (foldlM) import Data.List (foldl', sortBy) import Data.Map.Strict (Map) import Data.Maybe (fromMaybe, maybeToList, mapMaybe) @@ -39,7 +40,7 @@ import Data.ProtoLens.TextFormat (showMessage) import Data.Set (Set) import Data.Text (Text) import Data.Tuple (swap) -import Lens.Family2 (Lens', (&), (^.), (.~), (%~)) +import Lens.Family2 (Lens', view, (&), (^.), (.~), (%~)) import Lens.Family2.State.Strict (uses) import Lens.Family2.Stock (at, intAt) import Lens.Family2.Unchecked (lens, iso) @@ -59,11 +60,10 @@ import TensorFlow.Build ( MonadBuild , Build , build - , render - , renderNodeName , renderedNodeDefs , opDef , opAttr + , opInputs ) import TensorFlow.BuildOp import TensorFlow.Ops @@ -86,16 +86,19 @@ import TensorFlow.Ops ) import TensorFlow.Output ( NodeName(..) - , Op (Rendered) , Output(..) , OutputIx(..) , outputIndex ) import TensorFlow.Tensor ( Tensor(..) - , TensorKind (ValueKind) , Value - , tensorOutput + , render + , expr + , Rendered + , tensorNodeName + , renderedOutput + , renderValue ) import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens) import Proto.Tensorflow.Core.Framework.NodeDef @@ -114,10 +117,7 @@ type GradientCompatible a = -- | Gradient of @y@ w.r.t. each element of @xs@. gradients :: forall a v1 v2 m . (MonadBuild m - , Num (Tensor v1 a) - -- TODO(gnezdo): remove indirect constraint. - -- It's a wart inherited from Num instance. - , v1 ~ Value + , Rendered v2 , GradientCompatible a ) => Tensor v1 a -- ^ The output of the graph. @@ -150,27 +150,31 @@ gradients y xs = build $ do -- -- 4. Lookup the recorded gradient for each x in xs. - yName <- renderNodeName y + y' <- renderValue y + let yName = tensorNodeName y' + yOne <- render $ fill (shape y') (scalar 1) -- TODO(fmayle): Move this into Build.hs and call it unsafeNodeDefFromName? nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $ (\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x)) . flip Map.lookup let (gr, nodeMap) = createGraph yName nodeDefLookup -- Set gradient of y to one. + -- TODO: nicer let initPending :: Map.Map FGL.Node (PendingGradients a) - initPending = Map.empty & at (nodeMap Map.! yName) + = Map.empty & (at (nodeMap Map.! yName) . nonEmpty - . outputIxAt (y ^. tensorOutput . outputIndex) + . outputIxAt (outputIndex $ renderedOutput y') . nonEmpty - .~ [fill (shape y) (scalar 1)] + .~ [yOne] + ) -- Calculate the gradients of y w.r.t. each node in the graph. gradientMap <- graphGrads gr initPending -- Lookup the gradients for each x. - forM xs $ \x -> do - xName <- renderNodeName x - render $ fromMaybe (zerosLike x) $ do + forM xs $ \x -> + let xName = tensorNodeName x + in maybe (render $ zerosLike x) return $ do n <- nodeMap ^. at xName - let i = x ^. tensorOutput . outputIndex + let i = outputIndex $ renderedOutput x gradientMap ^. at n . nonEmpty . outputIxAt i outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v) @@ -182,6 +186,7 @@ outputIxAt = intAt . unOutputIx type PendingGradients a = IntMap.IntMap [Tensor Value a] -- | Gradients of a node's outputs. The key is an OutputIx sans newtype. +-- TODO: precache the rendering? type Gradients a = IntMap.IntMap (Tensor Value a) -- | Graph of TensorFlow operations. @@ -229,42 +234,43 @@ non a = anon a (a==) nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v) nonEmpty = anon mempty null +-- TODO: strictness (e.g., foldlM') + -- | Calculate the gradients for every node in a graph. graphGrads :: forall a. GradientCompatible a => Graph -> Map FGL.Node (PendingGradients a) -- ^ Initial gradients (usually just 1 for the node of interest). -> Build (Map FGL.Node (Gradients a)) -graphGrads gr initPending = pure (foldl' go initState nodeOrder ^. gradientsResult) +graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrder where initState = GradientsState initPending Map.empty -- Reverse topological sort. -- TODO(fmayle): Filter out nodes that are not successors of any x in xs to -- avoid calculating gradients that won't be used. nodeOrder = FGL.topsort $ FGL.grev gr - go state node = + go :: GradientsState a -> Int -> Build (GradientsState a) + go state node = do -- Aggregate the accumulated gradients for this node. - let outputGrads = + outputGrads <- sumPendingGradient (state ^. gradientsPending . at node . nonEmpty) - in if null outputGrads - then state - else + if null outputGrads + then pure state + else do + let ctx = FGL.context gr node + inputGrads <- calculateInputGrads ctx outputGrads gr -- Calculate the gradients for each of the node's inputs. let nextState = state & gradientsResult %~ Map.insert node outputGrads - ctx = FGL.context gr node - in updatePendingGradients - ctx - (calculateInputGrads ctx outputGrads gr) - nextState + pure $ updatePendingGradients ctx inputGrads nextState -- | Reduce accumulated gradients for each output to one Tensor. sumPendingGradient :: GradientCompatible a - => PendingGradients a -> Gradients a -sumPendingGradient = IntMap.mapMaybe f + => PendingGradients a -> Build (Gradients a) +sumPendingGradient = sequence . IntMap.mapMaybe f where f [] = Nothing - f [x] = Just x - f xs = Just (addN xs) + f [x] = Just (pure x) + f xs = Just (render $ addN xs) -- | Calculate the gradients of a node's input tensors. @@ -274,18 +280,18 @@ calculateInputGrads :: forall a. GradientCompatible a => FGL.Context NodeDef EdgeLabel -> Gradients a -- ^ Output gradients of the node. -> Graph - -> [Maybe (Tensor Value a)] -calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr = - opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads + -> Build [Maybe (Tensor Value a)] +calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr = do + fullOutGrads <- fullOutputGrads (numOutputs nodeDef) (nodeDefName nodeDef) + outputGrads + traverse (traverse render) $ opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads where - fullOutGrads = - fullOutputGrads (numOutputs nodeDef) (Rendered nodeDef) outputGrads -- Create a tensor from an edge (technically an Output, but it seems less -- confusing to refer to it as a tensor here). edgeToTensor :: (EdgeLabel, FGL.Node) -> Output edgeToTensor ((i, _), n) = case FGL.lab gr n of - Just edgeNodeDef -> Output i (Rendered edgeNodeDef) + Just edgeNodeDef -> Output i (NodeName $ edgeNodeDef ^. name) Nothing -> error $ "calculateInputGrads: missing input node for " ++ Text.unpack (nodeDef ^. name) -- Input tensors, sorted by input index. @@ -294,11 +300,11 @@ calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr = -- | Convert a Map of gradients to a list, with zeros for missing outputs. fullOutputGrads :: (TensorType a, Num a) => OutputIx -- ^ Number of outputs. - -> Op + -> NodeName -> Gradients a - -> [Tensor Value a] + -> Build [Tensor Value a] fullOutputGrads n o gs = - map (\i -> fromMaybe (zero i) (gs ^. outputIxAt i)) [0..n-1] + mapM (\i -> maybe (render $ zero i) return (gs ^. outputIxAt i)) [0..n-1] where -- A tensor of zeros with the same shape as the i'th output. zero i = zerosLike $ toT (Output i o) @@ -397,19 +403,19 @@ type GradientFunc a = NodeDef -- ^ Input tensors. -> [Tensor Value a] -- ^ Gradient of y w.r.t. each output tensor. - -> [Maybe (Tensor Value a)] + -> [Maybe (Tensor Build a)] -- ^ Gradient of y w.r.t. each input tensor. -- TODO(fmayle): Assert the type is correct. -- | Create a Tensor from an Output. -toT :: Output -> Tensor Value a -toT = Tensor ValueKind +toT :: Output -> Tensor Build a +toT = Tensor . pure -- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for -- simple slicing operations. -flatSlice :: forall v1 t . (TensorType t) +flatSlice :: forall v1 t . TensorType t => Tensor v1 t -- ^ __input__ -> Int32 -- ^ __begin__: specifies the offset into the first dimension of -- 'input' to slice from. @@ -417,9 +423,12 @@ flatSlice :: forall v1 t . (TensorType t) -- of 'input' to slice. If size is -1, all remaining elements in the dimension -- are included in the slice (i.e. this is equivalent to setting -- size = input.dim_size(0) - begin). - -> Tensor Value t -- ^ __output__ + -> Tensor Build t -- ^ __output__ flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size]) +nodeDefName :: NodeDef -> NodeName +nodeDefName = NodeName . view name + -- | The gradient function for an op type. -- @@ -427,8 +436,8 @@ flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size]) -- third_party/tensorflow/python/ops/*_grad.py opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a -opGrad "Abs" _ [toT -> x] [dz] = [Just $ dz * signum x] -opGrad "Neg" _ [_] [dz] = [Just $ -dz] +opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x] +opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz] opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x] opGrad "Square" _ [toT -> x] [dz] = @@ -436,7 +445,7 @@ opGrad "Square" _ [toT -> x] [dz] = -- TODO(fmayle): The python code makes dz a control dependency of the 2*x -- (for performance reasons?). Will need to put these functions in the Build -- monad to replicate that. - [Just $ dz * (2 * x)] + [Just $ dz `CoreOps.mul` (2 * x)] opGrad "Gather" _ [toT -> x, toT -> indices] [dz] = -- TODO(fmayle): The python version uses a better performance implementation @@ -448,20 +457,20 @@ opGrad "Gather" _ [toT -> x, toT -> indices] [dz] = ] where -- TODO(gnezdo): Use colocateWith but it requires Build monad. - denseShape = shape (x :: Tensor Value a) + denseShape = shape (x :: Tensor Build a) numRows = scalarize $ flatSlice denseShape 0 1 valuesShape = CoreOps.concat 0 [ allDimensions , flatSlice denseShape 1 (-1) ] values = reshape dz valuesShape -- TODO(fmayle): This could be either Int32 or Int64. - indices' = reshape indices allDimensions :: Tensor Value Int32 + indices' = reshape indices allDimensions :: Tensor Build Int32 opGrad "Max" _ [toT -> x, toT -> indices] [dz] = [Just $ indicators `CoreOps.div` numSelected * dz', Nothing] where - sx = shape (x :: Tensor Value a) - outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32) + sx = shape (x :: Tensor Build a) + outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32) y = CoreOps.max x indices y' = reshape y outputShapeKeptDims dz' = reshape dz outputShapeKeptDims @@ -475,8 +484,8 @@ opGrad "Sum" _ [toT -> x, toT -> indices] [dz] = [ Just $ CoreOps.tile grad tileScaling, Nothing ] where -- TODO(gnezdo): Implement the fast-path from math_grad._SumGrad. - sx = shape (x :: Tensor Value a) - outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32) + sx = shape (x :: Tensor Build a) + outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32) tileScaling = safeShapeDiv sx outputShapeKeptDims grad = reshape dz outputShapeKeptDims @@ -484,8 +493,8 @@ opGrad "Mean" u v@[toT -> x, _] w = [Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing] where [Just dz, Nothing] = opGrad "Sum" u v w - inputShape = shape (x :: Tensor Value a) - outputShape = shape (dz :: Tensor Value a) + inputShape = shape (x :: Tensor Build a) + outputShape = shape (dz :: Tensor Build a) -- TODO(fmayle): Add fast path when shape is known. inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape @@ -495,8 +504,8 @@ opGrad "Add" _ [toT -> x, toT -> y] [dz] = [ Just $ reshape (sum dz rx) sx , Just $ reshape (sum dz ry) sy ] where - sx = shape (x :: Tensor Value a) - sy = shape (y :: Tensor Value a) + sx = shape (x :: Tensor Build a) + sy = shape (y :: Tensor Build a) (rx, ry) = broadcastGradientArgs sx sy opGrad "Sub" u v w = @@ -510,22 +519,24 @@ opGrad "SoftmaxCrossEntropyWithLogits" _ [toT -> x, toT -> y] [dz, _] = opGrad "Mul" _ [toT -> x, toT -> y] [dz] = -- TODO(fmayle): Handle complex numbers. - [ Just $ reshape (sum (dz * y) rx) sx - , Just $ reshape (sum (x * dz) ry) sy ] + [ Just $ reshape (sum (dz `CoreOps.mul` y) rx) sx + , Just $ reshape (sum (x `CoreOps.mul` dz) ry) sy ] where - sx = shape (x :: Tensor Value a) - sy = shape (y :: Tensor Value a) + sx = shape (x :: Tensor Build a) + sy = shape (y :: Tensor Build a) (rx, ry) = broadcastGradientArgs sx sy opGrad "Div" _ [toT -> x, toT -> y] [dz] = -- TODO(fmayle): Handle complex numbers. -- TODO(gnezdo): Provide Fractional instance and use '/' instead of div. [ Just $ reshape (sum (dz `CoreOps.div` y) rx) sx - , Just $ reshape (sum (dz * (negate x `CoreOps.div` (y * y))) ry) sy + , Just $ reshape (sum (dz `CoreOps.mul` (negate x `CoreOps.div` (y * y))) + ry) + sy ] where - sx = shape (x :: Tensor Value a) - sy = shape (y :: Tensor Value a) + sx = shape (x :: Tensor Build a) + sy = shape (y :: Tensor Build a) (rx, ry) = broadcastGradientArgs sx sy opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] = @@ -549,7 +560,7 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] = opGrad "Transpose" _ [_, toT -> p] [dz] = [ Just $ CoreOps.transpose dz - (CoreOps.invertPermutation p :: Tensor Value Int32) + (CoreOps.invertPermutation p :: Tensor Build Int32) , Nothing ] @@ -582,28 +593,28 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] = x output dz ] where - output :: Tensor Value a - output = toT $ Output 0 (Rendered nodeDef) + output :: Tensor Build a + output = toT $ Output 0 (nodeDefName nodeDef) ksize = lookupAttr nodeDef "ksize" :: [Int64] strides = lookupAttr nodeDef "strides" :: [Int64] padding = lookupAttr nodeDef "padding" :: ByteString dataFormat = lookupAttr nodeDef "data_format" :: ByteString opGrad "Reshape" _ [toT -> x, _] [dz] = - [Just $ reshape dz $ shape (x :: Tensor Value a), Nothing] + [Just $ reshape dz $ shape (x :: Tensor Build a), Nothing] opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing] opGrad "TruncatedNormal" _ _ _ = [Nothing] -opGrad "RefIdentity" _ _ [dz] = [Just dz] +opGrad "RefIdentity" _ _ [dz] = [Just $ expr dz] opGrad "Cast" nodeDef _ [dz] = [Just reverseCast] where -- TODO(gnezdo): too permissive, python only allows float types as src_type. reverseCast = - buildOp (opDef "Cast" + pureOp [] $ pure (opDef "Cast" & opAttr "DstT" .~ (lookupAttr nodeDef "SrcT" :: ByteString) - & opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString)) - dz + & opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString) + & opInputs .~ [renderedOutput dz]) opGrad "DynamicStitch" nodeDef inputs [dz] = replicate halfLen Nothing ++ valuesGrads @@ -614,7 +625,7 @@ opGrad "DynamicStitch" nodeDef inputs [dz] = in if 2 * half == len then half else error ("Uneven input size " ++ show (len, showMessage nodeDef)) - valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Value Int32) + valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Build Int32) | idx <- take halfLen inputs ] @@ -622,14 +633,14 @@ opGrad "DynamicPartition" nodeDef [toT -> xs, toT -> indices] dz = [ Just reconstructed, Nothing ] where reconstructed = CoreOps.reshape stitched - (CoreOps.shape (xs :: Tensor Value a) :: Tensor Value Int32) + (CoreOps.shape (xs :: Tensor Build a) :: Tensor Build Int32) stitched = CoreOps.dynamicStitch partitionedIndices dz partitionedIndices = CoreOps.dynamicPartition np originalIndices indices np = lookupAttr nodeDef "num_partitions" :: Int64 originalIndices = CoreOps.reshape (CoreOps.range 0 (CoreOps.size indices) 1) prefixShape prefixShape = shapeInt32 indices - shapeInt32 = CoreOps.shape :: Tensor Value Int32 -> Tensor Value Int32 + shapeInt32 t = CoreOps.shape t :: Tensor Build Int32 opGrad "Select" _ [toT -> c, toT -> x, _] [dz] = [ Nothing @@ -639,18 +650,18 @@ opGrad "Select" _ [toT -> c, toT -> x, _] [dz] = where zeros = CoreOps.zerosLike x -- TODO(gnezdo): Unlike Python, no control dependency on dz. -opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.inv x ] +opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz `CoreOps.mul` CoreOps.inv x ] -- TODO(gnezdo): Reuse the output instead of doing another exp, -- though, it is probably CSE'd away anyway. -opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.exp x ] +opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz `CoreOps.mul` CoreOps.exp x ] opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] = [ Just $ CoreOps.unsortedSegmentSum - (CoreOps.gather dz (t :: Tensor Value Int32)) - (y :: Tensor Value Int32) inputRows + (CoreOps.gather dz (t :: Tensor Build Int32)) + (y :: Tensor Build Int32) inputRows , Nothing , Nothing ] - where inputRows = flatSlice (shape (x :: Tensor Value a)) 0 1 + where inputRows = flatSlice (shape (x :: Tensor Build a)) 0 1 opGrad "LabelClasses" _ _ _ = [Nothing, Nothing] opGrad "LabelWeights" _ _ _ = [Nothing] @@ -710,13 +721,13 @@ numOutputs o = _ -> error $ "numOuputs not implemented for " ++ show (o ^. op) -- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0` -safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Value Int32 +safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Build Int32 safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1) -allDimensions :: Tensor Value Int32 +allDimensions :: Tensor Build Int32 allDimensions = vector [-1 :: Int32] -rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Value Int32 +rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Build Int32 rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1 lookupAttr :: Attribute a1 => NodeDef -> Text -> a1 diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs index e123a41..3f27a80 100644 --- a/tensorflow-ops/src/TensorFlow/Ops.hs +++ b/tensorflow-ops/src/TensorFlow/Ops.hs @@ -166,7 +166,6 @@ import qualified Proto.Tensorflow.Core.Framework.TensorShape import TensorFlow.Build import TensorFlow.BuildOp import TensorFlow.ControlFlow (group) -import TensorFlow.Output (unNodeName) import TensorFlow.Tensor import TensorFlow.Types @@ -183,7 +182,7 @@ import qualified Prelude (abs) -- "1". instance ( TensorType a , Num a - , v ~ Value + , v ~ Build , OneOf '[ Double, Float, Int32, Int64 , Complex Float, Complex Double] a) => Num (Tensor v a) where (+) = CoreOps.add @@ -194,10 +193,10 @@ instance ( TensorType a signum = CoreOps.sign negate = CoreOps.neg -matTranspose :: TensorType a => Tensor v a -> Tensor Value a +matTranspose :: TensorType a => Tensor e a -> Tensor Build a matTranspose = matTranspose' id -matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Value a +matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Build a matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32]) placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a) @@ -208,7 +207,7 @@ placeholder' :: forall m a . (MonadBuild m, TensorType a) placeholder' params pShape -- Note: we don't use CoreOps.placeholder' since that op isn't stateful, -- and thus would be CSE'd. - = build $ buildOp $ opDef "Placeholder" + = build $ buildOp [] $ opDef "Placeholder" & opAttr "dtype" .~ tensorType (undefined :: a) & opAttr "shape" .~ pShape & params @@ -216,11 +215,11 @@ placeholder' params pShape -- | Creates a variable initialized to the given value. -- Initialization happens next time session runs. initializedVariable :: (MonadBuild m, TensorType a) - => Tensor Value a -> m (Tensor Ref a) + => Tensor v a -> m (Tensor Ref a) initializedVariable = initializedVariable' id initializedVariable' :: (MonadBuild m, TensorType a) - => OpParams -> Tensor Value a -> m (Tensor Ref a) + => OpParams -> Tensor v a -> m (Tensor Ref a) initializedVariable' params initializer = do v <- CoreOps.variable' params [] -- The shape is not known initially. i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v @@ -240,17 +239,20 @@ zeroInitializedVariable' zeroInitializedVariable' params = initializedVariable' params . zeros -- TODO: Support heterogeneous list of tensors. -save :: forall a m v . (MonadBuild m, TensorType a) +save :: forall a m v . (Rendered v, MonadBuild m, TensorType a) => ByteString -- ^ File path. -> [Tensor v a] -- ^ Tensors to save. -> m ControlNode -save path xs = do - let toByteStringTensor = scalar . encodeUtf8 . unNodeName - names <- mapM (fmap toByteStringTensor . build . renderNodeName) xs +save path xs = build $ do + let toByteStringTensor = scalar . encodeUtf8 . encodeOutput . renderedOutput + let names = fmap toByteStringTensor xs let types = replicate (length xs) (tensorType (undefined :: a)) - let saveOp = buildOp $ opDef "Save" - & opAttr "T" .~ types - build $ saveOp (scalar path) (CoreOps.pack names) xs + names' <- buildInputs $ CoreOps.pack names + xs' <- buildInputs xs + path' <- buildInputs $ scalar path + buildOp [] $ opDef "Save" + & opAttr "T" .~ types + & opInputs .~ (path' ++ names' ++ xs') -- | Restore a tensor's value from a checkpoint file. -- @@ -261,20 +263,22 @@ restoreFromName :: forall a m . (MonadBuild m, TensorType a) -> ByteString -- ^ Tensor name override. -> Tensor Ref a -- ^ Tensor to restore. -> m ControlNode -restoreFromName path name x = do - let restoreOp = buildOp $ opDef "Restore" - & opAttr "dt" .~ tensorType (undefined :: a) - group =<< CoreOps.assign x - (restoreOp (scalar path) (scalar name) :: Tensor Value a) +restoreFromName path name x = build $ do + path' <- buildInputs $ scalar path + name' <- buildInputs $ scalar name + restoreOp <- buildOp [] $ opDef "Restore" + & opAttr "dt" .~ tensorType (undefined :: a) + & opInputs .~ (path' ++ name') + group =<< CoreOps.assign x (restoreOp :: Tensor Value a) -- | Restore a tensor's value from a checkpoint file. restore :: forall a m . (MonadBuild m, TensorType a) => ByteString -- ^ File path. -> Tensor Ref a -- ^ Tensor to restore. -> m ControlNode -restore path x = do - name <- encodeUtf8 . unNodeName <$> build (renderNodeName x) - restoreFromName path name x +restore path x = restoreFromName path name x + where + name = encodeUtf8 $ encodeOutput $ renderedOutput x -- | Create a constant tensor. -- @@ -283,10 +287,10 @@ restore path x = do -- element 0: index (0, ..., 0) -- element 1: index (0, ..., 1) -- ... -constant :: TensorType a => Shape -> [a] -> Tensor Value a +constant :: TensorType a => Shape -> [a] -> Tensor Build a constant = constant' id -constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Value a +constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Build a constant' params (Shape cShape) values | invalidLength = error invalidLengthMsg | otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode)) @@ -305,24 +309,24 @@ constant' params (Shape cShape) values -- | Reshape a N-D tensor down to a scalar. -- -- See `TensorFlow.GenOps.Core.reshape`. -scalarize :: (TensorType a) => Tensor v a -> Tensor Value a +scalarize :: TensorType a => Tensor v a -> Tensor Build a scalarize t = CoreOps.reshape t (vector scalarShape) where scalarShape = [] :: [Int32] -- | Create a constant vector. -vector :: TensorType a => [a] -> Tensor Value a +vector :: TensorType a => [a] -> Tensor Build a vector = vector' id -vector' :: TensorType a => OpParams -> [a] -> Tensor Value a +vector' :: TensorType a => OpParams -> [a] -> Tensor Build a vector' params xs = constant' params [fromIntegral $ length xs] xs -- | Create a constant scalar. -scalar :: TensorType a => a -> Tensor Value a +scalar :: TensorType a => a -> Tensor Build a scalar = scalar' id -scalar' :: TensorType a => OpParams -> a -> Tensor Value a +scalar' :: TensorType a => OpParams -> a -> Tensor Build a scalar' params x = constant' params [] [x] -- | Random tensor from the unit normal distribution with bounded values. @@ -338,28 +342,28 @@ truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a) -> m (Tensor Value a) truncatedNormal' = CoreOps.truncatedNormal' -zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a +zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Build a zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0) -shape :: TensorType t => Tensor v1 t -> Tensor Value Int32 +shape :: TensorType t => Tensor v t -> Tensor Build Int32 shape = CoreOps.shape -shape' :: TensorType t => OpParams -> Tensor v1 t -> Tensor Value Int32 +shape' :: TensorType t => OpParams -> Tensor v t -> Tensor Build Int32 shape' = CoreOps.shape' -expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t +expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t expandDims = CoreOps.expandDims -expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t +expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t expandDims' = CoreOps.expandDims' -- | Helper function for reduction ops (translation of math_ops.reduced_shape). reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) => - Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32 + Tensor v1 t1 -> Tensor v2 t2 -> Tensor Build Int32 reducedShape inputShape axes = let inputShape32 = toInt32 inputShape -- [2, 3, 5, 7] axes32 = toInt32 axes -- [1, 2] - toInt32 x = CoreOps.cast x :: Tensor Value Int32 + toInt32 x = CoreOps.cast x :: Tensor Build Int32 inputRank = CoreOps.size inputShape32 -- 4 axesMod = (axes32 + inputRank) `CoreOps.mod` inputRank axesShape = shape axesMod -- [2] diff --git a/tensorflow-ops/tensorflow-ops.cabal b/tensorflow-ops/tensorflow-ops.cabal index d5a37bd..d11b093 100644 --- a/tensorflow-ops/tensorflow-ops.cabal +++ b/tensorflow-ops/tensorflow-ops.cabal @@ -79,6 +79,7 @@ Test-Suite EmbeddingOpsTest , test-framework , test-framework-hunit , test-framework-quickcheck2 + , transformers , vector Test-Suite ArrayOpsTest diff --git a/tensorflow-ops/tests/ArrayOpsTest.hs b/tensorflow-ops/tests/ArrayOpsTest.hs index 1b32b03..e266ace 100644 --- a/tensorflow-ops/tests/ArrayOpsTest.hs +++ b/tensorflow-ops/tests/ArrayOpsTest.hs @@ -24,9 +24,7 @@ import Test.HUnit ((@=?)) import qualified Data.Vector as V import qualified TensorFlow.Ops as TF -import qualified TensorFlow.Session as TF -import qualified TensorFlow.Tensor as TF -import qualified TensorFlow.Types as TF +import qualified TensorFlow.Core as TF import qualified TensorFlow.GenOps.Core as CoreOps -- | Test split and concat are inverses. @@ -44,7 +42,7 @@ testSplit = testCase "testSplit" $ TF.runSession $ do testShapeN :: Test testShapeN = testCase "testShapeN" $ TF.runSession $ do let shapes = map TF.Shape [[1],[2,3]] - let tensors = map TF.zeros shapes :: [TF.Tensor TF.Value Float] + let tensors = map TF.zeros shapes :: [TF.Tensor TF.Build Float] result <- TF.run $ CoreOps.shapeN tensors liftIO $ [V.fromList [1], V.fromList [2,3]] @=? (result :: [V.Vector Int64]) diff --git a/tensorflow-ops/tests/BuildTest.hs b/tensorflow-ops/tests/BuildTest.hs index 9410dab..1df0719 100644 --- a/tensorflow-ops/tests/BuildTest.hs +++ b/tensorflow-ops/tests/BuildTest.hs @@ -34,9 +34,7 @@ import TensorFlow.Build , asGraphDef , evalBuildT , flushNodeBuffer - , render , withDevice - , colocateWith , withNameScope , opName ) @@ -50,7 +48,13 @@ import TensorFlow.Ops , variable' ) import TensorFlow.Output (Device(..)) -import TensorFlow.Tensor (Tensor, Value, Ref) +import TensorFlow.Tensor + ( colocateWith + , render + , Tensor + , Value + , Ref + ) import TensorFlow.Session ( run , runSession @@ -65,8 +69,7 @@ import qualified Data.Vector as V -- | Test 'opName' behavior. testOpName :: Test testOpName = testCase "testOpName" $ do - let graph = variable' (opName .~ "foo") [] - >>= render :: Build (Tensor Ref Float) + let graph = variable' (opName .~ "foo") [] :: Build (Tensor Ref Float) nodeDef :: NodeDef nodeDef = head $ asGraphDef graph ^. node "Variable" @=? (nodeDef ^. op) @@ -114,7 +117,6 @@ testNamedAndScoped :: Test testNamedAndScoped = testCase "testNamedAndScoped" $ do let graph :: Build (Tensor Ref Float) graph = withNameScope "foo1" (variable' (opName .~ "bar1") []) - >>= render nodeDef :: NodeDef nodeDef = head $ asGraphDef graph ^. node "Variable" @=? (nodeDef ^. op) diff --git a/tensorflow-ops/tests/DataFlowOpsTest.hs b/tensorflow-ops/tests/DataFlowOpsTest.hs index 38fdb5a..184cbd2 100644 --- a/tensorflow-ops/tests/DataFlowOpsTest.hs +++ b/tensorflow-ops/tests/DataFlowOpsTest.hs @@ -26,16 +26,14 @@ import Test.QuickCheck.Monadic (monadicIO, run) import qualified Data.Vector as V import qualified TensorFlow.GenOps.Core as CoreOps import qualified TensorFlow.Ops as TF -import qualified TensorFlow.Session as TF -import qualified TensorFlow.Tensor as TF -import qualified TensorFlow.Types as TF +import qualified TensorFlow.Core as TF -- DynamicSplit is undone with DynamicStitch to get the original input -- back. testDynamicPartitionStitchInverse :: forall a. (TF.TensorDataType V.Vector a, Show a, Eq a) => StitchExample a -> Property testDynamicPartitionStitchInverse (StitchExample numParts values partitions) = - let splitParts :: [TF.Tensor TF.Value a] = + let splitParts :: [TF.Tensor TF.Build a] = CoreOps.dynamicPartition numParts (TF.vector values) partTensor partTensor = TF.vector partitions restitchIndices = CoreOps.dynamicPartition numParts diff --git a/tensorflow-ops/tests/EmbeddingOpsTest.hs b/tensorflow-ops/tests/EmbeddingOpsTest.hs index 137f7f6..ecaeaa4 100644 --- a/tensorflow-ops/tests/EmbeddingOpsTest.hs +++ b/tensorflow-ops/tests/EmbeddingOpsTest.hs @@ -19,6 +19,7 @@ -- | Tests for EmbeddingOps. module Main where +import Control.Monad.IO.Class (liftIO) import Data.Int (Int32, Int64) import Data.List (genericLength) import Google.Test (googleTest) @@ -48,16 +49,15 @@ testEmbeddingLookupHasRightShapeWithPartition = let embShape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items. let embedding1 = [1, 1, 1 :: Int32] let embedding2 = [0, 0, 0 :: Int32] - let embedding = [ TF.constant embShape embedding1 - , TF.constant embShape embedding2 - ] let idValues = [0, 1 :: Int32] - let ids = TF.constant (TF.Shape [1, 2]) idValues - let op = embeddingLookup embedding ids (values, shape) <- TF.runSession $ do - vs <- op + embedding <- mapM TF.render [ TF.constant embShape embedding1 + , TF.constant embShape embedding2 + ] + let ids = TF.constant (TF.Shape [1, 2]) idValues + vs <- embeddingLookup embedding ids TF.run (vs, TF.shape vs) -- This is the shape that is returned in the equiv. Python. @@ -77,13 +77,12 @@ testEmbeddingLookupHasRightShape = , 0, 0, 0 :: Int32 ] - let embedding = TF.constant embShape embeddingInit let idValues = [0, 1 :: Int32] - let ids = TF.constant (TF.Shape [1, 2]) idValues - let op = embeddingLookup [embedding] ids (values, shape) <- TF.runSession $ do - vs <- op + embedding <- TF.render $ TF.constant embShape embeddingInit + let ids = TF.constant (TF.Shape [1, 2]) idValues + vs <- embeddingLookup [embedding] ids TF.run (vs, TF.shape vs) -- This is the shape that is returned in the equiv. Python. @@ -92,7 +91,6 @@ testEmbeddingLookupHasRightShape = -- "[0, 1]" should pull out the resulting vector. values @=? V.fromList [1, 1, 1, 0, 0, 0] - -- | Check that we can calculate gradients w.r.t embeddings. testEmbeddingLookupGradients :: Test testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do @@ -108,10 +106,10 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do x <- TF.placeholder (TF.Shape [2]) embedding <- TF.initializedVariable - =<< TF.render (TF.constant embShape embeddingInit) + (TF.constant embShape embeddingInit) op <- embeddingLookup [embedding] ids - let twoNorm = CoreOps.square $ TF.abs (op - x) + let twoNorm = CoreOps.square $ TF.abs (op `TF.sub` x) loss = TF.mean twoNorm (TF.scalar (0 :: Int32)) grad <- fmap head (TF.gradients loss [embedding]) @@ -131,23 +129,21 @@ testEmbeddingLookupUndoesSplit (LookupExample numParts shape@(TF.Shape (firstDim : restDims)) values - indices) = - let modShardedValues :: [TF.Tensor TF.Value a] = - CoreOps.dynamicPartition numParts shapedValues cyclicCounter - cyclicCounter :: TF.Tensor TF.Value Int32 = + indices) = monadicIO $ run $ TF.runSession $ do + let shapedValues = TF.constant shape values + indicesVector <- TF.render $ TF.vector indices + let directs = CoreOps.gather shapedValues indicesVector + let cyclicCounter :: TF.Tensor TF.Build Int32 = TF.vector [0..fromIntegral firstDim-1] `CoreOps.mod` fromIntegral numParts - indicesVector = TF.vector indices - directs = CoreOps.gather shapedValues indicesVector - shapedValues = TF.constant shape values - in monadicIO $ run $ do - (shapeOut, got, want :: V.Vector a) <- - TF.runSession $ TF.run =<< do - embeddings <- embeddingLookup modShardedValues indicesVector - return (TF.cast (TF.shape embeddings), embeddings, directs) - -- Checks the explicitly documented invariant of embeddingLookup. - shapeOut @=? V.fromList (genericLength indices : restDims) - got @=? want + modShardedValues :: [TF.Tensor TF.Value a] <- + mapM TF.render $ CoreOps.dynamicPartition numParts shapedValues cyclicCounter + embeddings <- embeddingLookup modShardedValues indicesVector + (shapeOut, got, want :: V.Vector a) <- + TF.run (TF.cast (TF.shape embeddings), embeddings, directs) + -- Checks the explicitly documented invariant of embeddingLookup. + liftIO $ shapeOut @=? V.fromList (genericLength indices : restDims) + liftIO $ got @=? want testEmbeddingLookupUndoesSplit _ = error "Bug in Arbitrary (LookupExample)" -- | Consistent set of parameters for EmbeddingLookupUndoesSplit. diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs index 765ed22..b51eab9 100644 --- a/tensorflow-ops/tests/GradientTest.hs +++ b/tensorflow-ops/tests/GradientTest.hs @@ -36,10 +36,11 @@ import Proto.Tensorflow.Core.Framework.NodeDef (op) testGradientSimple :: Test testGradientSimple = testCase "testGradientSimple" $ do - let x = TF.scalar (3 :: Float) - b = TF.scalar (4 :: Float) - y = x*x + b - grads = TF.gradients y [x, b] + let grads = do + x <- TF.render $ TF.scalar (3 :: Float) + b <- TF.render $ TF.scalar (4 :: Float) + let y = x `TF.mul` x `TF.add` b + TF.gradients y [x, b] -- Assert that the gradients are right. [dx, db] <- TF.runSession $ grads >>= TF.run 6 @=? TF.unScalar dx @@ -88,9 +89,10 @@ testGradientSimple = testCase "testGradientSimple" $ do testGradientDisconnected :: Test testGradientDisconnected = testCase "testGradientDisconnected" $ do - let x = TF.scalar (3 :: Float) - b = TF.scalar (4 :: Float) - grads = TF.gradients x [x, b] + let grads = do + x <- TF.render $ TF.scalar (3 :: Float) + b <- TF.render $ TF.scalar (4 :: Float) + TF.gradients x [x, b] -- Assert that the gradients are right. [dx, db] <- TF.runSession $ grads >>= TF.run 1 @=? TF.unScalar dx @@ -118,7 +120,7 @@ testCreateGraphStateful = testCase "testCreateGraphStateful" $ do let shape = TF.constant (TF.Shape [1]) [1] x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape - TF.gradients (x + y*3) [x, y] >>= TF.run + TF.gradients (TF.expr x + TF.expr y * 3) [x, y] >>= TF.run -- If this test fails, it will likely be caused by an exception within -- `TF.gradients`. These asserts are extra. 1 @=? TF.unScalar dx @@ -142,8 +144,8 @@ testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do testDiamond :: Test testDiamond = testCase "testDiamond" $ do [dx] <- TF.runSession $ do - let x = TF.vector [1] - y = x*x + x <- TF.render $ TF.vector [1] + let y = x `TF.mul` x z = y*y TF.gradients z [x] >>= TF.run (4 :: Float) @=? TF.unScalar dx @@ -152,8 +154,8 @@ testDiamond = testCase "testDiamond" $ do testMaxGradient :: Test testMaxGradient = testCase "testMaxGradient" $ do [dx] <- TF.runSession $ do - let x = TF.vector [1, 2, 3, 0, 1 :: Float] - y = TF.max x (0 :: TF.Tensor TF.Value Int32) + x <- TF.render $ TF.vector [1, 2, 3, 0, 1 :: Float] + let y = TF.max x (0 :: TF.Tensor TF.Build Int32) TF.gradients y [x] >>= TF.run V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx diff --git a/tensorflow-ops/tests/OpsTest.hs b/tensorflow-ops/tests/OpsTest.hs index a117015..de35828 100644 --- a/tensorflow-ops/tests/OpsTest.hs +++ b/tensorflow-ops/tests/OpsTest.hs @@ -56,8 +56,7 @@ testSaveRestore = testCase "testSaveRestore" $ withSystemTempDirectory "" $ \dirPath -> do let path = B8.pack $ dirPath ++ "/checkpoint" var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float) - var = TF.render =<< - TF.zeroInitializedVariable' (TF.opName .~ "a") + var = TF.zeroInitializedVariable' (TF.opName .~ "a") (TF.Shape []) TF.runSession $ do v <- var @@ -76,7 +75,8 @@ testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do p2 <- TF.placeholder [] let enc :: Float -> TF.TensorData Float enc n = TF.encodeTensorData [] (V.fromList [n]) - result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2 + result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] + $ p1 `TF.add` p2 liftIO $ result @=? TF.Scalar 5 -- | Test that regular tensors can also be used for feeds, as long as they each @@ -90,7 +90,8 @@ testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do p2 <- TF.render $ TF.scalar' (TF.opName .~ "B") 0 let enc :: Float -> TF.TensorData Float enc n = TF.encodeTensorData [] (V.fromList [n]) - result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2 + result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] + $ p1 `TF.add` p2 liftIO $ result @=? TF.Scalar 5 main :: IO () diff --git a/tensorflow-ops/tests/RegressionTest.hs b/tensorflow-ops/tests/RegressionTest.hs index ec83eed..7fd2e73 100644 --- a/tensorflow-ops/tests/RegressionTest.hs +++ b/tensorflow-ops/tests/RegressionTest.hs @@ -38,7 +38,7 @@ fit xData yData = TF.runSession $ do return (w', b') gradientDescent :: Float - -> TF.Tensor TF.Value Float + -> TF.Tensor TF.Build Float -> [TF.Tensor TF.Ref Float] -> TF.Session TF.ControlNode gradientDescent alpha loss params = do diff --git a/tensorflow-ops/tests/TypesTest.hs b/tensorflow-ops/tests/TypesTest.hs index 567d232..a7fdc84 100644 --- a/tensorflow-ops/tests/TypesTest.hs +++ b/tensorflow-ops/tests/TypesTest.hs @@ -53,9 +53,9 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $ let floatData = V.fromList [1..6 :: Float] stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]] boolData = V.fromList [True, True, False, True, False, False] - f <- TF.build $ TF.placeholder [2,3] - s <- TF.build $ TF.placeholder [2,3] - b <- TF.build $ TF.placeholder [2,3] + f <- TF.placeholder [2,3] + s <- TF.placeholder [2,3] + b <- TF.placeholder [2,3] let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData) , TF.feed s (TF.encodeTensorData [2,3] stringData) , TF.feed b (TF.encodeTensorData [2,3] boolData) @@ -63,7 +63,8 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $ -- Do something idempotent to the tensors to verify that tensorflow can -- handle the encoding. Originally this used `TF.identity`, but that -- wasn't enough to catch a bug in the encoding of Bool. - (f', s', b') <- TF.runWithFeeds feeds (f+0, TF.identity s, TF.select b b b) + (f', s', b') <- TF.runWithFeeds feeds + (f `TF.add` 0, TF.identity s, TF.select b b b) liftIO $ do floatData @=? f' stringData @=? s' diff --git a/tensorflow-queue/src/TensorFlow/Queue.hs b/tensorflow-queue/src/TensorFlow/Queue.hs index 2e0a5cf..0131e9b 100644 --- a/tensorflow-queue/src/TensorFlow/Queue.hs +++ b/tensorflow-queue/src/TensorFlow/Queue.hs @@ -60,7 +60,7 @@ makeQueue :: forall as m . (MonadBuild m, TensorTypes as) -- under the given name across multiple sessions. -> m (Queue as) makeQueue capacity sharedName = do - q <- build $ buildOp (opDef "FIFOQueue" + q <- build $ buildOp [] (opDef "FIFOQueue" & opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as) & opAttr "shared_name" .~ sharedName & opAttr "capacity" .~ capacity diff --git a/tensorflow/src/TensorFlow/Build.hs b/tensorflow/src/TensorFlow/Build.hs index 9cdd43b..84c6c09 100644 --- a/tensorflow/src/TensorFlow/Build.hs +++ b/tensorflow/src/TensorFlow/Build.hs @@ -13,9 +13,13 @@ -- limitations under the License. {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE Rank2Types #-} +{-# LANGUAGE TypeFamilies #-} module TensorFlow.Build ( -- * Graph node types ControlNode(..) @@ -32,8 +36,6 @@ module TensorFlow.Build , opControlInputs -- * The Build monad , GraphState - , render - , renderNodeName , renderedNodeDefs , BuildT , Build @@ -46,27 +48,23 @@ module TensorFlow.Build , addGraphDef , flushInitializers , flushNodeBuffer + , summaries -- * Creating and looking up Ops , getOrAddOp , addNewOp - , renderOutput + , encodeOutput + , lookupNode -- * Modifying all nodes in a Build action - , colocateWith , withStateLens , withDevice , withNameScope , withNodeDependencies - -- * Internal Summary related bits. - , addSummary - , SummaryTensor - , collectAllSummaries ) where import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask) import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Trans.Class (MonadTrans(..)) import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT) -import Data.ByteString (ByteString) import Data.Default (def) import Data.Functor.Identity (Identity(..)) import qualified Data.Map.Strict as Map @@ -94,7 +92,6 @@ import Proto.Tensorflow.Core.Framework.NodeDef import TensorFlow.Orphans () import TensorFlow.Output -import TensorFlow.Tensor newtype Unique = Unique Int deriving (Eq, Ord, Enum) @@ -125,9 +122,6 @@ opDefWithName n t = OpDef , _opControlInputs = [] } --- | Synonym for the tensors that return serialized Summary proto. -type SummaryTensor = Tensor Value ByteString - data GraphState = GraphState { _renderedNodes :: !(Map.Map PendingNode NodeDef) -- ^ Nodes which have been rendered. Keeps track of the unique ID we @@ -148,8 +142,8 @@ data GraphState = GraphState , _initializationNodes :: [NodeName] -- ^ The nodes to run next time a TF.run is issued, typically -- variable initializers. - , _summaries :: [SummaryTensor] - -- ^ The tensors for summary + , _summaries :: [Output] + -- ^ The tensors for summary (ByteString type) } -- | A node definition without its final name. Used as a key in the @@ -191,7 +185,7 @@ defaultControlInputs = lens _defaultControlInputs initializationNodes :: Lens' GraphState [NodeName] initializationNodes = lens _initializationNodes (\g x -> g { _initializationNodes = x }) -summaries :: Lens' GraphState [SummaryTensor] +summaries :: Lens' GraphState [Output] summaries = lens _summaries (\g x -> g { _summaries = x }) -- | An action for building nodes in a TensorFlow graph. @@ -238,9 +232,7 @@ flushInitializers = do -- | Registers the given node to be executed before the next -- 'TensorFlow.Session.run'. addInitializer :: MonadBuild m => ControlNode -> m () -addInitializer (ControlNode o) = build $ do - i <- getOrAddOp o - initializationNodes %= (i:) +addInitializer (ControlNode i) = build $ initializationNodes %= (i:) -- | Produce a GraphDef proto representation of the nodes that are rendered in -- the given 'Build' action. @@ -255,30 +247,31 @@ addGraphDef g = build $ nodeBuffer <>= g ^. node -- | Render the given op if it hasn't been rendered already, and return its -- name. -getOrAddOp :: Op -> Build NodeName -getOrAddOp o = NodeName . (^. name) <$> resolveOp o - -resolveOp :: Op -> Build NodeDef -resolveOp (Rendered n) = return n -resolveOp (Unrendered o) = do +getOrAddOp :: OpDef -> Build NodeName +getOrAddOp o = do pending <- getPendingNode o uses renderedNodes (Map.lookup pending) >>= \case - Just n -> return n + Just n -> return $ NodeName $ n ^. name Nothing -> addNewOpFromPending pending +lookupNode :: NodeName -> Build NodeDef +lookupNode n = uses renderedNodeDefs (Map.lookup n) >>= \case + Just n' -> return n' + Nothing -> error $ "lookupNode: unknown node name " ++ show n + -- | Add a new node for a given 'OpDef'. This is used for making "stateful" ops -- which are not safe to dedup (e.g, "variable" and "assign"). -addNewOp :: OpDef -> Build NodeDef +addNewOp :: OpDef -> Build NodeName addNewOp o = getPendingNode o >>= addNewOpFromPending -addNewOpFromPending :: PendingNode -> Build NodeDef +addNewOpFromPending :: PendingNode -> Build NodeName addNewOpFromPending pending = do nodeName <- renderPendingNode pending let nodeDef = pendingNodeDef pending & name .~ unNodeName nodeName nodeBuffer %= (nodeDef :) renderedNodes %= Map.insert pending nodeDef renderedNodeDefs %= Map.insert nodeName nodeDef - return nodeDef + return nodeName -- | Get the pending node corresponding to an OpDef, which may or may not have -- been rendered before. Implicitly renders all of this node's inputs. @@ -287,20 +280,18 @@ getPendingNode o = do -- An empty string in the proto field means that no specific -- device is specified. dev <- maybe "" deviceName <$> use defaultDevice - inputs <- mapM getInput (o ^. opInputs) scope <- use currentScope controls <- use defaultControlInputs + let inputs = map encodeOutput (o ^. opInputs) let controlInputs - = map getDep (o ^. opControlInputs ++ Set.toList controls) + = map makeDep (o ^. opControlInputs ++ Set.toList controls) return $ PendingNode scope (o ^. opName) $ def & op .~ (unOpType (o ^. opType) :: Text) & attr .~ _opAttrs o & input .~ (inputs ++ controlInputs) & device .~ dev where - getInput (Output (OutputIx k) subOp) - = (<> ":" <> Text.pack (show k)) . unNodeName <$> getOrAddOp subOp - getDep = ("^" <>) . unNodeName + makeDep = ("^" <>) . unNodeName -- | Pick a name for a pending node. If it has an explicit name, just use that; -- if the name is implicit, assign a new unique name based on the op type. @@ -317,12 +308,11 @@ renderPendingNode (PendingNode scope pendingName nodeDef) return $ nodeDef ^. op <> "_" <> Text.pack (show k) --- | Render an 'Output' and return a string representation for the TensorFlow +-- | Turn an 'Output' into a string representation for the TensorFlow -- foreign APIs. -renderOutput :: Output -> Build Text -renderOutput (Output (OutputIx i) o) = do - n <- getOrAddOp o - return $ unNodeName n <> Text.pack (":" ++ show i) +encodeOutput :: Output -> Text +encodeOutput (Output (OutputIx 0) n) = unNodeName n +encodeOutput (Output (OutputIx i) n) = unNodeName n <> Text.pack (':' : show i) -- | Modify some part of the state, run an action, and restore the state -- after that action is done. @@ -339,15 +329,6 @@ withStateLens accessor f act = do withDevice :: MonadBuild m => Maybe Device -> m a -> m a withDevice d = withStateLens defaultDevice (const d) --- | Places all nodes rendered in the given 'Build' action on the same --- device as the given Tensor (see also 'withDevice'). Make sure that --- the action has side effects of rendering the desired tensors. A pure --- return would not have the desired effect. -colocateWith :: MonadBuild m => forall a v b . Tensor v b -> m a -> m a -colocateWith t x = do - d <- build $ Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp) - withDevice (Just d) x - -- | Prepend a scope to all nodes rendered in the given 'Build' action. withNameScope :: MonadBuild m => Text -> m a -> m a withNameScope s = withStateLens currentScope (Scope s :) @@ -355,31 +336,3 @@ withNameScope s = withStateLens currentScope (Scope s :) -- | Add control inputs to all nodes rendered in the given 'Build' action. withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes) - --- | Render a 'Tensor', fixing its name, scope, device and control inputs from --- the 'Build' context. Also renders any dependencies of the 'Tensor' that --- weren't already rendered. --- --- This operation is idempotent; @render >=> render === render@. However, --- rendering a (previously un-rendered) 'Tensor' in two different contexts --- may result in two different 'Tensor's. -render :: MonadBuild m => Tensor v a -> m (Tensor v a) -render = build . tensorOutput (outputOp $ fmap Rendered . resolveOp) - --- | Render a 'Tensor' and get its node's name. -renderNodeName :: Tensor v a -> Build NodeName -renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp) - --- | Records the given summary action in Build for retrieval with --- 'collectAllSummaries'. The summary op is required to produce a --- Summary protocol buffer in string form. For safety, use the --- pre-composed functions: Logging.scalarSummary and --- Logging.histogramSummary. -addSummary :: SummaryTensor -> Build () -addSummary t = summaries %= (t :) - --- | Retrieves the summary ops collected thus far. Typically this only --- happens once, but if 'TensorFlow.Session.buildWithSummary' is used --- repeatedly, the values accumulate. -collectAllSummaries :: Monad m => BuildT m [SummaryTensor] -collectAllSummaries = use summaries diff --git a/tensorflow/src/TensorFlow/BuildOp.hs b/tensorflow/src/TensorFlow/BuildOp.hs index ad625d1..4a6907e 100644 --- a/tensorflow/src/TensorFlow/BuildOp.hs +++ b/tensorflow/src/TensorFlow/BuildOp.hs @@ -12,26 +12,27 @@ -- See the License for the specific language governing permissions and -- limitations under the License. +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} module TensorFlow.BuildOp - ( OpResult - , BuildOp + ( BuildResult(..) , buildOp - , buildListOp + , PureResult(..) + , pureOp , eqLengthGuard + , BuildInputs(..) , OpParams ) where -import Control.Monad (replicateM) +import Control.Monad (liftM2, replicateM) import Control.Monad.Reader (ReaderT, runReaderT, ask) -import Control.Monad.State.Strict (State, runState, get, put) +import Control.Monad.State.Strict (State, evalState, get, put) import Data.Int (Int64) -import Lens.Family2 ((&), (<>~), (^.)) import TensorFlow.Build import TensorFlow.Output @@ -40,48 +41,45 @@ import TensorFlow.Types data ResultState = ResultState !OutputIx [Int64] deriving Show -type Result = ReaderT Op (State ResultState) +type Result = ReaderT NodeName (State ResultState) -- | Class of types that can be used as op outputs. -class OpResult a where - toResult :: Result a +class BuildResult a where + buildResult :: Result a -instance (OpResult a1, OpResult a2) => OpResult (a1, a2) where - toResult = (,) <$> toResult <*> toResult +instance (BuildResult a1, BuildResult a2) => BuildResult (a1, a2) where + buildResult = (,) <$> buildResult <*> buildResult -instance (OpResult a1, OpResult a2, OpResult a3) => OpResult (a1, a2, a3) where - toResult = (,,) <$> toResult <*> toResult <*> toResult +instance (BuildResult a1, BuildResult a2, BuildResult a3) => BuildResult (a1, a2, a3) where + buildResult = (,,) <$> buildResult <*> buildResult <*> buildResult -instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4) - => OpResult (a1, a2, a3, a4) where - toResult = (,,,) <$> toResult <*> toResult <*> toResult <*> toResult +instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4) + => BuildResult (a1, a2, a3, a4) where + buildResult = (,,,) <$> buildResult <*> buildResult <*> buildResult <*> buildResult -instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4, OpResult a5) - => OpResult (a1, a2, a3, a4, a5) where - toResult = (,,,,) <$> toResult - <*> toResult - <*> toResult - <*> toResult - <*> toResult +instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4, BuildResult a5) + => BuildResult (a1, a2, a3, a4, a5) where + buildResult = (,,,,) <$> buildResult + <*> buildResult + <*> buildResult + <*> buildResult + <*> buildResult -instance ( OpResult a1 - , OpResult a2 - , OpResult a3 - , OpResult a4 - , OpResult a5 - , OpResult a6 +instance ( BuildResult a1 + , BuildResult a2 + , BuildResult a3 + , BuildResult a4 + , BuildResult a5 + , BuildResult a6 ) - => OpResult (a1, a2, a3, a4, a5, a6) where - toResult = (,,,,,) - <$> toResult - <*> toResult - <*> toResult - <*> toResult - <*> toResult - <*> toResult - -tensorResult :: TensorKind v -> Result (Tensor v a) -tensorResult v = Tensor v <$> recordResult + => BuildResult (a1, a2, a3, a4, a5, a6) where + buildResult = (,,,,,) + <$> buildResult + <*> buildResult + <*> buildResult + <*> buildResult + <*> buildResult + <*> buildResult recordResult :: Result Output recordResult = do @@ -90,144 +88,39 @@ recordResult = do put $! ResultState (i+1) ns return $! output i o -instance OpResult ResourceHandle where - toResult = ResourceHandle <$> recordResult +instance BuildResult ResourceHandle where + buildResult = ResourceHandle <$> recordResult -instance OpResult (Tensor Value a) where - toResult = tensorResult ValueKind +instance Rendered v => BuildResult (Tensor v a) where + buildResult = Tensor . pure <$> recordResult -instance OpResult (Tensor Ref a) where - toResult = tensorResult RefKind +instance BuildResult ControlNode where + buildResult = ControlNode <$> ask -instance OpResult ControlNode where - toResult = ControlNode <$> ask +instance (Rendered v, TensorTypes as) => BuildResult (TensorList v as) where + buildResult = loop (tensorTypes :: TensorTypeList as) + where + loop :: TensorTypeList bs -> Result (TensorList v bs) + loop Nil = return Nil + loop (TensorTypeProxy :/ ls) = do + t <- buildResult + ts <- loop ls + return (t :/ ts) -tensorListResult :: forall as v . TensorTypes as => TensorKind v -> Result (TensorList v as) -tensorListResult v = loop (tensorTypes :: TensorTypeList as) - where - loop :: TensorTypeList bs -> Result (TensorList v bs) - loop Nil = return Nil - loop (TensorTypeProxy :/ ls) = do - t <- tensorResult v - ts <- loop ls - return (t :/ ts) - -instance TensorTypes as => OpResult (TensorList Value as) where - toResult = tensorListResult ValueKind - -instance TensorTypes as => OpResult (TensorList Ref as) where - toResult = tensorListResult RefKind - -instance OpResult a => OpResult [a] where - toResult = do +instance BuildResult a => BuildResult [a] where + buildResult = do ResultState i ns <- get case ns of - [] -> error $ "Ran out of counts in toResult. " ++ - "Likely misuse of buildListOp." + [] -> error $ "Ran out of counts in buildResult. " ++ + "Likely misuse of buildOp." (n : rest) -> do put $! ResultState i rest - replicateM (fromIntegral n) toResult + replicateM (fromIntegral n) buildResult -runResult :: OpResult a => [Int64] -> Op -> a -runResult ns o = - case runState (runReaderT toResult o) (ResultState 0 ns) of - (x, ResultState _ []) -> x - (_, ns') -> error $ "Ununsed length in runResult attributes: " ++ - show (ns, ns') - --- | Make a new "pure" op, which may be deduped with identical ops within --- the same scope. -pureResult :: OpResult a => [Int64] -> OpDef -> [Output] -> a -pureResult ns o ts = runResult ns $ Unrendered $ addReversedInputs o ts - --- | Make a new "stateful" op, which will not be deduped with otherwise --- identical ops. -buildResult :: OpResult a => [Int64] -> OpDef -> [Output] -> Build a -buildResult ns o ts - = runResult ns . Rendered <$> addNewOp (addReversedInputs o ts) - -addReversedInputs :: OpDef -> [Output] -> OpDef -addReversedInputs o ts = o & opInputs <>~ reverse ts - --- | Class of types that can be used as op functions. -class BuildOp f where - buildOp' :: [Int64] -- ^ Sizes of list results (having number_attr) - -> OpDef - -> [Output] -- ^ Accumulator for inputs to the op. - -> f - --- | Starts an operation that returns a structured set of tensors --- (singletons or tuples). -buildOp :: BuildOp f => OpDef -> f -buildOp o = buildOp' [] o [] - --- | Starts an operation that returns a list of tensors. -buildListOp :: BuildOp f => [Int64] - -- ^ Cardinality of the corresponding list of tensors output. - -> OpDef -> f -buildListOp counts o = buildOp' counts o [] - -instance BuildOp ControlNode where - buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts - -instance BuildOp ResourceHandle where - buildOp' = pureResult - -instance BuildOp (Tensor Value a) where - buildOp' = pureResult - -instance BuildOp (Tensor Ref a) where - buildOp' = pureResult - -instance TensorTypes as => BuildOp (TensorList Value as) where - buildOp' = pureResult - -instance TensorTypes as => BuildOp (TensorList Ref as) where - buildOp' = pureResult - -instance BuildOp [Tensor Value a] where - buildOp' = pureResult - -instance (OpResult t1, OpResult t2) => BuildOp (t1, t2) where - buildOp' = pureResult - -instance (OpResult t1, OpResult t2, OpResult t3) => BuildOp (t1, t2, t3) where - buildOp' = pureResult - -instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4) - => BuildOp (t1, t2, t3, t4) where - buildOp' = pureResult - -instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4, OpResult t5) - => BuildOp (t1, t2, t3, t4, t5) where - buildOp' = pureResult - -instance ( OpResult t1 - , OpResult t2 - , OpResult t3 - , OpResult t4 - , OpResult t5 - , OpResult t6 - ) - => BuildOp (t1, t2, t3, t4, t5, t6) where - buildOp' = pureResult - -instance OpResult a => BuildOp (Build a) where - buildOp' = buildResult - -instance BuildOp f => BuildOp (ResourceHandle -> f) where - buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts) - -instance BuildOp f => BuildOp (Tensor v a -> f) where - buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts) - -instance BuildOp f => BuildOp ([Tensor v a] -> f) where - buildOp' rf o accum ts - = buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum) - -instance BuildOp f => BuildOp (TensorList v as -> f) where - buildOp' rf o accum ts - = buildOp' rf o (reverse (tensorListOutputs ts) ++ accum) +buildOp :: BuildResult a => [Int64] -> OpDef -> Build a +buildOp sizes o = do + n <- addNewOp o + return $ flip evalState (ResultState 0 sizes) (runReaderT buildResult n) -- | Returns true if all the integers in each tuple are identical. -- Throws an error with a descriptive message if not. @@ -240,6 +133,104 @@ eqLengthGuard = all eachOk error ("number_attr " ++ numberAttrName ++ " contains tensors with different length " ++ show pairs) +----------- + + +-- | Class of types that can be used as op outputs. +class PureResult a where + pureResult :: ReaderT (Build OpDef) (State ResultState) a + +instance PureResult (Tensor Build a) where + pureResult = do + ResultState i ns <- get + put $! ResultState (i+1) ns + makeOp <- ask + return $ Tensor $ do + o <- makeOp + -- TODO: unify with BuildResult (Tensor v) + output i <$> getOrAddOp o + +instance (PureResult a1, PureResult a2) => PureResult (a1, a2) where + pureResult = (,) <$> pureResult <*> pureResult + +instance (PureResult a1, PureResult a2, PureResult a3) => PureResult (a1, a2, a3) where + pureResult = (,,) <$> pureResult <*> pureResult <*> pureResult + +instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4) + => PureResult (a1, a2, a3, a4) where + pureResult = (,,,) <$> pureResult <*> pureResult <*> pureResult <*> pureResult + +instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4, PureResult a5) + => PureResult (a1, a2, a3, a4, a5) where + pureResult = (,,,,) <$> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + +instance ( PureResult a1 + , PureResult a2 + , PureResult a3 + , PureResult a4 + , PureResult a5 + , PureResult a6 + ) + => PureResult (a1, a2, a3, a4, a5, a6) where + pureResult = (,,,,,) + <$> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + +instance PureResult a => PureResult [a] where + pureResult = do + ResultState i ns <- get + case ns of + [] -> error $ "Ran out of counts in pureResult. " ++ + "Likely misuse of pureOp with output lists." + n : rest -> do + put $! ResultState i rest + replicateM (fromIntegral n) pureResult + +instance TensorTypes as => PureResult (TensorList Build as) where + pureResult = loop (tensorTypes :: TensorTypeList as) + where + loop :: TensorTypeList bs -> ReaderT (Build OpDef) (State ResultState) + (TensorList Build bs) + loop Nil = return Nil + loop (TensorTypeProxy :/ ls) = do + t <- pureResult + ts <- loop ls + return (t :/ ts) + +pureOp :: PureResult a => [Int64] -> Build OpDef -> a +pureOp sizes o = flip evalState (ResultState 0 sizes) (runReaderT pureResult o) + +----- +-- Class of types that can be used as arguments + +class BuildInputs a where + buildInputs :: a -> Build [Output] + +instance BuildInputs a => BuildInputs [a] where + buildInputs = fmap concat . mapM buildInputs + +instance BuildInputs (Tensor v a) where + buildInputs (Tensor t) = do + o <- toBuild t + return [o] + +instance BuildInputs (ListOf (Tensor v) as) where + buildInputs Nil = return [] + buildInputs (t :/ ts) = liftM2 (++) (buildInputs t) (buildInputs ts) + +instance BuildInputs ResourceHandle where + buildInputs (ResourceHandle o) = return [o] + +---- + -- | Parameters to build an op (for example, the node name or optional attributes). -- TODO: be more type safe. type OpParams = OpDef -> OpDef diff --git a/tensorflow/src/TensorFlow/ControlFlow.hs b/tensorflow/src/TensorFlow/ControlFlow.hs index 2f20b92..d8ebc46 100644 --- a/tensorflow/src/TensorFlow/ControlFlow.hs +++ b/tensorflow/src/TensorFlow/ControlFlow.hs @@ -25,9 +25,6 @@ module TensorFlow.ControlFlow , noOp ) where -import qualified Data.Set as Set -import Lens.Family2 ((&), (.~)) - import TensorFlow.BuildOp import TensorFlow.Build import TensorFlow.Nodes @@ -46,11 +43,8 @@ withControlDependencies deps act = do -- When this op finishes, all ops in the input @n@ have finished. This op has -- no output. group :: (MonadBuild m, Nodes t) => t -> m ControlNode -group deps = do - nodes <- build $ Set.toList <$> getNodes deps - -- TODO: slicker way - return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes +group deps = withControlDependencies deps noOp -- | Does nothing. Only useful as a placeholder for control edges. -noOp :: ControlNode -noOp = buildOp $ opDef "NoOp" +noOp :: MonadBuild m => m ControlNode +noOp = build $ buildOp [] $ opDef "NoOp" diff --git a/tensorflow/src/TensorFlow/Core.hs b/tensorflow/src/TensorFlow/Core.hs index 7764e37..7834211 100644 --- a/tensorflow/src/TensorFlow/Core.hs +++ b/tensorflow/src/TensorFlow/Core.hs @@ -57,9 +57,9 @@ module TensorFlow.Core , Tensor , Value , Ref - , TensorKind(..) , value , tensorFromName + , expr -- ** Element types , TensorType , TensorData diff --git a/tensorflow/src/TensorFlow/Nodes.hs b/tensorflow/src/TensorFlow/Nodes.hs index a7ce925..6731600 100644 --- a/tensorflow/src/TensorFlow/Nodes.hs +++ b/tensorflow/src/TensorFlow/Nodes.hs @@ -20,6 +20,7 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} -- For Fetchable (TensorExpr a) module TensorFlow.Nodes where import Control.Applicative (liftA2, liftA3) @@ -28,7 +29,6 @@ import Data.Map.Strict (Map) import Data.Monoid ((<>)) import Data.Set (Set) import Data.Text (Text) -import Lens.Family2 ((^.)) import qualified Data.Map.Strict as Map import qualified Data.Set as Set @@ -90,7 +90,7 @@ instance Fetchable t a => Fetchable [t] [a] where getFetch ts = sequenceA <$> mapM getFetch ts instance Nodes ControlNode where - getNodes (ControlNode o) = Set.singleton <$> getOrAddOp o + getNodes (ControlNode o) = pure $ Set.singleton o -- We use the constraint @(a ~ ())@ to help with type inference. For example, -- if @t :: ControlNode@, then this constraint ensures that @run t :: Session @@ -113,13 +113,13 @@ instance (Fetchable (f t) a, Fetchable (ListOf f ts) (List as), i ~ Identity) getFetch (x :/ xs) = liftA2 (\y ys -> y /:/ ys) <$> getFetch x <*> getFetch xs instance Nodes (Tensor v a) where - getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp) + getNodes (Tensor o) = Set.singleton . outputNodeName <$> toBuild o -fetchTensorVector :: forall a v . TensorType a +fetchTensorVector :: forall a v . (TensorType a) => Tensor v a -> Build (Fetch (TensorData a)) -fetchTensorVector (Tensor _ o) = do - outputName <- renderOutput o - return $ Fetch (Set.singleton outputName) $ \tensors -> +fetchTensorVector (Tensor o) = do + outputName <- encodeOutput <$> toBuild o + pure $ Fetch (Set.singleton outputName) $ \tensors -> let tensorData = tensors Map.! outputName expectedType = tensorType (undefined :: a) actualType = FFI.tensorDataType tensorData diff --git a/tensorflow/src/TensorFlow/Output.hs b/tensorflow/src/TensorFlow/Output.hs index 9ee31c8..ef6b489 100644 --- a/tensorflow/src/TensorFlow/Output.hs +++ b/tensorflow/src/TensorFlow/Output.hs @@ -22,8 +22,6 @@ module TensorFlow.Output , Device(..) -- * Ops , NodeName(..) - , Op(..) - , opUnrendered , OpDef(..) , opName , opType @@ -34,28 +32,24 @@ module TensorFlow.Output , OutputIx(..) , Output(..) , output - , outputIndex - , outputOp , PendingNodeName(..) , ResourceHandle(..) ) where import qualified Data.Map.Strict as Map -import Data.ProtoLens.TextFormat (showMessage) import Data.String (IsString(..)) import Data.Text (Text) import qualified Data.Text as Text -import Lens.Family2 (Lens', Traversal', (.~), (&), (^.)) +import Lens.Family2 (Lens') import Lens.Family2.Unchecked (lens) import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..)) -import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef(..), name) import Data.Default (def) import TensorFlow.Types (Attribute, attrLens) import TensorFlow.Orphans () -- | A type of graph node which has no outputs. These nodes are -- valuable for causing side effects when they are run. -newtype ControlNode = ControlNode { unControlNode :: Op } +newtype ControlNode = ControlNode { unControlNode :: NodeName } -- | The type of op of a node in the graph. This corresponds to the proto field -- NodeDef.op. @@ -66,18 +60,12 @@ instance IsString OpType where fromString = OpType . Text.pack -- | An output of a TensorFlow node. -data Output = Output !OutputIx !Op +data Output = Output {outputIndex :: !OutputIx, outputNodeName :: !NodeName} deriving (Eq, Ord, Show) -output :: OutputIx -> Op -> Output +output :: OutputIx -> NodeName -> Output output = Output -outputOp :: Lens' Output Op -outputOp = lens (\(Output _ o) -> o) (\(Output i _) o -> Output i o) - -outputIndex :: Lens' Output OutputIx -outputIndex = lens (\(Output i _) -> i) (\(Output _ o) i -> Output i o) - newtype OutputIx = OutputIx { unOutputIx :: Int } deriving (Eq, Ord, Num, Enum, Show) @@ -90,25 +78,6 @@ newtype Device = Device {deviceName :: Text} instance Show Device where show (Device d) = show d --- | The representation of a node in a TensorFlow graph. -data Op - = Rendered !NodeDef -- ^ Properties are fixed, including the - -- device, name, and scope. - | Unrendered !OpDef -- ^ Properties are not fixed, and may change depending - -- on which context this op is rendered in. - deriving (Eq, Ord) - -instance Show Op where - show (Rendered n) = "Rendered " ++ showMessage n - show (Unrendered o) = "Unrendered " ++ show (o ^. opName) - --- | Traverse on the 'Unrendered' of an 'Op'. --- --- Same implementation as _Left. -opUnrendered :: Traversal' Op OpDef -opUnrendered f (Unrendered a) = Unrendered <$> f a -opUnrendered _ (Rendered b) = pure (Rendered b) - -- | Op definition. This corresponds somewhat to the 'NodeDef' proto. data OpDef = OpDef { _opName :: !PendingNodeName @@ -157,7 +126,7 @@ instance IsString Output where (n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr -> Output (fromInteger ix) $ assigned n _ -> Output 0 $ assigned s - where assigned n = Rendered $ def & name .~ Text.pack n + where assigned = NodeName . Text.pack -- | Opaque handle to a mutable resource in the graph. Typical such diff --git a/tensorflow/src/TensorFlow/Session.hs b/tensorflow/src/TensorFlow/Session.hs index bc24e3d..199d7f5 100644 --- a/tensorflow/src/TensorFlow/Session.hs +++ b/tensorflow/src/TensorFlow/Session.hs @@ -163,7 +163,7 @@ runWithFeeds feeds t = do runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a runFetchWithFeeds feeds target (Fetch fetch restore) = do extend - feeds' <- build $ fixFeeds feeds + let feeds' = fixFeeds feeds let fetchNames = encodeUtf8 <$> Set.toList fetch targetNames = toNodeNames $ Set.toList target session <- Session (asks rawSession) @@ -192,8 +192,8 @@ runWithFeeds_ feeds t = do ns <- build $ getNodes t runFetchWithFeeds feeds ns (pure ()) -fixFeeds :: [Feed] -> Build [(ByteString, FFI.TensorData)] -fixFeeds = mapM $ \(Feed o d) -> (,d) . encodeUtf8 <$> renderOutput o +fixFeeds :: [Feed] -> [(ByteString, FFI.TensorData)] +fixFeeds = map $ \(Feed o d) -> (encodeUtf8 $ encodeOutput o, d) -- | Starts a concurrent thread which evaluates the given Nodes -- forever until runSession exits or an exception occurs. Graph diff --git a/tensorflow/src/TensorFlow/Tensor.hs b/tensorflow/src/TensorFlow/Tensor.hs index e353eb6..bf3fc1f 100644 --- a/tensorflow/src/TensorFlow/Tensor.hs +++ b/tensorflow/src/TensorFlow/Tensor.hs @@ -16,21 +16,26 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE Rank2Types #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} -- For the Render class module TensorFlow.Tensor where +import Data.ByteString (ByteString) import Data.String (IsString(..)) import qualified Data.Text as Text -import Lens.Family2 (Lens', (^.)) -import Lens.Family2.Unchecked (lens) +import Lens.Family2 ((^.)) +import Lens.Family2.State ((%=), use) -import TensorFlow.Output (Output) +import Proto.Tensorflow.Core.Framework.NodeDef (device) +import TensorFlow.Build +import TensorFlow.Output (Output, NodeName, outputNodeName, Device(..)) import TensorFlow.Types ( TensorData(..) , ListOf(..) @@ -40,52 +45,149 @@ import qualified TensorFlow.Internal.FFI as FFI -- | A named output of a TensorFlow operation. -- -- The type parameter @a@ is the type of the elements in the 'Tensor'. The --- parameter @v@ is either 'Value' or 'Ref', depending on whether the graph is --- treating this op output as an immutable 'Value' or a stateful 'Ref' (e.g., a --- variable). Note that a @Tensor Ref@ can be casted into a @Tensor Value@ via --- 'value'. -data Tensor v a = Tensor (TensorKind v) Output +-- parameter @v@ is either: +-- +-- * 'Build': An unrendered, immutable value. +-- * 'Value': A rendered, immutable value. +-- * 'Ref': A rendered stateful handle (e.g., a variable). +-- +-- Note that 'expr', 'value', 'render' and 'renderValue' can help convert between +-- the different types of 'Tensor'. +data Tensor v a where + Tensor :: TensorKind v => {tensorOutput :: v Output} -> Tensor v a -data Value -data Ref +newtype Value a = Value {runValue :: a} + deriving Functor --- | This class provides a runtime switch on whether a 'Tensor' should be --- treated as a 'Value' or as a 'Ref'. -data TensorKind v where - ValueKind :: TensorKind Value - RefKind :: TensorKind Ref +instance Applicative Value where + pure = Value + Value f <*> Value x = Value $ f x -tensorKind :: Lens' (Tensor v a) (TensorKind v) -tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o) +instance Monad Value where + f >>= g = g $ runValue f -tensorOutput :: Lens' (Tensor v a) Output -tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o) +newtype Ref a = Ref {runRef :: a} + deriving Functor --- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a --- Ref into Value. This behaves like a no-op. -value :: Tensor v a -> Tensor Value a -value (Tensor _ o) = Tensor ValueKind o +instance Applicative Ref where + pure = Ref + Ref f <*> Ref x = Ref $ f x + +instance Monad Ref where + f >>= g = g $ runRef f + +-- | Cast a 'Tensor Ref' into a 'Tensor Value'. This behaves like a no-op. +value :: Tensor Ref a -> Tensor Value a +value (Tensor o) = Tensor $ Value $ runRef o + +renderValue :: MonadBuild m => Tensor v a -> m (Tensor Value a) +renderValue (Tensor o) = render $ Tensor $ toBuild o -- | A pair of a 'Tensor' and some data that should be fed into that 'Tensor' -- when running the graph. data Feed = Feed Output FFI.TensorData +-- | A class ensuring that a given tensor is rendered, i.e., has a fixed +-- name, device, etc. +class TensorKind v => Rendered v where + rendered :: v a -> a + +instance Rendered Value where + rendered = runValue + +instance Rendered Ref where + rendered = runRef + +renderedOutput :: Rendered v => Tensor v a -> Output +renderedOutput = rendered . tensorOutput + +tensorNodeName :: Rendered v => Tensor v a -> NodeName +tensorNodeName = outputNodeName . renderedOutput + + -- | Create a 'Feed' for feeding the given data into a 'Tensor' when running -- the graph. -- -- Note that if a 'Tensor' is rendered, its identity may change; so feeding the -- rendered 'Tensor' may be different than feeding the original 'Tensor'. -feed :: Tensor v a -> TensorData a -> Feed -feed (Tensor _ o) (TensorData td) = Feed o td +feed :: Rendered v => Tensor v a -> TensorData a -> Feed +feed t (TensorData td) = Feed (renderedOutput t) td -- | Create a 'Tensor' for a given name. This can be used to reference nodes -- in a 'GraphDef' that was loaded via 'addGraphDef'. -- TODO(judahjacobson): add more safety checks here. -tensorFromName :: TensorKind v -> Text.Text -> Tensor v a -tensorFromName v = Tensor v . fromString . Text.unpack +tensorFromName :: TensorKind v => Text.Text -> Tensor v a +tensorFromName = Tensor . pure . fromString . Text.unpack + +-- | Like 'tensorFromName', but type-restricted to 'Value'. +tensorValueFromName :: Text.Text -> Tensor Value a +tensorValueFromName = tensorFromName + +-- | Like 'tensorFromName', but type-restricted to 'Ref'. +tensorRefFromName :: Text.Text -> Tensor Ref a +tensorRefFromName = tensorFromName type TensorList v = ListOf (Tensor v) -tensorListOutputs :: TensorList v as -> [Output] +tensorListOutputs :: Rendered v => TensorList v as -> [Output] tensorListOutputs Nil = [] -tensorListOutputs (t :/ ts) = (t ^. tensorOutput) : tensorListOutputs ts +tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts + +-- | Places all nodes rendered in the given 'Build' action on the same +-- device as the given Tensor (see also 'withDevice'). Make sure that +-- the action has side effects of rendering the desired tensors. A pure +-- return would not have the desired effect. +colocateWith :: (MonadBuild m, Rendered v) => Tensor v b -> m a -> m a +colocateWith t x = do + d <- build $ Device . (^. device) + <$> lookupNode (outputNodeName $ renderedOutput t) + withDevice (Just d) x + + +-- | Render a 'Tensor', fixing its name, scope, device and control inputs from +-- the 'MonadBuild' context. Also renders any dependencies of the 'Tensor' that +-- weren't already rendered. +-- +-- This operation is idempotent; calling 'render' on the same input in the same +-- context will produce the same result. However, rendering the same +-- @Tensor Build@ in two different contexts may result in two different +-- @Tensor Value@s. +render :: MonadBuild m => Tensor Build a -> m (Tensor Value a) +render (Tensor t) = Tensor . Value <$> build t + +-- TODO: better name. +expr :: TensorKind v => Tensor v a -> Tensor Build a +expr (Tensor o) = Tensor $ toBuild o + +-- | Records the given summary action in Build for retrieval with +-- Summary protocol buffer in string form. For safety, use the +-- pre-composed functions: Logging.scalarSummary and +-- Logging.histogramSummary. +addSummary :: (MonadBuild m, TensorKind v) => Tensor v ByteString -- ^ A 'SummaryTensor' + -> m () +addSummary t = build $ do + -- TODO: more generic way + o <- toBuild $ tensorOutput t + summaries %= (o :) + +-- | Retrieves the summary ops collected thus far. Typically this only +-- happens once, but if 'TensorFlow.Session.buildWithSummary' is used +-- repeatedly, the values accumulate. +collectAllSummaries :: MonadBuild m => m [SummaryTensor] +collectAllSummaries = build $ map (Tensor . Value) <$> use summaries + +-- | Synonym for the tensors that return serialized Summary proto. +type SummaryTensor = Tensor Value ByteString + +-- | An internal class for kinds of Tensors. +class Monad v => TensorKind v where + toBuild :: v a -> Build a + +instance TensorKind Value where + toBuild = return . rendered + +instance TensorKind Ref where + toBuild = return . rendered + +instance TensorKind Build where + toBuild = id