From 7c889548a73166cf2b93c2cb4420367ed361314b Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Tue, 28 Jan 2020 22:48:11 +0900 Subject: [PATCH] Allow importing variadic inputs/outputs of onnx operators (#16) * Allow importing variadic inputs/outputs of onnx operators * Enable testcases for variadic ops * Modify gen_doc.py --- src/builder/frontend_dialect_transformer.cpp | 16 +++++++++------- src/builder/op_build_table.inc | 12 ++++++------ src/dialect/onnx/gen_doc.py | 19 ++++++++++++++++++- test/backend/test.py | 12 ++++++------ 4 files changed, 39 insertions(+), 20 deletions(-) diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp index 5f60809..7058cdc 100644 --- a/src/builder/frontend_dialect_transformer.cpp +++ b/src/builder/frontend_dialect_transformer.cpp @@ -330,8 +330,8 @@ private: * default} */ template - void - ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut) { + void ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut, + bool variadicIn = false, bool variadicOut = false) { std::vector inputs; for (const auto &item : node.input()) { if (frontend_symbols_.ContainKey(legalize_name(item))) { @@ -348,8 +348,8 @@ private: auto attributes = ImportNodeAttributes(node); llvm::StringRef OpName = node.op_type(); - - if (nIn == inputs.size() && nOut == outputTypes.size()) { + if ((variadicIn || nIn == inputs.size()) && + (variadicOut || nOut == outputTypes.size())) { auto op = builder_.create(UnknownLoc(), outputTypes, inputs, attributes); frontend_symbols_.AddMapping(legalize_name(node.output()[0]), @@ -360,8 +360,9 @@ private: } template - void ImportNodeMultipleOuts( - const onnx::NodeProto &node, int nIn, int nOut) { + void ImportNodeMultipleOuts(const onnx::NodeProto &node, int nIn, int nOut, + bool variadicIn = false, + bool variadicOut = false) { std::vector inputs; for (const auto &item : node.input()) { if (frontend_symbols_.ContainKey(legalize_name(item))) { @@ -379,7 +380,8 @@ private: llvm::StringRef OpName = node.op_type(); - if (nIn == inputs.size() && nOut == outputTypes.size()) { + if ((variadicIn || nIn == inputs.size()) && + (variadicOut || nOut == outputTypes.size())) { auto op = builder_.create(UnknownLoc(), outputTypes, inputs, attributes); for (int i = 0; i < node.output().size(); i++) { diff --git a/src/builder/op_build_table.inc b/src/builder/op_build_table.inc index 7771473..3596b65 100644 --- a/src/builder/op_build_table.inc +++ b/src/builder/op_build_table.inc @@ -36,7 +36,7 @@ }else if (OpName == "Compress") { ImportNodeOneOut(node, 2, 1); }else if (OpName == "Concat") { - ImportNodeOneOut(node, 1, 1); + ImportNodeOneOut(node, 1, 1, true, false); }else if (OpName == "ConcatFromSequence") { ImportNodeOneOut(node, 1, 1); }else if (OpName == "Constant") { @@ -138,7 +138,7 @@ }else if (OpName == "MatMulInteger") { ImportNodeOneOut(node, 4, 1); }else if (OpName == "Max") { - ImportNodeOneOut(node, 1, 1); + ImportNodeOneOut(node, 1, 1, true, false); }else if (OpName == "MaxPool") { ImportNodeMaxPool(node, 1, 2); }else if (OpName == "MaxRoiPool") { @@ -146,11 +146,11 @@ }else if (OpName == "MaxUnpool") { ImportNodeOneOut(node, 3, 1); }else if (OpName == "Mean") { - ImportNodeOneOut(node, 1, 1); + ImportNodeOneOut(node, 1, 1, true, false); }else if (OpName == "MeanVarianceNormalization") { ImportNodeOneOut(node, 1, 1); }else if (OpName == "Min") { - ImportNodeOneOut(node, 1, 1); + ImportNodeOneOut(node, 1, 1, true, false); }else if (OpName == "Mod") { ImportNodeOneOut(node, 2, 1); }else if (OpName == "Mul") { @@ -240,7 +240,7 @@ }else if (OpName == "SequenceAt") { ImportNodeOneOut(node, 2, 1); }else if (OpName == "SequenceConstruct") { - ImportNodeOneOut(node, 1, 1); + ImportNodeOneOut(node, 1, 1, true, false); }else if (OpName == "SequenceEmpty") { ImportNodeOneOut(node, 0, 1); }else if (OpName == "SequenceErase") { @@ -286,7 +286,7 @@ }else if (OpName == "Sub") { ImportNodeOneOut(node, 2, 1); }else if (OpName == "Sum") { - ImportNodeOneOut(node, 1, 1); + ImportNodeOneOut(node, 1, 1, true, false); }else if (OpName == "Tan") { ImportNodeOneOut(node, 1, 1); }else if (OpName == "Tanh") { diff --git a/src/dialect/onnx/gen_doc.py b/src/dialect/onnx/gen_doc.py index 9b4356a..a40680d 100644 --- a/src/dialect/onnx/gen_doc.py +++ b/src/dialect/onnx/gen_doc.py @@ -384,6 +384,8 @@ def gen_code(schema,fefile) : #("Transpose", "ImportNodeTranspose") ]) + handle_variadic = False + line_indent = ' ' fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n') op_type_str='mlir::ONNX'+schema.name+'Op' @@ -399,7 +401,22 @@ def gen_code(schema,fefile) : fefile.write(' '+'ImportNodeOneOut<'+op_type_str+'>(node, ' +str(len(schema.inputs)) +', ' +str(len(schema.outputs))) - fefile.write(');\n') + + variadicIn = 'false' + variadicOut = 'false' + for input in schema.inputs: + if OpSchema.FormalParameterOption.Variadic == input.option: + if input.isHomogeneous: + variadicIn = 'true' + handle_variadic = True + for output in schema.outputs: + if OpSchema.FormalParameterOption.Variadic == output.option: + if output.isHomogeneous: + variadicOut = 'true' + if not handle_variadic: + fefile.write(');\n') + else: + fefile.write(', '+variadicIn+', '+variadicOut+');\n') def gen_attr_ins(schema, isfirst) : special_defaults = dict([ diff --git a/test/backend/test.py b/test/backend/test.py index a369a5b..39f1b6e 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -104,14 +104,14 @@ test_to_enable = [ "test_leakyrelu_example_cpu", # Max Op: - # "test_max_example_cpu", <- error + "test_max_example_cpu", "test_max_one_input_cpu", - # "test_max_two_inputs_cpu", <- error + "test_max_two_inputs_cpu", # Min Op: - # "test_min_example_cpu", <- error + "test_min_example_cpu", "test_min_one_input_cpu", - # "test_min_two_inputs_cpu", <- error + "test_min_two_inputs_cpu", # Mul Op: "test_mul_cpu", @@ -139,9 +139,9 @@ test_to_enable = [ "test_softmax_large_number_cpu", # Sum Op: - #"test_sum_example_cpu", <- error + "test_sum_example_cpu", "test_sum_one_input_cpu", - #"test_sum_two_inputs_cpu", <- error + "test_sum_two_inputs_cpu", # Reciprocal Op: "test_reciprocal_cpu",