1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-12-03 16:29:46 +01:00

Support type attributes that aren't used by an input/output. (#51)

We should treat such attributes as regular `DataType` values rather than type
parameters; otherwise we'll get ambiguous types.  As with other attributes,
they can either set by default or passed in as an explicit argument to the op.

Allows us to reenable a couple more ops.
This commit is contained in:
Judah Jacobson 2016-12-15 11:52:48 -08:00 committed by Greg Steuck
parent f170df9d13
commit db75350969
2 changed files with 22 additions and 8 deletions

View file

@ -83,10 +83,6 @@ blackList =
, "Print"
, "QueueEnqueue"
, "QueueEnqueueMany"
-- These have type ambiguities because one of the type arguments
-- doesn't appear in the signature.
, "ConditionalAccumulator"
, "SparseConditionalAccumulator"
-- Need list of types support.
, "DecodeCSV"
, "ParseExample"

View file

@ -216,15 +216,23 @@ parseOp o = ParsedOp
(o ^. inputArg) tensorKindParams
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
explicitInputAttrs = sortBy (comparing (tfName . attrName))
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
$ o ^. attr
inferredTypeAttrs = mapMaybeAttrs getInferredTypeAttr $ o ^. attr
-- Type attributes that can be inferred from at least one input or output.
argTypeAttrs = Set.fromList $ mapMaybe parsedArgTypeAttr
$ parsedInputs ++ parsedOutputs
inferredTypeAttrs = filter ((`Set.member` argTypeAttrs) . tfName . attrName)
$ mapMaybeAttrs getInferredTypeAttr $ o ^. attr
-- Integer attributes that can be inferred from the size of at least one
-- input list.
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
$ o ^. attr
implicitAttrs = Set.fromList $ map tfName $
map attrName inferredTypeAttrs
++ map attrName inferredListSizeAttrs
-- Attributes that can't be inferred and don't have defaults, so must be passed
-- as separate arguments to the op.
explicitInputAttrs = sortBy (comparing (tfName . attrName))
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
$ o ^. attr
-- TODO(judahjacobson): Some arguments should be refs.
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
@ -247,6 +255,16 @@ getExplicitInputAttr implicitAttrs a
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t
| otherwise = Nothing
-- | The type attribute used by this input or output (if any).
parsedArgTypeAttr :: ParsedArg -> Maybe TFName
parsedArgTypeAttr p = case parsedArgCase p of
SimpleArg {argType = t} -> fromArgType t
ListArg {argType = t} -> fromArgType t
MixedListArg {argTypeAttr = n} -> Just $ tfName n
where
fromArgType (ArgTypeAttr n) = Just $ tfName n
fromArgType _ = Nothing
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
getInferredTypeAttr a
| a ^. type' == "type" = Just $ a ^. allowedValues . list . type'