support map and seq in tablegen (#159)

* support map and seq in tablegen

* register MLONNX for testing

* format

* remove the unwanted test

* add a test

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
chentong319 2020-06-18 09:49:40 -04:00 committed by GitHub
parent b28fcbc745
commit 1fc43fa181
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 96 additions and 48 deletions

View File

@ -63,7 +63,7 @@ ONNX CastMap operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`X` | memref of any type values or tensor of any type values `X` | tuple with any combination of tensor of 64-bit signless integer values values or memref of 64-bit signless integer values
#### Results: #### Results:
@ -133,7 +133,7 @@ ONNX DictVectorizer operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`X` | memref of any type values or tensor of any type values `X` | tuple with any combination of tensor of 64-bit signless integer or 32-bit float or 64-bit float values values or memref of 64-bit signless integer or 32-bit float or 64-bit float values
#### Results: #### Results:
@ -593,5 +593,5 @@ ONNX ZipMap operation
| Result | Description | | Result | Description |
| :----: | ----------- | | :----: | ----------- |
`Z` | memref of any type values or tensor of any type values `Z` | tensor of tensor of 32-bit float or 64-bit signless integer values values or memref of 32-bit float or 64-bit signless integer values

View File

@ -531,7 +531,7 @@ ONNX ConcatFromSequence operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`input_sequence` | memref of any type values or tensor of any type values `input_sequence` | memref of any type values or tensor of tensor of any type values values
#### Results: #### Results:
@ -4363,7 +4363,7 @@ ONNX SequenceAt operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`input_sequence` | memref of any type values or tensor of any type values `input_sequence` | memref of any type values or tensor of tensor of any type values values
`position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values `position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values
#### Results: #### Results:
@ -4389,7 +4389,7 @@ ONNX SequenceConstruct operation
| Result | Description | | Result | Description |
| :----: | ----------- | | :----: | ----------- |
`output_sequence` | memref of any type values or tensor of any type values `output_sequence` | memref of any type values or tensor of tensor of any type values values
### `onnx.SequenceEmpty` (ONNXSequenceEmptyOp) ### `onnx.SequenceEmpty` (ONNXSequenceEmptyOp)
@ -4407,7 +4407,7 @@ ONNX SequenceEmpty operation
| Result | Description | | Result | Description |
| :----: | ----------- | | :----: | ----------- |
`output` | memref of any type values or tensor of any type values `output` | memref of any type values or tensor of tensor of any type values values
### `onnx.SequenceErase` (ONNXSequenceEraseOp) ### `onnx.SequenceErase` (ONNXSequenceEraseOp)
@ -4422,14 +4422,14 @@ ONNX SequenceErase operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`input_sequence` | memref of any type values or tensor of any type values `input_sequence` | memref of any type values or tensor of tensor of any type values values
`position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type `position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type
#### Results: #### Results:
| Result | Description | | Result | Description |
| :----: | ----------- | | :----: | ----------- |
`output_sequence` | memref of any type values or tensor of any type values `output_sequence` | memref of any type values or tensor of tensor of any type values values
### `onnx.SequenceInsert` (ONNXSequenceInsertOp) ### `onnx.SequenceInsert` (ONNXSequenceInsertOp)
@ -4445,7 +4445,7 @@ ONNX SequenceInsert operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`input_sequence` | memref of any type values or tensor of any type values `input_sequence` | memref of any type values or tensor of tensor of any type values values
`tensor` | memref of any type values or tensor of any type values `tensor` | memref of any type values or tensor of any type values
`position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type `position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type
@ -4453,7 +4453,7 @@ ONNX SequenceInsert operation
| Result | Description | | Result | Description |
| :----: | ----------- | | :----: | ----------- |
`output_sequence` | memref of any type values or tensor of any type values `output_sequence` | memref of any type values or tensor of tensor of any type values values
### `onnx.SequenceLength` (ONNXSequenceLengthOp) ### `onnx.SequenceLength` (ONNXSequenceLengthOp)
@ -4465,7 +4465,7 @@ ONNX SequenceLength operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`input_sequence` | memref of any type values or tensor of any type values `input_sequence` | memref of any type values or tensor of tensor of any type values values
#### Results: #### Results:
@ -4828,7 +4828,7 @@ ONNX SplitToSequence operation
| Result | Description | | Result | Description |
| :----: | ----------- | | :----: | ----------- |
`output_sequence` | memref of any type values or tensor of any type values `output_sequence` | memref of any type values or tensor of tensor of any type values values
### `onnx.Sqrt` (ONNXSqrtOp) ### `onnx.Sqrt` (ONNXSqrtOp)

View File

@ -31,7 +31,7 @@ public:
/// Provide a utility accessor to the dialect namespace. This is used by /// Provide a utility accessor to the dialect namespace. This is used by
/// several utilities for casting between dialects. /// several utilities for casting between dialects.
static StringRef getDialectNamespace() { return "onnx"; } static StringRef getDialectNamespace() { return "mlonnx"; }
}; };
/// Include the auto-generated header file containing the declarations of the /// Include the auto-generated header file containing the declarations of the

