From 0c8d41250aebc4290840fa1dea0907b5e73e9901 Mon Sep 17 00:00:00 2001 From: Judah Jacobson Date: Tue, 21 Feb 2017 19:38:26 -0800 Subject: [PATCH] Remove the type parameter from ResourceHandle. (#76) This change allows us to reenable the rest of the ResourceHandle ops, and future-proofs us against more being added. It removes the custom logic that assumed there was a "dtype" attribute to guess what the type parameter is (which wasn't true in general.) When we switch to ResourceHandle (e.g., for queues and variables) we can add parameters to the wrapper types like "Queue" on a case-by-case basis. --- tensorflow-core-ops/Setup.hs | 24 ----------- tensorflow-opgen/src/TensorFlow/OpGen.hs | 16 +++---- .../src/TensorFlow/OpGen/ParsedOp.hs | 43 +++++++++++-------- tensorflow/src/TensorFlow/BuildOp.hs | 6 +-- tensorflow/src/TensorFlow/Output.hs | 5 +-- 5 files changed, 37 insertions(+), 57 deletions(-) diff --git a/tensorflow-core-ops/Setup.hs b/tensorflow-core-ops/Setup.hs index ff4a8c5..349ae11 100644 --- a/tensorflow-core-ops/Setup.hs +++ b/tensorflow-core-ops/Setup.hs @@ -103,28 +103,4 @@ blackList = , "_ListToArray" -- Easy: support larger result tuples. , "Skipgram" - -- Resource ops which don't use "dtype" as the type parameter. - , "ResourceApplyAdadelta" - , "ResourceApplyAdagrad" - , "ResourceApplyAdagradDA" - , "ResourceApplyAdam" - , "ResourceApplyCenteredRMSProp" - , "ResourceApplyFtrl" - , "ResourceApplyGradientDescent" - , "ResourceApplyMomentum" - , "ResourceApplyProximalAdagrad" - , "ResourceApplyProximalGradientDescent" - , "ResourceApplyRMSProp" - , "ResourceSparseApplyAdadelta" - , "ResourceSparseApplyAdagrad" - , "ResourceSparseApplyAdagradDA" - , "ResourceSparseApplyCenteredRMSProp" - , "ResourceSparseApplyFtrl" - , "ResourceSparseApplyMomentum" - , "ResourceSparseApplyProximalAdagrad" - , "ResourceSparseApplyProximalGradientDescent" - , "ResourceSparseApplyRMSProp" - , "TensorArrayScatterV3" - , "TensorArraySplitV3" - , "TensorArrayWriteV3" ] diff --git a/tensorflow-opgen/src/TensorFlow/OpGen.hs b/tensorflow-opgen/src/TensorFlow/OpGen.hs index 2bca901..8e2a6f7 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen.hs @@ -32,7 +32,7 @@ where: (for example: 'TensorType' or 'OneOf'). * @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of -the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle a@ (or a list of one +the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle@ (or a list of one of those types), and @a@ is either a concrete type or a (constrained) type variable. @@ -271,7 +271,7 @@ typeSig pOp = constraints | null (inferredTypeAttrs pOp) = empty | otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>" typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp, - ArgTensorEither v <- [parsedArgKind k]] + Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]] ++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp] classConstraints = tuple $ concatMap tensorArgConstraint $ inferredTypeAttrs pOp @@ -299,19 +299,19 @@ typeSig pOp = constraints | otherwise = o -- | Render an op input or output. --- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype" +-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle" tensorArg :: ParsedArg -> Doc tensorArg p = case parsedArgCase p of - SimpleArg { argType = t } -> tensorType t - ListArg { argType = t } -> brackets $ tensorType t + ResourceArg -> "ResourceHandle" + SimpleArg { argType = t, argCaseKind = k } -> tensorType t k + ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}" where - tensorType t = let - v = case parsedArgKind p of + tensorType t k = let + v = case k of ArgTensorRef -> "Tensor Ref" ArgTensorValue -> "Tensor Value" ArgTensorEither v' -> "Tensor" <+> strictText v' - ArgResource -> "ResourceHandle" a = case t of ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt ArgTypeAttr n -> renderHaskellName n diff --git a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs index 74263c4..49d5c34 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs @@ -16,6 +16,7 @@ module TensorFlow.OpGen.ParsedOp , ParsedArgCase(..) , ArgType(..) , ArgKind(..) + , argKind , parseOp , camelCase ) where @@ -108,18 +109,23 @@ data ParsedArg = ParsedArg { parsedArgName :: Name , parsedArgDescription :: Text , parsedArgCase :: ParsedArgCase - , parsedArgKind :: ArgKind } data ParsedArgCase - = SimpleArg { argType :: ArgType } + = SimpleArg { argType :: ArgType, argCaseKind :: ArgKind } | ListArg { argLength :: Name -- ^ The attribute that specifies this list's length. , argType :: ArgType + , argCaseKind :: ArgKind } - | MixedListArg { argTypeAttr :: Name } + | MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind } -- ^ A heterogeneous list. -- TODO(judahjacobson): Implement this. + | ResourceArg + +argKind :: ParsedArgCase -> Maybe ArgKind +argKind ResourceArg = Nothing +argKind a = Just $ argCaseKind a -- | The type of an argument. data ArgType @@ -131,12 +137,13 @@ data ArgKind = ArgTensorRef -- Tensor Ref a | ArgTensorValue -- Tensor Value a | ArgTensorEither Text -- Tensor v a; the Text is the variable `v` - | ArgResource -- Resource a + deriving (Eq) -isRefKind :: ArgKind -> Bool -isRefKind ArgTensorRef = True -isRefKind ArgResource = True -isRefKind _ = False +isRefCase :: ParsedArgCase -> Bool +isRefCase a = case argKind a of + Nothing -> True -- Resource + Just ArgTensorRef -> True + _ -> False makeName :: Text -> Name makeName n = Name @@ -208,7 +215,7 @@ parseOp o = ParsedOp , parsedOpSummary = o ^. summary , parsedOpDescription = o ^. description , parsedOpIsMonadic = o ^. isStateful - || any (isRefKind . parsedArgKind) parsedInputs + || any (isRefCase . parsedArgCase) parsedInputs , .. } where @@ -237,13 +244,11 @@ parseOp o = ParsedOp -- TODO(judahjacobson): Some arguments should be refs. inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind inputTensorKind a v - | a ^. type' == DT_RESOURCE = ArgResource | a ^. isRef = ArgTensorRef | otherwise = ArgTensorEither v outputTensorKind :: OpDef'ArgDef -> ArgKind outputTensorKind a - | a ^. type' == DT_RESOURCE = ArgResource | a ^. isRef = ArgTensorRef | otherwise = ArgTensorValue @@ -258,6 +263,7 @@ getExplicitInputAttr implicitAttrs a -- | The type attribute used by this input or output (if any). parsedArgTypeAttr :: ParsedArg -> Maybe TFName parsedArgTypeAttr p = case parsedArgCase p of + ResourceArg -> Nothing SimpleArg {argType = t} -> fromArgType t ListArg {argType = t} -> fromArgType t MixedListArg {argTypeAttr = n} -> Just $ tfName n @@ -294,19 +300,18 @@ parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg parseArg a tKind = ParsedArg { parsedArgName = makeName (a ^. name) , parsedArgDescription = a ^. description - , parsedArgCase = parseArgCase a - , parsedArgKind = tKind + , parsedArgCase = parseArgCase a tKind } -parseArgCase :: OpDef'ArgDef -> ParsedArgCase -parseArgCase a - | Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n - | Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType - | otherwise = SimpleArg thisArgType +parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase +parseArgCase a tKind + | a ^. type' == DT_RESOURCE = ResourceArg + | Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n tKind + | Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType tKind + | otherwise = SimpleArg thisArgType tKind where thisArgType | Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n - | a ^. type' == DT_RESOURCE = ArgTypeAttr (makeName "dtype") | otherwise = ArgTypeFixed (a ^. type') maybeAttr :: Text -> Maybe Name maybeAttr "" = Nothing diff --git a/tensorflow/src/TensorFlow/BuildOp.hs b/tensorflow/src/TensorFlow/BuildOp.hs index 3d6c675..6b2df3e 100644 --- a/tensorflow/src/TensorFlow/BuildOp.hs +++ b/tensorflow/src/TensorFlow/BuildOp.hs @@ -86,7 +86,7 @@ recordResult = do put $! ResultState (i+1) ns return $! output i o -instance OpResult (ResourceHandle a) where +instance OpResult ResourceHandle where toResult = ResourceHandle <$> recordResult instance OpResult (Tensor Value a) where @@ -150,7 +150,7 @@ buildListOp counts o = buildOp' counts o [] instance BuildOp ControlNode where buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts -instance BuildOp (ResourceHandle a) where +instance BuildOp ResourceHandle where buildOp' = pureResult instance BuildOp (Tensor Value a) where @@ -189,7 +189,7 @@ instance ( OpResult t1 instance OpResult a => BuildOp (Build a) where buildOp' = buildResult -instance BuildOp f => BuildOp (ResourceHandle a -> f) where +instance BuildOp f => BuildOp (ResourceHandle -> f) where buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts) instance BuildOp f => BuildOp (Tensor v a -> f) where diff --git a/tensorflow/src/TensorFlow/Output.hs b/tensorflow/src/TensorFlow/Output.hs index 241a622..9edd720 100644 --- a/tensorflow/src/TensorFlow/Output.hs +++ b/tensorflow/src/TensorFlow/Output.hs @@ -158,6 +158,5 @@ instance IsString Output where -- | Opaque handle to a mutable resource in the graph. Typical such --- resources are variables. The type parameter corresponds to the --- dtype of the tensor held in the variable. -newtype ResourceHandle a = ResourceHandle Output +-- resources are variables. +newtype ResourceHandle = ResourceHandle Output