1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-23 03:19:44 +01:00

Remove the type parameter from ResourceHandle. (#76)

This change allows us to reenable the rest of the ResourceHandle ops, and
future-proofs us against more being added.  It removes the custom logic that
assumed there was a "dtype" attribute to guess what the type parameter is
(which wasn't true in general.)

When we switch to ResourceHandle (e.g., for queues and variables) we can add
parameters to the wrapper types like "Queue" on a case-by-case basis.
This commit is contained in:
Judah Jacobson 2017-02-21 19:38:26 -08:00 committed by Greg Steuck
parent b3c0997a8c
commit 0c8d41250a
5 changed files with 37 additions and 57 deletions

View file

@ -103,28 +103,4 @@ blackList =
, "_ListToArray" , "_ListToArray"
-- Easy: support larger result tuples. -- Easy: support larger result tuples.
, "Skipgram" , "Skipgram"
-- Resource ops which don't use "dtype" as the type parameter.
, "ResourceApplyAdadelta"
, "ResourceApplyAdagrad"
, "ResourceApplyAdagradDA"
, "ResourceApplyAdam"
, "ResourceApplyCenteredRMSProp"
, "ResourceApplyFtrl"
, "ResourceApplyGradientDescent"
, "ResourceApplyMomentum"
, "ResourceApplyProximalAdagrad"
, "ResourceApplyProximalGradientDescent"
, "ResourceApplyRMSProp"
, "ResourceSparseApplyAdadelta"
, "ResourceSparseApplyAdagrad"
, "ResourceSparseApplyAdagradDA"
, "ResourceSparseApplyCenteredRMSProp"
, "ResourceSparseApplyFtrl"
, "ResourceSparseApplyMomentum"
, "ResourceSparseApplyProximalAdagrad"
, "ResourceSparseApplyProximalGradientDescent"
, "ResourceSparseApplyRMSProp"
, "TensorArrayScatterV3"
, "TensorArraySplitV3"
, "TensorArrayWriteV3"
] ]

View file

@ -32,7 +32,7 @@ where:
(for example: 'TensorType' or 'OneOf'). (for example: 'TensorType' or 'OneOf').
* @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of * @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of
the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle a@ (or a list of one the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle@ (or a list of one
of those types), and @a@ is either a concrete type or a (constrained) type of those types), and @a@ is either a concrete type or a (constrained) type
variable. variable.
@ -271,7 +271,7 @@ typeSig pOp = constraints
| null (inferredTypeAttrs pOp) = empty | null (inferredTypeAttrs pOp) = empty
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>" | otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp, typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
ArgTensorEither v <- [parsedArgKind k]] Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp] ++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
classConstraints = tuple $ concatMap tensorArgConstraint classConstraints = tuple $ concatMap tensorArgConstraint
$ inferredTypeAttrs pOp $ inferredTypeAttrs pOp
@ -299,19 +299,19 @@ typeSig pOp = constraints
| otherwise = o | otherwise = o
-- | Render an op input or output. -- | Render an op input or output.
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype" -- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle"
tensorArg :: ParsedArg -> Doc tensorArg :: ParsedArg -> Doc
tensorArg p = case parsedArgCase p of tensorArg p = case parsedArgCase p of
SimpleArg { argType = t } -> tensorType t ResourceArg -> "ResourceHandle"
ListArg { argType = t } -> brackets $ tensorType t SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}" MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}"
where where
tensorType t = let tensorType t k = let
v = case parsedArgKind p of v = case k of
ArgTensorRef -> "Tensor Ref" ArgTensorRef -> "Tensor Ref"
ArgTensorValue -> "Tensor Value" ArgTensorValue -> "Tensor Value"
ArgTensorEither v' -> "Tensor" <+> strictText v' ArgTensorEither v' -> "Tensor" <+> strictText v'
ArgResource -> "ResourceHandle"
a = case t of a = case t of
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
ArgTypeAttr n -> renderHaskellName n ArgTypeAttr n -> renderHaskellName n

View file