View File

@ -57,7 +57,7 @@ def MLONNXCastMapOp:MLONNX_Op<"CastMap",
" in ascending order based on this key.<br>The operator supports dense packing or sparse packing." " in ascending order based on this key.<br>The operator supports dense packing or sparse packing."
" If using sparse packing, the key cannot exceed the max_map-1 value." " If using sparse packing, the key cannot exceed the max_map-1 value."
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, let arguments = (ins AnyTypeOf<[TupleOf<[TensorOf<[I64]>]>, MemRefOf<[I64]>]>:$X,
DefaultValuedAttr<StrAttr, "TO_FLOAT">:$cast_to, DefaultValuedAttr<StrAttr, "TO_FLOAT">:$cast_to,
DefaultValuedAttr<StrAttr, "DENSE">:$map_form, DefaultValuedAttr<StrAttr, "DENSE">:$map_form,
DefaultValuedAttr<I64Attr, "1">:$max_map); DefaultValuedAttr<I64Attr, "1">:$max_map);
@ -124,7 +124,7 @@ def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer",
" then an input of ``{\"a\": 4, \"c\": 8}`` will produce an output of ``[4, 8, 0, 0]``." " then an input of ``{\"a\": 4, \"c\": 8}`` will produce an output of ``[4, 8, 0, 0]``."
" " " "
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, let arguments = (ins AnyTypeOf<[TupleOf<[TensorOf<[I64,F32,F64]>]>, MemRefOf<[I64,F32,F64]>]>:$X,
OptionalAttr<I64ArrayAttr>:$int64_vocabulary, OptionalAttr<I64ArrayAttr>:$int64_vocabulary,
OptionalAttr<StrArrayAttr>:$string_vocabulary); OptionalAttr<StrArrayAttr>:$string_vocabulary);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
@ -555,7 +555,7 @@ def MLONNXZipMapOp:MLONNX_Op<"ZipMap",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
OptionalAttr<I64ArrayAttr>:$classlabels_int64s, OptionalAttr<I64ArrayAttr>:$classlabels_int64s,
OptionalAttr<StrArrayAttr>:$classlabels_strings); OptionalAttr<StrArrayAttr>:$classlabels_strings);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[F32,I64]>]>, MemRefOf<[F32,I64]>]>:$Z);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
return 1; return 1;

View File

@ -601,7 +601,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence",
"By default 'new_axis' is 0, the behavior is similar to numpy.concatenate." "By default 'new_axis' is 0, the behavior is similar to numpy.concatenate."
"When 'new_axis' is 1, the behavior is similar to numpy.stack." "When 'new_axis' is 1, the behavior is similar to numpy.stack."
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence,
I64Attr:$axis, I64Attr:$axis,
DefaultValuedAttr<I64Attr, "0">:$new_axis); DefaultValuedAttr<I64Attr, "0">:$new_axis);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$concat_result); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$concat_result);
@ -4562,7 +4562,7 @@ def ONNXSequenceAtOp:ONNX_Op<"SequenceAt",
"Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'." "Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'."
"Negative value means counting positions from the back." "Negative value means counting positions from the back."
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence,
AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>]>:$position); AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>]>:$position);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
@ -4586,7 +4586,7 @@ def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct",
"All tensors in 'inputs' must have the same data type." "All tensors in 'inputs' must have the same data type."
}]; }];
let arguments = (ins Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$inputs); let arguments = (ins Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$inputs);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence); let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
return -1; return -1;
@ -4607,7 +4607,7 @@ def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty",
"Construct an empty tensor sequence, with given data type." "Construct an empty tensor sequence, with given data type."
}]; }];
let arguments = (ins OptionalAttr<I64Attr>:$dtype); let arguments = (ins OptionalAttr<I64Attr>:$dtype);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
return 0; return 0;
@ -4630,9 +4630,9 @@ def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase",
"Negative value means counting positions from the back." "Negative value means counting positions from the back."
"'position' is optional, by default it erases the last tensor from 'input_sequence'." "'position' is optional, by default it erases the last tensor from 'input_sequence'."
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence,
AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$position); AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$position);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence); let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
return 2; return 2;
@ -4656,10 +4656,10 @@ def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert",
"Negative value means counting positions from the back." "Negative value means counting positions from the back."
"'position' is optional, by default it inserts 'tensor' to the back of 'input_sequence'." "'position' is optional, by default it inserts 'tensor' to the back of 'input_sequence'."
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor, AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor,
AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$position); AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$position);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence); let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
return 3; return 3;
@ -4679,7 +4679,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength",
let description = [{ let description = [{
"Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'." "Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'."
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence); let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence);
let results = (outs AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>]>:$length); let results = (outs AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>]>:$length);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
@ -5054,7 +5054,7 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence",
AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$split, AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$split,
DefaultValuedAttr<I64Attr, "0">:$axis, DefaultValuedAttr<I64Attr, "0">:$axis,
DefaultValuedAttr<I64Attr, "1">:$keepdims); DefaultValuedAttr<I64Attr, "1">:$keepdims);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence); let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
return 2; return 2;

