Allow importing variadic inputs/outputs of onnx operators (#16)

* Allow importing variadic inputs/outputs of onnx operators

* Enable testcases for variadic ops

* Modify gen_doc.py
This commit is contained in:
Tung D. Le 2020-01-28 22:48:11 +09:00 committed by Tian Jin
parent 31116ec3c2
commit 7c889548a7
4 changed files with 39 additions and 20 deletions

View File

@ -330,8 +330,8 @@ private:
* default} * default}
*/ */
template <typename T> template <typename T>
void void ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut,
ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut) { bool variadicIn = false, bool variadicOut = false) {
std::vector<mlir::Value> inputs; std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) { for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) { if (frontend_symbols_.ContainKey(legalize_name(item))) {
@ -348,8 +348,8 @@ private:
auto attributes = ImportNodeAttributes(node); auto attributes = ImportNodeAttributes(node);
llvm::StringRef OpName = node.op_type(); llvm::StringRef OpName = node.op_type();
if ((variadicIn || nIn == inputs.size()) &&
if (nIn == inputs.size() && nOut == outputTypes.size()) { (variadicOut || nOut == outputTypes.size())) {
auto op = auto op =
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes); builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
frontend_symbols_.AddMapping(legalize_name(node.output()[0]), frontend_symbols_.AddMapping(legalize_name(node.output()[0]),
@ -360,8 +360,9 @@ private:
} }
template <typename T> template <typename T>
void ImportNodeMultipleOuts( void ImportNodeMultipleOuts(const onnx::NodeProto &node, int nIn, int nOut,
const onnx::NodeProto &node, int nIn, int nOut) { bool variadicIn = false,
bool variadicOut = false) {
std::vector<mlir::Value> inputs; std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) { for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) { if (frontend_symbols_.ContainKey(legalize_name(item))) {
@ -379,7 +380,8 @@ private:
llvm::StringRef OpName = node.op_type(); llvm::StringRef OpName = node.op_type();
if (nIn == inputs.size() && nOut == outputTypes.size()) { if ((variadicIn || nIn == inputs.size()) &&
(variadicOut || nOut == outputTypes.size())) {
auto op = auto op =
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes); builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
for (int i = 0; i < node.output().size(); i++) { for (int i = 0; i < node.output().size(); i++) {

View File

@ -36,7 +36,7 @@
}else if (OpName == "Compress") { }else if (OpName == "Compress") {
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1); ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1);
}else if (OpName == "Concat") { }else if (OpName == "Concat") {
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1); ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, true, false);
}else if (OpName == "ConcatFromSequence") { }else if (OpName == "ConcatFromSequence") {
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1); ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1);
}else if (OpName == "Constant") { }else if (OpName == "Constant") {
@ -138,7 +138,7 @@
}else if (OpName == "MatMulInteger") { }else if (OpName == "MatMulInteger") {
ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1); ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1);
}else if (OpName == "Max") { }else if (OpName == "Max") {
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1); ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, true, false);
}else if (OpName == "MaxPool") { }else if (OpName == "MaxPool") {
ImportNodeMaxPool(node, 1, 2); ImportNodeMaxPool(node, 1, 2);
}else if (OpName == "MaxRoiPool") { }else if (OpName == "MaxRoiPool") {
@ -146,11 +146,11 @@
}else if (OpName == "MaxUnpool") { }else if (OpName == "MaxUnpool") {
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1); ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1);
}else if (OpName == "Mean") { }else if (OpName == "Mean") {
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1); ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, true, false);
}else if (OpName == "MeanVarianceNormalization") { }else if (OpName == "MeanVarianceNormalization") {
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1); ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1);
}else if (OpName == "Min") { }else if (OpName == "Min") {
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1); ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, true, false);
}else if (OpName == "Mod") { }else if (OpName == "Mod") {
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1); ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1);
}else if (OpName == "Mul") { }else if (OpName == "Mul") {
@ -240,7 +240,7 @@
}else if (OpName == "SequenceAt") { }else if (OpName == "SequenceAt") {
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1); ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1);
}else if (OpName == "SequenceConstruct") { }else if (OpName == "SequenceConstruct") {
ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1); ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1, true, false);
}else if (OpName == "SequenceEmpty") { }else if (OpName == "SequenceEmpty") {
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1); ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1);
}else if (OpName == "SequenceErase") { }else if (OpName == "SequenceErase") {
@ -286,7 +286,7 @@
}else if (OpName == "Sub") { }else if (OpName == "Sub") {
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1); ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1);
}else if (OpName == "Sum") { }else if (OpName == "Sum") {
ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1); ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1, true, false);
}else if (OpName == "Tan") { }else if (OpName == "Tan") {
ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1); ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1);
}else if (OpName == "Tanh") { }else if (OpName == "Tanh") {

View File

@ -384,6 +384,8 @@ def gen_code(schema,fefile) :
#("Transpose", "ImportNodeTranspose") #("Transpose", "ImportNodeTranspose")
]) ])
handle_variadic = False
line_indent = ' ' line_indent = ' '
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n') fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
op_type_str='mlir::ONNX'+schema.name+'Op' op_type_str='mlir::ONNX'+schema.name+'Op'
@ -399,7 +401,22 @@ def gen_code(schema,fefile) :
fefile.write(' '+'ImportNodeOneOut<'+op_type_str+'>(node, ' fefile.write(' '+'ImportNodeOneOut<'+op_type_str+'>(node, '
+str(len(schema.inputs)) +str(len(schema.inputs))
+', ' +str(len(schema.outputs))) +', ' +str(len(schema.outputs)))
fefile.write(');\n')
variadicIn = 'false'
variadicOut = 'false'
for input in schema.inputs:
if OpSchema.FormalParameterOption.Variadic == input.option:
if input.isHomogeneous:
variadicIn = 'true'
handle_variadic = True
for output in schema.outputs:
if OpSchema.FormalParameterOption.Variadic == output.option:
if output.isHomogeneous:
variadicOut = 'true'
if not handle_variadic:
fefile.write(');\n')
else:
fefile.write(', '+variadicIn+', '+variadicOut+');\n')
def gen_attr_ins(schema, isfirst) : def gen_attr_ins(schema, isfirst) :
special_defaults = dict([ special_defaults = dict([

View File

@ -104,14 +104,14 @@ test_to_enable = [
"test_leakyrelu_example_cpu", "test_leakyrelu_example_cpu",
# Max Op: # Max Op:
# "test_max_example_cpu", <- error "test_max_example_cpu",
"test_max_one_input_cpu", "test_max_one_input_cpu",
# "test_max_two_inputs_cpu", <- error "test_max_two_inputs_cpu",
# Min Op: # Min Op:
# "test_min_example_cpu", <- error "test_min_example_cpu",
"test_min_one_input_cpu", "test_min_one_input_cpu",
# "test_min_two_inputs_cpu", <- error "test_min_two_inputs_cpu",
# Mul Op: # Mul Op:
"test_mul_cpu", "test_mul_cpu",
@ -139,9 +139,9 @@ test_to_enable = [
"test_softmax_large_number_cpu", "test_softmax_large_number_cpu",
# Sum Op: # Sum Op:
#"test_sum_example_cpu", <- error "test_sum_example_cpu",
"test_sum_one_input_cpu", "test_sum_one_input_cpu",
#"test_sum_two_inputs_cpu", <- error "test_sum_two_inputs_cpu",
# Reciprocal Op: # Reciprocal Op:
"test_reciprocal_cpu", "test_reciprocal_cpu",