diff --git a/docs/Dialects/mlonnx.md b/docs/Dialects/mlonnx.md
index 485fc42..88f41f6 100644
--- a/docs/Dialects/mlonnx.md
+++ b/docs/Dialects/mlonnx.md
@@ -63,7 +63,7 @@ ONNX CastMap operation
| Operand | Description |
| :-----: | ----------- |
-`X` | memref of any type values or tensor of any type values
+`X` | tuple with any combination of tensor of 64-bit signless integer values values or memref of 64-bit signless integer values
#### Results:
@@ -133,7 +133,7 @@ ONNX DictVectorizer operation
| Operand | Description |
| :-----: | ----------- |
-`X` | memref of any type values or tensor of any type values
+`X` | tuple with any combination of tensor of 64-bit signless integer or 32-bit float or 64-bit float values values or memref of 64-bit signless integer or 32-bit float or 64-bit float values
#### Results:
@@ -593,5 +593,5 @@ ONNX ZipMap operation
| Result | Description |
| :----: | ----------- |
-`Z` | memref of any type values or tensor of any type values
+`Z` | tensor of tensor of 32-bit float or 64-bit signless integer values values or memref of 32-bit float or 64-bit signless integer values
diff --git a/docs/Dialects/onnx.md b/docs/Dialects/onnx.md
index a447655..9e34f23 100644
--- a/docs/Dialects/onnx.md
+++ b/docs/Dialects/onnx.md
@@ -531,7 +531,7 @@ ONNX ConcatFromSequence operation
| Operand | Description |
| :-----: | ----------- |
-`input_sequence` | memref of any type values or tensor of any type values
+`input_sequence` | memref of any type values or tensor of tensor of any type values values
#### Results:
@@ -4363,7 +4363,7 @@ ONNX SequenceAt operation
| Operand | Description |
| :-----: | ----------- |
-`input_sequence` | memref of any type values or tensor of any type values
+`input_sequence` | memref of any type values or tensor of tensor of any type values values
`position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values
#### Results:
@@ -4389,7 +4389,7 @@ ONNX SequenceConstruct operation
| Result | Description |
| :----: | ----------- |
-`output_sequence` | memref of any type values or tensor of any type values
+`output_sequence` | memref of any type values or tensor of tensor of any type values values
### `onnx.SequenceEmpty` (ONNXSequenceEmptyOp)
@@ -4407,7 +4407,7 @@ ONNX SequenceEmpty operation
| Result | Description |
| :----: | ----------- |
-`output` | memref of any type values or tensor of any type values
+`output` | memref of any type values or tensor of tensor of any type values values
### `onnx.SequenceErase` (ONNXSequenceEraseOp)
@@ -4422,14 +4422,14 @@ ONNX SequenceErase operation
| Operand | Description |
| :-----: | ----------- |
-`input_sequence` | memref of any type values or tensor of any type values
+`input_sequence` | memref of any type values or tensor of tensor of any type values values
`position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type
#### Results:
| Result | Description |
| :----: | ----------- |
-`output_sequence` | memref of any type values or tensor of any type values
+`output_sequence` | memref of any type values or tensor of tensor of any type values values
### `onnx.SequenceInsert` (ONNXSequenceInsertOp)
@@ -4445,7 +4445,7 @@ ONNX SequenceInsert operation
| Operand | Description |
| :-----: | ----------- |
-`input_sequence` | memref of any type values or tensor of any type values
+`input_sequence` | memref of any type values or tensor of tensor of any type values values
`tensor` | memref of any type values or tensor of any type values
`position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type
@@ -4453,7 +4453,7 @@ ONNX SequenceInsert operation
| Result | Description |
| :----: | ----------- |
-`output_sequence` | memref of any type values or tensor of any type values
+`output_sequence` | memref of any type values or tensor of tensor of any type values values
### `onnx.SequenceLength` (ONNXSequenceLengthOp)
@@ -4465,7 +4465,7 @@ ONNX SequenceLength operation
| Operand | Description |
| :-----: | ----------- |
-`input_sequence` | memref of any type values or tensor of any type values
+`input_sequence` | memref of any type values or tensor of tensor of any type values values
#### Results:
@@ -4828,7 +4828,7 @@ ONNX SplitToSequence operation
| Result | Description |
| :----: | ----------- |
-`output_sequence` | memref of any type values or tensor of any type values
+`output_sequence` | memref of any type values or tensor of tensor of any type values values
### `onnx.Sqrt` (ONNXSqrtOp)
diff --git a/src/Dialect/MLONNX/MLONNXOps.hpp b/src/Dialect/MLONNX/MLONNXOps.hpp
index 5457a0b..9474422 100644
--- a/src/Dialect/MLONNX/MLONNXOps.hpp
+++ b/src/Dialect/MLONNX/MLONNXOps.hpp
@@ -31,7 +31,7 @@ public:
/// Provide a utility accessor to the dialect namespace. This is used by
/// several utilities for casting between dialects.
- static StringRef getDialectNamespace() { return "onnx"; }
+ static StringRef getDialectNamespace() { return "mlonnx"; }
};
/// Include the auto-generated header file containing the declarations of the
diff --git a/src/Dialect/MLONNX/MLONNXOps.td.inc b/src/Dialect/MLONNX/MLONNXOps.td.inc
index f56efa6..c288113 100644
--- a/src/Dialect/MLONNX/MLONNXOps.td.inc
+++ b/src/Dialect/MLONNX/MLONNXOps.td.inc
@@ -57,7 +57,7 @@ def MLONNXCastMapOp:MLONNX_Op<"CastMap",
" in ascending order based on this key.
The operator supports dense packing or sparse packing."
" If using sparse packing, the key cannot exceed the max_map-1 value."
}];
- let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
+ let arguments = (ins AnyTypeOf<[TupleOf<[TensorOf<[I64]>]>, MemRefOf<[I64]>]>:$X,
DefaultValuedAttr:$cast_to,
DefaultValuedAttr:$map_form,
DefaultValuedAttr:$max_map);
@@ -124,7 +124,7 @@ def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer",
" then an input of ``{\"a\": 4, \"c\": 8}`` will produce an output of ``[4, 8, 0, 0]``."
" "
}];
- let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
+ let arguments = (ins AnyTypeOf<[TupleOf<[TensorOf<[I64,F32,F64]>]>, MemRefOf<[I64,F32,F64]>]>:$X,
OptionalAttr:$int64_vocabulary,
OptionalAttr:$string_vocabulary);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
@@ -555,7 +555,7 @@ def MLONNXZipMapOp:MLONNX_Op<"ZipMap",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
OptionalAttr:$classlabels_int64s,
OptionalAttr:$classlabels_strings);
- let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
+ let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[F32,I64]>]>, MemRefOf<[F32,I64]>]>:$Z);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc
index b082fd8..d2b5c57 100644
--- a/src/Dialect/ONNX/ONNXOps.td.inc
+++ b/src/Dialect/ONNX/ONNXOps.td.inc
@@ -601,7 +601,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence",
"By default 'new_axis' is 0, the behavior is similar to numpy.concatenate."
"When 'new_axis' is 1, the behavior is similar to numpy.stack."
}];
- let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence,
+ let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence,
I64Attr:$axis,
DefaultValuedAttr:$new_axis);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$concat_result);
@@ -4562,7 +4562,7 @@ def ONNXSequenceAtOp:ONNX_Op<"SequenceAt",
"Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'."
"Negative value means counting positions from the back."
}];
- let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence,
+ let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence,
AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>]>:$position);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor);
let extraClassDeclaration = [{
@@ -4586,7 +4586,7 @@ def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct",
"All tensors in 'inputs' must have the same data type."
}];
let arguments = (ins Variadic>:$inputs);
- let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence);
+ let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return -1;
@@ -4607,7 +4607,7 @@ def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty",
"Construct an empty tensor sequence, with given data type."
}];
let arguments = (ins OptionalAttr:$dtype);
- let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
+ let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 0;
@@ -4630,9 +4630,9 @@ def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase",
"Negative value means counting positions from the back."
"'position' is optional, by default it erases the last tensor from 'input_sequence'."
}];
- let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence,
+ let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence,
AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$position);
- let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence);
+ let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 2;
@@ -4656,10 +4656,10 @@ def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert",
"Negative value means counting positions from the back."
"'position' is optional, by default it inserts 'tensor' to the back of 'input_sequence'."
}];
- let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence,
+ let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor,
AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$position);
- let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence);
+ let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 3;
@@ -4679,7 +4679,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength",
let description = [{
"Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'."
}];
- let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence);
+ let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence);
let results = (outs AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>]>:$length);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
@@ -5054,7 +5054,7 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence",
AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$split,
DefaultValuedAttr:$axis,
DefaultValuedAttr:$keepdims);
- let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence);
+ let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 2;
diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp
index 50f2a7f..99f283d 100644
--- a/src/MainUtils.cpp
+++ b/src/MainUtils.cpp
@@ -246,6 +246,7 @@ void registerDialects() {
mlir::registerDialect();
mlir::registerDialect();
mlir::registerDialect();
+ mlir::registerDialect();
mlir::registerDialect();
}
diff --git a/src/MainUtils.hpp b/src/MainUtils.hpp
index 2ccb60f..834dbbd 100644
--- a/src/MainUtils.hpp
+++ b/src/MainUtils.hpp
@@ -22,6 +22,7 @@
#include "src/Builder/FrontendDialectTransformer.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
+#include "src/Dialect/MLONNX/MLONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Pass/Passes.hpp"
diff --git a/src/Tool/ONNXMLIROpt/CMakeLists.txt b/src/Tool/ONNXMLIROpt/CMakeLists.txt
index 5de897b..8cdafef 100644
--- a/src/Tool/ONNXMLIROpt/CMakeLists.txt
+++ b/src/Tool/ONNXMLIROpt/CMakeLists.txt
@@ -18,3 +18,8 @@ target_link_libraries(onnx-mlir-opt
OMONNXOps
OMONNXRewrite
${MLIRLibs})
+
+if (INCLUDE_ONNX_ML)
+ target_link_libraries(onnx-mlir-opt OMMLONNXOps)
+ add_dependencies(onnx-mlir-opt OMMLONNXOpsInc)
+endif()
diff --git a/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp b/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp
index ec35aed..7c121f2 100644
--- a/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp
+++ b/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp
@@ -20,6 +20,7 @@
#include
#include "src/Dialect/Krnl/KrnlOps.hpp"
+#include "src/Dialect/MLONNX/MLONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Pass/Passes.hpp"
@@ -67,6 +68,7 @@ int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
mlir::registerDialect();
+ mlir::registerDialect();
mlir::registerDialect();
mlir::registerAsmPrinterCLOptions();
diff --git a/test/mlir/onnx/onnx_structure.mlir b/test/mlir/onnx/onnx_structure.mlir
new file mode 100644
index 0000000..f974c84
--- /dev/null
+++ b/test/mlir/onnx/onnx_structure.mlir
@@ -0,0 +1,9 @@
+// RUN: onnx-mlir-opt %s -split-input-file | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// CHECK-LABEL: @check_map1(%arg0: tuple, tensor<10xi64>>) -> tensor<*xi64> {
+func @check_map1(%arg0: tuple, tensor<10xi64>>) -> tensor<*xi64> {
+ %0 = "mlonnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple, tensor<10xi64>>) -> tensor<*xi64>
+ return %0 : tensor<*xi64>
+ // CHECK-NEXT: %0 = "mlonnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple, tensor<10xi64>>) -> tensor<*xi64>
+}
diff --git a/utils/CMakeLists.txt b/utils/CMakeLists.txt
index 9a74721..b3aa2a6 100644
--- a/utils/CMakeLists.txt
+++ b/utils/CMakeLists.txt
@@ -23,7 +23,7 @@ add_custom_target(OMONNXOpsIncTranslation
DEPENDS OMONNXOpsTableGenIncGen
OMONNXOpsBuildTableIncGen)
-# Invoke gen_onnx_mlir.py to obtain ONNXOps.td.inc, OpBuildTable.inc.
+# Invoke gen_onnx_mlir.py to obtain MLONNXOps.td.inc, MLOpBuildTable.inc.
add_custom_command(OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/MLONNXOps.td.inc
${CMAKE_CURRENT_SOURCE_DIR}/MLOpBuildTable.inc
COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/gen_onnx_mlir.py --domain="ONNX_ML"
diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py
index 5e30d25..c498c95 100644
--- a/utils/gen_onnx_mlir.py
+++ b/utils/gen_onnx_mlir.py
@@ -424,13 +424,14 @@ def get_tblgen_type_index(type_str):
return tblgen_types.index(type_str)
#the possible data structures are tensor, map and seq(tensor())
-#TOFIX: currently, only tensor structure is supported
-def get_data_structure_element(allowed_type_str):
- if allowed_type_str.startswith('tensor') :
- element = allowed_type_str.replace('tensor(', '', 1).replace(')', '', 1)
- return ('tensor', element)
- else :
- return (None, None)
+def get_data_structure_element(allowed_type_str):
+ structure_list = ['tensor', 'seq', 'map']
+ for structure in structure_list:
+ if allowed_type_str.startswith(structure) :
+ element = allowed_type_str.replace(
+ structure+'(', '', 1).replace(')', '', 1)
+ return (structure, element)
+ return (None, None)
def get_allowed_elem_types(schema, input):
#allowed_types_str = None
@@ -446,19 +447,25 @@ def get_allowed_elem_types(schema, input):
continue
allowed_type_list=[]
allowedTypes = type_constraint.allowed_type_strs
+ allowed_structure = None
for allowedType in allowedTypes:
structure, element = get_data_structure_element(allowedType);
if structure == None or element == None:
- return None
+ return None, None
+
+ if allowed_structure != None and allowed_structure != structure :
+ print("{}: one structure assumed".format(schema.name))
+ sys.exit(-1)
+ allowed_structure = structure
t = np_type_to_tblgen_attr_type(element)
if t == None :
- return None
+ return allowed_structure, None
if not t in allowed_type_list :
allowed_tyoe_list = allowed_type_list.append(t)
-
- return allowed_type_list
-
- return None
+
+ return allowed_structure,allowed_type_list
+
+ return None, None
def inc_indent(indent=None):
@@ -486,14 +493,37 @@ def get_operands_or_results(schema, is_input):
name_to_types = OrderedDict()
for i, value in enumerate(value_list):
- elem_types = get_allowed_elem_types(schema, value)
+ structure, elem_types = get_allowed_elem_types(schema, value)
- if elem_types is None:
- types = ["AnyMemRef", "AnyTensor"]
+ if structure == 'tensor' :
+ if elem_types is None:
+ types = ["AnyMemRef", "AnyTensor"]
+ else:
+ elem_types_str = ','.join(elem_types)
+ types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"]
+ types = list(map(lambda x: x.format(elem_types_str), types))
+ elif structure == 'seq' :
+ # Seq is not supported yet.
+ # Use of TensorOf<[AnyTensor]> as a placeholder for tablegen.
+ # When the Operation is used, warning/error will be generated at runtime.
+ if elem_types is None:
+ types = ["AnyMemRef", "TensorOf<[AnyTensor]>"]
+ else:
+ elem_types_str = ','.join(elem_types)
+ types = ["TensorOf<[TensorOf<[{}]>]>", "MemRefOf<[{}]>"]
+ types = list(map(lambda x: x.format(elem_types_str), types))
+ elif structure == 'map' :
+ # Map is not supported yet.
+ # Use of TupleOf as a placeholder for tablegen.
+ # When the Operation is used, warning/error will be generated at runtime.
+ if elem_types is None:
+ types = ["AnyMemRef", "TupleOf<[AnyTensor]>"]
+ else:
+ elem_types_str = ','.join(elem_types)
+ types = ["TupleOf<[TensorOf<[{}]>]>", "MemRefOf<[{}]>"]
+ types = list(map(lambda x: x.format(elem_types_str), types))
else:
- elem_types_str = ','.join(elem_types)
- types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"]
- types = list(map(lambda x: x.format(elem_types_str), types))
+ types = ["AnyMemRef", "AnyTensor"]
# If operand is promotable to an attribute, then it must be
# nullable in case it migrates to be an attribute.
@@ -592,7 +622,7 @@ def get_output_type_mapping(schema):
mapping=[]
for output in schema.outputs :
#if only one type is allowed, just set that
- allowed_elem_types = get_allowed_elem_types(schema, output)
+ structure, allowed_elem_types = get_allowed_elem_types(schema, output)
if allowed_elem_types != None and len(allowed_elem_types) == 1 :
mapping.append(str(get_tblgen_type_index(allowed_elem_types[0])))
continue