@ -16,6 +16,7 @@ module TensorFlow.OpGen.ParsedOp
, ParsedArgCase(..) , ParsedArgCase(..)
, ArgType(..) , ArgType(..)
, ArgKind(..) , ArgKind(..)
, argKind
, parseOp , parseOp
, camelCase , camelCase
) where ) where
@ -108,18 +109,23 @@ data ParsedArg = ParsedArg
{ parsedArgName :: Name { parsedArgName :: Name
, parsedArgDescription :: Text , parsedArgDescription :: Text
, parsedArgCase :: ParsedArgCase , parsedArgCase :: ParsedArgCase
, parsedArgKind :: ArgKind
} }
data ParsedArgCase data ParsedArgCase
= SimpleArg { argType :: ArgType } = SimpleArg { argType :: ArgType, argCaseKind :: ArgKind }
| ListArg | ListArg
{ argLength :: Name -- ^ The attribute that specifies this list's length. { argLength :: Name -- ^ The attribute that specifies this list's length.
, argType :: ArgType , argType :: ArgType
, argCaseKind :: ArgKind
} }
| MixedListArg { argTypeAttr :: Name } | MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind }
-- ^ A heterogeneous list. -- ^ A heterogeneous list.
-- TODO(judahjacobson): Implement this. -- TODO(judahjacobson): Implement this.
| ResourceArg
argKind :: ParsedArgCase -> Maybe ArgKind
argKind ResourceArg = Nothing
argKind a = Just $ argCaseKind a
-- | The type of an argument. -- | The type of an argument.
data ArgType data ArgType
@ -131,12 +137,13 @@ data ArgKind
= ArgTensorRef -- Tensor Ref a = ArgTensorRef -- Tensor Ref a
| ArgTensorValue -- Tensor Value a | ArgTensorValue -- Tensor Value a
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v` | ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
| ArgResource -- Resource a deriving (Eq)
isRefKind :: ArgKind -> Bool isRefCase :: ParsedArgCase -> Bool
isRefKind ArgTensorRef = True isRefCase a = case argKind a of
isRefKind ArgResource = True Nothing -> True -- Resource
isRefKind _ = False Just ArgTensorRef -> True
_ -> False
makeName :: Text -> Name makeName :: Text -> Name
makeName n = Name makeName n = Name
@ -208,7 +215,7 @@ parseOp o = ParsedOp
, parsedOpSummary = o ^. summary , parsedOpSummary = o ^. summary
, parsedOpDescription = o ^. description , parsedOpDescription = o ^. description
, parsedOpIsMonadic = o ^. isStateful , parsedOpIsMonadic = o ^. isStateful
|| any (isRefKind . parsedArgKind) parsedInputs || any (isRefCase . parsedArgCase) parsedInputs
, .. , ..
} }
where where
@ -237,13 +244,11 @@ parseOp o = ParsedOp
-- TODO(judahjacobson): Some arguments should be refs. -- TODO(judahjacobson): Some arguments should be refs.
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
inputTensorKind a v inputTensorKind a v
| a ^. type' == DT_RESOURCE = ArgResource
| a ^. isRef = ArgTensorRef | a ^. isRef = ArgTensorRef
| otherwise = ArgTensorEither v | otherwise = ArgTensorEither v
outputTensorKind :: OpDef'ArgDef -> ArgKind outputTensorKind :: OpDef'ArgDef -> ArgKind
outputTensorKind a outputTensorKind a
| a ^. type' == DT_RESOURCE = ArgResource
| a ^. isRef = ArgTensorRef | a ^. isRef = ArgTensorRef
| otherwise = ArgTensorValue | otherwise = ArgTensorValue
@ -258,6 +263,7 @@ getExplicitInputAttr implicitAttrs a
-- | The type attribute used by this input or output (if any). -- | The type attribute used by this input or output (if any).
parsedArgTypeAttr :: ParsedArg -> Maybe TFName parsedArgTypeAttr :: ParsedArg -> Maybe TFName
parsedArgTypeAttr p = case parsedArgCase p of parsedArgTypeAttr p = case parsedArgCase p of
ResourceArg -> Nothing
SimpleArg {argType = t} -> fromArgType t SimpleArg {argType = t} -> fromArgType t
ListArg {argType = t} -> fromArgType t ListArg {argType = t} -> fromArgType t
MixedListArg {argTypeAttr = n} -> Just $ tfName n MixedListArg {argTypeAttr = n} -> Just $ tfName n
@ -294,19 +300,18 @@ parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
parseArg a tKind = ParsedArg parseArg a tKind = ParsedArg
{ parsedArgName = makeName (a ^. name) { parsedArgName = makeName (a ^. name)
, parsedArgDescription = a ^. description , parsedArgDescription = a ^. description
, parsedArgCase = parseArgCase a , parsedArgCase = parseArgCase a tKind
, parsedArgKind = tKind
} }
parseArgCase :: OpDef'ArgDef -> ParsedArgCase parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase
parseArgCase a parseArgCase a tKind
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n | a ^. type' == DT_RESOURCE = ResourceArg
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType | Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n tKind
| otherwise = SimpleArg thisArgType | Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType tKind
| otherwise = SimpleArg thisArgType tKind
where where
thisArgType thisArgType
| Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n | Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n
| a ^. type' == DT_RESOURCE = ArgTypeAttr (makeName "dtype")
| otherwise = ArgTypeFixed (a ^. type') | otherwise = ArgTypeFixed (a ^. type')
maybeAttr :: Text -> Maybe Name maybeAttr :: Text -> Maybe Name
maybeAttr "" = Nothing maybeAttr "" = Nothing

View file

@ -86,7 +86,7 @@ recordResult = do
put $! ResultState (i+1) ns put $! ResultState (i+1) ns
return $! output i o return $! output i o
instance OpResult (ResourceHandle a) where instance OpResult ResourceHandle where
toResult = ResourceHandle <$> recordResult toResult = ResourceHandle <$> recordResult
instance OpResult (Tensor Value a) where instance OpResult (Tensor Value a) where
@ -150,7 +150,7 @@ buildListOp counts o = buildOp' counts o []
instance BuildOp ControlNode where instance BuildOp ControlNode where
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
instance BuildOp (ResourceHandle a) where instance BuildOp ResourceHandle where
buildOp' = pureResult buildOp' = pureResult
instance BuildOp (Tensor Value a) where instance BuildOp (Tensor Value a) where
@ -189,7 +189,7 @@ instance ( OpResult t1
instance OpResult a => BuildOp (Build a) where instance OpResult a => BuildOp (Build a) where
buildOp' = buildResult buildOp' = buildResult
instance BuildOp f => BuildOp (ResourceHandle a -> f) where instance BuildOp f => BuildOp (ResourceHandle -> f) where
buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts) buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts)
instance BuildOp f => BuildOp (Tensor v a -> f) where instance BuildOp f => BuildOp (Tensor v a -> f) where

View file

@ -158,6 +158,5 @@ instance IsString Output where
-- | Opaque handle to a mutable resource in the graph. Typical such -- | Opaque handle to a mutable resource in the graph. Typical such
-- resources are variables. The type parameter corresponds to the -- resources are variables.
-- dtype of the tensor held in the variable. newtype ResourceHandle = ResourceHandle Output
newtype ResourceHandle a = ResourceHandle Output