Remove the type parameter from ResourceHandle. (#76)

This change allows us to reenable the rest of the ResourceHandle ops, and
future-proofs us against more being added.  It removes the custom logic that
assumed there was a "dtype" attribute to guess what the type parameter is
(which wasn't true in general.)

When we switch to ResourceHandle (e.g., for queues and variables) we can add
parameters to the wrapper types like "Queue" on a case-by-case basis.
This commit is contained in:
Judah Jacobson 2017-02-21 19:38:26 -08:00 committed by Greg Steuck
parent b3c0997a8c
commit 0c8d41250a
5 changed files with 37 additions and 57 deletions

View File

@ -103,28 +103,4 @@ blackList =
, "_ListToArray"
-- Easy: support larger result tuples.
, "Skipgram"
-- Resource ops which don't use "dtype" as the type parameter.
, "ResourceApplyAdadelta"
, "ResourceApplyAdagrad"
, "ResourceApplyAdagradDA"
, "ResourceApplyAdam"
, "ResourceApplyCenteredRMSProp"
, "ResourceApplyFtrl"
, "ResourceApplyGradientDescent"
, "ResourceApplyMomentum"
, "ResourceApplyProximalAdagrad"
, "ResourceApplyProximalGradientDescent"
, "ResourceApplyRMSProp"
, "ResourceSparseApplyAdadelta"
, "ResourceSparseApplyAdagrad"
, "ResourceSparseApplyAdagradDA"
, "ResourceSparseApplyCenteredRMSProp"
, "ResourceSparseApplyFtrl"
, "ResourceSparseApplyMomentum"
, "ResourceSparseApplyProximalAdagrad"
, "ResourceSparseApplyProximalGradientDescent"
, "ResourceSparseApplyRMSProp"
, "TensorArrayScatterV3"
, "TensorArraySplitV3"
, "TensorArrayWriteV3"
]

View File

@ -32,7 +32,7 @@ where:
(for example: 'TensorType' or 'OneOf').
* @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of
the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle a@ (or a list of one
the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle@ (or a list of one
of those types), and @a@ is either a concrete type or a (constrained) type
variable.
@ -271,7 +271,7 @@ typeSig pOp = constraints
| null (inferredTypeAttrs pOp) = empty
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
ArgTensorEither v <- [parsedArgKind k]]
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
classConstraints = tuple $ concatMap tensorArgConstraint
$ inferredTypeAttrs pOp
@ -299,19 +299,19 @@ typeSig pOp = constraints
| otherwise = o
-- | Render an op input or output.
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype"
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle"
tensorArg :: ParsedArg -> Doc
tensorArg p = case parsedArgCase p of
SimpleArg { argType = t } -> tensorType t
ListArg { argType = t } -> brackets $ tensorType t
ResourceArg -> "ResourceHandle"
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}"
where
tensorType t = let
v = case parsedArgKind p of
tensorType t k = let
v = case k of
ArgTensorRef -> "Tensor Ref"
ArgTensorValue -> "Tensor Value"
ArgTensorEither v' -> "Tensor" <+> strictText v'
ArgResource -> "ResourceHandle"
a = case t of
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
ArgTypeAttr n -> renderHaskellName n

View File

