diff --git a/tensorflow-core-ops/Setup.hs b/tensorflow-core-ops/Setup.hs index 25e93b9..cdd835d 100644 --- a/tensorflow-core-ops/Setup.hs +++ b/tensorflow-core-ops/Setup.hs @@ -83,10 +83,6 @@ blackList = , "Print" , "QueueEnqueue" , "QueueEnqueueMany" - -- These have type ambiguities because one of the type arguments - -- doesn't appear in the signature. - , "ConditionalAccumulator" - , "SparseConditionalAccumulator" -- Need list of types support. , "DecodeCSV" , "ParseExample" diff --git a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs index ddc1311..74263c4 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs @@ -216,15 +216,23 @@ parseOp o = ParsedOp (o ^. inputArg) tensorKindParams tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]] parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg) - explicitInputAttrs = sortBy (comparing (tfName . attrName)) - $ mapMaybeAttrs (getExplicitInputAttr implicitAttrs) - $ o ^. attr - inferredTypeAttrs = mapMaybeAttrs getInferredTypeAttr $ o ^. attr + -- Type attributes that can be inferred from at least one input or output. + argTypeAttrs = Set.fromList $ mapMaybe parsedArgTypeAttr + $ parsedInputs ++ parsedOutputs + inferredTypeAttrs = filter ((`Set.member` argTypeAttrs) . tfName . attrName) + $ mapMaybeAttrs getInferredTypeAttr $ o ^. attr + -- Integer attributes that can be inferred from the size of at least one + -- input list. inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs) $ o ^. attr implicitAttrs = Set.fromList $ map tfName $ map attrName inferredTypeAttrs ++ map attrName inferredListSizeAttrs + -- Attributes that can't be inferred and don't have defaults, so must be passed + -- as separate arguments to the op. + explicitInputAttrs = sortBy (comparing (tfName . attrName)) + $ mapMaybeAttrs (getExplicitInputAttr implicitAttrs) + $ o ^. attr -- TODO(judahjacobson): Some arguments should be refs. inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind @@ -247,6 +255,16 @@ getExplicitInputAttr implicitAttrs a , t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t | otherwise = Nothing +-- | The type attribute used by this input or output (if any). +parsedArgTypeAttr :: ParsedArg -> Maybe TFName +parsedArgTypeAttr p = case parsedArgCase p of + SimpleArg {argType = t} -> fromArgType t + ListArg {argType = t} -> fromArgType t + MixedListArg {argTypeAttr = n} -> Just $ tfName n + where + fromArgType (ArgTypeAttr n) = Just $ tfName n + fromArgType _ = Nothing + getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType] getInferredTypeAttr a | a ^. type' == "type" = Just $ a ^. allowedValues . list . type'