diff --git a/README.md b/README.md index 62793f7..8ebb889 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ fit xData yData = TF.runSession $ do return (w', b') gradientDescent :: Float - -> TF.Tensor TF.Value Float + -> TF.Tensor TF.Build Float -> [TF.Tensor TF.Ref Float] -> TF.Session TF.ControlNode gradientDescent alpha loss params = do diff --git a/tensorflow-logging/src/TensorFlow/Logging.hs b/tensorflow-logging/src/TensorFlow/Logging.hs index da1c29f..f0d5439 100644 --- a/tensorflow-logging/src/TensorFlow/Logging.hs +++ b/tensorflow-logging/src/TensorFlow/Logging.hs @@ -67,10 +67,10 @@ import Proto.Tensorflow.Core.Framework.Summary (Summary) import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step, summary, wallTime) import System.Directory (createDirectoryIfMissing) import System.FilePath ((</>)) -import TensorFlow.Build (Build, render, SummaryTensor, addSummary, collectAllSummaries) +import TensorFlow.Build (MonadBuild) import TensorFlow.Ops (scalar) import TensorFlow.Records.Conduit (sinkTFRecords) -import TensorFlow.Tensor (Tensor) +import TensorFlow.Tensor (Tensor, render, SummaryTensor, addSummary, collectAllSummaries) import TensorFlow.Types (TensorType, type(/=)) import Text.Printf (printf) import qualified Data.ByteString.Lazy as L @@ -141,19 +141,19 @@ doubleWallTime = asDouble <$> getCurrentTime -- | Adds a 'CoreOps.histogramSummary' node. The tag argument is intentionally -- limited to a single value for simplicity. histogramSummary :: - (TensorType t, t /= ByteString, t /= Bool) + (MonadBuild m, TensorType t, t /= ByteString, t /= Bool) -- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t) - => ByteString -> Tensor v t -> Build () + => ByteString -> Tensor v t -> m () histogramSummary tag = addSummary . CoreOps.histogramSummary (scalar tag) -- | Adds a 'CoreOps.scalarSummary' node. scalarSummary :: - (TensorType t, t /= ByteString, t /= Bool) + (TensorType t, t /= ByteString, t /= Bool, MonadBuild m) -- (TensorType t, -- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t) - => ByteString -> Tensor v t -> Build () + => ByteString -> Tensor v t -> m () scalarSummary tag = addSummary . CoreOps.scalarSummary (scalar tag) -- | Merge all summaries accumulated in the 'Build' into one summary. -mergeAllSummaries :: Build SummaryTensor +mergeAllSummaries :: MonadBuild m => m SummaryTensor mergeAllSummaries = collectAllSummaries >>= render . CoreOps.mergeSummary diff --git a/tensorflow-mnist/app/Main.hs b/tensorflow-mnist/app/Main.hs index 5475469..ccb032b 100644 --- a/tensorflow-mnist/app/Main.hs +++ b/tensorflow-mnist/app/Main.hs @@ -34,13 +34,13 @@ numPixels = 28*28 :: Int64 numLabels = 10 :: Int64 -- | Create tensor with random values where the stddev depends on the width. -randomParam :: Int64 -> TF.Shape -> TF.Build (TF.Tensor TF.Value Float) +randomParam :: Int64 -> TF.Shape -> TF.Build (TF.Tensor TF.Build Float) randomParam width (TF.Shape shape) = - (* stddev) <$> TF.truncatedNormal (TF.vector shape) + (`TF.mul` stddev) <$> TF.truncatedNormal (TF.vector shape) where stddev = TF.scalar (1 / sqrt (fromIntegral width)) -reduceMean :: TF.Tensor TF.Value Float -> TF.Tensor TF.Value Float +reduceMean :: TF.Tensor TF.Build Float -> TF.Tensor TF.Build Float reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32)) -- Types must match due to model structure.
@@ -87,7 +87,7 @@ createModel = do grads <- TF.gradients loss params let lr = TF.scalar 0.00001 - applyGrad param grad = TF.assign param $ param `TF.sub` (lr * grad) + applyGrad param grad = TF.assign param $ param `TF.sub` (lr `TF.mul` grad) trainStep <- TF.group =<< zipWithM applyGrad params grads let correctPredictions = TF.equal predict labels diff --git a/tensorflow-mnist/tests/ParseTest.hs b/tensorflow-mnist/tests/ParseTest.hs index 95e8f11..9a0f4db 100644 --- a/tensorflow-mnist/tests/ParseTest.hs +++ b/tensorflow-mnist/tests/ParseTest.hs @@ -37,15 +37,15 @@ import TensorFlow.Examples.MNIST.TrainedGraph import TensorFlow.Build ( asGraphDef , addGraphDef - , render + , Build ) import TensorFlow.Tensor ( Tensor(..) , Ref - , Value , feed - , TensorKind(..) + , render , tensorFromName + , tensorValueFromName ) import TensorFlow.Ops import TensorFlow.Session @@ -80,7 +80,7 @@ testReadMNIST = testCase "testReadMNIST" $ do labelData <- readMNISTLabels =<< testLabelData 10000 @=? length labelData -testNodeName :: Text -> Tensor v a -> Assertion +testNodeName :: Text -> Tensor Build a -> Assertion testNodeName n g = n @=? opName where opName = head (gDef^.node)^.op @@ -89,7 +89,7 @@ testNodeName n g = n @=? opName testGraphDefGen :: Test testGraphDefGen = testCase "testGraphDefGen" $ do -- Test the inferred operation type. - let f0 :: Tensor Value Float + let f0 :: Tensor Build Float f0 = 0 testNodeName "Const" f0 testNodeName "Add" $ 1 + f0 @@ -109,7 +109,7 @@ testGraphDefExec = testCase "testGraphDefExec" $ do let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10 runSession $ do addGraphDef graphDef - x <- run $ tensorFromName ValueKind "Mul_2" + x <- run $ tensorValueFromName "Mul_2" liftIO $ (50 :: Float) @=? unScalar x -- | Load MNIST from a GraphDef and the weights from a checkpoint and run on @@ -142,8 +142,8 @@ testMNISTExec = testCase "testMNISTExec" $ do build $ addGraphDef $ mnist & version .~ 0 -- Define nodes that restore saved weights and biases. let bias, wts :: Tensor Ref Float - bias = tensorFromName RefKind "Variable" - wts = tensorFromName RefKind "weights" + bias = tensorFromName "Variable" + wts = tensorFromName "weights" wtsCkptPath <- liftIO wtsCkpt biasCkptPath <- liftIO biasCkpt -- Run those restoring nodes on the graph in the current session. @@ -155,12 +155,12 @@ testMNISTExec = testCase "testMNISTExec" $ do let ty = encodeTensorData [10] oneHotLabels where oneHotLabels = V.replicate 10 (0 :: Float) V.// updates updates = [(fromIntegral label, 1)] - let feeds = [ feed (tensorFromName ValueKind "x-input") tensorSample - , feed (tensorFromName ValueKind "y-input") ty + let feeds = [ feed (tensorValueFromName "x-input") tensorSample + , feed (tensorValueFromName "y-input") ty ] -- Run the graph with the input feeds and read the ArgMax'd result from -- the test (not training) side of the evaluation. - x <- runWithFeeds feeds $ tensorFromName ValueKind "test/ArgMax" + x <- runWithFeeds feeds $ tensorValueFromName "test/ArgMax" -- Print the trained model's predicted outcome. 
liftIO $ putStrLn $ "Expectation: " ++ show label ++ "\n" ++ "Prediction: " ++ show (unScalar x :: Int64) diff --git a/tensorflow-nn/src/TensorFlow/NN.hs b/tensorflow-nn/src/TensorFlow/NN.hs index 0ae1c05..9baaeae 100644 --- a/tensorflow-nn/src/TensorFlow/NN.hs +++ b/tensorflow-nn/src/TensorFlow/NN.hs @@ -24,7 +24,6 @@ import Prelude hiding ( log , exp ) import TensorFlow.Build ( MonadBuild - , render , withNameScope ) import TensorFlow.GenOps.Core ( greaterEqual @@ -33,6 +32,7 @@ import TensorFlow.GenOps.Core ( greaterEqual , exp ) import TensorFlow.Tensor ( Tensor(..) + , render , Value ) import TensorFlow.Types ( TensorType(..) @@ -40,6 +40,8 @@ import TensorFlow.Types ( TensorType(..) ) import TensorFlow.Ops ( zerosLike , add + , mul + , neg ) -- | Computes sigmoid cross entropy given `logits`. @@ -76,13 +78,11 @@ sigmoidCrossEntropyWithLogits -> Tensor Value a -- ^ __targets__ -> m (Tensor Value a) sigmoidCrossEntropyWithLogits logits targets = do - logits' <- render logits - targets' <- render targets - let zeros = zerosLike logits' - cond = logits' `greaterEqual` zeros - relu_logits = select cond logits' zeros - neg_abs_logits = select cond (-logits') logits' + let zeros = zerosLike logits + cond = logits `greaterEqual` zeros + relu_logits = select cond logits zeros + neg_abs_logits = select cond (neg logits) logits withNameScope "logistic_loss" $ do - left <- render $ relu_logits - logits' * targets' + left <- render $ relu_logits - logits `mul` targets right <- render $ log (1 + exp neg_abs_logits) withNameScope "sigmoid_add" $ render $ left `add` right diff --git a/tensorflow-nn/tests/NNTest.hs b/tensorflow-nn/tests/NNTest.hs index d91dd70..4051212 100644 --- a/tensorflow-nn/tests/NNTest.hs +++ b/tensorflow-nn/tests/NNTest.hs @@ -23,10 +23,9 @@ import Test.Framework (Test) import Test.Framework.Providers.HUnit (testCase) import qualified Data.Vector as V import qualified TensorFlow.Gradient as TF -import qualified TensorFlow.Nodes as TF import qualified TensorFlow.NN as TF import qualified TensorFlow.Ops as TF -import qualified TensorFlow.Session as TF +import qualified TensorFlow.Core as TF -- | These tests are ported from: -- @@ -60,12 +59,11 @@ defInputs = Inputs { testLogisticOutput :: Test testLogisticOutput = testCase "testLogisticOutput" $ do let inputs = defInputs - vLogits = TF.vector $ logits inputs - vTargets = TF.vector $ targets inputs - tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets - ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs) - - r <- run tfLoss + r <- run $ do + vLogits <- TF.render $ TF.vector $ logits inputs + vTargets <- TF.render $ TF.vector $ targets inputs + TF.sigmoidCrossEntropyWithLogits vLogits vTargets + let ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs) assertAllClose r ourLoss @@ -74,23 +72,22 @@ testLogisticOutputMultipleDim = testCase "testLogisticOutputMultipleDim" $ do let inputs = defInputs shape = [2, 2, 2] - vLogits = TF.constant shape (logits inputs) - vTargets = TF.constant shape (targets inputs) - tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets - ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs) - - r <- run tfLoss + r <- run $ do + vLogits <- TF.render $ TF.constant shape (logits inputs) + vTargets <- TF.render $ TF.constant shape (targets inputs) + TF.sigmoidCrossEntropyWithLogits vLogits vTargets + let ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs) assertAllClose r ourLoss testGradientAtZero 
:: Test testGradientAtZero = testCase "testGradientAtZero" $ do - let inputs = defInputs { logits = [0, 0], targets = [0, 1] } - vLogits = TF.vector $ logits inputs - vTargets = TF.vector $ targets inputs - tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets - r <- run $ do + let inputs = defInputs { logits = [0, 0], targets = [0, 1] } + vTargets <- TF.render $ TF.vector $ targets inputs + vLogits <- TF.render $ TF.vector $ logits inputs + let tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets + l <- tfLoss TF.gradients l [vLogits] diff --git a/tensorflow-opgen/src/TensorFlow/OpGen.hs b/tensorflow-opgen/src/TensorFlow/OpGen.hs index ac1e02d..c512168 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen.hs @@ -231,21 +231,24 @@ renderHaskellAttrName :: Attr a -> Doc renderHaskellAttrName = renderHaskellName . attrName functionBody :: ParsedOp -> Doc -functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOpParts)) - </> indent indentation (sep tensorArgs) +functionBody pOp + | parsedOpIsMonadic pOp + = "build $ do" + </> indent indentation (bindOpInputsVar + </> "buildOp" <+> outputListsSizes <+> opDef) + | otherwise + = "pureOp" <+> outputListsSizes <+> "$ do" + </> indent indentation (bindOpInputsVar </> "return" <+> opDef) where - maybeLift - | parsedOpIsMonadic pOp = "build $" - | otherwise = "" - buildFunction - | null outputListsSizes = "buildOp" - | otherwise = "buildListOp" <+> - brackets (commasep $ - map renderHaskellName outputListsSizes) - outputListsSizes = [ a - | ParsedArg { parsedArgCase = ListArg { argLength = a } } - <- parsedOutputs pOp] - buildOpParts = + outputListsSizes = brackets $ commasep + [ renderHaskellName a + | ParsedArg { parsedArgCase = ListArg { argLength = a } } + <- parsedOutputs pOp + ] + opInputsVar = "op'inputs" + bindOpInputsVar = opInputsVar <+> "<- fmap Prelude.concat $ Prelude.sequence" + <+> brackets (commasep $ map (\a -> "buildInputs" <+> a) tensorArgs) + opDef = parens $ hang 0 $ stack $ "opDef" <+> renderQuotedTFName (parsedOpName pOp) : -- Renders type parameter arguments. [ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> inferredTypeExpr a | a <- inferredTypeAttrs pOp, let n = attrName a ] ++ @@ -259,10 +262,9 @@ functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOp [ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n | a <- inferredListSizeAttrs pOp, let n = attrName a ] ++ - ["& op'options"] - - - tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp + ["& op'options & opInputs .~" <+> opInputsVar] + tensorArgs = renderTensorArg <$> parsedInputs pOp + renderTensorArg = renderHaskellName . parsedArgName inferredTypeExpr a | typeParamIsList $ attrInfo a = "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a @@ -296,7 +298,7 @@ typeSig pre pOp = constraints | null classConstraints = empty | otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>" typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp, - Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]] + Just (ArgSomeTensor v) <- [argKind $ parsedArgCase k]] ++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp] ++ if parsedOpIsMonadic pOp then ["m'"] else [] -- Use m' as the type parameter to avoid clashing with an attribute name. @@ -327,7 +329,7 @@ typeSig pre pOp = constraints wrapOutput o | parsedOpIsMonadic pOp = "m'" <+> parens o | otherwise = o - + -- | Render an op input or output.
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle" tensorArg :: ParsedArg -> Doc @@ -336,12 +338,13 @@ tensorArg p = case parsedArgCase p of SimpleArg { argType = t, argCaseKind = k } -> tensorType t k ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k MixedListArg {argTypeAttr = t, argCaseKind = k} - -> "TensorList" <+> kind k <+> renderHaskellName t + -> "TensorList" <+> parens (kind k) <+> renderHaskellName t where kind k = case k of ArgTensorRef -> "Ref" ArgTensorValue -> "Value" - ArgTensorEither v' -> strictText v' + ArgTensorBuild -> "Build" + ArgSomeTensor v -> strictText v tensorType t k = let a = case t of ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt @@ -350,7 +353,7 @@ tensorArg p = case parsedArgCase p of attrComment :: Attr a -> Doc attrComment a = argComment' (attrName a) (attrDescription a) - + argComment :: ParsedArg -> Doc argComment a = argComment' (parsedArgName a) (parsedArgDescription a) @@ -364,7 +367,7 @@ bold n = "__" <> n <> "__" -- | Comment for the outputs of an op. -- For example: -- -- ^ (__output1__, __output2__) --- -- +-- -- -- -- * __output1__: description1 -- -- -- -- * __output2__: description2 diff --git a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs index c7a5b69..2fbe4e4 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs @@ -141,7 +141,8 @@ data ArgType data ArgKind = ArgTensorRef -- Tensor Ref a | ArgTensorValue -- Tensor Value a - | ArgTensorEither Text -- Tensor v a; the Text is the variable `v` + | ArgTensorBuild -- Tensor Build a + | ArgSomeTensor Text -- Tensor v a; the Text is the variable 'v'. deriving (Eq) isRefCase :: ParsedArgCase -> Bool @@ -219,15 +220,17 @@ parseOp o = ParsedOp { parsedOpName = makeName $ o ^. name , parsedOpSummary = o ^. summary , parsedOpDescription = o ^. description - , parsedOpIsMonadic = o ^. isStateful - || any (isRefCase . parsedArgCase) parsedInputs , .. } where - parsedInputs = zipWith (\a v -> parseArg a (inputTensorKind a v)) - (o ^. inputArg) tensorKindParams - tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]] - parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg) + parsedOpIsMonadic = o ^. isStateful + || any (isRefCase . parsedArgCase) parsedInputs + || null (o ^. outputArg) + parsedInputs = zipWith (\t a -> parseArg a (inputTensorKind t a)) + tensorKindParams (o ^. inputArg) + tensorKindParams = ["v'" <> Text.pack (show x) | x <- [1::Integer ..]] + parsedOutputs = map (\a -> parseArg a (outputTensorKind parsedOpIsMonadic a)) + (o ^. outputArg) -- Integer attributes that can be inferred from the size of at least one -- input list. inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs) @@ -246,15 +249,16 @@ parseOp o = ParsedOp $ o ^. attr -- TODO(judahjacobson): Some arguments should be refs. -inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind -inputTensorKind a v +inputTensorKind :: Text -> OpDef'ArgDef -> ArgKind +inputTensorKind v a | a ^. isRef = ArgTensorRef - | otherwise = ArgTensorEither v + | otherwise = ArgSomeTensor v -outputTensorKind :: OpDef'ArgDef -> ArgKind -outputTensorKind a +outputTensorKind :: Bool -> OpDef'ArgDef -> ArgKind +outputTensorKind isMonadic a | a ^. 
isRef = ArgTensorRef - | otherwise = ArgTensorValue + | isMonadic = ArgTensorValue + | otherwise = ArgTensorBuild getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType getExplicitInputAttr o implicitAttrs a diff --git a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs index df1866f..7b9e7fc 100644 --- a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs +++ b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs @@ -24,9 +24,9 @@ module TensorFlow.EmbeddingOps where import Control.Monad (zipWithM) import Data.Int (Int32, Int64) -import TensorFlow.Build (MonadBuild, colocateWith, render) +import TensorFlow.Build (MonadBuild) import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor -import TensorFlow.Tensor (Tensor, Value) +import TensorFlow.Tensor (Tensor, Value, Rendered, colocateWith, render) import TensorFlow.Types (OneOf, TensorType) import qualified TensorFlow.GenOps.Core as CoreOps @@ -44,17 +44,18 @@ import qualified TensorFlow.GenOps.Core as CoreOps -- -- The results of the lookup are concatenated into a dense -- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`. -embeddingLookup :: forall a b v m . +embeddingLookup :: forall a b v1 v2 m . ( MonadBuild m + , Rendered v1 , TensorType a , OneOf '[Int64, Int32] b , Num b ) - => [Tensor v a] + => [Tensor v1 a] -- ^ A list of tensors which can be concatenated along -- dimension 0. Each `Tensor` must be appropriately -- sized for `mod` partition strategy. - -> Tensor Value b + -> Tensor v2 b -- ^ A `Tensor` with type `int32` or `int64` -- containing the ids to be looked up in `params`. -- The ids are required to have fewer than 2^31 diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index 099a196..52467c8 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -31,6 +31,7 @@ import Data.ByteString (ByteString) import Data.Complex (Complex) import Data.Default (def) import Data.Int (Int32, Int64) +import Data.Foldable (foldlM) import Data.List (foldl', sortBy) import Data.Map.Strict (Map) import Data.Maybe (fromMaybe, maybeToList, mapMaybe) @@ -39,7 +40,7 @@ import Data.ProtoLens.TextFormat (showMessage) import Data.Set (Set) import Data.Text (Text) import Data.Tuple (swap) -import Lens.Family2 (Lens', (&), (^.), (.~), (%~)) +import Lens.Family2 (Lens', view, (&), (^.), (.~), (%~)) import Lens.Family2.State.Strict (uses) import Lens.Family2.Stock (at, intAt) import Lens.Family2.Unchecked (lens, iso) @@ -59,11 +60,10 @@ import TensorFlow.Build ( MonadBuild , Build , build - , render - , renderNodeName , renderedNodeDefs , opDef , opAttr + , opInputs ) import TensorFlow.BuildOp import TensorFlow.Ops @@ -86,16 +86,19 @@ import TensorFlow.Ops ) import TensorFlow.Output ( NodeName(..) - , Op (Rendered) , Output(..) , OutputIx(..) , outputIndex ) import TensorFlow.Tensor ( Tensor(..) - , TensorKind (ValueKind) , Value - , tensorOutput + , render + , expr + , Rendered + , tensorNodeName + , renderedOutput + , renderValue ) import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens) import Proto.Tensorflow.Core.Framework.NodeDef @@ -114,10 +117,7 @@ type GradientCompatible a = -- | Gradient of @y@ w.r.t. each element of @xs@. gradients :: forall a v1 v2 m . (MonadBuild m - , Num (Tensor v1 a) - -- TODO(gnezdo): remove indirect constraint. - -- It's a wart inherited from Num instance. 
- , v1 ~ Value + , Rendered v2 , GradientCompatible a ) => Tensor v1 a -- ^ The output of the graph. @@ -150,27 +150,31 @@ gradients y xs = build $ do -- -- 4. Lookup the recorded gradient for each x in xs. - yName <- renderNodeName y + y' <- renderValue y + let yName = tensorNodeName y' + yOne <- render $ fill (shape y') (scalar 1) -- TODO(fmayle): Move this into Build.hs and call it unsafeNodeDefFromName? nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $ (\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x)) . flip Map.lookup let (gr, nodeMap) = createGraph yName nodeDefLookup -- Set gradient of y to one. + -- TODO: nicer let initPending :: Map.Map FGL.Node (PendingGradients a) - initPending = Map.empty & at (nodeMap Map.! yName) + = Map.empty & (at (nodeMap Map.! yName) . nonEmpty - . outputIxAt (y ^. tensorOutput . outputIndex) + . outputIxAt (outputIndex $ renderedOutput y') . nonEmpty - .~ [fill (shape y) (scalar 1)] + .~ [yOne] + ) -- Calculate the gradients of y w.r.t. each node in the graph. gradientMap <- graphGrads gr initPending -- Lookup the gradients for each x. - forM xs $ \x -> do - xName <- renderNodeName x - render $ fromMaybe (zerosLike x) $ do + forM xs $ \x -> + let xName = tensorNodeName x + in maybe (render $ zerosLike x) return $ do n <- nodeMap ^. at xName - let i = x ^. tensorOutput . outputIndex + let i = outputIndex $ renderedOutput x gradientMap ^. at n . nonEmpty . outputIxAt i outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v) @@ -182,6 +186,7 @@ outputIxAt = intAt . unOutputIx type PendingGradients a = IntMap.IntMap [Tensor Value a] -- | Gradients of a node's outputs. The key is an OutputIx sans newtype. +-- TODO: precache the rendering? type Gradients a = IntMap.IntMap (Tensor Value a) -- | Graph of TensorFlow operations. @@ -229,42 +234,43 @@ non a = anon a (a==) nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v) nonEmpty = anon mempty null +-- TODO: strictness (e.g., foldlM') + -- | Calculate the gradients for every node in a graph. graphGrads :: forall a. GradientCompatible a => Graph -> Map FGL.Node (PendingGradients a) -- ^ Initial gradients (usually just 1 for the node of interest). -> Build (Map FGL.Node (Gradients a)) -graphGrads gr initPending = pure (foldl' go initState nodeOrder ^. gradientsResult) +graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrder where initState = GradientsState initPending Map.empty -- Reverse topological sort. -- TODO(fmayle): Filter out nodes that are not successors of any x in xs to -- avoid calculating gradients that won't be used. nodeOrder = FGL.topsort $ FGL.grev gr - go state node = + go :: GradientsState a -> Int -> Build (GradientsState a) + go state node = do -- Aggregate the accumulated gradients for this node. - let outputGrads = + outputGrads <- sumPendingGradient (state ^. gradientsPending . at node . nonEmpty) - in if null outputGrads - then state - else + if null outputGrads + then pure state + else do + let ctx = FGL.context gr node + inputGrads <- calculateInputGrads ctx outputGrads gr -- Calculate the gradients for each of the node's inputs. let nextState = state & gradientsResult %~ Map.insert node outputGrads - ctx = FGL.context gr node - in updatePendingGradients - ctx - (calculateInputGrads ctx outputGrads gr) - nextState + pure $ updatePendingGradients ctx inputGrads nextState -- | Reduce accumulated gradients for each output to one Tensor. 
sumPendingGradient :: GradientCompatible a - => PendingGradients a -> Gradients a -sumPendingGradient = IntMap.mapMaybe f + => PendingGradients a -> Build (Gradients a) +sumPendingGradient = sequence . IntMap.mapMaybe f where f [] = Nothing - f [x] = Just x - f xs = Just (addN xs) + f [x] = Just (pure x) + f xs = Just (render $ addN xs) -- | Calculate the gradients of a node's input tensors. @@ -274,18 +280,18 @@ calculateInputGrads :: forall a. GradientCompatible a => FGL.Context NodeDef EdgeLabel -> Gradients a -- ^ Output gradients of the node. -> Graph - -> [Maybe (Tensor Value a)] -calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr = - opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads + -> Build [Maybe (Tensor Value a)] +calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr = do + fullOutGrads <- fullOutputGrads (numOutputs nodeDef) (nodeDefName nodeDef) + outputGrads + traverse (traverse render) $ opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads where - fullOutGrads = - fullOutputGrads (numOutputs nodeDef) (Rendered nodeDef) outputGrads -- Create a tensor from an edge (technically an Output, but it seems less -- confusing to refer to it as a tensor here). edgeToTensor :: (EdgeLabel, FGL.Node) -> Output edgeToTensor ((i, _), n) = case FGL.lab gr n of - Just edgeNodeDef -> Output i (Rendered edgeNodeDef) + Just edgeNodeDef -> Output i (NodeName $ edgeNodeDef ^. name) Nothing -> error $ "calculateInputGrads: missing input node for " ++ Text.unpack (nodeDef ^. name) -- Input tensors, sorted by input index. @@ -294,11 +300,11 @@ calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr = -- | Convert a Map of gradients to a list, with zeros for missing outputs. fullOutputGrads :: (TensorType a, Num a) => OutputIx -- ^ Number of outputs. - -> Op + -> NodeName -> Gradients a - -> [Tensor Value a] + -> Build [Tensor Value a] fullOutputGrads n o gs = - map (\i -> fromMaybe (zero i) (gs ^. outputIxAt i)) [0..n-1] + mapM (\i -> maybe (render $ zero i) return (gs ^. outputIxAt i)) [0..n-1] where -- A tensor of zeros with the same shape as the i'th output. zero i = zerosLike $ toT (Output i o) @@ -397,19 +403,19 @@ type GradientFunc a = NodeDef -- ^ Input tensors. -> [Tensor Value a] -- ^ Gradient of y w.r.t. each output tensor. - -> [Maybe (Tensor Value a)] + -> [Maybe (Tensor Build a)] -- ^ Gradient of y w.r.t. each input tensor. -- TODO(fmayle): Assert the type is correct. -- | Create a Tensor from an Output. -toT :: Output -> Tensor Value a -toT = Tensor ValueKind +toT :: Output -> Tensor Build a +toT = Tensor . pure -- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for -- simple slicing operations. -flatSlice :: forall v1 t . (TensorType t) +flatSlice :: forall v1 t . TensorType t => Tensor v1 t -- ^ __input__ -> Int32 -- ^ __begin__: specifies the offset into the first dimension of -- 'input' to slice from. @@ -417,9 +423,12 @@ flatSlice :: forall v1 t . (TensorType t) -- of 'input' to slice. If size is -1, all remaining elements in the dimension -- are included in the slice (i.e. this is equivalent to setting -- size = input.dim_size(0) - begin). - -> Tensor Value t -- ^ __output__ + -> Tensor Build t -- ^ __output__ flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size]) +nodeDefName :: NodeDef -> NodeName +nodeDefName = NodeName . view name + -- | The gradient function for an op type. 
-- @@ -427,8 +436,8 @@ flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size]) -- third_party/tensorflow/python/ops/*_grad.py opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a -opGrad "Abs" _ [toT -> x] [dz] = [Just $ dz * signum x] -opGrad "Neg" _ [_] [dz] = [Just $ -dz] +opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x] +opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz] opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x] opGrad "Square" _ [toT -> x] [dz] = @@ -436,7 +445,7 @@ opGrad "Square" _ [toT -> x] [dz] = -- TODO(fmayle): The python code makes dz a control dependency of the 2*x -- (for performance reasons?). Will need to put these functions in the Build -- monad to replicate that. - [Just $ dz * (2 * x)] + [Just $ dz `CoreOps.mul` (2 * x)] opGrad "Gather" _ [toT -> x, toT -> indices] [dz] = -- TODO(fmayle): The python version uses a better performance implementation @@ -448,20 +457,20 @@ opGrad "Gather" _ [toT -> x, toT -> indices] [dz] = ] where -- TODO(gnezdo): Use colocateWith but it requires Build monad. - denseShape = shape (x :: Tensor Value a) + denseShape = shape (x :: Tensor Build a) numRows = scalarize $ flatSlice denseShape 0 1 valuesShape = CoreOps.concat 0 [ allDimensions , flatSlice denseShape 1 (-1) ] values = reshape dz valuesShape -- TODO(fmayle): This could be either Int32 or Int64. - indices' = reshape indices allDimensions :: Tensor Value Int32 + indices' = reshape indices allDimensions :: Tensor Build Int32 opGrad "Max" _ [toT -> x, toT -> indices] [dz] = [Just $ indicators `CoreOps.div` numSelected * dz', Nothing] where - sx = shape (x :: Tensor Value a) - outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32) + sx = shape (x :: Tensor Build a) + outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32) y = CoreOps.max x indices y' = reshape y outputShapeKeptDims dz' = reshape dz outputShapeKeptDims @@ -475,8 +484,8 @@ opGrad "Sum" _ [toT -> x, toT -> indices] [dz] = [ Just $ CoreOps.tile grad tileScaling, Nothing ] where -- TODO(gnezdo): Implement the fast-path from math_grad._SumGrad. - sx = shape (x :: Tensor Value a) - outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32) + sx = shape (x :: Tensor Build a) + outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32) tileScaling = safeShapeDiv sx outputShapeKeptDims grad = reshape dz outputShapeKeptDims @@ -484,8 +493,8 @@ opGrad "Mean" u v@[toT -> x, _] w = [Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing] where [Just dz, Nothing] = opGrad "Sum" u v w - inputShape = shape (x :: Tensor Value a) - outputShape = shape (dz :: Tensor Value a) + inputShape = shape (x :: Tensor Build a) + outputShape = shape (dz :: Tensor Build a) -- TODO(fmayle): Add fast path when shape is known. inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape @@ -495,8 +504,8 @@ opGrad "Add" _ [toT -> x, toT -> y] [dz] = [ Just $ reshape (sum dz rx) sx , Just $ reshape (sum dz ry) sy ] where - sx = shape (x :: Tensor Value a) - sy = shape (y :: Tensor Value a) + sx = shape (x :: Tensor Build a) + sy = shape (y :: Tensor Build a) (rx, ry) = broadcastGradientArgs sx sy opGrad "Sub" u v w = @@ -510,22 +519,24 @@ opGrad "SoftmaxCrossEntropyWithLogits" _ [toT -> x, toT -> y] [dz, _] = opGrad "Mul" _ [toT -> x, toT -> y] [dz] = -- TODO(fmayle): Handle complex numbers. 
- [ Just $ reshape (sum (dz * y) rx) sx - , Just $ reshape (sum (x * dz) ry) sy ] + [ Just $ reshape (sum (dz `CoreOps.mul` y) rx) sx + , Just $ reshape (sum (x `CoreOps.mul` dz) ry) sy ] where - sx = shape (x :: Tensor Value a) - sy = shape (y :: Tensor Value a) + sx = shape (x :: Tensor Build a) + sy = shape (y :: Tensor Build a) (rx, ry) = broadcastGradientArgs sx sy opGrad "Div" _ [toT -> x, toT -> y] [dz] = -- TODO(fmayle): Handle complex numbers. -- TODO(gnezdo): Provide Fractional instance and use '/' instead of div. [ Just $ reshape (sum (dz `CoreOps.div` y) rx) sx - , Just $ reshape (sum (dz * (negate x `CoreOps.div` (y * y))) ry) sy + , Just $ reshape (sum (dz `CoreOps.mul` (negate x `CoreOps.div` (y * y))) + ry) + sy ] where - sx = shape (x :: Tensor Value a) - sy = shape (y :: Tensor Value a) + sx = shape (x :: Tensor Build a) + sy = shape (y :: Tensor Build a) (rx, ry) = broadcastGradientArgs sx sy opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] = @@ -549,7 +560,7 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] = opGrad "Transpose" _ [_, toT -> p] [dz] = [ Just $ CoreOps.transpose dz - (CoreOps.invertPermutation p :: Tensor Value Int32) + (CoreOps.invertPermutation p :: Tensor Build Int32) , Nothing ] @@ -582,28 +593,28 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] = x output dz ] where - output :: Tensor Value a - output = toT $ Output 0 (Rendered nodeDef) + output :: Tensor Build a + output = toT $ Output 0 (nodeDefName nodeDef) ksize = lookupAttr nodeDef "ksize" :: [Int64] strides = lookupAttr nodeDef "strides" :: [Int64] padding = lookupAttr nodeDef "padding" :: ByteString dataFormat = lookupAttr nodeDef "data_format" :: ByteString opGrad "Reshape" _ [toT -> x, _] [dz] = - [Just $ reshape dz $ shape (x :: Tensor Value a), Nothing] + [Just $ reshape dz $ shape (x :: Tensor Build a), Nothing] opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing] opGrad "TruncatedNormal" _ _ _ = [Nothing] -opGrad "RefIdentity" _ _ [dz] = [Just dz] +opGrad "RefIdentity" _ _ [dz] = [Just $ expr dz] opGrad "Cast" nodeDef _ [dz] = [Just reverseCast] where -- TODO(gnezdo): too permissive, python only allows float types as src_type. 
reverseCast = - buildOp (opDef "Cast" + pureOp [] $ pure (opDef "Cast" & opAttr "DstT" .~ (lookupAttr nodeDef "SrcT" :: ByteString) - & opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString)) - dz + & opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString) + & opInputs .~ [renderedOutput dz]) opGrad "DynamicStitch" nodeDef inputs [dz] = replicate halfLen Nothing ++ valuesGrads @@ -614,7 +625,7 @@ opGrad "DynamicStitch" nodeDef inputs [dz] = in if 2 * half == len then half else error ("Uneven input size " ++ show (len, showMessage nodeDef)) - valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Value Int32) + valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Build Int32) | idx <- take halfLen inputs ] @@ -622,14 +633,14 @@ opGrad "DynamicPartition" nodeDef [toT -> xs, toT -> indices] dz = [ Just reconstructed, Nothing ] where reconstructed = CoreOps.reshape stitched - (CoreOps.shape (xs :: Tensor Value a) :: Tensor Value Int32) + (CoreOps.shape (xs :: Tensor Build a) :: Tensor Build Int32) stitched = CoreOps.dynamicStitch partitionedIndices dz partitionedIndices = CoreOps.dynamicPartition np originalIndices indices np = lookupAttr nodeDef "num_partitions" :: Int64 originalIndices = CoreOps.reshape (CoreOps.range 0 (CoreOps.size indices) 1) prefixShape prefixShape = shapeInt32 indices - shapeInt32 = CoreOps.shape :: Tensor Value Int32 -> Tensor Value Int32 + shapeInt32 t = CoreOps.shape t :: Tensor Build Int32 opGrad "Select" _ [toT -> c, toT -> x, _] [dz] = [ Nothing @@ -639,18 +650,18 @@ opGrad "Select" _ [toT -> c, toT -> x, _] [dz] = where zeros = CoreOps.zerosLike x -- TODO(gnezdo): Unlike Python, no control dependency on dz. -opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.inv x ] +opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz `CoreOps.mul` CoreOps.inv x ] -- TODO(gnezdo): Reuse the output instead of doing another exp, -- though, it is probably CSE'd away anyway. -opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.exp x ] +opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz `CoreOps.mul` CoreOps.exp x ] opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] = [ Just $ CoreOps.unsortedSegmentSum - (CoreOps.gather dz (t :: Tensor Value Int32)) - (y :: Tensor Value Int32) inputRows + (CoreOps.gather dz (t :: Tensor Build Int32)) + (y :: Tensor Build Int32) inputRows , Nothing , Nothing ] - where inputRows = flatSlice (shape (x :: Tensor Value a)) 0 1 + where inputRows = flatSlice (shape (x :: Tensor Build a)) 0 1 opGrad "LabelClasses" _ _ _ = [Nothing, Nothing] opGrad "LabelWeights" _ _ _ = [Nothing] @@ -710,13 +721,13 @@ numOutputs o = _ -> error $ "numOuputs not implemented for " ++ show (o ^. op) -- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0` -safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Value Int32 +safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Build Int32 safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1) -allDimensions :: Tensor Value Int32 +allDimensions :: Tensor Build Int32 allDimensions = vector [-1 :: Int32] -rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Value Int32 +rangeOfRank :: forall v1 t. 
TensorType t => Tensor v1 t -> Tensor Build Int32 rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1 lookupAttr :: Attribute a1 => NodeDef -> Text -> a1 diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs index e123a41..3f27a80 100644 --- a/tensorflow-ops/src/TensorFlow/Ops.hs +++ b/tensorflow-ops/src/TensorFlow/Ops.hs @@ -166,7 +166,6 @@ import qualified Proto.Tensorflow.Core.Framework.TensorShape import TensorFlow.Build import TensorFlow.BuildOp import TensorFlow.ControlFlow (group) -import TensorFlow.Output (unNodeName) import TensorFlow.Tensor import TensorFlow.Types @@ -183,7 +182,7 @@ import qualified Prelude (abs) -- "1". instance ( TensorType a , Num a - , v ~ Value + , v ~ Build , OneOf '[ Double, Float, Int32, Int64 , Complex Float, Complex Double] a) => Num (Tensor v a) where (+) = CoreOps.add @@ -194,10 +193,10 @@ instance ( TensorType a signum = CoreOps.sign negate = CoreOps.neg -matTranspose :: TensorType a => Tensor v a -> Tensor Value a +matTranspose :: TensorType a => Tensor e a -> Tensor Build a matTranspose = matTranspose' id -matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Value a +matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Build a matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32]) placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a) @@ -208,7 +207,7 @@ placeholder' :: forall m a . (MonadBuild m, TensorType a) placeholder' params pShape -- Note: we don't use CoreOps.placeholder' since that op isn't stateful, -- and thus would be CSE'd. - = build $ buildOp $ opDef "Placeholder" + = build $ buildOp [] $ opDef "Placeholder" & opAttr "dtype" .~ tensorType (undefined :: a) & opAttr "shape" .~ pShape & params @@ -216,11 +215,11 @@ placeholder' params pShape -- | Creates a variable initialized to the given value. -- Initialization happens next time session runs. initializedVariable :: (MonadBuild m, TensorType a) - => Tensor Value a -> m (Tensor Ref a) + => Tensor v a -> m (Tensor Ref a) initializedVariable = initializedVariable' id initializedVariable' :: (MonadBuild m, TensorType a) - => OpParams -> Tensor Value a -> m (Tensor Ref a) + => OpParams -> Tensor v a -> m (Tensor Ref a) initializedVariable' params initializer = do v <- CoreOps.variable' params [] -- The shape is not known initially. i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v @@ -240,17 +239,20 @@ zeroInitializedVariable' zeroInitializedVariable' params = initializedVariable' params . zeros -- TODO: Support heterogeneous list of tensors. -save :: forall a m v . (MonadBuild m, TensorType a) +save :: forall a m v . (Rendered v, MonadBuild m, TensorType a) => ByteString -- ^ File path. -> [Tensor v a] -- ^ Tensors to save. -> m ControlNode -save path xs = do - let toByteStringTensor = scalar . encodeUtf8 . unNodeName - names <- mapM (fmap toByteStringTensor . build . renderNodeName) xs +save path xs = build $ do + let toByteStringTensor = scalar . encodeUtf8 . encodeOutput . renderedOutput + let names = fmap toByteStringTensor xs let types = replicate (length xs) (tensorType (undefined :: a)) - let saveOp = buildOp $ opDef "Save" - & opAttr "T" .~ types - build $ saveOp (scalar path) (CoreOps.pack names) xs + names' <- buildInputs $ CoreOps.pack names + xs' <- buildInputs xs + path' <- buildInputs $ scalar path + buildOp [] $ opDef "Save" + & opAttr "T" .~ types + & opInputs .~ (path' ++ names' ++ xs') -- | Restore a tensor's value from a checkpoint file. 
-- @@ -261,20 +263,22 @@ restoreFromName :: forall a m . (MonadBuild m, TensorType a) -> ByteString -- ^ Tensor name override. -> Tensor Ref a -- ^ Tensor to restore. -> m ControlNode -restoreFromName path name x = do - let restoreOp = buildOp $ opDef "Restore" - & opAttr "dt" .~ tensorType (undefined :: a) - group =<< CoreOps.assign x - (restoreOp (scalar path) (scalar name) :: Tensor Value a) +restoreFromName path name x = build $ do + path' <- buildInputs $ scalar path + name' <- buildInputs $ scalar name + restoreOp <- buildOp [] $ opDef "Restore" + & opAttr "dt" .~ tensorType (undefined :: a) + & opInputs .~ (path' ++ name') + group =<< CoreOps.assign x (restoreOp :: Tensor Value a) -- | Restore a tensor's value from a checkpoint file. restore :: forall a m . (MonadBuild m, TensorType a) => ByteString -- ^ File path. -> Tensor Ref a -- ^ Tensor to restore. -> m ControlNode -restore path x = do - name <- encodeUtf8 . unNodeName <$> build (renderNodeName x) - restoreFromName path name x +restore path x = restoreFromName path name x + where + name = encodeUtf8 $ encodeOutput $ renderedOutput x -- | Create a constant tensor. -- @@ -283,10 +287,10 @@ restore path x = do -- element 0: index (0, ..., 0) -- element 1: index (0, ..., 1) -- ... -constant :: TensorType a => Shape -> [a] -> Tensor Value a +constant :: TensorType a => Shape -> [a] -> Tensor Build a constant = constant' id -constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Value a +constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Build a constant' params (Shape cShape) values | invalidLength = error invalidLengthMsg | otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode)) @@ -305,24 +309,24 @@ constant' params (Shape cShape) values -- | Reshape a N-D tensor down to a scalar. -- -- See `TensorFlow.GenOps.Core.reshape`. -scalarize :: (TensorType a) => Tensor v a -> Tensor Value a +scalarize :: TensorType a => Tensor v a -> Tensor Build a scalarize t = CoreOps.reshape t (vector scalarShape) where scalarShape = [] :: [Int32] -- | Create a constant vector. -vector :: TensorType a => [a] -> Tensor Value a +vector :: TensorType a => [a] -> Tensor Build a vector = vector' id -vector' :: TensorType a => OpParams -> [a] -> Tensor Value a +vector' :: TensorType a => OpParams -> [a] -> Tensor Build a vector' params xs = constant' params [fromIntegral $ length xs] xs -- | Create a constant scalar. -scalar :: TensorType a => a -> Tensor Value a +scalar :: TensorType a => a -> Tensor Build a scalar = scalar' id -scalar' :: TensorType a => OpParams -> a -> Tensor Value a +scalar' :: TensorType a => OpParams -> a -> Tensor Build a scalar' params x = constant' params [] [x] -- | Random tensor from the unit normal distribution with bounded values. @@ -338,28 +342,28 @@ truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a) -> m (Tensor Value a) truncatedNormal' = CoreOps.truncatedNormal' -zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a +zeros :: forall a . 
(Num a, TensorType a) => Shape -> Tensor Build a zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0) -shape :: TensorType t => Tensor v1 t -> Tensor Value Int32 +shape :: TensorType t => Tensor v t -> Tensor Build Int32 shape = CoreOps.shape -shape' :: TensorType t => OpParams -> Tensor v1 t -> Tensor Value Int32 +shape' :: TensorType t => OpParams -> Tensor v t -> Tensor Build Int32 shape' = CoreOps.shape' -expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t +expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t expandDims = CoreOps.expandDims -expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t +expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t expandDims' = CoreOps.expandDims' -- | Helper function for reduction ops (translation of math_ops.reduced_shape). reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) => - Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32 + Tensor v1 t1 -> Tensor v2 t2 -> Tensor Build Int32 reducedShape inputShape axes = let inputShape32 = toInt32 inputShape -- [2, 3, 5, 7] axes32 = toInt32 axes -- [1, 2] - toInt32 x = CoreOps.cast x :: Tensor Value Int32 + toInt32 x = CoreOps.cast x :: Tensor Build Int32 inputRank = CoreOps.size inputShape32 -- 4 axesMod = (axes32 + inputRank) `CoreOps.mod` inputRank axesShape = shape axesMod -- [2] diff --git a/tensorflow-ops/tensorflow-ops.cabal b/tensorflow-ops/tensorflow-ops.cabal index d5a37bd..d11b093 100644 --- a/tensorflow-ops/tensorflow-ops.cabal +++ b/tensorflow-ops/tensorflow-ops.cabal @@ -79,6 +79,7 @@ Test-Suite EmbeddingOpsTest , test-framework , test-framework-hunit , test-framework-quickcheck2 + , transformers , vector Test-Suite ArrayOpsTest diff --git a/tensorflow-ops/tests/ArrayOpsTest.hs b/tensorflow-ops/tests/ArrayOpsTest.hs index 1b32b03..e266ace 100644 --- a/tensorflow-ops/tests/ArrayOpsTest.hs +++ b/tensorflow-ops/tests/ArrayOpsTest.hs @@ -24,9 +24,7 @@ import Test.HUnit ((@=?)) import qualified Data.Vector as V import qualified TensorFlow.Ops as TF -import qualified TensorFlow.Session as TF -import qualified TensorFlow.Tensor as TF -import qualified TensorFlow.Types as TF +import qualified TensorFlow.Core as TF import qualified TensorFlow.GenOps.Core as CoreOps -- | Test split and concat are inverses. @@ -44,7 +42,7 @@ testSplit = testCase "testSplit" $ TF.runSession $ do testShapeN :: Test testShapeN = testCase "testShapeN" $ TF.runSession $ do let shapes = map TF.Shape [[1],[2,3]] - let tensors = map TF.zeros shapes :: [TF.Tensor TF.Value Float] + let tensors = map TF.zeros shapes :: [TF.Tensor TF.Build Float] result <- TF.run $ CoreOps.shapeN tensors liftIO $ [V.fromList [1], V.fromList [2,3]] @=? (result :: [V.Vector Int64]) diff --git a/tensorflow-ops/tests/BuildTest.hs b/tensorflow-ops/tests/BuildTest.hs index 9410dab..1df0719 100644 --- a/tensorflow-ops/tests/BuildTest.hs +++ b/tensorflow-ops/tests/BuildTest.hs @@ -34,9 +34,7 @@ import TensorFlow.Build , asGraphDef , evalBuildT , flushNodeBuffer - , render , withDevice - , colocateWith , withNameScope , opName ) @@ -50,7 +48,13 @@ import TensorFlow.Ops , variable' ) import TensorFlow.Output (Device(..)) -import TensorFlow.Tensor (Tensor, Value, Ref) +import TensorFlow.Tensor + ( colocateWith + , render + , Tensor + , Value + , Ref + ) import TensorFlow.Session ( run , runSession @@ -65,8 +69,7 @@ import qualified Data.Vector as V -- | Test 'opName' behavior. 
testOpName :: Test testOpName = testCase "testOpName" $ do - let graph = variable' (opName .~ "foo") [] - >>= render :: Build (Tensor Ref Float) + let graph = variable' (opName .~ "foo") [] :: Build (Tensor Ref Float) nodeDef :: NodeDef nodeDef = head $ asGraphDef graph ^. node "Variable" @=? (nodeDef ^. op) @@ -114,7 +117,6 @@ testNamedAndScoped :: Test testNamedAndScoped = testCase "testNamedAndScoped" $ do let graph :: Build (Tensor Ref Float) graph = withNameScope "foo1" (variable' (opName .~ "bar1") []) - >>= render nodeDef :: NodeDef nodeDef = head $ asGraphDef graph ^. node "Variable" @=? (nodeDef ^. op) diff --git a/tensorflow-ops/tests/DataFlowOpsTest.hs b/tensorflow-ops/tests/DataFlowOpsTest.hs index 38fdb5a..184cbd2 100644 --- a/tensorflow-ops/tests/DataFlowOpsTest.hs +++ b/tensorflow-ops/tests/DataFlowOpsTest.hs @@ -26,16 +26,14 @@ import Test.QuickCheck.Monadic (monadicIO, run) import qualified Data.Vector as V import qualified TensorFlow.GenOps.Core as CoreOps import qualified TensorFlow.Ops as TF -import qualified TensorFlow.Session as TF -import qualified TensorFlow.Tensor as TF -import qualified TensorFlow.Types as TF +import qualified TensorFlow.Core as TF -- DynamicSplit is undone with DynamicStitch to get the original input -- back. testDynamicPartitionStitchInverse :: forall a. (TF.TensorDataType V.Vector a, Show a, Eq a) => StitchExample a -> Property testDynamicPartitionStitchInverse (StitchExample numParts values partitions) = - let splitParts :: [TF.Tensor TF.Value a] = + let splitParts :: [TF.Tensor TF.Build a] = CoreOps.dynamicPartition numParts (TF.vector values) partTensor partTensor = TF.vector partitions restitchIndices = CoreOps.dynamicPartition numParts diff --git a/tensorflow-ops/tests/EmbeddingOpsTest.hs b/tensorflow-ops/tests/EmbeddingOpsTest.hs index 137f7f6..ecaeaa4 100644 --- a/tensorflow-ops/tests/EmbeddingOpsTest.hs +++ b/tensorflow-ops/tests/EmbeddingOpsTest.hs @@ -19,6 +19,7 @@ -- | Tests for EmbeddingOps. module Main where +import Control.Monad.IO.Class (liftIO) import Data.Int (Int32, Int64) import Data.List (genericLength) import Google.Test (googleTest) @@ -48,16 +49,15 @@ testEmbeddingLookupHasRightShapeWithPartition = let embShape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items. let embedding1 = [1, 1, 1 :: Int32] let embedding2 = [0, 0, 0 :: Int32] - let embedding = [ TF.constant embShape embedding1 - , TF.constant embShape embedding2 - ] let idValues = [0, 1 :: Int32] - let ids = TF.constant (TF.Shape [1, 2]) idValues - let op = embeddingLookup embedding ids (values, shape) <- TF.runSession $ do - vs <- op + embedding <- mapM TF.render [ TF.constant embShape embedding1 + , TF.constant embShape embedding2 + ] + let ids = TF.constant (TF.Shape [1, 2]) idValues + vs <- embeddingLookup embedding ids TF.run (vs, TF.shape vs) -- This is the shape that is returned in the equiv. Python. @@ -77,13 +77,12 @@ testEmbeddingLookupHasRightShape = , 0, 0, 0 :: Int32 ] - let embedding = TF.constant embShape embeddingInit let idValues = [0, 1 :: Int32] - let ids = TF.constant (TF.Shape [1, 2]) idValues - let op = embeddingLookup [embedding] ids (values, shape) <- TF.runSession $ do - vs <- op + embedding <- TF.render $ TF.constant embShape embeddingInit + let ids = TF.constant (TF.Shape [1, 2]) idValues + vs <- embeddingLookup [embedding] ids TF.run (vs, TF.shape vs) -- This is the shape that is returned in the equiv. Python. @@ -92,7 +91,6 @@ testEmbeddingLookupHasRightShape = -- "[0, 1]" should pull out the resulting vector. values @=? 
V.fromList [1, 1, 1, 0, 0, 0] - -- | Check that we can calculate gradients w.r.t embeddings. testEmbeddingLookupGradients :: Test testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do @@ -108,10 +106,10 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do x <- TF.placeholder (TF.Shape [2]) embedding <- TF.initializedVariable - =<< TF.render (TF.constant embShape embeddingInit) + (TF.constant embShape embeddingInit) op <- embeddingLookup [embedding] ids - let twoNorm = CoreOps.square $ TF.abs (op - x) + let twoNorm = CoreOps.square $ TF.abs (op `TF.sub` x) loss = TF.mean twoNorm (TF.scalar (0 :: Int32)) grad <- fmap head (TF.gradients loss [embedding]) @@ -131,23 +129,21 @@ testEmbeddingLookupUndoesSplit (LookupExample numParts shape@(TF.Shape (firstDim : restDims)) values - indices) = - let modShardedValues :: [TF.Tensor TF.Value a] = - CoreOps.dynamicPartition numParts shapedValues cyclicCounter - cyclicCounter :: TF.Tensor TF.Value Int32 = + indices) = monadicIO $ run $ TF.runSession $ do + let shapedValues = TF.constant shape values + indicesVector <- TF.render $ TF.vector indices + let directs = CoreOps.gather shapedValues indicesVector + let cyclicCounter :: TF.Tensor TF.Build Int32 = TF.vector [0..fromIntegral firstDim-1] `CoreOps.mod` fromIntegral numParts - indicesVector = TF.vector indices - directs = CoreOps.gather shapedValues indicesVector - shapedValues = TF.constant shape values - in monadicIO $ run $ do - (shapeOut, got, want :: V.Vector a) <- - TF.runSession $ TF.run =<< do - embeddings <- embeddingLookup modShardedValues indicesVector - return (TF.cast (TF.shape embeddings), embeddings, directs) - -- Checks the explicitly documented invariant of embeddingLookup. - shapeOut @=? V.fromList (genericLength indices : restDims) - got @=? want + modShardedValues :: [TF.Tensor TF.Value a] <- + mapM TF.render $ CoreOps.dynamicPartition numParts shapedValues cyclicCounter + embeddings <- embeddingLookup modShardedValues indicesVector + (shapeOut, got, want :: V.Vector a) <- + TF.run (TF.cast (TF.shape embeddings), embeddings, directs) + -- Checks the explicitly documented invariant of embeddingLookup. + liftIO $ shapeOut @=? V.fromList (genericLength indices : restDims) + liftIO $ got @=? want testEmbeddingLookupUndoesSplit _ = error "Bug in Arbitrary (LookupExample)" -- | Consistent set of parameters for EmbeddingLookupUndoesSplit. diff --git a/tensorflow-ops/tests/GradientTest.hs b/tensorflow-ops/tests/GradientTest.hs index 765ed22..b51eab9 100644 --- a/tensorflow-ops/tests/GradientTest.hs +++ b/tensorflow-ops/tests/GradientTest.hs @@ -36,10 +36,11 @@ import Proto.Tensorflow.Core.Framework.NodeDef (op) testGradientSimple :: Test testGradientSimple = testCase "testGradientSimple" $ do - let x = TF.scalar (3 :: Float) - b = TF.scalar (4 :: Float) - y = x*x + b - grads = TF.gradients y [x, b] + let grads = do + x <- TF.render $ TF.scalar (3 :: Float) + b <- TF.render $ TF.scalar (4 :: Float) + let y = x `TF.mul` x `TF.add` b + TF.gradients y [x, b] -- Assert that the gradients are right. [dx, db] <- TF.runSession $ grads >>= TF.run 6 @=? 
TF.unScalar dx @@ -88,9 +89,10 @@ testGradientSimple = testCase "testGradientSimple" $ do testGradientDisconnected :: Test testGradientDisconnected = testCase "testGradientDisconnected" $ do - let x = TF.scalar (3 :: Float) - b = TF.scalar (4 :: Float) - grads = TF.gradients x [x, b] + let grads = do + x <- TF.render $ TF.scalar (3 :: Float) + b <- TF.render $ TF.scalar (4 :: Float) + TF.gradients x [x, b] -- Assert that the gradients are right. [dx, db] <- TF.runSession $ grads >>= TF.run 1 @=? TF.unScalar dx @@ -118,7 +120,7 @@ testCreateGraphStateful = testCase "testCreateGraphStateful" $ do let shape = TF.constant (TF.Shape [1]) [1] x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape - TF.gradients (x + y*3) [x, y] >>= TF.run + TF.gradients (TF.expr x + TF.expr y * 3) [x, y] >>= TF.run -- If this test fails, it will likely be caused by an exception within -- `TF.gradients`. These asserts are extra. 1 @=? TF.unScalar dx @@ -142,8 +144,8 @@ testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do testDiamond :: Test testDiamond = testCase "testDiamond" $ do [dx] <- TF.runSession $ do - let x = TF.vector [1] - y = x*x + x <- TF.render $ TF.vector [1] + let y = x `TF.mul` x z = y*y TF.gradients z [x] >>= TF.run (4 :: Float) @=? TF.unScalar dx @@ -152,8 +154,8 @@ testDiamond = testCase "testDiamond" $ do testMaxGradient :: Test testMaxGradient = testCase "testMaxGradient" $ do [dx] <- TF.runSession $ do - let x = TF.vector [1, 2, 3, 0, 1 :: Float] - y = TF.max x (0 :: TF.Tensor TF.Value Int32) + x <- TF.render $ TF.vector [1, 2, 3, 0, 1 :: Float] + let y = TF.max x (0 :: TF.Tensor TF.Build Int32) TF.gradients y [x] >>= TF.run V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx diff --git a/tensorflow-ops/tests/OpsTest.hs b/tensorflow-ops/tests/OpsTest.hs index a117015..de35828 100644 --- a/tensorflow-ops/tests/OpsTest.hs +++ b/tensorflow-ops/tests/OpsTest.hs @@ -56,8 +56,7 @@ testSaveRestore = testCase "testSaveRestore" $ withSystemTempDirectory "" $ \dirPath -> do let path = B8.pack $ dirPath ++ "/checkpoint" var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float) - var = TF.render =<< - TF.zeroInitializedVariable' (TF.opName .~ "a") + var = TF.zeroInitializedVariable' (TF.opName .~ "a") (TF.Shape []) TF.runSession $ do v <- var @@ -76,7 +75,8 @@ testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do p2 <- TF.placeholder [] let enc :: Float -> TF.TensorData Float enc n = TF.encodeTensorData [] (V.fromList [n]) - result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2 + result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] + $ p1 `TF.add` p2 liftIO $ result @=? TF.Scalar 5 -- | Test that regular tensors can also be used for feeds, as long as they each @@ -90,7 +90,8 @@ testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do p2 <- TF.render $ TF.scalar' (TF.opName .~ "B") 0 let enc :: Float -> TF.TensorData Float enc n = TF.encodeTensorData [] (V.fromList [n]) - result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2 + result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] + $ p1 `TF.add` p2 liftIO $ result @=? 
TF.Scalar 5 main :: IO () diff --git a/tensorflow-ops/tests/RegressionTest.hs b/tensorflow-ops/tests/RegressionTest.hs index ec83eed..7fd2e73 100644 --- a/tensorflow-ops/tests/RegressionTest.hs +++ b/tensorflow-ops/tests/RegressionTest.hs @@ -38,7 +38,7 @@ fit xData yData = TF.runSession $ do return (w', b') gradientDescent :: Float - -> TF.Tensor TF.Value Float + -> TF.Tensor TF.Build Float -> [TF.Tensor TF.Ref Float] -> TF.Session TF.ControlNode gradientDescent alpha loss params = do diff --git a/tensorflow-ops/tests/TypesTest.hs b/tensorflow-ops/tests/TypesTest.hs index 567d232..a7fdc84 100644 --- a/tensorflow-ops/tests/TypesTest.hs +++ b/tensorflow-ops/tests/TypesTest.hs @@ -53,9 +53,9 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $ let floatData = V.fromList [1..6 :: Float] stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]] boolData = V.fromList [True, True, False, True, False, False] - f <- TF.build $ TF.placeholder [2,3] - s <- TF.build $ TF.placeholder [2,3] - b <- TF.build $ TF.placeholder [2,3] + f <- TF.placeholder [2,3] + s <- TF.placeholder [2,3] + b <- TF.placeholder [2,3] let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData) , TF.feed s (TF.encodeTensorData [2,3] stringData) , TF.feed b (TF.encodeTensorData [2,3] boolData) @@ -63,7 +63,8 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $ -- Do something idempotent to the tensors to verify that tensorflow can -- handle the encoding. Originally this used `TF.identity`, but that -- wasn't enough to catch a bug in the encoding of Bool. - (f', s', b') <- TF.runWithFeeds feeds (f+0, TF.identity s, TF.select b b b) + (f', s', b') <- TF.runWithFeeds feeds + (f `TF.add` 0, TF.identity s, TF.select b b b) liftIO $ do floatData @=? f' stringData @=? s' diff --git a/tensorflow-queue/src/TensorFlow/Queue.hs b/tensorflow-queue/src/TensorFlow/Queue.hs index 2e0a5cf..0131e9b 100644 --- a/tensorflow-queue/src/TensorFlow/Queue.hs +++ b/tensorflow-queue/src/TensorFlow/Queue.hs @@ -60,7 +60,7 @@ makeQueue :: forall as m . (MonadBuild m, TensorTypes as) -- under the given name across multiple sessions. -> m (Queue as) makeQueue capacity sharedName = do - q <- build $ buildOp (opDef "FIFOQueue" + q <- build $ buildOp [] (opDef "FIFOQueue" & opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as) & opAttr "shared_name" .~ sharedName & opAttr "capacity" .~ capacity diff --git a/tensorflow/src/TensorFlow/Build.hs b/tensorflow/src/TensorFlow/Build.hs index 9cdd43b..84c6c09 100644 --- a/tensorflow/src/TensorFlow/Build.hs +++ b/tensorflow/src/TensorFlow/Build.hs @@ -13,9 +13,13 @@ -- limitations under the License. {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE Rank2Types #-} +{-# LANGUAGE TypeFamilies #-} module TensorFlow.Build ( -- * Graph node types ControlNode(..) @@ -32,8 +36,6 @@ module TensorFlow.Build , opControlInputs -- * The Build monad , GraphState - , render - , renderNodeName , renderedNodeDefs , BuildT , Build @@ -46,27 +48,23 @@ module TensorFlow.Build , addGraphDef , flushInitializers , flushNodeBuffer + , summaries -- * Creating and looking up Ops , getOrAddOp , addNewOp - , renderOutput + , encodeOutput + , lookupNode -- * Modifying all nodes in a Build action - , colocateWith , withStateLens , withDevice , withNameScope , withNodeDependencies - -- * Internal Summary related bits. 
- , addSummary - , SummaryTensor - , collectAllSummaries ) where import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask) import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Trans.Class (MonadTrans(..)) import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT) -import Data.ByteString (ByteString) import Data.Default (def) import Data.Functor.Identity (Identity(..)) import qualified Data.Map.Strict as Map @@ -94,7 +92,6 @@ import Proto.Tensorflow.Core.Framework.NodeDef import TensorFlow.Orphans () import TensorFlow.Output -import TensorFlow.Tensor newtype Unique = Unique Int deriving (Eq, Ord, Enum) @@ -125,9 +122,6 @@ opDefWithName n t = OpDef , _opControlInputs = [] } --- | Synonym for the tensors that return serialized Summary proto. -type SummaryTensor = Tensor Value ByteString - data GraphState = GraphState { _renderedNodes :: !(Map.Map PendingNode NodeDef) -- ^ Nodes which have been rendered. Keeps track of the unique ID we @@ -148,8 +142,8 @@ data GraphState = GraphState , _initializationNodes :: [NodeName] -- ^ The nodes to run next time a TF.run is issued, typically -- variable initializers. - , _summaries :: [SummaryTensor] - -- ^ The tensors for summary + , _summaries :: [Output] + -- ^ The tensors for summary (ByteString type) } -- | A node definition without its final name. Used as a key in the @@ -191,7 +185,7 @@ defaultControlInputs = lens _defaultControlInputs initializationNodes :: Lens' GraphState [NodeName] initializationNodes = lens _initializationNodes (\g x -> g { _initializationNodes = x }) -summaries :: Lens' GraphState [SummaryTensor] +summaries :: Lens' GraphState [Output] summaries = lens _summaries (\g x -> g { _summaries = x }) -- | An action for building nodes in a TensorFlow graph. @@ -238,9 +232,7 @@ flushInitializers = do -- | Registers the given node to be executed before the next -- 'TensorFlow.Session.run'. addInitializer :: MonadBuild m => ControlNode -> m () -addInitializer (ControlNode o) = build $ do - i <- getOrAddOp o - initializationNodes %= (i:) +addInitializer (ControlNode i) = build $ initializationNodes %= (i:) -- | Produce a GraphDef proto representation of the nodes that are rendered in -- the given 'Build' action. @@ -255,30 +247,31 @@ addGraphDef g = build $ nodeBuffer <>= g ^. node -- | Render the given op if it hasn't been rendered already, and return its -- name. -getOrAddOp :: Op -> Build NodeName -getOrAddOp o = NodeName . (^. name) <$> resolveOp o - -resolveOp :: Op -> Build NodeDef -resolveOp (Rendered n) = return n -resolveOp (Unrendered o) = do +getOrAddOp :: OpDef -> Build NodeName +getOrAddOp o = do pending <- getPendingNode o uses renderedNodes (Map.lookup pending) >>= \case - Just n -> return n + Just n -> return $ NodeName $ n ^. name Nothing -> addNewOpFromPending pending +lookupNode :: NodeName -> Build NodeDef +lookupNode n = uses renderedNodeDefs (Map.lookup n) >>= \case + Just n' -> return n' + Nothing -> error $ "lookupNode: unknown node name " ++ show n + -- | Add a new node for a given 'OpDef'. This is used for making "stateful" ops -- which are not safe to dedup (e.g, "variable" and "assign"). 
-addNewOp :: OpDef -> Build NodeDef +addNewOp :: OpDef -> Build NodeName addNewOp o = getPendingNode o >>= addNewOpFromPending -addNewOpFromPending :: PendingNode -> Build NodeDef +addNewOpFromPending :: PendingNode -> Build NodeName addNewOpFromPending pending = do nodeName <- renderPendingNode pending let nodeDef = pendingNodeDef pending & name .~ unNodeName nodeName nodeBuffer %= (nodeDef :) renderedNodes %= Map.insert pending nodeDef renderedNodeDefs %= Map.insert nodeName nodeDef - return nodeDef + return nodeName -- | Get the pending node corresponding to an OpDef, which may or may not have -- been rendered before. Implicitly renders all of this node's inputs. @@ -287,20 +280,18 @@ getPendingNode o = do -- An empty string in the proto field means that no specific -- device is specified. dev <- maybe "" deviceName <$> use defaultDevice - inputs <- mapM getInput (o ^. opInputs) scope <- use currentScope controls <- use defaultControlInputs + let inputs = map encodeOutput (o ^. opInputs) let controlInputs - = map getDep (o ^. opControlInputs ++ Set.toList controls) + = map makeDep (o ^. opControlInputs ++ Set.toList controls) return $ PendingNode scope (o ^. opName) $ def & op .~ (unOpType (o ^. opType) :: Text) & attr .~ _opAttrs o & input .~ (inputs ++ controlInputs) & device .~ dev where - getInput (Output (OutputIx k) subOp) - = (<> ":" <> Text.pack (show k)) . unNodeName <$> getOrAddOp subOp - getDep = ("^" <>) . unNodeName + makeDep = ("^" <>) . unNodeName -- | Pick a name for a pending node. If it has an explicit name, just use that; -- if the name is implicit, assign a new unique name based on the op type. @@ -317,12 +308,11 @@ renderPendingNode (PendingNode scope pendingName nodeDef) return $ nodeDef ^. op <> "_" <> Text.pack (show k) --- | Render an 'Output' and return a string representation for the TensorFlow +-- | Turn an 'Output' into a string representation for the TensorFlow -- foreign APIs. -renderOutput :: Output -> Build Text -renderOutput (Output (OutputIx i) o) = do - n <- getOrAddOp o - return $ unNodeName n <> Text.pack (":" ++ show i) +encodeOutput :: Output -> Text +encodeOutput (Output (OutputIx 0) n) = unNodeName n +encodeOutput (Output (OutputIx i) n) = unNodeName n <> Text.pack (':' : show i) -- | Modify some part of the state, run an action, and restore the state -- after that action is done. @@ -339,15 +329,6 @@ withStateLens accessor f act = do withDevice :: MonadBuild m => Maybe Device -> m a -> m a withDevice d = withStateLens defaultDevice (const d) --- | Places all nodes rendered in the given 'Build' action on the same --- device as the given Tensor (see also 'withDevice'). Make sure that --- the action has side effects of rendering the desired tensors. A pure --- return would not have the desired effect. -colocateWith :: MonadBuild m => forall a v b . Tensor v b -> m a -> m a -colocateWith t x = do - d <- build $ Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp) - withDevice (Just d) x - -- | Prepend a scope to all nodes rendered in the given 'Build' action. withNameScope :: MonadBuild m => Text -> m a -> m a withNameScope s = withStateLens currentScope (Scope s :) @@ -355,31 +336,3 @@ withNameScope s = withStateLens currentScope (Scope s :) -- | Add control inputs to all nodes rendered in the given 'Build' action. 
withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes) - --- | Render a 'Tensor', fixing its name, scope, device and control inputs from --- the 'Build' context. Also renders any dependencies of the 'Tensor' that --- weren't already rendered. --- --- This operation is idempotent; @render >=> render === render@. However, --- rendering a (previously un-rendered) 'Tensor' in two different contexts --- may result in two different 'Tensor's. -render :: MonadBuild m => Tensor v a -> m (Tensor v a) -render = build . tensorOutput (outputOp $ fmap Rendered . resolveOp) - --- | Render a 'Tensor' and get its node's name. -renderNodeName :: Tensor v a -> Build NodeName -renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp) - --- | Records the given summary action in Build for retrieval with --- 'collectAllSummaries'. The summary op is required to produce a --- Summary protocol buffer in string form. For safety, use the --- pre-composed functions: Logging.scalarSummary and --- Logging.histogramSummary. -addSummary :: SummaryTensor -> Build () -addSummary t = summaries %= (t :) - --- | Retrieves the summary ops collected thus far. Typically this only --- happens once, but if 'TensorFlow.Session.buildWithSummary' is used --- repeatedly, the values accumulate. -collectAllSummaries :: Monad m => BuildT m [SummaryTensor] -collectAllSummaries = use summaries diff --git a/tensorflow/src/TensorFlow/BuildOp.hs b/tensorflow/src/TensorFlow/BuildOp.hs index ad625d1..4a6907e 100644 --- a/tensorflow/src/TensorFlow/BuildOp.hs +++ b/tensorflow/src/TensorFlow/BuildOp.hs @@ -12,26 +12,27 @@ -- See the License for the specific language governing permissions and -- limitations under the License. +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} module TensorFlow.BuildOp - ( OpResult - , BuildOp + ( BuildResult(..) , buildOp - , buildListOp + , PureResult(..) + , pureOp , eqLengthGuard + , BuildInputs(..) , OpParams ) where -import Control.Monad (replicateM) +import Control.Monad (liftM2, replicateM) import Control.Monad.Reader (ReaderT, runReaderT, ask) -import Control.Monad.State.Strict (State, runState, get, put) +import Control.Monad.State.Strict (State, evalState, get, put) import Data.Int (Int64) -import Lens.Family2 ((&), (<>~), (^.)) import TensorFlow.Build import TensorFlow.Output @@ -40,48 +41,45 @@ import TensorFlow.Types data ResultState = ResultState !OutputIx [Int64] deriving Show -type Result = ReaderT Op (State ResultState) +type Result = ReaderT NodeName (State ResultState) -- | Class of types that can be used as op outputs. 
-class OpResult a where - toResult :: Result a +class BuildResult a where + buildResult :: Result a -instance (OpResult a1, OpResult a2) => OpResult (a1, a2) where - toResult = (,) <$> toResult <*> toResult +instance (BuildResult a1, BuildResult a2) => BuildResult (a1, a2) where + buildResult = (,) <$> buildResult <*> buildResult -instance (OpResult a1, OpResult a2, OpResult a3) => OpResult (a1, a2, a3) where - toResult = (,,) <$> toResult <*> toResult <*> toResult +instance (BuildResult a1, BuildResult a2, BuildResult a3) => BuildResult (a1, a2, a3) where + buildResult = (,,) <$> buildResult <*> buildResult <*> buildResult -instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4) - => OpResult (a1, a2, a3, a4) where - toResult = (,,,) <$> toResult <*> toResult <*> toResult <*> toResult +instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4) + => BuildResult (a1, a2, a3, a4) where + buildResult = (,,,) <$> buildResult <*> buildResult <*> buildResult <*> buildResult -instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4, OpResult a5) - => OpResult (a1, a2, a3, a4, a5) where - toResult = (,,,,) <$> toResult - <*> toResult - <*> toResult - <*> toResult - <*> toResult +instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4, BuildResult a5) + => BuildResult (a1, a2, a3, a4, a5) where + buildResult = (,,,,) <$> buildResult + <*> buildResult + <*> buildResult + <*> buildResult + <*> buildResult -instance ( OpResult a1 - , OpResult a2 - , OpResult a3 - , OpResult a4 - , OpResult a5 - , OpResult a6 +instance ( BuildResult a1 + , BuildResult a2 + , BuildResult a3 + , BuildResult a4 + , BuildResult a5 + , BuildResult a6 ) - => OpResult (a1, a2, a3, a4, a5, a6) where - toResult = (,,,,,) - <$> toResult - <*> toResult - <*> toResult - <*> toResult - <*> toResult - <*> toResult - -tensorResult :: TensorKind v -> Result (Tensor v a) -tensorResult v = Tensor v <$> recordResult + => BuildResult (a1, a2, a3, a4, a5, a6) where + buildResult = (,,,,,) + <$> buildResult + <*> buildResult + <*> buildResult + <*> buildResult + <*> buildResult + <*> buildResult recordResult :: Result Output recordResult = do @@ -90,144 +88,39 @@ recordResult = do put $! ResultState (i+1) ns return $! output i o -instance OpResult ResourceHandle where - toResult = ResourceHandle <$> recordResult +instance BuildResult ResourceHandle where + buildResult = ResourceHandle <$> recordResult -instance OpResult (Tensor Value a) where - toResult = tensorResult ValueKind +instance Rendered v => BuildResult (Tensor v a) where + buildResult = Tensor . pure <$> recordResult -instance OpResult (Tensor Ref a) where - toResult = tensorResult RefKind +instance BuildResult ControlNode where + buildResult = ControlNode <$> ask -instance OpResult ControlNode where - toResult = ControlNode <$> ask +instance (Rendered v, TensorTypes as) => BuildResult (TensorList v as) where + buildResult = loop (tensorTypes :: TensorTypeList as) + where + loop :: TensorTypeList bs -> Result (TensorList v bs) + loop Nil = return Nil + loop (TensorTypeProxy :/ ls) = do + t <- buildResult + ts <- loop ls + return (t :/ ts) -tensorListResult :: forall as v . 
TensorTypes as => TensorKind v -> Result (TensorList v as) -tensorListResult v = loop (tensorTypes :: TensorTypeList as) - where - loop :: TensorTypeList bs -> Result (TensorList v bs) - loop Nil = return Nil - loop (TensorTypeProxy :/ ls) = do - t <- tensorResult v - ts <- loop ls - return (t :/ ts) - -instance TensorTypes as => OpResult (TensorList Value as) where - toResult = tensorListResult ValueKind - -instance TensorTypes as => OpResult (TensorList Ref as) where - toResult = tensorListResult RefKind - -instance OpResult a => OpResult [a] where - toResult = do +instance BuildResult a => BuildResult [a] where + buildResult = do ResultState i ns <- get case ns of - [] -> error $ "Ran out of counts in toResult. " ++ - "Likely misuse of buildListOp." + [] -> error $ "Ran out of counts in buildResult. " ++ + "Likely misuse of buildOp." (n : rest) -> do put $! ResultState i rest - replicateM (fromIntegral n) toResult + replicateM (fromIntegral n) buildResult -runResult :: OpResult a => [Int64] -> Op -> a -runResult ns o = - case runState (runReaderT toResult o) (ResultState 0 ns) of - (x, ResultState _ []) -> x - (_, ns') -> error $ "Ununsed length in runResult attributes: " ++ - show (ns, ns') - --- | Make a new "pure" op, which may be deduped with identical ops within --- the same scope. -pureResult :: OpResult a => [Int64] -> OpDef -> [Output] -> a -pureResult ns o ts = runResult ns $ Unrendered $ addReversedInputs o ts - --- | Make a new "stateful" op, which will not be deduped with otherwise --- identical ops. -buildResult :: OpResult a => [Int64] -> OpDef -> [Output] -> Build a -buildResult ns o ts - = runResult ns . Rendered <$> addNewOp (addReversedInputs o ts) - -addReversedInputs :: OpDef -> [Output] -> OpDef -addReversedInputs o ts = o & opInputs <>~ reverse ts - --- | Class of types that can be used as op functions. -class BuildOp f where - buildOp' :: [Int64] -- ^ Sizes of list results (having number_attr) - -> OpDef - -> [Output] -- ^ Accumulator for inputs to the op. - -> f - --- | Starts an operation that returns a structured set of tensors --- (singletons or tuples). -buildOp :: BuildOp f => OpDef -> f -buildOp o = buildOp' [] o [] - --- | Starts an operation that returns a list of tensors. -buildListOp :: BuildOp f => [Int64] - -- ^ Cardinality of the corresponding list of tensors output. 
- -> OpDef -> f -buildListOp counts o = buildOp' counts o [] - -instance BuildOp ControlNode where - buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts - -instance BuildOp ResourceHandle where - buildOp' = pureResult - -instance BuildOp (Tensor Value a) where - buildOp' = pureResult - -instance BuildOp (Tensor Ref a) where - buildOp' = pureResult - -instance TensorTypes as => BuildOp (TensorList Value as) where - buildOp' = pureResult - -instance TensorTypes as => BuildOp (TensorList Ref as) where - buildOp' = pureResult - -instance BuildOp [Tensor Value a] where - buildOp' = pureResult - -instance (OpResult t1, OpResult t2) => BuildOp (t1, t2) where - buildOp' = pureResult - -instance (OpResult t1, OpResult t2, OpResult t3) => BuildOp (t1, t2, t3) where - buildOp' = pureResult - -instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4) - => BuildOp (t1, t2, t3, t4) where - buildOp' = pureResult - -instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4, OpResult t5) - => BuildOp (t1, t2, t3, t4, t5) where - buildOp' = pureResult - -instance ( OpResult t1 - , OpResult t2 - , OpResult t3 - , OpResult t4 - , OpResult t5 - , OpResult t6 - ) - => BuildOp (t1, t2, t3, t4, t5, t6) where - buildOp' = pureResult - -instance OpResult a => BuildOp (Build a) where - buildOp' = buildResult - -instance BuildOp f => BuildOp (ResourceHandle -> f) where - buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts) - -instance BuildOp f => BuildOp (Tensor v a -> f) where - buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts) - -instance BuildOp f => BuildOp ([Tensor v a] -> f) where - buildOp' rf o accum ts - = buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum) - -instance BuildOp f => BuildOp (TensorList v as -> f) where - buildOp' rf o accum ts - = buildOp' rf o (reverse (tensorListOutputs ts) ++ accum) +buildOp :: BuildResult a => [Int64] -> OpDef -> Build a +buildOp sizes o = do + n <- addNewOp o + return $ flip evalState (ResultState 0 sizes) (runReaderT buildResult n) -- | Returns true if all the integers in each tuple are identical. -- Throws an error with a descriptive message if not. @@ -240,6 +133,104 @@ eqLengthGuard = all eachOk error ("number_attr " ++ numberAttrName ++ " contains tensors with different length " ++ show pairs) +----------- + + +-- | Class of types that can be used as op outputs. +class PureResult a where + pureResult :: ReaderT (Build OpDef) (State ResultState) a + +instance PureResult (Tensor Build a) where + pureResult = do + ResultState i ns <- get + put $! 
ResultState (i+1) ns + makeOp <- ask + return $ Tensor $ do + o <- makeOp + -- TODO: unify with BuildResult (Tensor v) + output i <$> getOrAddOp o + +instance (PureResult a1, PureResult a2) => PureResult (a1, a2) where + pureResult = (,) <$> pureResult <*> pureResult + +instance (PureResult a1, PureResult a2, PureResult a3) => PureResult (a1, a2, a3) where + pureResult = (,,) <$> pureResult <*> pureResult <*> pureResult + +instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4) + => PureResult (a1, a2, a3, a4) where + pureResult = (,,,) <$> pureResult <*> pureResult <*> pureResult <*> pureResult + +instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4, PureResult a5) + => PureResult (a1, a2, a3, a4, a5) where + pureResult = (,,,,) <$> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + +instance ( PureResult a1 + , PureResult a2 + , PureResult a3 + , PureResult a4 + , PureResult a5 + , PureResult a6 + ) + => PureResult (a1, a2, a3, a4, a5, a6) where + pureResult = (,,,,,) + <$> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + +instance PureResult a => PureResult [a] where + pureResult = do + ResultState i ns <- get + case ns of + [] -> error $ "Ran out of counts in pureResult. " ++ + "Likely misuse of pureOp with output lists." + n : rest -> do + put $! ResultState i rest + replicateM (fromIntegral n) pureResult + +instance TensorTypes as => PureResult (TensorList Build as) where + pureResult = loop (tensorTypes :: TensorTypeList as) + where + loop :: TensorTypeList bs -> ReaderT (Build OpDef) (State ResultState) + (TensorList Build bs) + loop Nil = return Nil + loop (TensorTypeProxy :/ ls) = do + t <- pureResult + ts <- loop ls + return (t :/ ts) + +pureOp :: PureResult a => [Int64] -> Build OpDef -> a +pureOp sizes o = flip evalState (ResultState 0 sizes) (runReaderT pureResult o) + +----- +-- Class of types that can be used as arguments + +class BuildInputs a where + buildInputs :: a -> Build [Output] + +instance BuildInputs a => BuildInputs [a] where + buildInputs = fmap concat . mapM buildInputs + +instance BuildInputs (Tensor v a) where + buildInputs (Tensor t) = do + o <- toBuild t + return [o] + +instance BuildInputs (ListOf (Tensor v) as) where + buildInputs Nil = return [] + buildInputs (t :/ ts) = liftM2 (++) (buildInputs t) (buildInputs ts) + +instance BuildInputs ResourceHandle where + buildInputs (ResourceHandle o) = return [o] + +---- + -- | Parameters to build an op (for example, the node name or optional attributes). -- TODO: be more type safe. type OpParams = OpDef -> OpDef diff --git a/tensorflow/src/TensorFlow/ControlFlow.hs b/tensorflow/src/TensorFlow/ControlFlow.hs index 2f20b92..d8ebc46 100644 --- a/tensorflow/src/TensorFlow/ControlFlow.hs +++ b/tensorflow/src/TensorFlow/ControlFlow.hs @@ -25,9 +25,6 @@ module TensorFlow.ControlFlow , noOp ) where -import qualified Data.Set as Set -import Lens.Family2 ((&), (.~)) - import TensorFlow.BuildOp import TensorFlow.Build import TensorFlow.Nodes @@ -46,11 +43,8 @@ withControlDependencies deps act = do -- When this op finishes, all ops in the input @n@ have finished. This op has -- no output. group :: (MonadBuild m, Nodes t) => t -> m ControlNode -group deps = do - nodes <- build $ Set.toList <$> getNodes deps - -- TODO: slicker way - return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes +group deps = withControlDependencies deps noOp -- | Does nothing. Only useful as a placeholder for control edges. 
-noOp :: ControlNode -noOp = buildOp $ opDef "NoOp" +noOp :: MonadBuild m => m ControlNode +noOp = build $ buildOp [] $ opDef "NoOp" diff --git a/tensorflow/src/TensorFlow/Core.hs b/tensorflow/src/TensorFlow/Core.hs index 7764e37..7834211 100644 --- a/tensorflow/src/TensorFlow/Core.hs +++ b/tensorflow/src/TensorFlow/Core.hs @@ -57,9 +57,9 @@ module TensorFlow.Core , Tensor , Value , Ref - , TensorKind(..) , value , tensorFromName + , expr -- ** Element types , TensorType , TensorData diff --git a/tensorflow/src/TensorFlow/Nodes.hs b/tensorflow/src/TensorFlow/Nodes.hs index a7ce925..6731600 100644 --- a/tensorflow/src/TensorFlow/Nodes.hs +++ b/tensorflow/src/TensorFlow/Nodes.hs @@ -20,6 +20,7 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} -- For Fetchable (TensorExpr a) module TensorFlow.Nodes where import Control.Applicative (liftA2, liftA3) @@ -28,7 +29,6 @@ import Data.Map.Strict (Map) import Data.Monoid ((<>)) import Data.Set (Set) import Data.Text (Text) -import Lens.Family2 ((^.)) import qualified Data.Map.Strict as Map import qualified Data.Set as Set @@ -90,7 +90,7 @@ instance Fetchable t a => Fetchable [t] [a] where getFetch ts = sequenceA <$> mapM getFetch ts instance Nodes ControlNode where - getNodes (ControlNode o) = Set.singleton <$> getOrAddOp o + getNodes (ControlNode o) = pure $ Set.singleton o -- We use the constraint @(a ~ ())@ to help with type inference. For example, -- if @t :: ControlNode@, then this constraint ensures that @run t :: Session @@ -113,13 +113,13 @@ instance (Fetchable (f t) a, Fetchable (ListOf f ts) (List as), i ~ Identity) getFetch (x :/ xs) = liftA2 (\y ys -> y /:/ ys) <$> getFetch x <*> getFetch xs instance Nodes (Tensor v a) where - getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp) + getNodes (Tensor o) = Set.singleton . outputNodeName <$> toBuild o -fetchTensorVector :: forall a v . TensorType a +fetchTensorVector :: forall a v . (TensorType a) => Tensor v a -> Build (Fetch (TensorData a)) -fetchTensorVector (Tensor _ o) = do - outputName <- renderOutput o - return $ Fetch (Set.singleton outputName) $ \tensors -> +fetchTensorVector (Tensor o) = do + outputName <- encodeOutput <$> toBuild o + pure $ Fetch (Set.singleton outputName) $ \tensors -> let tensorData = tensors Map.! outputName expectedType = tensorType (undefined :: a) actualType = FFI.tensorDataType tensorData diff --git a/tensorflow/src/TensorFlow/Output.hs b/tensorflow/src/TensorFlow/Output.hs index 9ee31c8..ef6b489 100644 --- a/tensorflow/src/TensorFlow/Output.hs +++ b/tensorflow/src/TensorFlow/Output.hs @@ -22,8 +22,6 @@ module TensorFlow.Output , Device(..) -- * Ops , NodeName(..) - , Op(..) - , opUnrendered , OpDef(..) , opName , opType @@ -34,28 +32,24 @@ module TensorFlow.Output , OutputIx(..) , Output(..) , output - , outputIndex - , outputOp , PendingNodeName(..) , ResourceHandle(..) 
) where import qualified Data.Map.Strict as Map -import Data.ProtoLens.TextFormat (showMessage) import Data.String (IsString(..)) import Data.Text (Text) import qualified Data.Text as Text -import Lens.Family2 (Lens', Traversal', (.~), (&), (^.)) +import Lens.Family2 (Lens') import Lens.Family2.Unchecked (lens) import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..)) -import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef(..), name) import Data.Default (def) import TensorFlow.Types (Attribute, attrLens) import TensorFlow.Orphans () -- | A type of graph node which has no outputs. These nodes are -- valuable for causing side effects when they are run. -newtype ControlNode = ControlNode { unControlNode :: Op } +newtype ControlNode = ControlNode { unControlNode :: NodeName } -- | The type of op of a node in the graph. This corresponds to the proto field -- NodeDef.op. @@ -66,18 +60,12 @@ instance IsString OpType where fromString = OpType . Text.pack -- | An output of a TensorFlow node. -data Output = Output !OutputIx !Op +data Output = Output {outputIndex :: !OutputIx, outputNodeName :: !NodeName} deriving (Eq, Ord, Show) -output :: OutputIx -> Op -> Output +output :: OutputIx -> NodeName -> Output output = Output -outputOp :: Lens' Output Op -outputOp = lens (\(Output _ o) -> o) (\(Output i _) o -> Output i o) - -outputIndex :: Lens' Output OutputIx -outputIndex = lens (\(Output i _) -> i) (\(Output _ o) i -> Output i o) - newtype OutputIx = OutputIx { unOutputIx :: Int } deriving (Eq, Ord, Num, Enum, Show) @@ -90,25 +78,6 @@ newtype Device = Device {deviceName :: Text} instance Show Device where show (Device d) = show d --- | The representation of a node in a TensorFlow graph. -data Op - = Rendered !NodeDef -- ^ Properties are fixed, including the - -- device, name, and scope. - | Unrendered !OpDef -- ^ Properties are not fixed, and may change depending - -- on which context this op is rendered in. - deriving (Eq, Ord) - -instance Show Op where - show (Rendered n) = "Rendered " ++ showMessage n - show (Unrendered o) = "Unrendered " ++ show (o ^. opName) - --- | Traverse on the 'Unrendered' of an 'Op'. --- --- Same implementation as _Left. -opUnrendered :: Traversal' Op OpDef -opUnrendered f (Unrendered a) = Unrendered <$> f a -opUnrendered _ (Rendered b) = pure (Rendered b) - -- | Op definition. This corresponds somewhat to the 'NodeDef' proto. data OpDef = OpDef { _opName :: !PendingNodeName @@ -157,7 +126,7 @@ instance IsString Output where (n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr -> Output (fromInteger ix) $ assigned n _ -> Output 0 $ assigned s - where assigned n = Rendered $ def & name .~ Text.pack n + where assigned = NodeName . Text.pack -- | Opaque handle to a mutable resource in the graph. 
Typical such diff --git a/tensorflow/src/TensorFlow/Session.hs b/tensorflow/src/TensorFlow/Session.hs index bc24e3d..199d7f5 100644 --- a/tensorflow/src/TensorFlow/Session.hs +++ b/tensorflow/src/TensorFlow/Session.hs @@ -163,7 +163,7 @@ runWithFeeds feeds t = do runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a runFetchWithFeeds feeds target (Fetch fetch restore) = do extend - feeds' <- build $ fixFeeds feeds + let feeds' = fixFeeds feeds let fetchNames = encodeUtf8 <$> Set.toList fetch targetNames = toNodeNames $ Set.toList target session <- Session (asks rawSession) @@ -192,8 +192,8 @@ runWithFeeds_ feeds t = do ns <- build $ getNodes t runFetchWithFeeds feeds ns (pure ()) -fixFeeds :: [Feed] -> Build [(ByteString, FFI.TensorData)] -fixFeeds = mapM $ \(Feed o d) -> (,d) . encodeUtf8 <$> renderOutput o +fixFeeds :: [Feed] -> [(ByteString, FFI.TensorData)] +fixFeeds = map $ \(Feed o d) -> (encodeUtf8 $ encodeOutput o, d) -- | Starts a concurrent thread which evaluates the given Nodes -- forever until runSession exits or an exception occurs. Graph diff --git a/tensorflow/src/TensorFlow/Tensor.hs b/tensorflow/src/TensorFlow/Tensor.hs index e353eb6..bf3fc1f 100644 --- a/tensorflow/src/TensorFlow/Tensor.hs +++ b/tensorflow/src/TensorFlow/Tensor.hs @@ -16,21 +16,26 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE Rank2Types #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} -- For the Render class module TensorFlow.Tensor where +import Data.ByteString (ByteString) import Data.String (IsString(..)) import qualified Data.Text as Text -import Lens.Family2 (Lens', (^.)) -import Lens.Family2.Unchecked (lens) +import Lens.Family2 ((^.)) +import Lens.Family2.State ((%=), use) -import TensorFlow.Output (Output) +import Proto.Tensorflow.Core.Framework.NodeDef (device) +import TensorFlow.Build +import TensorFlow.Output (Output, NodeName, outputNodeName, Device(..)) import TensorFlow.Types ( TensorData(..) , ListOf(..) @@ -40,52 +45,149 @@ import qualified TensorFlow.Internal.FFI as FFI -- | A named output of a TensorFlow operation. -- -- The type parameter @a@ is the type of the elements in the 'Tensor'. The --- parameter @v@ is either 'Value' or 'Ref', depending on whether the graph is --- treating this op output as an immutable 'Value' or a stateful 'Ref' (e.g., a --- variable). Note that a @Tensor Ref@ can be casted into a @Tensor Value@ via --- 'value'. -data Tensor v a = Tensor (TensorKind v) Output +-- parameter @v@ is either: +-- +-- * 'Build': An unrendered, immutable value. +-- * 'Value': A rendered, immutable value. +-- * 'Ref': A rendered stateful handle (e.g., a variable). +-- +-- Note that 'expr', 'value', 'render' and 'renderValue' can help convert between +-- the different types of 'Tensor'. +data Tensor v a where + Tensor :: TensorKind v => {tensorOutput :: v Output} -> Tensor v a -data Value -data Ref +newtype Value a = Value {runValue :: a} + deriving Functor --- | This class provides a runtime switch on whether a 'Tensor' should be --- treated as a 'Value' or as a 'Ref'. 
-data TensorKind v where - ValueKind :: TensorKind Value - RefKind :: TensorKind Ref +instance Applicative Value where + pure = Value + Value f <*> Value x = Value $ f x -tensorKind :: Lens' (Tensor v a) (TensorKind v) -tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o) +instance Monad Value where + f >>= g = g $ runValue f -tensorOutput :: Lens' (Tensor v a) Output -tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o) +newtype Ref a = Ref {runRef :: a} + deriving Functor --- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a --- Ref into Value. This behaves like a no-op. -value :: Tensor v a -> Tensor Value a -value (Tensor _ o) = Tensor ValueKind o +instance Applicative Ref where + pure = Ref + Ref f <*> Ref x = Ref $ f x + +instance Monad Ref where + f >>= g = g $ runRef f + +-- | Cast a 'Tensor Ref' into a 'Tensor Value'. This behaves like a no-op. +value :: Tensor Ref a -> Tensor Value a +value (Tensor o) = Tensor $ Value $ runRef o + +renderValue :: MonadBuild m => Tensor v a -> m (Tensor Value a) +renderValue (Tensor o) = render $ Tensor $ toBuild o -- | A pair of a 'Tensor' and some data that should be fed into that 'Tensor' -- when running the graph. data Feed = Feed Output FFI.TensorData +-- | A class ensuring that a given tensor is rendered, i.e., has a fixed +-- name, device, etc. +class TensorKind v => Rendered v where + rendered :: v a -> a + +instance Rendered Value where + rendered = runValue + +instance Rendered Ref where + rendered = runRef + +renderedOutput :: Rendered v => Tensor v a -> Output +renderedOutput = rendered . tensorOutput + +tensorNodeName :: Rendered v => Tensor v a -> NodeName +tensorNodeName = outputNodeName . renderedOutput + + -- | Create a 'Feed' for feeding the given data into a 'Tensor' when running -- the graph. -- -- Note that if a 'Tensor' is rendered, its identity may change; so feeding the -- rendered 'Tensor' may be different than feeding the original 'Tensor'. -feed :: Tensor v a -> TensorData a -> Feed -feed (Tensor _ o) (TensorData td) = Feed o td +feed :: Rendered v => Tensor v a -> TensorData a -> Feed +feed t (TensorData td) = Feed (renderedOutput t) td -- | Create a 'Tensor' for a given name. This can be used to reference nodes -- in a 'GraphDef' that was loaded via 'addGraphDef'. -- TODO(judahjacobson): add more safety checks here. -tensorFromName :: TensorKind v -> Text.Text -> Tensor v a -tensorFromName v = Tensor v . fromString . Text.unpack +tensorFromName :: TensorKind v => Text.Text -> Tensor v a +tensorFromName = Tensor . pure . fromString . Text.unpack + +-- | Like 'tensorFromName', but type-restricted to 'Value'. +tensorValueFromName :: Text.Text -> Tensor Value a +tensorValueFromName = tensorFromName + +-- | Like 'tensorFromName', but type-restricted to 'Ref'. +tensorRefFromName :: Text.Text -> Tensor Ref a +tensorRefFromName = tensorFromName type TensorList v = ListOf (Tensor v) -tensorListOutputs :: TensorList v as -> [Output] +tensorListOutputs :: Rendered v => TensorList v as -> [Output] tensorListOutputs Nil = [] -tensorListOutputs (t :/ ts) = (t ^. tensorOutput) : tensorListOutputs ts +tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts + +-- | Places all nodes rendered in the given 'Build' action on the same +-- device as the given Tensor (see also 'withDevice'). Make sure that +-- the action has side effects of rendering the desired tensors. A pure +-- return would not have the desired effect. 
+colocateWith :: (MonadBuild m, Rendered v) => Tensor v b -> m a -> m a
+colocateWith t x = do
+    d <- build $ Device . (^. device)
+                    <$> lookupNode (outputNodeName $ renderedOutput t)
+    withDevice (Just d) x
+
+
+-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
+-- the 'MonadBuild' context. Also renders any dependencies of the 'Tensor' that
+-- weren't already rendered.
+--
+-- This operation is idempotent; calling 'render' on the same input in the same
+-- context will produce the same result. However, rendering the same
+-- @Tensor Build@ in two different contexts may result in two different
+-- @Tensor Value@s.
+render :: MonadBuild m => Tensor Build a -> m (Tensor Value a)
+render (Tensor t) = Tensor . Value <$> build t
+
+-- TODO: better name.
+expr :: TensorKind v => Tensor v a -> Tensor Build a
+expr (Tensor o) = Tensor $ toBuild o
+
+-- | Records the given summary action in Build for later retrieval with
+-- 'collectAllSummaries'. The summary op must produce a Summary protocol
+-- buffer in string form. For safety, use the pre-composed functions:
+-- Logging.scalarSummary and Logging.histogramSummary.
+addSummary :: (MonadBuild m, TensorKind v) => Tensor v ByteString -- ^ A 'SummaryTensor'
+           -> m ()
+addSummary t = build $ do
+    -- TODO: more generic way
+    o <- toBuild $ tensorOutput t
+    summaries %= (o :)
+
+-- | Retrieves the summary ops collected thus far. Typically this only
+-- happens once, but if 'TensorFlow.Session.buildWithSummary' is used
+-- repeatedly, the values accumulate.
+collectAllSummaries :: MonadBuild m => m [SummaryTensor]
+collectAllSummaries = build $ map (Tensor . Value) <$> use summaries
+
+-- | Synonym for the tensors that return serialized Summary proto.
+type SummaryTensor = Tensor Value ByteString
+
+-- | An internal class for kinds of Tensors.
+class Monad v => TensorKind v where
+    toBuild :: v a -> Build a
+
+instance TensorKind Value where
+    toBuild = return . rendered
+
+instance TensorKind Ref where
+    toBuild = return . rendered
+
+instance TensorKind Build where
+    toBuild = id
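
The sketch below, mirroring the updated testDiamond above, shows how the new tensor kinds compose at a call site. It is illustrative only (not part of the patch), and the qualified TF import set is assumed to follow the updated tests and the tensorflow-mnist app: a pure op result now has kind Build, render pins it to a Value with a fixed node name, and rendered tensors are what get fed, fetched, and differentiated.

    import qualified TensorFlow.Core as TF
    import qualified TensorFlow.Gradient as TF
    import qualified TensorFlow.Ops as TF

    -- Illustrative sketch using the assumed imports above.
    diamondGradient :: IO Float
    diamondGradient = TF.runSession $ do
        -- TF.vector produces a Tensor Build Float; TF.render fixes its node
        -- name, yielding a Tensor Value Float that can appear in the
        -- gradient list below.
        x <- TF.render $ TF.vector [1 :: Float]
        let y = x `TF.mul` x   -- ops accept any kind and return Tensor Build
            z = y * y          -- Num still works at kind Build
        -- Gradients are taken with respect to the rendered tensor x.
        [dx] <- TF.gradients z [x] >>= TF.run
        -- (TF.expr x would convert x back to a Tensor Build Float if needed.)
        return (TF.unScalar dx)   -- dz/dx at x = 1 is 4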