@ -16,6 +16,7 @@ module TensorFlow.OpGen.ParsedOp
, ParsedArgCase(..)
, ArgType(..)
, ArgKind(..)
, argKind
, parseOp
, camelCase
) where
@ -108,18 +109,23 @@ data ParsedArg = ParsedArg
{ parsedArgName :: Name
, parsedArgDescription :: Text
, parsedArgCase :: ParsedArgCase
, parsedArgKind :: ArgKind
}
data ParsedArgCase
= SimpleArg { argType :: ArgType }
= SimpleArg { argType :: ArgType, argCaseKind :: ArgKind }
| ListArg
{ argLength :: Name -- ^ The attribute that specifies this list's length.
, argType :: ArgType
, argCaseKind :: ArgKind
}
| MixedListArg { argTypeAttr :: Name }
| MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind }
-- ^ A heterogeneous list.
-- TODO(judahjacobson): Implement this.
| ResourceArg
argKind :: ParsedArgCase -> Maybe ArgKind
argKind ResourceArg = Nothing
argKind a = Just $ argCaseKind a
-- | The type of an argument.
data ArgType
@ -131,12 +137,13 @@ data ArgKind
= ArgTensorRef -- Tensor Ref a
| ArgTensorValue -- Tensor Value a
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
| ArgResource -- Resource a
deriving (Eq)
isRefKind :: ArgKind -> Bool
isRefKind ArgTensorRef = True
isRefKind ArgResource = True
isRefKind _ = False
isRefCase :: ParsedArgCase -> Bool
isRefCase a = case argKind a of
Nothing -> True -- Resource
Just ArgTensorRef -> True
_ -> False
makeName :: Text -> Name
makeName n = Name
@ -208,7 +215,7 @@ parseOp o = ParsedOp
, parsedOpSummary = o ^. summary
, parsedOpDescription = o ^. description
, parsedOpIsMonadic = o ^. isStateful
|| any (isRefKind . parsedArgKind) parsedInputs
|| any (isRefCase . parsedArgCase) parsedInputs
, ..
}
where
@ -237,13 +244,11 @@ parseOp o = ParsedOp
-- TODO(judahjacobson): Some arguments should be refs.
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
inputTensorKind a v
| a ^. type' == DT_RESOURCE = ArgResource
| a ^. isRef = ArgTensorRef
| otherwise = ArgTensorEither v
outputTensorKind :: OpDef'ArgDef -> ArgKind
outputTensorKind a
| a ^. type' == DT_RESOURCE = ArgResource
| a ^. isRef = ArgTensorRef
| otherwise = ArgTensorValue
@ -258,6 +263,7 @@ getExplicitInputAttr implicitAttrs a
-- | The type attribute used by this input or output (if any).
parsedArgTypeAttr :: ParsedArg -> Maybe TFName
parsedArgTypeAttr p = case parsedArgCase p of
ResourceArg -> Nothing
SimpleArg {argType = t} -> fromArgType t
ListArg {argType = t} -> fromArgType t
MixedListArg {argTypeAttr = n} -> Just $ tfName n
@ -294,19 +300,18 @@ parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
parseArg a tKind = ParsedArg
{ parsedArgName = makeName (a ^. name)
, parsedArgDescription = a ^. description
, parsedArgCase = parseArgCase a
, parsedArgKind = tKind
, parsedArgCase = parseArgCase a tKind
}
parseArgCase :: OpDef'ArgDef -> ParsedArgCase
parseArgCase a
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType
| otherwise = SimpleArg thisArgType
parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase
parseArgCase a tKind
| a ^. type' == DT_RESOURCE = ResourceArg
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n tKind
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType tKind
| otherwise = SimpleArg thisArgType tKind
where
thisArgType
| Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n
| a ^. type' == DT_RESOURCE = ArgTypeAttr (makeName "dtype")
| otherwise = ArgTypeFixed (a ^. type')
maybeAttr :: Text -> Maybe Name
maybeAttr "" = Nothing

View File

@ -86,7 +86,7 @@ recordResult = do
put $! ResultState (i+1) ns
return $! output i o
instance OpResult (ResourceHandle a) where
instance OpResult ResourceHandle where
toResult = ResourceHandle <$> recordResult
instance OpResult (Tensor Value a) where
@ -150,7 +150,7 @@ buildListOp counts o = buildOp' counts o []
instance BuildOp ControlNode where
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
instance BuildOp (ResourceHandle a) where
instance BuildOp ResourceHandle where
buildOp' = pureResult
instance BuildOp (Tensor Value a) where
@ -189,7 +189,7 @@ instance ( OpResult t1
instance OpResult a => BuildOp (Build a) where
buildOp' = buildResult
instance BuildOp f => BuildOp (ResourceHandle a -> f) where
instance BuildOp f => BuildOp (ResourceHandle -> f) where
buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts)
instance BuildOp f => BuildOp (Tensor v a -> f) where

View File

@ -158,6 +158,5 @@ instance IsString Output where
-- | Opaque handle to a mutable resource in the graph. Typical such
-- resources are variables. The type parameter corresponds to the
-- dtype of the tensor held in the variable.
newtype ResourceHandle a = ResourceHandle Output
-- resources are variables.
newtype ResourceHandle = ResourceHandle Output