1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-26 21:09:44 +01:00

Support type attributes that aren't used by an input/output. (#51)

We should treat such attributes as regular `DataType` values rather than type
parameters; otherwise we'll get ambiguous types.  As with other attributes,
they can either set by default or passed in as an explicit argument to the op.

Allows us to reenable a couple more ops.
This commit is contained in:
Judah Jacobson 2016-12-15 11:52:48 -08:00 committed by Greg Steuck
parent f170df9d13
commit db75350969
2 changed files with 22 additions and 8 deletions

View file

@ -83,10 +83,6 @@ blackList =
, "Print" , "Print"
, "QueueEnqueue" , "QueueEnqueue"
, "QueueEnqueueMany" , "QueueEnqueueMany"
-- These have type ambiguities because one of the type arguments
-- doesn't appear in the signature.
, "ConditionalAccumulator"
, "SparseConditionalAccumulator"
-- Need list of types support. -- Need list of types support.
, "DecodeCSV" , "DecodeCSV"
, "ParseExample" , "ParseExample"

View file

@ -216,15 +216,23 @@ parseOp o = ParsedOp
(o ^. inputArg) tensorKindParams (o ^. inputArg) tensorKindParams
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]] tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg) parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
explicitInputAttrs = sortBy (comparing (tfName . attrName)) -- Type attributes that can be inferred from at least one input or output.
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs) argTypeAttrs = Set.fromList $ mapMaybe parsedArgTypeAttr
$ o ^. attr $ parsedInputs ++ parsedOutputs
inferredTypeAttrs = mapMaybeAttrs getInferredTypeAttr $ o ^. attr inferredTypeAttrs = filter ((`Set.member` argTypeAttrs) . tfName . attrName)
$ mapMaybeAttrs getInferredTypeAttr $ o ^. attr
-- Integer attributes that can be inferred from the size of at least one
-- input list.
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs) inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
$ o ^. attr $ o ^. attr
implicitAttrs = Set.fromList $ map tfName $ implicitAttrs = Set.fromList $ map tfName $
map attrName inferredTypeAttrs map attrName inferredTypeAttrs
++ map attrName inferredListSizeAttrs ++ map attrName inferredListSizeAttrs
-- Attributes that can't be inferred and don't have defaults, so must be passed
-- as separate arguments to the op.
explicitInputAttrs = sortBy (comparing (tfName . attrName))
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
$ o ^. attr
-- TODO(judahjacobson): Some arguments should be refs. -- TODO(judahjacobson): Some arguments should be refs.
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
@ -247,6 +255,16 @@ getExplicitInputAttr implicitAttrs a
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t , t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t
| otherwise = Nothing | otherwise = Nothing
-- | The type attribute used by this input or output (if any).
parsedArgTypeAttr :: ParsedArg -> Maybe TFName
parsedArgTypeAttr p = case parsedArgCase p of
SimpleArg {argType = t} -> fromArgType t
ListArg {argType = t} -> fromArgType t
MixedListArg {argTypeAttr = n} -> Just $ tfName n
where
fromArgType (ArgTypeAttr n) = Just $ tfName n
fromArgType _ = Nothing
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType] getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
getInferredTypeAttr a getInferredTypeAttr a
| a ^. type' == "type" = Just $ a ^. allowedValues . list . type' | a ^. type' == "type" = Just $ a ^. allowedValues . list . type'