mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Support type attributes that aren't used by an input/output. (#51)
We should treat such attributes as regular `DataType` values rather than type parameters; otherwise we'll get ambiguous types. As with other attributes, they can either set by default or passed in as an explicit argument to the op. Allows us to reenable a couple more ops.
This commit is contained in:
parent
f170df9d13
commit
db75350969
2 changed files with 22 additions and 8 deletions
|
@ -83,10 +83,6 @@ blackList =
|
||||||
, "Print"
|
, "Print"
|
||||||
, "QueueEnqueue"
|
, "QueueEnqueue"
|
||||||
, "QueueEnqueueMany"
|
, "QueueEnqueueMany"
|
||||||
-- These have type ambiguities because one of the type arguments
|
|
||||||
-- doesn't appear in the signature.
|
|
||||||
, "ConditionalAccumulator"
|
|
||||||
, "SparseConditionalAccumulator"
|
|
||||||
-- Need list of types support.
|
-- Need list of types support.
|
||||||
, "DecodeCSV"
|
, "DecodeCSV"
|
||||||
, "ParseExample"
|
, "ParseExample"
|
||||||
|
|
|
@ -216,15 +216,23 @@ parseOp o = ParsedOp
|
||||||
(o ^. inputArg) tensorKindParams
|
(o ^. inputArg) tensorKindParams
|
||||||
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
|
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
|
||||||
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
|
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
|
||||||
explicitInputAttrs = sortBy (comparing (tfName . attrName))
|
-- Type attributes that can be inferred from at least one input or output.
|
||||||
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
|
argTypeAttrs = Set.fromList $ mapMaybe parsedArgTypeAttr
|
||||||
$ o ^. attr
|
$ parsedInputs ++ parsedOutputs
|
||||||
inferredTypeAttrs = mapMaybeAttrs getInferredTypeAttr $ o ^. attr
|
inferredTypeAttrs = filter ((`Set.member` argTypeAttrs) . tfName . attrName)
|
||||||
|
$ mapMaybeAttrs getInferredTypeAttr $ o ^. attr
|
||||||
|
-- Integer attributes that can be inferred from the size of at least one
|
||||||
|
-- input list.
|
||||||
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
|
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
|
||||||
$ o ^. attr
|
$ o ^. attr
|
||||||
implicitAttrs = Set.fromList $ map tfName $
|
implicitAttrs = Set.fromList $ map tfName $
|
||||||
map attrName inferredTypeAttrs
|
map attrName inferredTypeAttrs
|
||||||
++ map attrName inferredListSizeAttrs
|
++ map attrName inferredListSizeAttrs
|
||||||
|
-- Attributes that can't be inferred and don't have defaults, so must be passed
|
||||||
|
-- as separate arguments to the op.
|
||||||
|
explicitInputAttrs = sortBy (comparing (tfName . attrName))
|
||||||
|
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
|
||||||
|
$ o ^. attr
|
||||||
|
|
||||||
-- TODO(judahjacobson): Some arguments should be refs.
|
-- TODO(judahjacobson): Some arguments should be refs.
|
||||||
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
|
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
|
||||||
|
@ -247,6 +255,16 @@ getExplicitInputAttr implicitAttrs a
|
||||||
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t
|
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t
|
||||||
| otherwise = Nothing
|
| otherwise = Nothing
|
||||||
|
|
||||||
|
-- | The type attribute used by this input or output (if any).
|
||||||
|
parsedArgTypeAttr :: ParsedArg -> Maybe TFName
|
||||||
|
parsedArgTypeAttr p = case parsedArgCase p of
|
||||||
|
SimpleArg {argType = t} -> fromArgType t
|
||||||
|
ListArg {argType = t} -> fromArgType t
|
||||||
|
MixedListArg {argTypeAttr = n} -> Just $ tfName n
|
||||||
|
where
|
||||||
|
fromArgType (ArgTypeAttr n) = Just $ tfName n
|
||||||
|
fromArgType _ = Nothing
|
||||||
|
|
||||||
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
|
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
|
||||||
getInferredTypeAttr a
|
getInferredTypeAttr a
|
||||||
| a ^. type' == "type" = Just $ a ^. allowedValues . list . type'
|
| a ^. type' == "type" = Just $ a ^. allowedValues . list . type'
|
||||||
|
|
Loading…
Reference in a new issue