diff --git a/tensorflow-core-ops/Setup.hs b/tensorflow-core-ops/Setup.hs index ff4a8c5..349ae11 100644 --- a/tensorflow-core-ops/Setup.hs +++ b/tensorflow-core-ops/Setup.hs @@ -103,28 +103,4 @@ blackList = , "_ListToArray" -- Easy: support larger result tuples. , "Skipgram" - -- Resource ops which don't use "dtype" as the type parameter. - , "ResourceApplyAdadelta" - , "ResourceApplyAdagrad" - , "ResourceApplyAdagradDA" - , "ResourceApplyAdam" - , "ResourceApplyCenteredRMSProp" - , "ResourceApplyFtrl" - , "ResourceApplyGradientDescent" - , "ResourceApplyMomentum" - , "ResourceApplyProximalAdagrad" - , "ResourceApplyProximalGradientDescent" - , "ResourceApplyRMSProp" - , "ResourceSparseApplyAdadelta" - , "ResourceSparseApplyAdagrad" - , "ResourceSparseApplyAdagradDA" - , "ResourceSparseApplyCenteredRMSProp" - , "ResourceSparseApplyFtrl" - , "ResourceSparseApplyMomentum" - , "ResourceSparseApplyProximalAdagrad" - , "ResourceSparseApplyProximalGradientDescent" - , "ResourceSparseApplyRMSProp" - , "TensorArrayScatterV3" - , "TensorArraySplitV3" - , "TensorArrayWriteV3" ] diff --git a/tensorflow-opgen/src/TensorFlow/OpGen.hs b/tensorflow-opgen/src/TensorFlow/OpGen.hs index 2bca901..8e2a6f7 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen.hs @@ -32,7 +32,7 @@ where: (for example: 'TensorType' or 'OneOf'). * @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of -the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle a@ (or a list of one +the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle@ (or a list of one of those types), and @a@ is either a concrete type or a (constrained) type variable. @@ -271,7 +271,7 @@ typeSig pOp = constraints | null (inferredTypeAttrs pOp) = empty | otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>" typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp, - ArgTensorEither v <- [parsedArgKind k]] + Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]] ++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp] classConstraints = tuple $ concatMap tensorArgConstraint $ inferredTypeAttrs pOp @@ -299,19 +299,19 @@ typeSig pOp = constraints | otherwise = o -- | Render an op input or output. --- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype" +-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle" tensorArg :: ParsedArg -> Doc tensorArg p = case parsedArgCase p of - SimpleArg { argType = t } -> tensorType t - ListArg { argType = t } -> brackets $ tensorType t + ResourceArg -> "ResourceHandle" + SimpleArg { argType = t, argCaseKind = k } -> tensorType t k + ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}" where - tensorType t = let - v = case parsedArgKind p of + tensorType t k = let + v = case k of ArgTensorRef -> "Tensor Ref" ArgTensorValue -> "Tensor Value" ArgTensorEither v' -> "Tensor" <+> strictText v' - ArgResource -> "ResourceHandle" a = case t of ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt ArgTypeAttr n -> renderHaskellName n diff --git a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs index 74263c4..49d5c34 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs @@ -16,6 +16,7 @@ module TensorFlow.OpGen.ParsedOp , ParsedArgCase(..) , ArgType(..) , ArgKind(..) + , argKind , parseOp , camelCase ) where @@ -108,18 +109,23 @@ data ParsedArg = ParsedArg { parsedArgName :: Name , parsedArgDescription :: Text , parsedArgCase :: ParsedArgCase - , parsedArgKind :: ArgKind } data ParsedArgCase - = SimpleArg { argType :: ArgType } + = SimpleArg { argType :: ArgType, argCaseKind :: ArgKind } | ListArg { argLength :: Name -- ^ The attribute that specifies this list's length. , argType :: ArgType + , argCaseKind :: ArgKind } - | MixedListArg { argTypeAttr :: Name } + | MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind } -- ^ A heterogeneous list. -- TODO(judahjacobson): Implement this. + | ResourceArg + +argKind :: ParsedArgCase -> Maybe ArgKind +argKind ResourceArg = Nothing +argKind a = Just $ argCaseKind a -- | The type of an argument. data ArgType @@ -131,12 +137,13 @@ data ArgKind = ArgTensorRef -- Tensor Ref a | ArgTensorValue -- Tensor Value a | ArgTensorEither Text -- Tensor v a; the Text is the variable `v` - | ArgResource -- Resource a + deriving (Eq) -isRefKind :: ArgKind -> Bool -isRefKind ArgTensorRef = True -isRefKind ArgResource = True -isRefKind _ = False +isRefCase :: ParsedArgCase -> Bool +isRefCase a = case argKind a of + Nothing -> True -- Resource + Just ArgTensorRef -> True + _ -> False makeName :: Text -> Name makeName n = Name @@ -208,7 +215,7 @@ parseOp o = ParsedOp , parsedOpSummary = o ^. summary , parsedOpDescription = o ^. description , parsedOpIsMonadic = o ^. isStateful - || any (isRefKind . parsedArgKind) parsedInputs + || any (isRefCase . parsedArgCase) parsedInputs , .. } where @@ -237,13 +244,11 @@ parseOp o = ParsedOp -- TODO(judahjacobson): Some arguments should be refs. inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind inputTensorKind a v - | a ^. type' == DT_RESOURCE = ArgResource | a ^. isRef = ArgTensorRef | otherwise = ArgTensorEither v outputTensorKind :: OpDef'ArgDef -> ArgKind outputTensorKind a - | a ^. type' == DT_RESOURCE = ArgResource | a ^. isRef = ArgTensorRef | otherwise = ArgTensorValue @@ -258,6 +263,7 @@ getExplicitInputAttr implicitAttrs a -- | The type attribute used by this input or output (if any). parsedArgTypeAttr :: ParsedArg -> Maybe TFName parsedArgTypeAttr p = case parsedArgCase p of + ResourceArg -> Nothing SimpleArg {argType = t} -> fromArgType t ListArg {argType = t} -> fromArgType t MixedListArg {argTypeAttr = n} -> Just $ tfName n @@ -294,19 +300,18 @@ parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg parseArg a tKind = ParsedArg { parsedArgName = makeName (a ^. name) , parsedArgDescription = a ^. description - , parsedArgCase = parseArgCase a - , parsedArgKind = tKind + , parsedArgCase = parseArgCase a tKind } -parseArgCase :: OpDef'ArgDef -> ParsedArgCase -parseArgCase a - | Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n - | Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType - | otherwise = SimpleArg thisArgType +parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase +parseArgCase a tKind + | a ^. type' == DT_RESOURCE = ResourceArg + | Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n tKind + | Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType tKind + | otherwise = SimpleArg thisArgType tKind where thisArgType | Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n - | a ^. type' == DT_RESOURCE = ArgTypeAttr (makeName "dtype") | otherwise = ArgTypeFixed (a ^. type') maybeAttr :: Text -> Maybe Name maybeAttr "" = Nothing diff --git a/tensorflow/src/TensorFlow/BuildOp.hs b/tensorflow/src/TensorFlow/BuildOp.hs index 3d6c675..6b2df3e 100644 --- a/tensorflow/src/TensorFlow/BuildOp.hs +++ b/tensorflow/src/TensorFlow/BuildOp.hs @@ -86,7 +86,7 @@ recordResult = do put $! ResultState (i+1) ns return $! output i o -instance OpResult (ResourceHandle a) where +instance OpResult ResourceHandle where toResult = ResourceHandle <$> recordResult instance OpResult (Tensor Value a) where @@ -150,7 +150,7 @@ buildListOp counts o = buildOp' counts o [] instance BuildOp ControlNode where buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts -instance BuildOp (ResourceHandle a) where +instance BuildOp ResourceHandle where buildOp' = pureResult instance BuildOp (Tensor Value a) where @@ -189,7 +189,7 @@ instance ( OpResult t1 instance OpResult a => BuildOp (Build a) where buildOp' = buildResult -instance BuildOp f => BuildOp (ResourceHandle a -> f) where +instance BuildOp f => BuildOp (ResourceHandle -> f) where buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts) instance BuildOp f => BuildOp (Tensor v a -> f) where diff --git a/tensorflow/src/TensorFlow/Output.hs b/tensorflow/src/TensorFlow/Output.hs index 241a622..9edd720 100644 --- a/tensorflow/src/TensorFlow/Output.hs +++ b/tensorflow/src/TensorFlow/Output.hs @@ -158,6 +158,5 @@ instance IsString Output where -- | Opaque handle to a mutable resource in the graph. Typical such --- resources are variables. The type parameter corresponds to the --- dtype of the tensor held in the variable. -newtype ResourceHandle a = ResourceHandle Output +-- resources are variables. +newtype ResourceHandle = ResourceHandle Output