View File

@ -246,6 +246,7 @@ void registerDialects() {
mlir::registerDialect<mlir::scf::SCFDialect>(); mlir::registerDialect<mlir::scf::SCFDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>(); mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::ONNXOpsDialect>(); mlir::registerDialect<mlir::ONNXOpsDialect>();
mlir::registerDialect<mlir::MLONNXOpsDialect>();
mlir::registerDialect<mlir::KrnlOpsDialect>(); mlir::registerDialect<mlir::KrnlOpsDialect>();
} }

View File

@ -22,6 +22,7 @@
#include "src/Builder/FrontendDialectTransformer.hpp" #include "src/Builder/FrontendDialectTransformer.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Dialect/MLONNX/MLONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Pass/Passes.hpp" #include "src/Pass/Passes.hpp"

View File

@ -18,3 +18,8 @@ target_link_libraries(onnx-mlir-opt
OMONNXOps OMONNXOps
OMONNXRewrite OMONNXRewrite
${MLIRLibs}) ${MLIRLibs})
if (INCLUDE_ONNX_ML)
target_link_libraries(onnx-mlir-opt OMMLONNXOps)
add_dependencies(onnx-mlir-opt OMMLONNXOpsInc)
endif()

View File

@ -20,6 +20,7 @@
#include <mlir/Support/MlirOptMain.h> #include <mlir/Support/MlirOptMain.h>
#include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Dialect/MLONNX/MLONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Pass/Passes.hpp" #include "src/Pass/Passes.hpp"
@ -67,6 +68,7 @@ int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv); llvm::InitLLVM y(argc, argv);
mlir::registerDialect<mlir::ONNXOpsDialect>(); mlir::registerDialect<mlir::ONNXOpsDialect>();
mlir::registerDialect<mlir::MLONNXOpsDialect>();
mlir::registerDialect<mlir::KrnlOpsDialect>(); mlir::registerDialect<mlir::KrnlOpsDialect>();
mlir::registerAsmPrinterCLOptions(); mlir::registerAsmPrinterCLOptions();

View File

