Allow importing variadic inputs/outputs of onnx operators (#16)
* Allow importing variadic inputs/outputs of onnx operators * Enable testcases for variadic ops * Modify gen_doc.py
This commit is contained in:
parent
31116ec3c2
commit
7c889548a7
|
@ -330,8 +330,8 @@ private:
|
||||||
* default}
|
* default}
|
||||||
*/
|
*/
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void
|
void ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut,
|
||||||
ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut) {
|
bool variadicIn = false, bool variadicOut = false) {
|
||||||
std::vector<mlir::Value> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
for (const auto &item : node.input()) {
|
for (const auto &item : node.input()) {
|
||||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
|
@ -348,8 +348,8 @@ private:
|
||||||
auto attributes = ImportNodeAttributes(node);
|
auto attributes = ImportNodeAttributes(node);
|
||||||
|
|
||||||
llvm::StringRef OpName = node.op_type();
|
llvm::StringRef OpName = node.op_type();
|
||||||
|
if ((variadicIn || nIn == inputs.size()) &&
|
||||||
if (nIn == inputs.size() && nOut == outputTypes.size()) {
|
(variadicOut || nOut == outputTypes.size())) {
|
||||||
auto op =
|
auto op =
|
||||||
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
|
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
|
||||||
frontend_symbols_.AddMapping(legalize_name(node.output()[0]),
|
frontend_symbols_.AddMapping(legalize_name(node.output()[0]),
|
||||||
|
@ -360,8 +360,9 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void ImportNodeMultipleOuts(
|
void ImportNodeMultipleOuts(const onnx::NodeProto &node, int nIn, int nOut,
|
||||||
const onnx::NodeProto &node, int nIn, int nOut) {
|
bool variadicIn = false,
|
||||||
|
bool variadicOut = false) {
|
||||||
std::vector<mlir::Value> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
for (const auto &item : node.input()) {
|
for (const auto &item : node.input()) {
|
||||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
|
@ -379,7 +380,8 @@ private:
|
||||||
|
|
||||||
llvm::StringRef OpName = node.op_type();
|
llvm::StringRef OpName = node.op_type();
|
||||||
|
|
||||||
if (nIn == inputs.size() && nOut == outputTypes.size()) {
|
if ((variadicIn || nIn == inputs.size()) &&
|
||||||
|
(variadicOut || nOut == outputTypes.size())) {
|
||||||
auto op =
|
auto op =
|
||||||
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
|
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
|
||||||
for (int i = 0; i < node.output().size(); i++) {
|
for (int i = 0; i < node.output().size(); i++) {
|
||||||
|
|
|
@ -36,7 +36,7 @@
|
||||||
}else if (OpName == "Compress") {
|
}else if (OpName == "Compress") {
|
||||||
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1);
|
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1);
|
||||||
}else if (OpName == "Concat") {
|
}else if (OpName == "Concat") {
|
||||||
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1);
|
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, true, false);
|
||||||
}else if (OpName == "ConcatFromSequence") {
|
}else if (OpName == "ConcatFromSequence") {
|
||||||
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1);
|
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1);
|
||||||
}else if (OpName == "Constant") {
|
}else if (OpName == "Constant") {
|
||||||
|
@ -138,7 +138,7 @@
|
||||||
}else if (OpName == "MatMulInteger") {
|
}else if (OpName == "MatMulInteger") {
|
||||||
ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1);
|
ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1);
|
||||||
}else if (OpName == "Max") {
|
}else if (OpName == "Max") {
|
||||||
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1);
|
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, true, false);
|
||||||
}else if (OpName == "MaxPool") {
|
}else if (OpName == "MaxPool") {
|
||||||
ImportNodeMaxPool(node, 1, 2);
|
ImportNodeMaxPool(node, 1, 2);
|
||||||
}else if (OpName == "MaxRoiPool") {
|
}else if (OpName == "MaxRoiPool") {
|
||||||
|
@ -146,11 +146,11 @@
|
||||||
}else if (OpName == "MaxUnpool") {
|
}else if (OpName == "MaxUnpool") {
|
||||||
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1);
|
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1);
|
||||||
}else if (OpName == "Mean") {
|
}else if (OpName == "Mean") {
|
||||||
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1);
|
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, true, false);
|
||||||
}else if (OpName == "MeanVarianceNormalization") {
|
}else if (OpName == "MeanVarianceNormalization") {
|
||||||
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1);
|
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1);
|
||||||
}else if (OpName == "Min") {
|
}else if (OpName == "Min") {
|
||||||
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1);
|
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, true, false);
|
||||||
}else if (OpName == "Mod") {
|
}else if (OpName == "Mod") {
|
||||||
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1);
|
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1);
|
||||||
}else if (OpName == "Mul") {
|
}else if (OpName == "Mul") {
|
||||||
|
@ -240,7 +240,7 @@
|
||||||
}else if (OpName == "SequenceAt") {
|
}else if (OpName == "SequenceAt") {
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1);
|
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1);
|
||||||
}else if (OpName == "SequenceConstruct") {
|
}else if (OpName == "SequenceConstruct") {
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1, true, false);
|
||||||
}else if (OpName == "SequenceEmpty") {
|
}else if (OpName == "SequenceEmpty") {
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1);
|
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1);
|
||||||
}else if (OpName == "SequenceErase") {
|
}else if (OpName == "SequenceErase") {
|
||||||
|
@ -286,7 +286,7 @@
|
||||||
}else if (OpName == "Sub") {
|
}else if (OpName == "Sub") {
|
||||||
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1);
|
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1);
|
||||||
}else if (OpName == "Sum") {
|
}else if (OpName == "Sum") {
|
||||||
ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1, true, false);
|
||||||
}else if (OpName == "Tan") {
|
}else if (OpName == "Tan") {
|
||||||
ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1);
|
ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1);
|
||||||
}else if (OpName == "Tanh") {
|
}else if (OpName == "Tanh") {
|
||||||
|
|
|
@ -384,6 +384,8 @@ def gen_code(schema,fefile) :
|
||||||
#("Transpose", "ImportNodeTranspose")
|
#("Transpose", "ImportNodeTranspose")
|
||||||
])
|
])
|
||||||
|
|
||||||
|
handle_variadic = False
|
||||||
|
|
||||||
line_indent = ' '
|
line_indent = ' '
|
||||||
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
|
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
|
||||||
op_type_str='mlir::ONNX'+schema.name+'Op'
|
op_type_str='mlir::ONNX'+schema.name+'Op'
|
||||||
|
@ -399,7 +401,22 @@ def gen_code(schema,fefile) :
|
||||||
fefile.write(' '+'ImportNodeOneOut<'+op_type_str+'>(node, '
|
fefile.write(' '+'ImportNodeOneOut<'+op_type_str+'>(node, '
|
||||||
+str(len(schema.inputs))
|
+str(len(schema.inputs))
|
||||||
+', ' +str(len(schema.outputs)))
|
+', ' +str(len(schema.outputs)))
|
||||||
fefile.write(');\n')
|
|
||||||
|
variadicIn = 'false'
|
||||||
|
variadicOut = 'false'
|
||||||
|
for input in schema.inputs:
|
||||||
|
if OpSchema.FormalParameterOption.Variadic == input.option:
|
||||||
|
if input.isHomogeneous:
|
||||||
|
variadicIn = 'true'
|
||||||
|
handle_variadic = True
|
||||||
|
for output in schema.outputs:
|
||||||
|
if OpSchema.FormalParameterOption.Variadic == output.option:
|
||||||
|
if output.isHomogeneous:
|
||||||
|
variadicOut = 'true'
|
||||||
|
if not handle_variadic:
|
||||||
|
fefile.write(');\n')
|
||||||
|
else:
|
||||||
|
fefile.write(', '+variadicIn+', '+variadicOut+');\n')
|
||||||
|
|
||||||
def gen_attr_ins(schema, isfirst) :
|
def gen_attr_ins(schema, isfirst) :
|
||||||
special_defaults = dict([
|
special_defaults = dict([
|
||||||
|
|
|
@ -104,14 +104,14 @@ test_to_enable = [
|
||||||
"test_leakyrelu_example_cpu",
|
"test_leakyrelu_example_cpu",
|
||||||
|
|
||||||
# Max Op:
|
# Max Op:
|
||||||
# "test_max_example_cpu", <- error
|
"test_max_example_cpu",
|
||||||
"test_max_one_input_cpu",
|
"test_max_one_input_cpu",
|
||||||
# "test_max_two_inputs_cpu", <- error
|
"test_max_two_inputs_cpu",
|
||||||
|
|
||||||
# Min Op:
|
# Min Op:
|
||||||
# "test_min_example_cpu", <- error
|
"test_min_example_cpu",
|
||||||
"test_min_one_input_cpu",
|
"test_min_one_input_cpu",
|
||||||
# "test_min_two_inputs_cpu", <- error
|
"test_min_two_inputs_cpu",
|
||||||
|
|
||||||
# Mul Op:
|
# Mul Op:
|
||||||
"test_mul_cpu",
|
"test_mul_cpu",
|
||||||
|
@ -139,9 +139,9 @@ test_to_enable = [
|
||||||
"test_softmax_large_number_cpu",
|
"test_softmax_large_number_cpu",
|
||||||
|
|
||||||
# Sum Op:
|
# Sum Op:
|
||||||
#"test_sum_example_cpu", <- error
|
"test_sum_example_cpu",
|
||||||
"test_sum_one_input_cpu",
|
"test_sum_one_input_cpu",
|
||||||
#"test_sum_two_inputs_cpu", <- error
|
"test_sum_two_inputs_cpu",
|
||||||
|
|
||||||
# Reciprocal Op:
|
# Reciprocal Op:
|
||||||
"test_reciprocal_cpu",
|
"test_reciprocal_cpu",
|
||||||
|
|
Loading…
Reference in New Issue