support map and seq in tablegen (#159)
* support map and seq in tablegen * register MLONNX for testing * format * remove the unwanted test * add a test Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
b28fcbc745
commit
1fc43fa181
|
@ -63,7 +63,7 @@ ONNX CastMap operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`X` | memref of any type values or tensor of any type values
|
`X` | tuple with any combination of tensor of 64-bit signless integer values values or memref of 64-bit signless integer values
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
||||||
|
@ -133,7 +133,7 @@ ONNX DictVectorizer operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`X` | memref of any type values or tensor of any type values
|
`X` | tuple with any combination of tensor of 64-bit signless integer or 32-bit float or 64-bit float values values or memref of 64-bit signless integer or 32-bit float or 64-bit float values
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
||||||
|
@ -593,5 +593,5 @@ ONNX ZipMap operation
|
||||||
|
|
||||||
| Result | Description |
|
| Result | Description |
|
||||||
| :----: | ----------- |
|
| :----: | ----------- |
|
||||||
`Z` | memref of any type values or tensor of any type values
|
`Z` | tensor of tensor of 32-bit float or 64-bit signless integer values values or memref of 32-bit float or 64-bit signless integer values
|
||||||
|
|
||||||
|
|
|
@ -531,7 +531,7 @@ ONNX ConcatFromSequence operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`input_sequence` | memref of any type values or tensor of any type values
|
`input_sequence` | memref of any type values or tensor of tensor of any type values values
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
||||||
|
@ -4363,7 +4363,7 @@ ONNX SequenceAt operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`input_sequence` | memref of any type values or tensor of any type values
|
`input_sequence` | memref of any type values or tensor of tensor of any type values values
|
||||||
`position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values
|
`position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
@ -4389,7 +4389,7 @@ ONNX SequenceConstruct operation
|
||||||
|
|
||||||
| Result | Description |
|
| Result | Description |
|
||||||
| :----: | ----------- |
|
| :----: | ----------- |
|
||||||
`output_sequence` | memref of any type values or tensor of any type values
|
`output_sequence` | memref of any type values or tensor of tensor of any type values values
|
||||||
|
|
||||||
### `onnx.SequenceEmpty` (ONNXSequenceEmptyOp)
|
### `onnx.SequenceEmpty` (ONNXSequenceEmptyOp)
|
||||||
|
|
||||||
|
@ -4407,7 +4407,7 @@ ONNX SequenceEmpty operation
|
||||||
|
|
||||||
| Result | Description |
|
| Result | Description |
|
||||||
| :----: | ----------- |
|
| :----: | ----------- |
|
||||||
`output` | memref of any type values or tensor of any type values
|
`output` | memref of any type values or tensor of tensor of any type values values
|
||||||
|
|
||||||
### `onnx.SequenceErase` (ONNXSequenceEraseOp)
|
### `onnx.SequenceErase` (ONNXSequenceEraseOp)
|
||||||
|
|
||||||
|
@ -4422,14 +4422,14 @@ ONNX SequenceErase operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`input_sequence` | memref of any type values or tensor of any type values
|
`input_sequence` | memref of any type values or tensor of tensor of any type values values
|
||||||
`position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type
|
`position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
||||||
| Result | Description |
|
| Result | Description |
|
||||||
| :----: | ----------- |
|
| :----: | ----------- |
|
||||||
`output_sequence` | memref of any type values or tensor of any type values
|
`output_sequence` | memref of any type values or tensor of tensor of any type values values
|
||||||
|
|
||||||
### `onnx.SequenceInsert` (ONNXSequenceInsertOp)
|
### `onnx.SequenceInsert` (ONNXSequenceInsertOp)
|
||||||
|
|
||||||
|
@ -4445,7 +4445,7 @@ ONNX SequenceInsert operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`input_sequence` | memref of any type values or tensor of any type values
|
`input_sequence` | memref of any type values or tensor of tensor of any type values values
|
||||||
`tensor` | memref of any type values or tensor of any type values
|
`tensor` | memref of any type values or tensor of any type values
|
||||||
`position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type
|
`position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type
|
||||||
|
|
||||||
|
@ -4453,7 +4453,7 @@ ONNX SequenceInsert operation
|
||||||
|
|
||||||
| Result | Description |
|
| Result | Description |
|
||||||
| :----: | ----------- |
|
| :----: | ----------- |
|
||||||
`output_sequence` | memref of any type values or tensor of any type values
|
`output_sequence` | memref of any type values or tensor of tensor of any type values values
|
||||||
|
|
||||||
### `onnx.SequenceLength` (ONNXSequenceLengthOp)
|
### `onnx.SequenceLength` (ONNXSequenceLengthOp)
|
||||||
|
|
||||||
|
@ -4465,7 +4465,7 @@ ONNX SequenceLength operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`input_sequence` | memref of any type values or tensor of any type values
|
`input_sequence` | memref of any type values or tensor of tensor of any type values values
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
||||||
|
@ -4828,7 +4828,7 @@ ONNX SplitToSequence operation
|
||||||
|
|
||||||
| Result | Description |
|
| Result | Description |
|
||||||
| :----: | ----------- |
|
| :----: | ----------- |
|
||||||
`output_sequence` | memref of any type values or tensor of any type values
|
`output_sequence` | memref of any type values or tensor of tensor of any type values values
|
||||||
|
|
||||||
### `onnx.Sqrt` (ONNXSqrtOp)
|
### `onnx.Sqrt` (ONNXSqrtOp)
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ public:
|
||||||
|
|
||||||
/// Provide a utility accessor to the dialect namespace. This is used by
|
/// Provide a utility accessor to the dialect namespace. This is used by
|
||||||
/// several utilities for casting between dialects.
|
/// several utilities for casting between dialects.
|
||||||
static StringRef getDialectNamespace() { return "onnx"; }
|
static StringRef getDialectNamespace() { return "mlonnx"; }
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Include the auto-generated header file containing the declarations of the
|
/// Include the auto-generated header file containing the declarations of the
|
||||||
|
|
|
@ -57,7 +57,7 @@ def MLONNXCastMapOp:MLONNX_Op<"CastMap",
|
||||||
" in ascending order based on this key.<br>The operator supports dense packing or sparse packing."
|
" in ascending order based on this key.<br>The operator supports dense packing or sparse packing."
|
||||||
" If using sparse packing, the key cannot exceed the max_map-1 value."
|
" If using sparse packing, the key cannot exceed the max_map-1 value."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
let arguments = (ins AnyTypeOf<[TupleOf<[TensorOf<[I64]>]>, MemRefOf<[I64]>]>:$X,
|
||||||
DefaultValuedAttr<StrAttr, "TO_FLOAT">:$cast_to,
|
DefaultValuedAttr<StrAttr, "TO_FLOAT">:$cast_to,
|
||||||
DefaultValuedAttr<StrAttr, "DENSE">:$map_form,
|
DefaultValuedAttr<StrAttr, "DENSE">:$map_form,
|
||||||
DefaultValuedAttr<I64Attr, "1">:$max_map);
|
DefaultValuedAttr<I64Attr, "1">:$max_map);
|
||||||
|
@ -124,7 +124,7 @@ def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer",
|
||||||
" then an input of ``{\"a\": 4, \"c\": 8}`` will produce an output of ``[4, 8, 0, 0]``."
|
" then an input of ``{\"a\": 4, \"c\": 8}`` will produce an output of ``[4, 8, 0, 0]``."
|
||||||
" "
|
" "
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
let arguments = (ins AnyTypeOf<[TupleOf<[TensorOf<[I64,F32,F64]>]>, MemRefOf<[I64,F32,F64]>]>:$X,
|
||||||
OptionalAttr<I64ArrayAttr>:$int64_vocabulary,
|
OptionalAttr<I64ArrayAttr>:$int64_vocabulary,
|
||||||
OptionalAttr<StrArrayAttr>:$string_vocabulary);
|
OptionalAttr<StrArrayAttr>:$string_vocabulary);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
||||||
|
@ -555,7 +555,7 @@ def MLONNXZipMapOp:MLONNX_Op<"ZipMap",
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
||||||
OptionalAttr<I64ArrayAttr>:$classlabels_int64s,
|
OptionalAttr<I64ArrayAttr>:$classlabels_int64s,
|
||||||
OptionalAttr<StrArrayAttr>:$classlabels_strings);
|
OptionalAttr<StrArrayAttr>:$classlabels_strings);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
|
let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[F32,I64]>]>, MemRefOf<[F32,I64]>]>:$Z);
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 1;
|
return 1;
|
||||||
|
|
|
@ -601,7 +601,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence",
|
||||||
"By default 'new_axis' is 0, the behavior is similar to numpy.concatenate."
|
"By default 'new_axis' is 0, the behavior is similar to numpy.concatenate."
|
||||||
"When 'new_axis' is 1, the behavior is similar to numpy.stack."
|
"When 'new_axis' is 1, the behavior is similar to numpy.stack."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence,
|
let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence,
|
||||||
I64Attr:$axis,
|
I64Attr:$axis,
|
||||||
DefaultValuedAttr<I64Attr, "0">:$new_axis);
|
DefaultValuedAttr<I64Attr, "0">:$new_axis);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$concat_result);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$concat_result);
|
||||||
|
@ -4562,7 +4562,7 @@ def ONNXSequenceAtOp:ONNX_Op<"SequenceAt",
|
||||||
"Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'."
|
"Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'."
|
||||||
"Negative value means counting positions from the back."
|
"Negative value means counting positions from the back."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence,
|
let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence,
|
||||||
AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>]>:$position);
|
AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>]>:$position);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor);
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
|
@ -4586,7 +4586,7 @@ def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct",
|
||||||
"All tensors in 'inputs' must have the same data type."
|
"All tensors in 'inputs' must have the same data type."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$inputs);
|
let arguments = (ins Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$inputs);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence);
|
let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence);
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return -1;
|
return -1;
|
||||||
|
@ -4607,7 +4607,7 @@ def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty",
|
||||||
"Construct an empty tensor sequence, with given data type."
|
"Construct an empty tensor sequence, with given data type."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins OptionalAttr<I64Attr>:$dtype);
|
let arguments = (ins OptionalAttr<I64Attr>:$dtype);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output);
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 0;
|
return 0;
|
||||||
|
@ -4630,9 +4630,9 @@ def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase",
|
||||||
"Negative value means counting positions from the back."
|
"Negative value means counting positions from the back."
|
||||||
"'position' is optional, by default it erases the last tensor from 'input_sequence'."
|
"'position' is optional, by default it erases the last tensor from 'input_sequence'."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence,
|
let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence,
|
||||||
AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$position);
|
AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$position);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence);
|
let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence);
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 2;
|
return 2;
|
||||||
|
@ -4656,10 +4656,10 @@ def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert",
|
||||||
"Negative value means counting positions from the back."
|
"Negative value means counting positions from the back."
|
||||||
"'position' is optional, by default it inserts 'tensor' to the back of 'input_sequence'."
|
"'position' is optional, by default it inserts 'tensor' to the back of 'input_sequence'."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence,
|
let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence,
|
||||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor,
|
AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor,
|
||||||
AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$position);
|
AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$position);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence);
|
let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence);
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 3;
|
return 3;
|
||||||
|
@ -4679,7 +4679,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength",
|
||||||
let description = [{
|
let description = [{
|
||||||
"Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'."
|
"Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence);
|
let arguments = (ins AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$input_sequence);
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>]>:$length);
|
let results = (outs AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>]>:$length);
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
|
@ -5054,7 +5054,7 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence",
|
||||||
AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$split,
|
AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$split,
|
||||||
DefaultValuedAttr<I64Attr, "0">:$axis,
|
DefaultValuedAttr<I64Attr, "0">:$axis,
|
||||||
DefaultValuedAttr<I64Attr, "1">:$keepdims);
|
DefaultValuedAttr<I64Attr, "1">:$keepdims);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence);
|
let results = (outs AnyTypeOf<[AnyMemRef, TensorOf<[AnyTensor]>]>:$output_sequence);
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 2;
|
return 2;
|
||||||
|
|
|
@ -246,6 +246,7 @@ void registerDialects() {
|
||||||
mlir::registerDialect<mlir::scf::SCFDialect>();
|
mlir::registerDialect<mlir::scf::SCFDialect>();
|
||||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||||
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||||
|
mlir::registerDialect<mlir::MLONNXOpsDialect>();
|
||||||
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
|
|
||||||
#include "src/Builder/FrontendDialectTransformer.hpp"
|
#include "src/Builder/FrontendDialectTransformer.hpp"
|
||||||
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
||||||
|
#include "src/Dialect/MLONNX/MLONNXOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
#include "src/Pass/Passes.hpp"
|
#include "src/Pass/Passes.hpp"
|
||||||
|
|
||||||
|
|
|
@ -18,3 +18,8 @@ target_link_libraries(onnx-mlir-opt
|
||||||
OMONNXOps
|
OMONNXOps
|
||||||
OMONNXRewrite
|
OMONNXRewrite
|
||||||
${MLIRLibs})
|
${MLIRLibs})
|
||||||
|
|
||||||
|
if (INCLUDE_ONNX_ML)
|
||||||
|
target_link_libraries(onnx-mlir-opt OMMLONNXOps)
|
||||||
|
add_dependencies(onnx-mlir-opt OMMLONNXOpsInc)
|
||||||
|
endif()
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include <mlir/Support/MlirOptMain.h>
|
#include <mlir/Support/MlirOptMain.h>
|
||||||
|
|
||||||
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
||||||
|
#include "src/Dialect/MLONNX/MLONNXOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
#include "src/Pass/Passes.hpp"
|
#include "src/Pass/Passes.hpp"
|
||||||
|
|
||||||
|
@ -67,6 +68,7 @@ int main(int argc, char **argv) {
|
||||||
llvm::InitLLVM y(argc, argv);
|
llvm::InitLLVM y(argc, argv);
|
||||||
|
|
||||||
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||||
|
mlir::registerDialect<mlir::MLONNXOpsDialect>();
|
||||||
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
||||||
|
|
||||||
mlir::registerAsmPrinterCLOptions();
|
mlir::registerAsmPrinterCLOptions();
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
// RUN: onnx-mlir-opt %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// CHECK-LABEL: @check_map1(%arg0: tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64> {
|
||||||
|
func @check_map1(%arg0: tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64> {
|
||||||
|
%0 = "mlonnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64>
|
||||||
|
return %0 : tensor<*xi64>
|
||||||
|
// CHECK-NEXT: %0 = "mlonnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64>
|
||||||
|
}
|
|
@ -23,7 +23,7 @@ add_custom_target(OMONNXOpsIncTranslation
|
||||||
DEPENDS OMONNXOpsTableGenIncGen
|
DEPENDS OMONNXOpsTableGenIncGen
|
||||||
OMONNXOpsBuildTableIncGen)
|
OMONNXOpsBuildTableIncGen)
|
||||||
|
|
||||||
# Invoke gen_onnx_mlir.py to obtain ONNXOps.td.inc, OpBuildTable.inc.
|
# Invoke gen_onnx_mlir.py to obtain MLONNXOps.td.inc, MLOpBuildTable.inc.
|
||||||
add_custom_command(OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/MLONNXOps.td.inc
|
add_custom_command(OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/MLONNXOps.td.inc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/MLOpBuildTable.inc
|
${CMAKE_CURRENT_SOURCE_DIR}/MLOpBuildTable.inc
|
||||||
COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/gen_onnx_mlir.py --domain="ONNX_ML"
|
COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/gen_onnx_mlir.py --domain="ONNX_ML"
|
||||||
|
|
|
@ -424,13 +424,14 @@ def get_tblgen_type_index(type_str):
|
||||||
return tblgen_types.index(type_str)
|
return tblgen_types.index(type_str)
|
||||||
|
|
||||||
#the possible data structures are tensor, map and seq(tensor())
|
#the possible data structures are tensor, map and seq(tensor())
|
||||||
#TOFIX: currently, only tensor structure is supported
|
def get_data_structure_element(allowed_type_str):
|
||||||
def get_data_structure_element(allowed_type_str):
|
structure_list = ['tensor', 'seq', 'map']
|
||||||
if allowed_type_str.startswith('tensor') :
|
for structure in structure_list:
|
||||||
element = allowed_type_str.replace('tensor(', '', 1).replace(')', '', 1)
|
if allowed_type_str.startswith(structure) :
|
||||||
return ('tensor', element)
|
element = allowed_type_str.replace(
|
||||||
else :
|
structure+'(', '', 1).replace(')', '', 1)
|
||||||
return (None, None)
|
return (structure, element)
|
||||||
|
return (None, None)
|
||||||
|
|
||||||
def get_allowed_elem_types(schema, input):
|
def get_allowed_elem_types(schema, input):
|
||||||
#allowed_types_str = None
|
#allowed_types_str = None
|
||||||
|
@ -446,19 +447,25 @@ def get_allowed_elem_types(schema, input):
|
||||||
continue
|
continue
|
||||||
allowed_type_list=[]
|
allowed_type_list=[]
|
||||||
allowedTypes = type_constraint.allowed_type_strs
|
allowedTypes = type_constraint.allowed_type_strs
|
||||||
|
allowed_structure = None
|
||||||
for allowedType in allowedTypes:
|
for allowedType in allowedTypes:
|
||||||
structure, element = get_data_structure_element(allowedType);
|
structure, element = get_data_structure_element(allowedType);
|
||||||
if structure == None or element == None:
|
if structure == None or element == None:
|
||||||
return None
|
return None, None
|
||||||
|
|
||||||
|
if allowed_structure != None and allowed_structure != structure :
|
||||||
|
print("{}: one structure assumed".format(schema.name))
|
||||||
|
sys.exit(-1)
|
||||||
|
allowed_structure = structure
|
||||||
t = np_type_to_tblgen_attr_type(element)
|
t = np_type_to_tblgen_attr_type(element)
|
||||||
if t == None :
|
if t == None :
|
||||||
return None
|
return allowed_structure, None
|
||||||
if not t in allowed_type_list :
|
if not t in allowed_type_list :
|
||||||
allowed_tyoe_list = allowed_type_list.append(t)
|
allowed_tyoe_list = allowed_type_list.append(t)
|
||||||
|
|
||||||
return allowed_type_list
|
return allowed_structure,allowed_type_list
|
||||||
|
|
||||||
return None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
def inc_indent(indent=None):
|
def inc_indent(indent=None):
|
||||||
|
@ -486,14 +493,37 @@ def get_operands_or_results(schema, is_input):
|
||||||
|
|
||||||
name_to_types = OrderedDict()
|
name_to_types = OrderedDict()
|
||||||
for i, value in enumerate(value_list):
|
for i, value in enumerate(value_list):
|
||||||
elem_types = get_allowed_elem_types(schema, value)
|
structure, elem_types = get_allowed_elem_types(schema, value)
|
||||||
|
|
||||||
if elem_types is None:
|
if structure == 'tensor' :
|
||||||
types = ["AnyMemRef", "AnyTensor"]
|
if elem_types is None:
|
||||||
|
types = ["AnyMemRef", "AnyTensor"]
|
||||||
|
else:
|
||||||
|
elem_types_str = ','.join(elem_types)
|
||||||
|
types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"]
|
||||||
|
types = list(map(lambda x: x.format(elem_types_str), types))
|
||||||
|
elif structure == 'seq' :
|
||||||
|
# Seq is not supported yet.
|
||||||
|
# Use of TensorOf<[AnyTensor]> as a placeholder for tablegen.
|
||||||
|
# When the Operation is used, warning/error will be generated at runtime.
|
||||||
|
if elem_types is None:
|
||||||
|
types = ["AnyMemRef", "TensorOf<[AnyTensor]>"]
|
||||||
|
else:
|
||||||
|
elem_types_str = ','.join(elem_types)
|
||||||
|
types = ["TensorOf<[TensorOf<[{}]>]>", "MemRefOf<[{}]>"]
|
||||||
|
types = list(map(lambda x: x.format(elem_types_str), types))
|
||||||
|
elif structure == 'map' :
|
||||||
|
# Map is not supported yet.
|
||||||
|
# Use of TupleOf as a placeholder for tablegen.
|
||||||
|
# When the Operation is used, warning/error will be generated at runtime.
|
||||||
|
if elem_types is None:
|
||||||
|
types = ["AnyMemRef", "TupleOf<[AnyTensor]>"]
|
||||||
|
else:
|
||||||
|
elem_types_str = ','.join(elem_types)
|
||||||
|
types = ["TupleOf<[TensorOf<[{}]>]>", "MemRefOf<[{}]>"]
|
||||||
|
types = list(map(lambda x: x.format(elem_types_str), types))
|
||||||
else:
|
else:
|
||||||
elem_types_str = ','.join(elem_types)
|
types = ["AnyMemRef", "AnyTensor"]
|
||||||
types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"]
|
|
||||||
types = list(map(lambda x: x.format(elem_types_str), types))
|
|
||||||
|
|
||||||
# If operand is promotable to an attribute, then it must be
|
# If operand is promotable to an attribute, then it must be
|
||||||
# nullable in case it migrates to be an attribute.
|
# nullable in case it migrates to be an attribute.
|
||||||
|
@ -592,7 +622,7 @@ def get_output_type_mapping(schema):
|
||||||
mapping=[]
|
mapping=[]
|
||||||
for output in schema.outputs :
|
for output in schema.outputs :
|
||||||
#if only one type is allowed, just set that
|
#if only one type is allowed, just set that
|
||||||
allowed_elem_types = get_allowed_elem_types(schema, output)
|
structure, allowed_elem_types = get_allowed_elem_types(schema, output)
|
||||||
if allowed_elem_types != None and len(allowed_elem_types) == 1 :
|
if allowed_elem_types != None and len(allowed_elem_types) == 1 :
|
||||||
mapping.append(str(get_tblgen_type_index(allowed_elem_types[0])))
|
mapping.append(str(get_tblgen_type_index(allowed_elem_types[0])))
|
||||||
continue
|
continue
|
||||||
|
|
Loading…
Reference in New Issue