@ -0,0 +1,9 @@
// RUN: onnx-mlir-opt %s -split-input-file | FileCheck %s
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @check_map1(%arg0: tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64> {
func @check_map1(%arg0: tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64> {
%0 = "mlonnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64>
return %0 : tensor<*xi64>
// CHECK-NEXT: %0 = "mlonnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64>
}

View File

@ -23,7 +23,7 @@ add_custom_target(OMONNXOpsIncTranslation
DEPENDS OMONNXOpsTableGenIncGen DEPENDS OMONNXOpsTableGenIncGen
OMONNXOpsBuildTableIncGen) OMONNXOpsBuildTableIncGen)
# Invoke gen_onnx_mlir.py to obtain ONNXOps.td.inc, OpBuildTable.inc. # Invoke gen_onnx_mlir.py to obtain MLONNXOps.td.inc, MLOpBuildTable.inc.
add_custom_command(OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/MLONNXOps.td.inc add_custom_command(OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/MLONNXOps.td.inc
${CMAKE_CURRENT_SOURCE_DIR}/MLOpBuildTable.inc ${CMAKE_CURRENT_SOURCE_DIR}/MLOpBuildTable.inc
COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/gen_onnx_mlir.py --domain="ONNX_ML" COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/gen_onnx_mlir.py --domain="ONNX_ML"

View File

@ -424,13 +424,14 @@ def get_tblgen_type_index(type_str):
return tblgen_types.index(type_str) return tblgen_types.index(type_str)
#the possible data structures are tensor, map and seq(tensor()) #the possible data structures are tensor, map and seq(tensor())
#TOFIX: currently, only tensor structure is supported
def get_data_structure_element(allowed_type_str): def get_data_structure_element(allowed_type_str):
if allowed_type_str.startswith('tensor') : structure_list = ['tensor', 'seq', 'map']
element = allowed_type_str.replace('tensor(', '', 1).replace(')', '', 1) for structure in structure_list:
return ('tensor', element) if allowed_type_str.startswith(structure) :
else : element = allowed_type_str.replace(
return (None, None) structure+'(', '', 1).replace(')', '', 1)
return (structure, element)
return (None, None)
def get_allowed_elem_types(schema, input): def get_allowed_elem_types(schema, input):
#allowed_types_str = None #allowed_types_str = None
@ -446,19 +447,25 @@ def get_allowed_elem_types(schema, input):
continue continue
allowed_type_list=[] allowed_type_list=[]
allowedTypes = type_constraint.allowed_type_strs allowedTypes = type_constraint.allowed_type_strs
allowed_structure = None
for allowedType in allowedTypes: for allowedType in allowedTypes:
structure, element = get_data_structure_element(allowedType); structure, element = get_data_structure_element(allowedType);
if structure == None or element == None: if structure == None or element == None:
return None return None, None
if allowed_structure != None and allowed_structure != structure :
print("{}: one structure assumed".format(schema.name))
sys.exit(-1)
allowed_structure = structure
t = np_type_to_tblgen_attr_type(element) t = np_type_to_tblgen_attr_type(element)
if t == None : if t == None :
return None return allowed_structure, None
if not t in allowed_type_list : if not t in allowed_type_list :
allowed_tyoe_list = allowed_type_list.append(t) allowed_tyoe_list = allowed_type_list.append(t)
return allowed_type_list return allowed_structure,allowed_type_list
return None return None, None
def inc_indent(indent=None): def inc_indent(indent=None):
@ -486,14 +493,37 @@ def get_operands_or_results(schema, is_input):
name_to_types = OrderedDict() name_to_types = OrderedDict()
for i, value in enumerate(value_list): for i, value in enumerate(value_list):
elem_types = get_allowed_elem_types(schema, value) structure, elem_types = get_allowed_elem_types(schema, value)
if elem_types is None: if structure == 'tensor' :
types = ["AnyMemRef", "AnyTensor"] if elem_types is None:
types = ["AnyMemRef", "AnyTensor"]
else:
elem_types_str = ','.join(elem_types)
types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"]
types = list(map(lambda x: x.format(elem_types_str), types))
elif structure == 'seq' :
# Seq is not supported yet.
# Use of TensorOf<[AnyTensor]> as a placeholder for tablegen.
# When the Operation is used, warning/error will be generated at runtime.
if elem_types is None:
types = ["AnyMemRef", "TensorOf<[AnyTensor]>"]
else:
elem_types_str = ','.join(elem_types)
types = ["TensorOf<[TensorOf<[{}]>]>", "MemRefOf<[{}]>"]
types = list(map(lambda x: x.format(elem_types_str), types))
elif structure == 'map' :
# Map is not supported yet.
# Use of TupleOf as a placeholder for tablegen.
# When the Operation is used, warning/error will be generated at runtime.
if elem_types is None:
types = ["AnyMemRef", "TupleOf<[AnyTensor]>"]
else:
elem_types_str = ','.join(elem_types)
types = ["TupleOf<[TensorOf<[{}]>]>", "MemRefOf<[{}]>"]
types = list(map(lambda x: x.format(elem_types_str), types))
else: else:
elem_types_str = ','.join(elem_types) types = ["AnyMemRef", "AnyTensor"]
types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"]
types = list(map(lambda x: x.format(elem_types_str), types))
# If operand is promotable to an attribute, then it must be # If operand is promotable to an attribute, then it must be
# nullable in case it migrates to be an attribute. # nullable in case it migrates to be an attribute.
@ -592,7 +622,7 @@ def get_output_type_mapping(schema):
mapping=[] mapping=[]
for output in schema.outputs : for output in schema.outputs :
#if only one type is allowed, just set that #if only one type is allowed, just set that
allowed_elem_types = get_allowed_elem_types(schema, output) structure, allowed_elem_types = get_allowed_elem_types(schema, output)
if allowed_elem_types != None and len(allowed_elem_types) == 1 : if allowed_elem_types != None and len(allowed_elem_types) == 1 :
mapping.append(str(get_tblgen_type_index(allowed_elem_types[0]))) mapping.append(str(get_tblgen_type_index(allowed_elem_types[0])))
continue continue