1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-26 21:09:44 +01:00

Add StatelessCase, TPUCompile to blacklist. (#271)

Needed to support Tensorflow 2.4.

Also, add AttrFunc to unbreak build.
This commit is contained in:
Mike Sperber 2021-02-09 18:08:46 +01:00 committed by GitHub
parent b1a8a0513d
commit d088e30b80
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 1 deletions

View file

@ -128,10 +128,12 @@ blackList =
, "ScanDataset" , "ScanDataset"
, "SnapshotDatasetV2" , "SnapshotDatasetV2"
, "StatefulPartitionedCall" , "StatefulPartitionedCall"
, "StatelessCase"
, "StatelessIf" , "StatelessIf"
, "StatelessWhile" , "StatelessWhile"
, "SymbolicGradient" , "SymbolicGradient"
, "TakeWhileDataset" , "TakeWhileDataset"
, "TPUCompile"
, "TPUPartitionedCall" , "TPUPartitionedCall"
, "TPUReplicate" , "TPUReplicate"
, "While" , "While"

View file

@ -321,6 +321,7 @@ typeSig pre pOp = constraints
AttrType -> "DataType" AttrType -> "DataType"
AttrShape -> "Shape" AttrShape -> "Shape"
AttrTensor -> "TensorProto" AttrTensor -> "TensorProto"
AttrFunc -> error "AttrFunc not supported"
tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t) tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t)
outputs = case parsedOutputs pOp of outputs = case parsedOutputs pOp of

View file

@ -101,7 +101,7 @@ data AttrType = AttrSingle AttrBaseType
deriving Eq deriving Eq
data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
| AttrType | AttrShape | AttrTensor | AttrType | AttrShape | AttrTensor | AttrFunc
deriving Eq deriving Eq
data TypeParam = TypeParam data TypeParam = TypeParam
@ -334,6 +334,7 @@ parseAttrType o = \case
"type" -> AttrSingle AttrType "type" -> AttrSingle AttrType
"shape" -> AttrSingle AttrShape "shape" -> AttrSingle AttrShape
"tensor" -> AttrSingle AttrTensor "tensor" -> AttrSingle AttrTensor
"func" -> AttrSingle AttrFunc
"list(string)" -> AttrList AttrBytes "list(string)" -> AttrList AttrBytes
"list(int)" -> AttrList AttrInt64 "list(int)" -> AttrList AttrInt64
"list(float)" -> AttrList AttrFloat "list(float)" -> AttrList AttrFloat
@ -341,5 +342,6 @@ parseAttrType o = \case
"list(type)" -> AttrList AttrType "list(type)" -> AttrList AttrType
"list(shape)" -> AttrList AttrShape "list(shape)" -> AttrList AttrShape
"list(tensor)" -> AttrList AttrTensor "list(tensor)" -> AttrList AttrTensor
"list(func)" -> AttrList AttrFunc
t -> error $ "parseAttrType: unrecognized type " ++ show t t -> error $ "parseAttrType: unrecognized type " ++ show t
++ " for op " ++ show (o ^. name) ++ " for op " ++ show (o ^. name)