mirror of
https://github.com/tensorflow/haskell.git
synced 2024-12-03 16:29:46 +01:00
Support type attributes that aren't used by an input/output. (#51)
We should treat such attributes as regular `DataType` values rather than type parameters; otherwise we'll get ambiguous types. As with other attributes, they can either set by default or passed in as an explicit argument to the op. Allows us to reenable a couple more ops.
This commit is contained in:
parent
f170df9d13
commit
db75350969
2 changed files with 22 additions and 8 deletions
|
@ -83,10 +83,6 @@ blackList =
|
|||
, "Print"
|
||||
, "QueueEnqueue"
|
||||
, "QueueEnqueueMany"
|
||||
-- These have type ambiguities because one of the type arguments
|
||||
-- doesn't appear in the signature.
|
||||
, "ConditionalAccumulator"
|
||||
, "SparseConditionalAccumulator"
|
||||
-- Need list of types support.
|
||||
, "DecodeCSV"
|
||||
, "ParseExample"
|
||||
|
|
|
@ -216,15 +216,23 @@ parseOp o = ParsedOp
|
|||
(o ^. inputArg) tensorKindParams
|
||||
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
|
||||
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
|
||||
explicitInputAttrs = sortBy (comparing (tfName . attrName))
|
||||
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
|
||||
$ o ^. attr
|
||||
inferredTypeAttrs = mapMaybeAttrs getInferredTypeAttr $ o ^. attr
|
||||
-- Type attributes that can be inferred from at least one input or output.
|
||||
argTypeAttrs = Set.fromList $ mapMaybe parsedArgTypeAttr
|
||||
$ parsedInputs ++ parsedOutputs
|
||||
inferredTypeAttrs = filter ((`Set.member` argTypeAttrs) . tfName . attrName)
|
||||
$ mapMaybeAttrs getInferredTypeAttr $ o ^. attr
|
||||
-- Integer attributes that can be inferred from the size of at least one
|
||||
-- input list.
|
||||
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
|
||||
$ o ^. attr
|
||||
implicitAttrs = Set.fromList $ map tfName $
|
||||
map attrName inferredTypeAttrs
|
||||
++ map attrName inferredListSizeAttrs
|
||||
-- Attributes that can't be inferred and don't have defaults, so must be passed
|
||||
-- as separate arguments to the op.
|
||||
explicitInputAttrs = sortBy (comparing (tfName . attrName))
|
||||
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
|
||||
$ o ^. attr
|
||||
|
||||
-- TODO(judahjacobson): Some arguments should be refs.
|
||||
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
|
||||
|
@ -247,6 +255,16 @@ getExplicitInputAttr implicitAttrs a
|
|||
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t
|
||||
| otherwise = Nothing
|
||||
|
||||
-- | The type attribute used by this input or output (if any).
|
||||
parsedArgTypeAttr :: ParsedArg -> Maybe TFName
|
||||
parsedArgTypeAttr p = case parsedArgCase p of
|
||||
SimpleArg {argType = t} -> fromArgType t
|
||||
ListArg {argType = t} -> fromArgType t
|
||||
MixedListArg {argTypeAttr = n} -> Just $ tfName n
|
||||
where
|
||||
fromArgType (ArgTypeAttr n) = Just $ tfName n
|
||||
fromArgType _ = Nothing
|
||||
|
||||
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
|
||||
getInferredTypeAttr a
|
||||
| a ^. type' == "type" = Just $ a ^. allowedValues . list . type'
|
||||
|
|
Loading…
Reference in a new issue