mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Remove the type parameter from ResourceHandle. (#76)
This change allows us to reenable the rest of the ResourceHandle ops, and future-proofs us against more being added. It removes the custom logic that assumed there was a "dtype" attribute to guess what the type parameter is (which wasn't true in general.) When we switch to ResourceHandle (e.g., for queues and variables) we can add parameters to the wrapper types like "Queue" on a case-by-case basis.
This commit is contained in:
parent
b3c0997a8c
commit
0c8d41250a
5 changed files with 37 additions and 57 deletions
|
@ -103,28 +103,4 @@ blackList =
|
||||||
, "_ListToArray"
|
, "_ListToArray"
|
||||||
-- Easy: support larger result tuples.
|
-- Easy: support larger result tuples.
|
||||||
, "Skipgram"
|
, "Skipgram"
|
||||||
-- Resource ops which don't use "dtype" as the type parameter.
|
|
||||||
, "ResourceApplyAdadelta"
|
|
||||||
, "ResourceApplyAdagrad"
|
|
||||||
, "ResourceApplyAdagradDA"
|
|
||||||
, "ResourceApplyAdam"
|
|
||||||
, "ResourceApplyCenteredRMSProp"
|
|
||||||
, "ResourceApplyFtrl"
|
|
||||||
, "ResourceApplyGradientDescent"
|
|
||||||
, "ResourceApplyMomentum"
|
|
||||||
, "ResourceApplyProximalAdagrad"
|
|
||||||
, "ResourceApplyProximalGradientDescent"
|
|
||||||
, "ResourceApplyRMSProp"
|
|
||||||
, "ResourceSparseApplyAdadelta"
|
|
||||||
, "ResourceSparseApplyAdagrad"
|
|
||||||
, "ResourceSparseApplyAdagradDA"
|
|
||||||
, "ResourceSparseApplyCenteredRMSProp"
|
|
||||||
, "ResourceSparseApplyFtrl"
|
|
||||||
, "ResourceSparseApplyMomentum"
|
|
||||||
, "ResourceSparseApplyProximalAdagrad"
|
|
||||||
, "ResourceSparseApplyProximalGradientDescent"
|
|
||||||
, "ResourceSparseApplyRMSProp"
|
|
||||||
, "TensorArrayScatterV3"
|
|
||||||
, "TensorArraySplitV3"
|
|
||||||
, "TensorArrayWriteV3"
|
|
||||||
]
|
]
|
||||||
|
|
|
@ -32,7 +32,7 @@ where:
|
||||||
(for example: 'TensorType' or 'OneOf').
|
(for example: 'TensorType' or 'OneOf').
|
||||||
|
|
||||||
* @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of
|
* @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of
|
||||||
the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle a@ (or a list of one
|
the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle@ (or a list of one
|
||||||
of those types), and @a@ is either a concrete type or a (constrained) type
|
of those types), and @a@ is either a concrete type or a (constrained) type
|
||||||
variable.
|
variable.
|
||||||
|
|
||||||
|
@ -271,7 +271,7 @@ typeSig pOp = constraints
|
||||||
| null (inferredTypeAttrs pOp) = empty
|
| null (inferredTypeAttrs pOp) = empty
|
||||||
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
|
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
|
||||||
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
||||||
ArgTensorEither v <- [parsedArgKind k]]
|
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
|
||||||
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
|
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
|
||||||
classConstraints = tuple $ concatMap tensorArgConstraint
|
classConstraints = tuple $ concatMap tensorArgConstraint
|
||||||
$ inferredTypeAttrs pOp
|
$ inferredTypeAttrs pOp
|
||||||
|
@ -299,19 +299,19 @@ typeSig pOp = constraints
|
||||||
| otherwise = o
|
| otherwise = o
|
||||||
|
|
||||||
-- | Render an op input or output.
|
-- | Render an op input or output.
|
||||||
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype"
|
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle"
|
||||||
tensorArg :: ParsedArg -> Doc
|
tensorArg :: ParsedArg -> Doc
|
||||||
tensorArg p = case parsedArgCase p of
|
tensorArg p = case parsedArgCase p of
|
||||||
SimpleArg { argType = t } -> tensorType t
|
ResourceArg -> "ResourceHandle"
|
||||||
ListArg { argType = t } -> brackets $ tensorType t
|
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
|
||||||
|
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
|
||||||
MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}"
|
MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}"
|
||||||
where
|
where
|
||||||
tensorType t = let
|
tensorType t k = let
|
||||||
v = case parsedArgKind p of
|
v = case k of
|
||||||
ArgTensorRef -> "Tensor Ref"
|
ArgTensorRef -> "Tensor Ref"
|
||||||
ArgTensorValue -> "Tensor Value"
|
ArgTensorValue -> "Tensor Value"
|
||||||
ArgTensorEither v' -> "Tensor" <+> strictText v'
|
ArgTensorEither v' -> "Tensor" <+> strictText v'
|
||||||
ArgResource -> "ResourceHandle"
|
|
||||||
a = case t of
|
a = case t of
|
||||||
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
|
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
|
||||||
ArgTypeAttr n -> renderHaskellName n
|
ArgTypeAttr n -> renderHaskellName n
|
||||||
|
|
|
@ -16,6 +16,7 @@ module TensorFlow.OpGen.ParsedOp
|
||||||
, ParsedArgCase(..)
|
, ParsedArgCase(..)
|
||||||
, ArgType(..)
|
, ArgType(..)
|
||||||
, ArgKind(..)
|
, ArgKind(..)
|
||||||
|
, argKind
|
||||||
, parseOp
|
, parseOp
|
||||||
, camelCase
|
, camelCase
|
||||||
) where
|
) where
|
||||||
|
@ -108,18 +109,23 @@ data ParsedArg = ParsedArg
|
||||||
{ parsedArgName :: Name
|
{ parsedArgName :: Name
|
||||||
, parsedArgDescription :: Text
|
, parsedArgDescription :: Text
|
||||||
, parsedArgCase :: ParsedArgCase
|
, parsedArgCase :: ParsedArgCase
|
||||||
, parsedArgKind :: ArgKind
|
|
||||||
}
|
}
|
||||||
|
|
||||||
data ParsedArgCase
|
data ParsedArgCase
|
||||||
= SimpleArg { argType :: ArgType }
|
= SimpleArg { argType :: ArgType, argCaseKind :: ArgKind }
|
||||||
| ListArg
|
| ListArg
|
||||||
{ argLength :: Name -- ^ The attribute that specifies this list's length.
|
{ argLength :: Name -- ^ The attribute that specifies this list's length.
|
||||||
, argType :: ArgType
|
, argType :: ArgType
|
||||||
|
, argCaseKind :: ArgKind
|
||||||
}
|
}
|
||||||
| MixedListArg { argTypeAttr :: Name }
|
| MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind }
|
||||||
-- ^ A heterogeneous list.
|
-- ^ A heterogeneous list.
|
||||||
-- TODO(judahjacobson): Implement this.
|
-- TODO(judahjacobson): Implement this.
|
||||||
|
| ResourceArg
|
||||||
|
|
||||||
|
argKind :: ParsedArgCase -> Maybe ArgKind
|
||||||
|
argKind ResourceArg = Nothing
|
||||||
|
argKind a = Just $ argCaseKind a
|
||||||
|
|
||||||
-- | The type of an argument.
|
-- | The type of an argument.
|
||||||
data ArgType
|
data ArgType
|
||||||
|
@ -131,12 +137,13 @@ data ArgKind
|
||||||
= ArgTensorRef -- Tensor Ref a
|
= ArgTensorRef -- Tensor Ref a
|
||||||
| ArgTensorValue -- Tensor Value a
|
| ArgTensorValue -- Tensor Value a
|
||||||
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
|
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
|
||||||
| ArgResource -- Resource a
|
deriving (Eq)
|
||||||
|
|
||||||
isRefKind :: ArgKind -> Bool
|
isRefCase :: ParsedArgCase -> Bool
|
||||||
isRefKind ArgTensorRef = True
|
isRefCase a = case argKind a of
|
||||||
isRefKind ArgResource = True
|
Nothing -> True -- Resource
|
||||||
isRefKind _ = False
|
Just ArgTensorRef -> True
|
||||||
|
_ -> False
|
||||||
|
|
||||||
makeName :: Text -> Name
|
makeName :: Text -> Name
|
||||||
makeName n = Name
|
makeName n = Name
|
||||||
|
@ -208,7 +215,7 @@ parseOp o = ParsedOp
|
||||||
, parsedOpSummary = o ^. summary
|
, parsedOpSummary = o ^. summary
|
||||||
, parsedOpDescription = o ^. description
|
, parsedOpDescription = o ^. description
|
||||||
, parsedOpIsMonadic = o ^. isStateful
|
, parsedOpIsMonadic = o ^. isStateful
|
||||||
|| any (isRefKind . parsedArgKind) parsedInputs
|
|| any (isRefCase . parsedArgCase) parsedInputs
|
||||||
, ..
|
, ..
|
||||||
}
|
}
|
||||||
where
|
where
|
||||||
|
@ -237,13 +244,11 @@ parseOp o = ParsedOp
|
||||||
-- TODO(judahjacobson): Some arguments should be refs.
|
-- TODO(judahjacobson): Some arguments should be refs.
|
||||||
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
|
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
|
||||||
inputTensorKind a v
|
inputTensorKind a v
|
||||||
| a ^. type' == DT_RESOURCE = ArgResource
|
|
||||||
| a ^. isRef = ArgTensorRef
|
| a ^. isRef = ArgTensorRef
|
||||||
| otherwise = ArgTensorEither v
|
| otherwise = ArgTensorEither v
|
||||||
|
|
||||||
outputTensorKind :: OpDef'ArgDef -> ArgKind
|
outputTensorKind :: OpDef'ArgDef -> ArgKind
|
||||||
outputTensorKind a
|
outputTensorKind a
|
||||||
| a ^. type' == DT_RESOURCE = ArgResource
|
|
||||||
| a ^. isRef = ArgTensorRef
|
| a ^. isRef = ArgTensorRef
|
||||||
| otherwise = ArgTensorValue
|
| otherwise = ArgTensorValue
|
||||||
|
|
||||||
|
@ -258,6 +263,7 @@ getExplicitInputAttr implicitAttrs a
|
||||||
-- | The type attribute used by this input or output (if any).
|
-- | The type attribute used by this input or output (if any).
|
||||||
parsedArgTypeAttr :: ParsedArg -> Maybe TFName
|
parsedArgTypeAttr :: ParsedArg -> Maybe TFName
|
||||||
parsedArgTypeAttr p = case parsedArgCase p of
|
parsedArgTypeAttr p = case parsedArgCase p of
|
||||||
|
ResourceArg -> Nothing
|
||||||
SimpleArg {argType = t} -> fromArgType t
|
SimpleArg {argType = t} -> fromArgType t
|
||||||
ListArg {argType = t} -> fromArgType t
|
ListArg {argType = t} -> fromArgType t
|
||||||
MixedListArg {argTypeAttr = n} -> Just $ tfName n
|
MixedListArg {argTypeAttr = n} -> Just $ tfName n
|
||||||
|
@ -294,19 +300,18 @@ parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
|
||||||
parseArg a tKind = ParsedArg
|
parseArg a tKind = ParsedArg
|
||||||
{ parsedArgName = makeName (a ^. name)
|
{ parsedArgName = makeName (a ^. name)
|
||||||
, parsedArgDescription = a ^. description
|
, parsedArgDescription = a ^. description
|
||||||
, parsedArgCase = parseArgCase a
|
, parsedArgCase = parseArgCase a tKind
|
||||||
, parsedArgKind = tKind
|
|
||||||
}
|
}
|
||||||
|
|
||||||
parseArgCase :: OpDef'ArgDef -> ParsedArgCase
|
parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase
|
||||||
parseArgCase a
|
parseArgCase a tKind
|
||||||
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n
|
| a ^. type' == DT_RESOURCE = ResourceArg
|
||||||
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType
|
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n tKind
|
||||||
| otherwise = SimpleArg thisArgType
|
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType tKind
|
||||||
|
| otherwise = SimpleArg thisArgType tKind
|
||||||
where
|
where
|
||||||
thisArgType
|
thisArgType
|
||||||
| Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n
|
| Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n
|
||||||
| a ^. type' == DT_RESOURCE = ArgTypeAttr (makeName "dtype")
|
|
||||||
| otherwise = ArgTypeFixed (a ^. type')
|
| otherwise = ArgTypeFixed (a ^. type')
|
||||||
maybeAttr :: Text -> Maybe Name
|
maybeAttr :: Text -> Maybe Name
|
||||||
maybeAttr "" = Nothing
|
maybeAttr "" = Nothing
|
||||||
|
|
|
@ -86,7 +86,7 @@ recordResult = do
|
||||||
put $! ResultState (i+1) ns
|
put $! ResultState (i+1) ns
|
||||||
return $! output i o
|
return $! output i o
|
||||||
|
|
||||||
instance OpResult (ResourceHandle a) where
|
instance OpResult ResourceHandle where
|
||||||
toResult = ResourceHandle <$> recordResult
|
toResult = ResourceHandle <$> recordResult
|
||||||
|
|
||||||
instance OpResult (Tensor Value a) where
|
instance OpResult (Tensor Value a) where
|
||||||
|
@ -150,7 +150,7 @@ buildListOp counts o = buildOp' counts o []
|
||||||
instance BuildOp ControlNode where
|
instance BuildOp ControlNode where
|
||||||
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
|
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
|
||||||
|
|
||||||
instance BuildOp (ResourceHandle a) where
|
instance BuildOp ResourceHandle where
|
||||||
buildOp' = pureResult
|
buildOp' = pureResult
|
||||||
|
|
||||||
instance BuildOp (Tensor Value a) where
|
instance BuildOp (Tensor Value a) where
|
||||||
|
@ -189,7 +189,7 @@ instance ( OpResult t1
|
||||||
instance OpResult a => BuildOp (Build a) where
|
instance OpResult a => BuildOp (Build a) where
|
||||||
buildOp' = buildResult
|
buildOp' = buildResult
|
||||||
|
|
||||||
instance BuildOp f => BuildOp (ResourceHandle a -> f) where
|
instance BuildOp f => BuildOp (ResourceHandle -> f) where
|
||||||
buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts)
|
buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts)
|
||||||
|
|
||||||
instance BuildOp f => BuildOp (Tensor v a -> f) where
|
instance BuildOp f => BuildOp (Tensor v a -> f) where
|
||||||
|
|
|
@ -158,6 +158,5 @@ instance IsString Output where
|
||||||
|
|
||||||
|
|
||||||
-- | Opaque handle to a mutable resource in the graph. Typical such
|
-- | Opaque handle to a mutable resource in the graph. Typical such
|
||||||
-- resources are variables. The type parameter corresponds to the
|
-- resources are variables.
|
||||||
-- dtype of the tensor held in the variable.
|
newtype ResourceHandle = ResourceHandle Output
|
||||||
newtype ResourceHandle a = ResourceHandle Output
|
|
||||||
|
|
Loading…
Reference in a new issue