mirror of
https://github.com/tensorflow/haskell.git
synced 2024-12-24 18:49:46 +01:00
Remove the type parameter from ResourceHandle. (#76)
This change allows us to reenable the rest of the ResourceHandle ops, and future-proofs us against more being added. It removes the custom logic that assumed there was a "dtype" attribute to guess what the type parameter is (which wasn't true in general.) When we switch to ResourceHandle (e.g., for queues and variables) we can add parameters to the wrapper types like "Queue" on a case-by-case basis.
This commit is contained in:
parent
b3c0997a8c
commit
0c8d41250a
5 changed files with 37 additions and 57 deletions
|
@ -103,28 +103,4 @@ blackList =
|
|||
, "_ListToArray"
|
||||
-- Easy: support larger result tuples.
|
||||
, "Skipgram"
|
||||
-- Resource ops which don't use "dtype" as the type parameter.
|
||||
, "ResourceApplyAdadelta"
|
||||
, "ResourceApplyAdagrad"
|
||||
, "ResourceApplyAdagradDA"
|
||||
, "ResourceApplyAdam"
|
||||
, "ResourceApplyCenteredRMSProp"
|
||||
, "ResourceApplyFtrl"
|
||||
, "ResourceApplyGradientDescent"
|
||||
, "ResourceApplyMomentum"
|
||||
, "ResourceApplyProximalAdagrad"
|
||||
, "ResourceApplyProximalGradientDescent"
|
||||
, "ResourceApplyRMSProp"
|
||||
, "ResourceSparseApplyAdadelta"
|
||||
, "ResourceSparseApplyAdagrad"
|
||||
, "ResourceSparseApplyAdagradDA"
|
||||
, "ResourceSparseApplyCenteredRMSProp"
|
||||
, "ResourceSparseApplyFtrl"
|
||||
, "ResourceSparseApplyMomentum"
|
||||
, "ResourceSparseApplyProximalAdagrad"
|
||||
, "ResourceSparseApplyProximalGradientDescent"
|
||||
, "ResourceSparseApplyRMSProp"
|
||||
, "TensorArrayScatterV3"
|
||||
, "TensorArraySplitV3"
|
||||
, "TensorArrayWriteV3"
|
||||
]
|
||||
|
|
|
@ -32,7 +32,7 @@ where:
|
|||
(for example: 'TensorType' or 'OneOf').
|
||||
|
||||
* @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of
|
||||
the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle a@ (or a list of one
|
||||
the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle@ (or a list of one
|
||||
of those types), and @a@ is either a concrete type or a (constrained) type
|
||||
variable.
|
||||
|
||||
|
@ -271,7 +271,7 @@ typeSig pOp = constraints
|
|||
| null (inferredTypeAttrs pOp) = empty
|
||||
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
|
||||
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
||||
ArgTensorEither v <- [parsedArgKind k]]
|
||||
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
|
||||
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
|
||||
classConstraints = tuple $ concatMap tensorArgConstraint
|
||||
$ inferredTypeAttrs pOp
|
||||
|
@ -299,19 +299,19 @@ typeSig pOp = constraints
|
|||
| otherwise = o
|
||||
|
||||
-- | Render an op input or output.
|
||||
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype"
|
||||
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle"
|
||||
tensorArg :: ParsedArg -> Doc
|
||||
tensorArg p = case parsedArgCase p of
|
||||
SimpleArg { argType = t } -> tensorType t
|
||||
ListArg { argType = t } -> brackets $ tensorType t
|
||||
ResourceArg -> "ResourceHandle"
|
||||
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
|
||||
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
|
||||
MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}"
|
||||
where
|
||||
tensorType t = let
|
||||
v = case parsedArgKind p of
|
||||
tensorType t k = let
|
||||
v = case k of
|
||||
ArgTensorRef -> "Tensor Ref"
|
||||
ArgTensorValue -> "Tensor Value"
|
||||
ArgTensorEither v' -> "Tensor" <+> strictText v'
|
||||
ArgResource -> "ResourceHandle"
|
||||
a = case t of
|
||||
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
|
||||
ArgTypeAttr n -> renderHaskellName n
|
||||
|
|
|
@ -16,6 +16,7 @@ module TensorFlow.OpGen.ParsedOp
|
|||
, ParsedArgCase(..)
|
||||
, ArgType(..)
|
||||
, ArgKind(..)
|
||||
, argKind
|
||||
, parseOp
|
||||
, camelCase
|
||||
) where
|
||||
|
@ -108,18 +109,23 @@ data ParsedArg = ParsedArg
|
|||
{ parsedArgName :: Name
|
||||
, parsedArgDescription :: Text
|
||||
, parsedArgCase :: ParsedArgCase
|
||||
, parsedArgKind :: ArgKind
|
||||
}
|
||||
|
||||
data ParsedArgCase
|
||||
= SimpleArg { argType :: ArgType }
|
||||
= SimpleArg { argType :: ArgType, argCaseKind :: ArgKind }
|
||||
| ListArg
|
||||
{ argLength :: Name -- ^ The attribute that specifies this list's length.
|
||||
, argType :: ArgType
|
||||
, argCaseKind :: ArgKind
|
||||
}
|
||||
| MixedListArg { argTypeAttr :: Name }
|
||||
| MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind }
|
||||
-- ^ A heterogeneous list.
|
||||
-- TODO(judahjacobson): Implement this.
|
||||
| ResourceArg
|
||||
|
||||
argKind :: ParsedArgCase -> Maybe ArgKind
|
||||
argKind ResourceArg = Nothing
|
||||
argKind a = Just $ argCaseKind a
|
||||
|
||||
-- | The type of an argument.
|
||||
data ArgType
|
||||
|
@ -131,12 +137,13 @@ data ArgKind
|
|||
= ArgTensorRef -- Tensor Ref a
|
||||
| ArgTensorValue -- Tensor Value a
|
||||
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
|
||||
| ArgResource -- Resource a
|
||||
deriving (Eq)
|
||||
|
||||
isRefKind :: ArgKind -> Bool
|
||||
isRefKind ArgTensorRef = True
|
||||
isRefKind ArgResource = True
|
||||
isRefKind _ = False
|
||||
isRefCase :: ParsedArgCase -> Bool
|
||||
isRefCase a = case argKind a of
|
||||
Nothing -> True -- Resource
|
||||
Just ArgTensorRef -> True
|
||||
_ -> False
|
||||
|
||||
makeName :: Text -> Name
|
||||
makeName n = Name
|
||||
|
@ -208,7 +215,7 @@ parseOp o = ParsedOp
|
|||
, parsedOpSummary = o ^. summary
|
||||
, parsedOpDescription = o ^. description
|
||||
, parsedOpIsMonadic = o ^. isStateful
|
||||
|| any (isRefKind . parsedArgKind) parsedInputs
|
||||
|| any (isRefCase . parsedArgCase) parsedInputs
|
||||
, ..
|
||||
}
|
||||
where
|
||||
|
@ -237,13 +244,11 @@ parseOp o = ParsedOp
|
|||
-- TODO(judahjacobson): Some arguments should be refs.
|
||||
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
|
||||
inputTensorKind a v
|
||||
| a ^. type' == DT_RESOURCE = ArgResource
|
||||
| a ^. isRef = ArgTensorRef
|
||||
| otherwise = ArgTensorEither v
|
||||
|
||||
outputTensorKind :: OpDef'ArgDef -> ArgKind
|
||||
outputTensorKind a
|
||||
| a ^. type' == DT_RESOURCE = ArgResource
|
||||
| a ^. isRef = ArgTensorRef
|
||||
| otherwise = ArgTensorValue
|
||||
|
||||
|
@ -258,6 +263,7 @@ getExplicitInputAttr implicitAttrs a
|
|||
-- | The type attribute used by this input or output (if any).
|
||||
parsedArgTypeAttr :: ParsedArg -> Maybe TFName
|
||||
parsedArgTypeAttr p = case parsedArgCase p of
|
||||
ResourceArg -> Nothing
|
||||
SimpleArg {argType = t} -> fromArgType t
|
||||
ListArg {argType = t} -> fromArgType t
|
||||
MixedListArg {argTypeAttr = n} -> Just $ tfName n
|
||||
|
@ -294,19 +300,18 @@ parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
|
|||
parseArg a tKind = ParsedArg
|
||||
{ parsedArgName = makeName (a ^. name)
|
||||
, parsedArgDescription = a ^. description
|
||||
, parsedArgCase = parseArgCase a
|
||||
, parsedArgKind = tKind
|
||||
, parsedArgCase = parseArgCase a tKind
|
||||
}
|
||||
|
||||
parseArgCase :: OpDef'ArgDef -> ParsedArgCase
|
||||
parseArgCase a
|
||||
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n
|
||||
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType
|
||||
| otherwise = SimpleArg thisArgType
|
||||
parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase
|
||||
parseArgCase a tKind
|
||||
| a ^. type' == DT_RESOURCE = ResourceArg
|
||||
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n tKind
|
||||
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType tKind
|
||||
| otherwise = SimpleArg thisArgType tKind
|
||||
where
|
||||
thisArgType
|
||||
| Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n
|
||||
| a ^. type' == DT_RESOURCE = ArgTypeAttr (makeName "dtype")
|
||||
| otherwise = ArgTypeFixed (a ^. type')
|
||||
maybeAttr :: Text -> Maybe Name
|
||||
maybeAttr "" = Nothing
|
||||
|
|
|
@ -86,7 +86,7 @@ recordResult = do
|
|||
put $! ResultState (i+1) ns
|
||||
return $! output i o
|
||||
|
||||
instance OpResult (ResourceHandle a) where
|
||||
instance OpResult ResourceHandle where
|
||||
toResult = ResourceHandle <$> recordResult
|
||||
|
||||
instance OpResult (Tensor Value a) where
|
||||
|
@ -150,7 +150,7 @@ buildListOp counts o = buildOp' counts o []
|
|||
instance BuildOp ControlNode where
|
||||
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
|
||||
|
||||
instance BuildOp (ResourceHandle a) where
|
||||
instance BuildOp ResourceHandle where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance BuildOp (Tensor Value a) where
|
||||
|
@ -189,7 +189,7 @@ instance ( OpResult t1
|
|||
instance OpResult a => BuildOp (Build a) where
|
||||
buildOp' = buildResult
|
||||
|
||||
instance BuildOp f => BuildOp (ResourceHandle a -> f) where
|
||||
instance BuildOp f => BuildOp (ResourceHandle -> f) where
|
||||
buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts)
|
||||
|
||||
instance BuildOp f => BuildOp (Tensor v a -> f) where
|
||||
|
|
|
@ -158,6 +158,5 @@ instance IsString Output where
|
|||
|
||||
|
||||
-- | Opaque handle to a mutable resource in the graph. Typical such
|
||||
-- resources are variables. The type parameter corresponds to the
|
||||
-- dtype of the tensor held in the variable.
|
||||
newtype ResourceHandle a = ResourceHandle Output
|
||||
-- resources are variables.
|
||||
newtype ResourceHandle = ResourceHandle Output
|
||||
|
|
Loading…
Reference in a new issue