Add StatelessCase, TPUCompile to blacklist. (#271)

Needed to support Tensorflow 2.4.

Also, add AttrFunc to unbreak build.
This commit is contained in:
Mike Sperber 2021-02-09 18:08:46 +01:00 committed by GitHub
parent b1a8a0513d
commit d088e30b80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 1 deletions

View File

@ -128,10 +128,12 @@ blackList =
, "ScanDataset"
, "SnapshotDatasetV2"
, "StatefulPartitionedCall"
, "StatelessCase"
, "StatelessIf"
, "StatelessWhile"
, "SymbolicGradient"
, "TakeWhileDataset"
, "TPUCompile"
, "TPUPartitionedCall"
, "TPUReplicate"
, "While"

View File

@ -321,6 +321,7 @@ typeSig pre pOp = constraints
AttrType -> "DataType"
AttrShape -> "Shape"
AttrTensor -> "TensorProto"
AttrFunc -> error "AttrFunc not supported"
tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t)
outputs = case parsedOutputs pOp of

View File

@ -101,7 +101,7 @@ data AttrType = AttrSingle AttrBaseType
deriving Eq
data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
| AttrType | AttrShape | AttrTensor
| AttrType | AttrShape | AttrTensor | AttrFunc
deriving Eq
data TypeParam = TypeParam
@ -334,6 +334,7 @@ parseAttrType o = \case
"type" -> AttrSingle AttrType
"shape" -> AttrSingle AttrShape
"tensor" -> AttrSingle AttrTensor
"func" -> AttrSingle AttrFunc
"list(string)" -> AttrList AttrBytes
"list(int)" -> AttrList AttrInt64
"list(float)" -> AttrList AttrFloat
@ -341,5 +342,6 @@ parseAttrType o = \case
"list(type)" -> AttrList AttrType
"list(shape)" -> AttrList AttrShape
"list(tensor)" -> AttrList AttrTensor
"list(func)" -> AttrList AttrFunc
t -> error $ "parseAttrType: unrecognized type " ++ show t
++ " for op " ++ show (o ^. name)