From cec666e135205bd4e91204b3c7ee04b83d3b2a8e Mon Sep 17 00:00:00 2001 From: Judah Jacobson Date: Mon, 21 Nov 2016 10:19:15 -0800 Subject: [PATCH] Fix Ref and Build semantics for generated code. (#37) Also: - Make TensorFlow.Ops.{variable,assign} be the Core generated versions. - Make ops take "Shape" as mandatory input. --- tensorflow-opgen/src/TensorFlow/OpGen.hs | 38 ++++++++++++++++--- .../src/TensorFlow/OpGen/ParsedOp.hs | 16 +++++++- tensorflow-ops/src/TensorFlow/Ops.hs | 22 +++-------- 3 files changed, 53 insertions(+), 23 deletions(-) diff --git a/tensorflow-opgen/src/TensorFlow/OpGen.hs b/tensorflow-opgen/src/TensorFlow/OpGen.hs index a825ec7..32b497b 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen.hs @@ -16,7 +16,32 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeFamilies #-} --- | Rendering of TensorFlow operations as Haskell functions. +{- | Rendering of TensorFlow operations as Haskell functions. + +The basic type signature generated for each op is: + +> {constraints} => {mandatory attrs} -> {input tensors} -> {output tensors} + +where: + +* @{mandatory attrs}@ is of the form @A_1 -> ... -> A_N@, where each @A@ is an + op attribute that doesn't have a default and can't be inferred from other + inputs. + +* @{constraints}@ restrict the type parameters of the input and output tensors + (for example: 'TensorType' or 'OneOf'). + +* @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of +the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle a@ (or a list of one +of those types), and @a@ is either a concrete type or a (constrained) type +variable. + +* @{output tensors}@ is of the form @(T_1,...,T_N)@ for "pure" ops, and +@Build (T_1,...,T_N)@ for "stateful" ops. An op is considered "stateful" if +it takes a @Tensor Ref@ or @ResourceHandle@ as input, or if it's explicitly +marked \"Stateful\" in its @REGISTER_OP@ definition. (If there are no outputs, +it is either @ControlNode@ or @Build ControlNode@.) +-} module TensorFlow.OpGen ( OpGenFlags(..) @@ -259,15 +284,18 @@ typeSig pOp = constraints AttrFloat -> "Float" AttrBool -> "Bool" AttrType -> "DataType" - AttrShape -> "TensorShapeProto" + AttrShape -> "Shape" AttrTensor -> "TensorProto" tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t) outputs = case parsedOutputs pOp of - [] -> "ControlNode" + [] -> wrapOutput "ControlNode" -- TODO(judahjacobson): To improve indentation: `tensorArgAndComment a` - [a] -> tensorArg a <+> "-- ^" <+> argComment a - as -> tuple (map tensorArg as) <+/> resultComment as + [a] -> wrapOutput (tensorArg a) <+> "-- ^" <+> argComment a + as -> wrapOutput (tuple (map tensorArg as)) <+/> resultComment as + wrapOutput o + | parsedOpIsMonadic pOp = "Build" <+> parens o + | otherwise = o -- | Render an op input or output. -- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype" diff --git a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs index 1d69fea..ddc1311 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs @@ -41,6 +41,8 @@ import Proto.Tensorflow.Core.Framework.OpDef , description , name , inputArg + , isRef + , isStateful , outputArg , summary , typeListAttr @@ -67,6 +69,10 @@ data ParsedOp = ParsedOp -- Attributes which are list sizes (ints) that are inferred automatically -- from one or more of the input tensors. -- Associated with the list of tensors whose size it describes. + , parsedOpIsMonadic :: Bool + -- ^ Whether this op is stateful or takes a stateful input. Such ops + -- should not be CSE'd and must be monadic in our API (i.e., return a + -- Build action). } data Name = Name @@ -127,6 +133,10 @@ data ArgKind | ArgTensorEither Text -- Tensor v a; the Text is the variable `v` | ArgResource -- Resource a +isRefKind :: ArgKind -> Bool +isRefKind ArgTensorRef = True +isRefKind ArgResource = True +isRefKind _ = False makeName :: Text -> Name makeName n = Name @@ -197,6 +207,8 @@ parseOp o = ParsedOp { parsedOpName = makeName $ o ^. name , parsedOpSummary = o ^. summary , parsedOpDescription = o ^. description + , parsedOpIsMonadic = o ^. isStateful + || any (isRefKind . parsedArgKind) parsedInputs , .. } where @@ -218,11 +230,13 @@ parseOp o = ParsedOp inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind inputTensorKind a v | a ^. type' == DT_RESOURCE = ArgResource + | a ^. isRef = ArgTensorRef | otherwise = ArgTensorEither v outputTensorKind :: OpDef'ArgDef -> ArgKind outputTensorKind a | a ^. type' == DT_RESOURCE = ArgResource + | a ^. isRef = ArgTensorRef | otherwise = ArgTensorValue getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType @@ -230,7 +244,7 @@ getExplicitInputAttr implicitAttrs a | TFName (a ^. name) `Set.notMember` implicitAttrs , a ^. maybe'defaultValue == Nothing , t <- parseAttrType (a ^. type') - , t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat] = Just t + , t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t | otherwise = Nothing getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType] diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs index 71b4bf1..e9fa3f5 100644 --- a/tensorflow-ops/src/TensorFlow/Ops.hs +++ b/tensorflow-ops/src/TensorFlow/Ops.hs @@ -60,7 +60,7 @@ module TensorFlow.Ops , CoreOps.abs , CoreOps.addN , CoreOps.argMax - , assign + , CoreOps.assign , CoreOps.broadcastGradientArgs , CoreOps.cast , CoreOps.concat @@ -97,7 +97,7 @@ module TensorFlow.Ops , CoreOps.sum , CoreOps.transpose , truncatedNormal - , variable + , CoreOps.variable , vector , zeros , CoreOps.zerosLike @@ -154,31 +154,18 @@ matTranspose :: forall a v . TensorType a => Tensor v a -> Tensor Value a matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32]) --- | Create a new, uninitialized stateful Tensor of the given shape. -variable :: forall a . TensorType a => Shape -> Build (Tensor Ref a) -variable shape' = buildOp $ opDef "Variable" - & opAttr "shape" .~ shape' - & opAttr "dtype" .~ tensorType (undefined :: a) - placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a) placeholder shape' = buildOp $ opDef "Placeholder" & opAttr "dtype" .~ tensorType (undefined :: a) & opAttr "shape" .~ shape' --- Assign returns the input ref. -assign :: forall a v . TensorType a - => Tensor Ref a -> Tensor v a -> Build (Tensor Ref a) -assign = buildOp $ opDef "Assign" - & opAttr "T" .~ tensorType (undefined :: a) - & opAttr "use_locking" .~ True - -- | Creates a variable initialized to the given value. -- Initialization happens next time session runs. initializedVariable :: forall a . TensorType a => Tensor Value a -> Build (Tensor Ref a) initializedVariable initializer = do - v <- variable [] -- The shape is not known initially. + v <- CoreOps.variable [] -- The shape is not known initially. (i :: Tensor Ref a) <- buildOp (opDef "Assign" & opAttr "T" .~ tensorType (undefined :: a) @@ -220,7 +207,8 @@ restoreFromName :: forall a . TensorType a restoreFromName path name x = do let restoreOp = buildOp $ opDef "Restore" & opAttr "dt" .~ tensorType (undefined :: a) - group =<< assign x (restoreOp (scalar path) (scalar name) :: Tensor Value a) + group =<< CoreOps.assign x + (restoreOp (scalar path) (scalar name) :: Tensor Value a) -- | Restore a tensor's value from a checkpoint file. restore :: forall a . TensorType a