support map and seq in tablegen (#159)
* support map and seq in tablegen * register MLONNX for testing * format * remove the unwanted test * add a test Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
		
							parent
							
								
									b28fcbc745
								
							
						
					
					
						commit
						1fc43fa181
					
				|  | @ -63,7 +63,7 @@ ONNX CastMap operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `X` | memref of any type values or tensor of any type values | ||||
| `X` | tuple with any combination of tensor of 64-bit signless integer values values or memref of 64-bit signless integer values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
|  | @ -133,7 +133,7 @@ ONNX DictVectorizer operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `X` | memref of any type values or tensor of any type values | ||||
| `X` | tuple with any combination of tensor of 64-bit signless integer or 32-bit float or 64-bit float values values or memref of 64-bit signless integer or 32-bit float or 64-bit float values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
|  | @ -593,5 +593,5 @@ ONNX ZipMap operation | |||
| 
 | ||||
| | Result | Description | | ||||
| | :----: | ----------- | | ||||
| `Z` | memref of any type values or tensor of any type values | ||||
| `Z` | tensor of tensor of 32-bit float or 64-bit signless integer values values or memref of 32-bit float or 64-bit signless integer values | ||||
| 
 | ||||
|  |  | |||
|  | @ -531,7 +531,7 @@ ONNX ConcatFromSequence operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `input_sequence` | memref of any type values or tensor of any type values | ||||
| `input_sequence` | memref of any type values or tensor of tensor of any type values values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
|  | @ -4363,7 +4363,7 @@ ONNX SequenceAt operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `input_sequence` | memref of any type values or tensor of any type values | ||||
| `input_sequence` | memref of any type values or tensor of tensor of any type values values | ||||
| `position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values | ||||
| 
 | ||||
| #### Results: | ||||
|  | @ -4389,7 +4389,7 @@ ONNX SequenceConstruct operation | |||
| 
 | ||||
| | Result | Description | | ||||
| | :----: | ----------- | | ||||
| `output_sequence` | memref of any type values or tensor of any type values | ||||
| `output_sequence` | memref of any type values or tensor of tensor of any type values values | ||||
| 
 | ||||
| ### `onnx.SequenceEmpty` (ONNXSequenceEmptyOp) | ||||
| 
 | ||||
|  | @ -4407,7 +4407,7 @@ ONNX SequenceEmpty operation | |||
| 
 | ||||
| | Result | Description | | ||||
| | :----: | ----------- | | ||||
| `output` | memref of any type values or tensor of any type values | ||||
| `output` | memref of any type values or tensor of tensor of any type values values | ||||
| 
 | ||||
| ### `onnx.SequenceErase` (ONNXSequenceEraseOp) | ||||
| 
 | ||||
|  | @ -4422,14 +4422,14 @@ ONNX SequenceErase operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `input_sequence` | memref of any type values or tensor of any type values | ||||
| `input_sequence` | memref of any type values or tensor of tensor of any type values values | ||||
| `position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
| | Result | Description | | ||||
| | :----: | ----------- | | ||||
| `output_sequence` | memref of any type values or tensor of any type values | ||||
| `output_sequence` | memref of any type values or tensor of tensor of any type values values | ||||
| 
 | ||||
| ### `onnx.SequenceInsert` (ONNXSequenceInsertOp) | ||||
| 
 | ||||
|  | @ -4445,7 +4445,7 @@ ONNX SequenceInsert operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `input_sequence` | memref of any type values or tensor of any type values | ||||
| `input_sequence` | memref of any type values or tensor of tensor of any type values values | ||||
| `tensor` | memref of any type values or tensor of any type values | ||||
| `position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type | ||||
| 
 | ||||
|  | @ -4453,7 +4453,7 @@ ONNX SequenceInsert operation | |||
| 
 | ||||
| | Result | Description | | ||||
| | :----: | ----------- | | ||||
| `output_sequence` | memref of any type values or tensor of any type values | ||||
| `output_sequence` | memref of any type values or tensor of tensor of any type values values | ||||
| 
 | ||||
| ### `onnx.SequenceLength` (ONNXSequenceLengthOp) | ||||
| 
 | ||||
|  | @ -4465,7 +4465,7 @@ ONNX SequenceLength operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `input_sequence` | memref of any type values or tensor of any type values | ||||
| `input_sequence` | memref of any type values or tensor of tensor of any type values values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
|  | @ -4828,7 +4828,7 @@ ONNX SplitToSequence operation | |||
| 
 | ||||
| | Result | Description | | ||||
| | :----: | ----------- | | ||||
| `output_sequence` | memref of any type values or tensor of any type values | ||||
| `output_sequence` | memref of any type values or tensor of tensor of any type values values | ||||
| 
 | ||||
| ### `onnx.Sqrt` (ONNXSqrtOp) | ||||
| 
 | ||||
|  |  | |||
|  | @ -31,7 +31,7 @@ public: | |||
| 
 | ||||
|   /// Provide a utility accessor to the dialect namespace. This is used by
 | ||||
|   /// several utilities for casting between dialects.
 | ||||
|   static StringRef getDialectNamespace() { return "onnx"; } | ||||
|   static StringRef getDialectNamespace() { return "mlonnx"; } | ||||
| }; | ||||
| 
 | ||||
| /// Include the auto-generated header file containing the declarations of the
 | ||||
|  |  | |||
|  | @ -57,7 +57,7 @@ def MLONNXCastMapOp:MLONNX_Op<"CastMap", | |||
|   "    in ascending order based on this key.<br>The operator supports dense packing or sparse packing." | ||||
|   "    If using sparse packing, the key cannot exceed the max_map-1 value." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, | ||||
|   let arguments = (ins AnyTypeOf<[TupleOf<[TensorOf<[I64]>]>, MemRefOf<[I64]>]>:$X, | ||||
|     DefaultValuedAttr<StrAttr, "TO_FLOAT">:$cast_to, | ||||
|     DefaultValuedAttr<StrAttr, "DENSE">:$map_form, | ||||
|     DefaultValuedAttr<I64Attr, "1">:$max_map); | ||||
|  | @ -124,7 +124,7 @@ def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer", | |||
|   "    then an input of ``{\"a\": 4, \"c\": 8}`` will produce an output of ``[4, 8, 0, 0]``." | ||||
|   "    " | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, | ||||
|   let arguments = (ins AnyTypeOf<[TupleOf<[TensorOf<[I64,F32,F64]>]>, MemRefOf<[I64,F32,F64]>]>:$X, | ||||
|     OptionalAttr<I64ArrayAttr>:$int64_vocabulary, | ||||
|     OptionalAttr<StrArrayAttr>:$string_vocabulary); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); | ||||
|  | @ -555,7 +555,7 @@ def MLONNXZipMapOp:MLONNX_Op<"ZipMap", | |||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, | ||||
|     OptionalAttr<I64ArrayAttr>:$classlabels_int64s, | ||||
|     OptionalAttr<StrArrayAttr>:$classlabels_strings); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); | ||||
|   let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[F32,I64]>]>, MemRefOf<[F32,I64]>]>:$Z); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|  |  | |||
|  | @ -601,7 +601,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence", | |||
|   "By default 'new_axis' is 0, the behavior is similar to numpy.concatenate." | ||||
|   "When 'new_axis' is 1, the behavior is similar to numpy.stack." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence, | ||||
|     I64Attr:$axis, | ||||
|     DefaultValuedAttr<I64Attr, "0">:$new_axis); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$concat_result); | ||||
|  | @ -4562,7 +4562,7 @@ def ONNXSequenceAtOp:ONNX_Op<"SequenceAt", | |||
|   "Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'." | ||||
|   "Negative value means counting positions from the back." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence, | ||||
|     AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>]>:$position); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor); | ||||
|   let extraClassDeclaration = [{ | ||||
|  | @ -4586,7 +4586,7 @@ def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct", | |||
|   "All tensors in 'inputs' must have the same data type." | ||||
|   }]; | ||||
|   let arguments = (ins Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$inputs); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return -1; | ||||
|  | @ -4607,7 +4607,7 @@ def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty", | |||
|   "Construct an empty tensor sequence, with given data type." | ||||
|   }]; | ||||
|   let arguments = (ins OptionalAttr<I64Attr>:$dtype); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 0; | ||||
|  | @ -4630,9 +4630,9 @@ def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase", | |||
|   "Negative value means counting positions from the back." | ||||
|   "'position' is optional, by default it erases the last tensor from 'input_sequence'." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence, | ||||
|     AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$position); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 2; | ||||
|  | @ -4656,10 +4656,10 @@ def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert", | |||
|   "Negative value means counting positions from the back." | ||||
|   "'position' is optional, by default it inserts 'tensor' to the back of 'input_sequence'." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence, | ||||
|     AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor, | ||||
|     AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$position); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 3; | ||||
|  | @ -4679,7 +4679,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength", | |||
|   let description = [{ | ||||
|   "Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence); | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence); | ||||
|   let results = (outs AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>]>:$length); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|  | @ -5054,7 +5054,7 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence", | |||
|     AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$split, | ||||
|     DefaultValuedAttr<I64Attr, "0">:$axis, | ||||
|     DefaultValuedAttr<I64Attr, "1">:$keepdims); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 2; | ||||
|  |  | |||
|  | @ -246,6 +246,7 @@ void registerDialects() { | |||
|   mlir::registerDialect<mlir::scf::SCFDialect>(); | ||||
|   mlir::registerDialect<mlir::StandardOpsDialect>(); | ||||
|   mlir::registerDialect<mlir::ONNXOpsDialect>(); | ||||
|   mlir::registerDialect<mlir::MLONNXOpsDialect>(); | ||||
|   mlir::registerDialect<mlir::KrnlOpsDialect>(); | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -22,6 +22,7 @@ | |||
| 
 | ||||
| #include "src/Builder/FrontendDialectTransformer.hpp" | ||||
| #include "src/Dialect/Krnl/KrnlOps.hpp" | ||||
| #include "src/Dialect/MLONNX/MLONNXOps.hpp" | ||||
| #include "src/Dialect/ONNX/ONNXOps.hpp" | ||||
| #include "src/Pass/Passes.hpp" | ||||
| 
 | ||||
|  |  | |||
|  | @ -18,3 +18,8 @@ target_link_libraries(onnx-mlir-opt | |||
|         OMONNXOps | ||||
|         OMONNXRewrite | ||||
|         ${MLIRLibs}) | ||||
| 
 | ||||
| if (INCLUDE_ONNX_ML) | ||||
|   target_link_libraries(onnx-mlir-opt OMMLONNXOps) | ||||
|   add_dependencies(onnx-mlir-opt OMMLONNXOpsInc) | ||||
| endif() | ||||
|  |  | |||
|  | @ -20,6 +20,7 @@ | |||
| #include <mlir/Support/MlirOptMain.h> | ||||
| 
 | ||||
| #include "src/Dialect/Krnl/KrnlOps.hpp" | ||||
| #include "src/Dialect/MLONNX/MLONNXOps.hpp" | ||||
| #include "src/Dialect/ONNX/ONNXOps.hpp" | ||||
| #include "src/Pass/Passes.hpp" | ||||
| 
 | ||||
|  | @ -67,6 +68,7 @@ int main(int argc, char **argv) { | |||
|   llvm::InitLLVM y(argc, argv); | ||||
| 
 | ||||
|   mlir::registerDialect<mlir::ONNXOpsDialect>(); | ||||
|   mlir::registerDialect<mlir::MLONNXOpsDialect>(); | ||||
|   mlir::registerDialect<mlir::KrnlOpsDialect>(); | ||||
| 
 | ||||
|   mlir::registerAsmPrinterCLOptions(); | ||||
|  |  | |||
|  | @ -0,0 +1,9 @@ | |||
| // RUN: onnx-mlir-opt %s -split-input-file | FileCheck %s | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===// | ||||
| // CHECK-LABEL: @check_map1(%arg0: tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64> { | ||||
| func @check_map1(%arg0: tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64> { | ||||
|   %0 = "mlonnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64> | ||||
|   return %0 : tensor<*xi64> | ||||
|   // CHECK-NEXT: %0 = "mlonnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64> | ||||
| } | ||||
|  | @ -23,7 +23,7 @@ add_custom_target(OMONNXOpsIncTranslation | |||
|         DEPENDS OMONNXOpsTableGenIncGen | ||||
|                 OMONNXOpsBuildTableIncGen) | ||||
| 
 | ||||
| # Invoke gen_onnx_mlir.py to obtain ONNXOps.td.inc, OpBuildTable.inc. | ||||
| # Invoke gen_onnx_mlir.py to obtain MLONNXOps.td.inc, MLOpBuildTable.inc. | ||||
| add_custom_command(OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/MLONNXOps.td.inc | ||||
|                 ${CMAKE_CURRENT_SOURCE_DIR}/MLOpBuildTable.inc | ||||
|        	COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/gen_onnx_mlir.py --domain="ONNX_ML" | ||||
|  |  | |||
|  | @ -424,13 +424,14 @@ def get_tblgen_type_index(type_str): | |||
|     return tblgen_types.index(type_str) | ||||
| 
 | ||||
| #the possible data structures are tensor, map and seq(tensor()) | ||||
| #TOFIX: currently, only tensor structure is supported | ||||
| def get_data_structure_element(allowed_type_str): | ||||
|     if allowed_type_str.startswith('tensor') : | ||||
|         element = allowed_type_str.replace('tensor(', '', 1).replace(')', '', 1) | ||||
|         return ('tensor', element) | ||||
|     else : | ||||
|         return (None, None) | ||||
| def get_data_structure_element(allowed_type_str):  | ||||
|     structure_list = ['tensor', 'seq', 'map'] | ||||
|     for structure in structure_list: | ||||
|         if allowed_type_str.startswith(structure) : | ||||
|             element = allowed_type_str.replace( | ||||
|                 structure+'(', '', 1).replace(')', '', 1) | ||||
|             return (structure, element) | ||||
|     return (None, None) | ||||
| 
 | ||||
| def get_allowed_elem_types(schema, input): | ||||
|     #allowed_types_str = None | ||||
|  | @ -446,19 +447,25 @@ def get_allowed_elem_types(schema, input): | |||
|                 continue | ||||
|             allowed_type_list=[] | ||||
|             allowedTypes = type_constraint.allowed_type_strs | ||||
|             allowed_structure = None | ||||
|             for allowedType in allowedTypes: | ||||
|                 structure, element = get_data_structure_element(allowedType); | ||||
|                 if structure == None or element == None: | ||||
|                     return None | ||||
|                     return None, None | ||||
| 
 | ||||
|                 if allowed_structure != None and allowed_structure != structure : | ||||
|                     print("{}: one structure assumed".format(schema.name)) | ||||
|                     sys.exit(-1) | ||||
|                 allowed_structure = structure | ||||
|                 t = np_type_to_tblgen_attr_type(element) | ||||
|                 if t == None : | ||||
|                     return None | ||||
|                     return allowed_structure, None | ||||
|                 if  not t in allowed_type_list : | ||||
|                     allowed_tyoe_list = allowed_type_list.append(t) | ||||
| 
 | ||||
|             return allowed_type_list | ||||
| 
 | ||||
|     return None | ||||
|      | ||||
|             return allowed_structure,allowed_type_list | ||||
|      | ||||
|     return None, None | ||||
| 
 | ||||
| 
 | ||||
| def inc_indent(indent=None): | ||||
|  | @ -486,14 +493,37 @@ def get_operands_or_results(schema, is_input): | |||
| 
 | ||||
|     name_to_types = OrderedDict() | ||||
|     for i, value in enumerate(value_list): | ||||
|         elem_types = get_allowed_elem_types(schema, value) | ||||
|         structure, elem_types = get_allowed_elem_types(schema, value) | ||||
| 
 | ||||
|         if elem_types is None: | ||||
|             types = ["AnyMemRef", "AnyTensor"] | ||||
|         if structure == 'tensor' : | ||||
|             if elem_types is None: | ||||
|                 types = ["AnyMemRef", "AnyTensor"] | ||||
|             else: | ||||
|                 elem_types_str = ','.join(elem_types) | ||||
|                 types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"] | ||||
|                 types = list(map(lambda x: x.format(elem_types_str), types)) | ||||
|         elif structure == 'seq' : | ||||
|             # Seq is not supported yet. | ||||
|             # Use of TensorOf<[AnyTensor]> as a placeholder for tablegen. | ||||
|             # When the Operation is used, warning/error will be generated at runtime. | ||||
|             if elem_types is None: | ||||
|                 types = ["AnyMemRef", "TensorOf<[AnyTensor]>"] | ||||
|             else: | ||||
|                 elem_types_str = ','.join(elem_types) | ||||
|                 types = ["TensorOf<[TensorOf<[{}]>]>", "MemRefOf<[{}]>"] | ||||
|                 types = list(map(lambda x: x.format(elem_types_str), types)) | ||||
|         elif structure == 'map' : | ||||
|             # Map is not supported yet. | ||||
|             # Use of TupleOf as a placeholder for tablegen. | ||||
|             # When the Operation is used, warning/error will be generated at runtime. | ||||
|             if elem_types is None: | ||||
|                 types = ["AnyMemRef", "TupleOf<[AnyTensor]>"] | ||||
|             else: | ||||
|                 elem_types_str = ','.join(elem_types) | ||||
|                 types = ["TupleOf<[TensorOf<[{}]>]>", "MemRefOf<[{}]>"] | ||||
|                 types = list(map(lambda x: x.format(elem_types_str), types)) | ||||
|         else: | ||||
|             elem_types_str = ','.join(elem_types) | ||||
|             types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"] | ||||
|             types = list(map(lambda x: x.format(elem_types_str), types)) | ||||
|             types = ["AnyMemRef", "AnyTensor"] | ||||
| 
 | ||||
|         # If operand is promotable to an attribute, then it must be | ||||
|         # nullable in case it migrates to be an attribute. | ||||
|  | @ -592,7 +622,7 @@ def get_output_type_mapping(schema): | |||
|     mapping=[] | ||||
|     for output in schema.outputs : | ||||
|         #if only one type is allowed, just set that | ||||
|         allowed_elem_types = get_allowed_elem_types(schema, output) | ||||
|         structure, allowed_elem_types = get_allowed_elem_types(schema, output) | ||||
|         if allowed_elem_types != None and len(allowed_elem_types) == 1 : | ||||
|             mapping.append(str(get_tblgen_type_index(allowed_elem_types[0]))) | ||||
|             continue | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue