diff --git a/docs/Dialects/mlonnx.md b/docs/Dialects/mlonnx.md index 485fc42..88f41f6 100644 --- a/docs/Dialects/mlonnx.md +++ b/docs/Dialects/mlonnx.md @@ -63,7 +63,7 @@ ONNX CastMap operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tuple with any combination of tensor of 64-bit signless integer values values or memref of 64-bit signless integer values #### Results: @@ -133,7 +133,7 @@ ONNX DictVectorizer operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tuple with any combination of tensor of 64-bit signless integer or 32-bit float or 64-bit float values values or memref of 64-bit signless integer or 32-bit float or 64-bit float values #### Results: @@ -593,5 +593,5 @@ ONNX ZipMap operation | Result | Description | | :----: | ----------- | -`Z` | memref of any type values or tensor of any type values +`Z` | tensor of tensor of 32-bit float or 64-bit signless integer values values or memref of 32-bit float or 64-bit signless integer values diff --git a/docs/Dialects/onnx.md b/docs/Dialects/onnx.md index a447655..9e34f23 100644 --- a/docs/Dialects/onnx.md +++ b/docs/Dialects/onnx.md @@ -531,7 +531,7 @@ ONNX ConcatFromSequence operation | Operand | Description | | :-----: | ----------- | -`input_sequence` | memref of any type values or tensor of any type values +`input_sequence` | memref of any type values or tensor of tensor of any type values values #### Results: @@ -4363,7 +4363,7 @@ ONNX SequenceAt operation | Operand | Description | | :-----: | ----------- | -`input_sequence` | memref of any type values or tensor of any type values +`input_sequence` | memref of any type values or tensor of tensor of any type values values `position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values #### Results: @@ -4389,7 +4389,7 @@ ONNX SequenceConstruct operation | Result | Description | | :----: | ----------- | -`output_sequence` | memref of any type values or tensor of any type values +`output_sequence` | memref of any type values or tensor of tensor of any type values values ### `onnx.SequenceEmpty` (ONNXSequenceEmptyOp) @@ -4407,7 +4407,7 @@ ONNX SequenceEmpty operation | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | memref of any type values or tensor of tensor of any type values values ### `onnx.SequenceErase` (ONNXSequenceEraseOp) @@ -4422,14 +4422,14 @@ ONNX SequenceErase operation | Operand | Description | | :-----: | ----------- | -`input_sequence` | memref of any type values or tensor of any type values +`input_sequence` | memref of any type values or tensor of tensor of any type values values `position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type #### Results: | Result | Description | | :----: | ----------- | -`output_sequence` | memref of any type values or tensor of any type values +`output_sequence` | memref of any type values or tensor of tensor of any type values values ### `onnx.SequenceInsert` (ONNXSequenceInsertOp) @@ -4445,7 +4445,7 @@ ONNX SequenceInsert operation | Operand | Description | | :-----: | ----------- | -`input_sequence` | memref of any type values or tensor of any type values +`input_sequence` | memref of any type values or tensor of tensor of any type values values `tensor` | memref of any type values or tensor of any type values `position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type @@ -4453,7 +4453,7 @@ ONNX SequenceInsert operation | Result | Description | | :----: | ----------- | -`output_sequence` | memref of any type values or tensor of any type values +`output_sequence` | memref of any type values or tensor of tensor of any type values values ### `onnx.SequenceLength` (ONNXSequenceLengthOp) @@ -4465,7 +4465,7 @@ ONNX SequenceLength operation | Operand | Description | | :-----: | ----------- | -`input_sequence` | memref of any type values or tensor of any type values +`input_sequence` | memref of any type values or tensor of tensor of any type values values #### Results: @@ -4828,7 +4828,7 @@ ONNX SplitToSequence operation | Result | Description | | :----: | ----------- | -`output_sequence` | memref of any type values or tensor of any type values +`output_sequence` | memref of any type values or tensor of tensor of any type values values ### `onnx.Sqrt` (ONNXSqrtOp) diff --git a/src/Dialect/MLONNX/MLONNXOps.hpp b/src/Dialect/MLONNX/MLONNXOps.hpp index 5457a0b..9474422 100644 --- a/src/Dialect/MLONNX/MLONNXOps.hpp +++ b/src/Dialect/MLONNX/MLONNXOps.hpp @@ -31,7 +31,7 @@ public: /// Provide a utility accessor to the dialect namespace. This is used by /// several utilities for casting between dialects. - static StringRef getDialectNamespace() { return "onnx"; } + static StringRef getDialectNamespace() { return "mlonnx"; } }; /// Include the auto-generated header file containing the declarations of the diff --git a/src/Dialect/MLONNX/MLONNXOps.td.inc b/src/Dialect/MLONNX/MLONNXOps.td.inc index f56efa6..c288113 100644 --- a/src/Dialect/MLONNX/MLONNXOps.td.inc +++ b/src/Dialect/MLONNX/MLONNXOps.td.inc @@ -57,7 +57,7 @@ def MLONNXCastMapOp:MLONNX_Op<"CastMap", " in ascending order based on this key.
The operator supports dense packing or sparse packing." " If using sparse packing, the key cannot exceed the max_map-1 value." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TupleOf<[TensorOf<[I64]>]>, MemRefOf<[I64]>]>:$X, DefaultValuedAttr:$cast_to, DefaultValuedAttr:$map_form, DefaultValuedAttr:$max_map); @@ -124,7 +124,7 @@ def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer", " then an input of ``{\"a\": 4, \"c\": 8}`` will produce an output of ``[4, 8, 0, 0]``." " " }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TupleOf<[TensorOf<[I64,F32,F64]>]>, MemRefOf<[I64,F32,F64]>]>:$X, OptionalAttr:$int64_vocabulary, OptionalAttr:$string_vocabulary); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); @@ -555,7 +555,7 @@ def MLONNXZipMapOp:MLONNX_Op<"ZipMap", let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, OptionalAttr:$classlabels_int64s, OptionalAttr:$classlabels_strings); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); + let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[F32,I64]>]>, MemRefOf<[F32,I64]>]>:$Z); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index b082fd8..d2b5c57 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -601,7 +601,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence", "By default 'new_axis' is 0, the behavior is similar to numpy.concatenate." "When 'new_axis' is 1, the behavior is similar to numpy.stack." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, + let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence, I64Attr:$axis, DefaultValuedAttr:$new_axis); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$concat_result); @@ -4562,7 +4562,7 @@ def ONNXSequenceAtOp:ONNX_Op<"SequenceAt", "Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'." "Negative value means counting positions from the back." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, + let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence, AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>]>:$position); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor); let extraClassDeclaration = [{ @@ -4586,7 +4586,7 @@ def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct", "All tensors in 'inputs' must have the same data type." }]; let arguments = (ins Variadic>:$inputs); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence); + let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence); let extraClassDeclaration = [{ static int getNumberOfOperands() { return -1; @@ -4607,7 +4607,7 @@ def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty", "Construct an empty tensor sequence, with given data type." }]; let arguments = (ins OptionalAttr:$dtype); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 0; @@ -4630,9 +4630,9 @@ def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase", "Negative value means counting positions from the back." "'position' is optional, by default it erases the last tensor from 'input_sequence'." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, + let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence, AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$position); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence); + let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 2; @@ -4656,10 +4656,10 @@ def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert", "Negative value means counting positions from the back." "'position' is optional, by default it inserts 'tensor' to the back of 'input_sequence'." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, + let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence, AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor, AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$position); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence); + let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 3; @@ -4679,7 +4679,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength", let description = [{ "Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence); + let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence); let results = (outs AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>]>:$length); let extraClassDeclaration = [{ static int getNumberOfOperands() { @@ -5054,7 +5054,7 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence", AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$split, DefaultValuedAttr:$axis, DefaultValuedAttr:$keepdims); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence); + let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 2; diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp index 50f2a7f..99f283d 100644 --- a/src/MainUtils.cpp +++ b/src/MainUtils.cpp @@ -246,6 +246,7 @@ void registerDialects() { mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); + mlir::registerDialect(); mlir::registerDialect(); } diff --git a/src/MainUtils.hpp b/src/MainUtils.hpp index 2ccb60f..834dbbd 100644 --- a/src/MainUtils.hpp +++ b/src/MainUtils.hpp @@ -22,6 +22,7 @@ #include "src/Builder/FrontendDialectTransformer.hpp" #include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Dialect/MLONNX/MLONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Pass/Passes.hpp" diff --git a/src/Tool/ONNXMLIROpt/CMakeLists.txt b/src/Tool/ONNXMLIROpt/CMakeLists.txt index 5de897b..8cdafef 100644 --- a/src/Tool/ONNXMLIROpt/CMakeLists.txt +++ b/src/Tool/ONNXMLIROpt/CMakeLists.txt @@ -18,3 +18,8 @@ target_link_libraries(onnx-mlir-opt OMONNXOps OMONNXRewrite ${MLIRLibs}) + +if (INCLUDE_ONNX_ML) + target_link_libraries(onnx-mlir-opt OMMLONNXOps) + add_dependencies(onnx-mlir-opt OMMLONNXOpsInc) +endif() diff --git a/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp b/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp index ec35aed..7c121f2 100644 --- a/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp +++ b/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp @@ -20,6 +20,7 @@ #include #include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Dialect/MLONNX/MLONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Pass/Passes.hpp" @@ -67,6 +68,7 @@ int main(int argc, char **argv) { llvm::InitLLVM y(argc, argv); mlir::registerDialect(); + mlir::registerDialect(); mlir::registerDialect(); mlir::registerAsmPrinterCLOptions(); diff --git a/test/mlir/onnx/onnx_structure.mlir b/test/mlir/onnx/onnx_structure.mlir new file mode 100644 index 0000000..f974c84 --- /dev/null +++ b/test/mlir/onnx/onnx_structure.mlir @@ -0,0 +1,9 @@ +// RUN: onnx-mlir-opt %s -split-input-file | FileCheck %s + +//===----------------------------------------------------------------------===// +// CHECK-LABEL: @check_map1(%arg0: tuple, tensor<10xi64>>) -> tensor<*xi64> { +func @check_map1(%arg0: tuple, tensor<10xi64>>) -> tensor<*xi64> { + %0 = "mlonnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple, tensor<10xi64>>) -> tensor<*xi64> + return %0 : tensor<*xi64> + // CHECK-NEXT: %0 = "mlonnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple, tensor<10xi64>>) -> tensor<*xi64> +} diff --git a/utils/CMakeLists.txt b/utils/CMakeLists.txt index 9a74721..b3aa2a6 100644 --- a/utils/CMakeLists.txt +++ b/utils/CMakeLists.txt @@ -23,7 +23,7 @@ add_custom_target(OMONNXOpsIncTranslation DEPENDS OMONNXOpsTableGenIncGen OMONNXOpsBuildTableIncGen) -# Invoke gen_onnx_mlir.py to obtain ONNXOps.td.inc, OpBuildTable.inc. +# Invoke gen_onnx_mlir.py to obtain MLONNXOps.td.inc, MLOpBuildTable.inc. add_custom_command(OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/MLONNXOps.td.inc ${CMAKE_CURRENT_SOURCE_DIR}/MLOpBuildTable.inc COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/gen_onnx_mlir.py --domain="ONNX_ML" diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 5e30d25..c498c95 100644 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -424,13 +424,14 @@ def get_tblgen_type_index(type_str): return tblgen_types.index(type_str) #the possible data structures are tensor, map and seq(tensor()) -#TOFIX: currently, only tensor structure is supported -def get_data_structure_element(allowed_type_str): - if allowed_type_str.startswith('tensor') : - element = allowed_type_str.replace('tensor(', '', 1).replace(')', '', 1) - return ('tensor', element) - else : - return (None, None) +def get_data_structure_element(allowed_type_str): + structure_list = ['tensor', 'seq', 'map'] + for structure in structure_list: + if allowed_type_str.startswith(structure) : + element = allowed_type_str.replace( + structure+'(', '', 1).replace(')', '', 1) + return (structure, element) + return (None, None) def get_allowed_elem_types(schema, input): #allowed_types_str = None @@ -446,19 +447,25 @@ def get_allowed_elem_types(schema, input): continue allowed_type_list=[] allowedTypes = type_constraint.allowed_type_strs + allowed_structure = None for allowedType in allowedTypes: structure, element = get_data_structure_element(allowedType); if structure == None or element == None: - return None + return None, None + + if allowed_structure != None and allowed_structure != structure : + print("{}: one structure assumed".format(schema.name)) + sys.exit(-1) + allowed_structure = structure t = np_type_to_tblgen_attr_type(element) if t == None : - return None + return allowed_structure, None if not t in allowed_type_list : allowed_tyoe_list = allowed_type_list.append(t) - - return allowed_type_list - - return None + + return allowed_structure,allowed_type_list + + return None, None def inc_indent(indent=None): @@ -486,14 +493,37 @@ def get_operands_or_results(schema, is_input): name_to_types = OrderedDict() for i, value in enumerate(value_list): - elem_types = get_allowed_elem_types(schema, value) + structure, elem_types = get_allowed_elem_types(schema, value) - if elem_types is None: - types = ["AnyMemRef", "AnyTensor"] + if structure == 'tensor' : + if elem_types is None: + types = ["AnyMemRef", "AnyTensor"] + else: + elem_types_str = ','.join(elem_types) + types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"] + types = list(map(lambda x: x.format(elem_types_str), types)) + elif structure == 'seq' : + # Seq is not supported yet. + # Use of TensorOf<[AnyTensor]> as a placeholder for tablegen. + # When the Operation is used, warning/error will be generated at runtime. + if elem_types is None: + types = ["AnyMemRef", "TensorOf<[AnyTensor]>"] + else: + elem_types_str = ','.join(elem_types) + types = ["TensorOf<[TensorOf<[{}]>]>", "MemRefOf<[{}]>"] + types = list(map(lambda x: x.format(elem_types_str), types)) + elif structure == 'map' : + # Map is not supported yet. + # Use of TupleOf as a placeholder for tablegen. + # When the Operation is used, warning/error will be generated at runtime. + if elem_types is None: + types = ["AnyMemRef", "TupleOf<[AnyTensor]>"] + else: + elem_types_str = ','.join(elem_types) + types = ["TupleOf<[TensorOf<[{}]>]>", "MemRefOf<[{}]>"] + types = list(map(lambda x: x.format(elem_types_str), types)) else: - elem_types_str = ','.join(elem_types) - types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"] - types = list(map(lambda x: x.format(elem_types_str), types)) + types = ["AnyMemRef", "AnyTensor"] # If operand is promotable to an attribute, then it must be # nullable in case it migrates to be an attribute. @@ -592,7 +622,7 @@ def get_output_type_mapping(schema): mapping=[] for output in schema.outputs : #if only one type is allowed, just set that - allowed_elem_types = get_allowed_elem_types(schema, output) + structure, allowed_elem_types = get_allowed_elem_types(schema, output) if allowed_elem_types != None and len(allowed_elem_types) == 1 : mapping.append(str(get_tblgen_type_index(allowed_elem_types[0]))) continue