Allow importing variadic inputs/outputs of onnx operators (#16)
* Allow importing variadic inputs/outputs of onnx operators * Enable testcases for variadic ops * Modify gen_doc.py
This commit is contained in:
parent
31116ec3c2
commit
7c889548a7
|
@ -330,8 +330,8 @@ private:
|
|||
* default}
|
||||
*/
|
||||
template <typename T>
|
||||
void
|
||||
ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut) {
|
||||
void ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut,
|
||||
bool variadicIn = false, bool variadicOut = false) {
|
||||
std::vector<mlir::Value> inputs;
|
||||
for (const auto &item : node.input()) {
|
||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||
|
@ -348,8 +348,8 @@ private:
|
|||
auto attributes = ImportNodeAttributes(node);
|
||||
|
||||
llvm::StringRef OpName = node.op_type();
|
||||
|
||||
if (nIn == inputs.size() && nOut == outputTypes.size()) {
|
||||
if ((variadicIn || nIn == inputs.size()) &&
|
||||
(variadicOut || nOut == outputTypes.size())) {
|
||||
auto op =
|
||||
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
|
||||
frontend_symbols_.AddMapping(legalize_name(node.output()[0]),
|
||||
|
@ -360,8 +360,9 @@ private:
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ImportNodeMultipleOuts(
|
||||
const onnx::NodeProto &node, int nIn, int nOut) {
|
||||
void ImportNodeMultipleOuts(const onnx::NodeProto &node, int nIn, int nOut,
|
||||
bool variadicIn = false,
|
||||
bool variadicOut = false) {
|
||||
std::vector<mlir::Value> inputs;
|
||||
for (const auto &item : node.input()) {
|
||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||
|
@ -379,7 +380,8 @@ private:
|
|||
|
||||
llvm::StringRef OpName = node.op_type();
|
||||
|
||||
if (nIn == inputs.size() && nOut == outputTypes.size()) {
|
||||
if ((variadicIn || nIn == inputs.size()) &&
|
||||
(variadicOut || nOut == outputTypes.size())) {
|
||||
auto op =
|
||||
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
|
||||
for (int i = 0; i < node.output().size(); i++) {
|
||||
|
|
|
@ -36,7 +36,7 @@
|
|||
}else if (OpName == "Compress") {
|
||||
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1);
|
||||
}else if (OpName == "Concat") {
|
||||
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, true, false);
|
||||
}else if (OpName == "ConcatFromSequence") {
|
||||
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1);
|
||||
}else if (OpName == "Constant") {
|
||||
|
@ -138,7 +138,7 @@
|
|||
}else if (OpName == "MatMulInteger") {
|
||||
ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1);
|
||||
}else if (OpName == "Max") {
|
||||
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, true, false);
|
||||
}else if (OpName == "MaxPool") {
|
||||
ImportNodeMaxPool(node, 1, 2);
|
||||
}else if (OpName == "MaxRoiPool") {
|
||||
|
@ -146,11 +146,11 @@
|
|||
}else if (OpName == "MaxUnpool") {
|
||||
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1);
|
||||
}else if (OpName == "Mean") {
|
||||
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, true, false);
|
||||
}else if (OpName == "MeanVarianceNormalization") {
|
||||
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1);
|
||||
}else if (OpName == "Min") {
|
||||
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, true, false);
|
||||
}else if (OpName == "Mod") {
|
||||
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1);
|
||||
}else if (OpName == "Mul") {
|
||||
|
@ -240,7 +240,7 @@
|
|||
}else if (OpName == "SequenceAt") {
|
||||
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1);
|
||||
}else if (OpName == "SequenceConstruct") {
|
||||
ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1, true, false);
|
||||
}else if (OpName == "SequenceEmpty") {
|
||||
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1);
|
||||
}else if (OpName == "SequenceErase") {
|
||||
|
@ -286,7 +286,7 @@
|
|||
}else if (OpName == "Sub") {
|
||||
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1);
|
||||
}else if (OpName == "Sum") {
|
||||
ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1, true, false);
|
||||
}else if (OpName == "Tan") {
|
||||
ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1);
|
||||
}else if (OpName == "Tanh") {
|
||||
|
|
|
@ -384,6 +384,8 @@ def gen_code(schema,fefile) :
|
|||
#("Transpose", "ImportNodeTranspose")
|
||||
])
|
||||
|
||||
handle_variadic = False
|
||||
|
||||
line_indent = ' '
|
||||
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
|
||||
op_type_str='mlir::ONNX'+schema.name+'Op'
|
||||
|
@ -399,7 +401,22 @@ def gen_code(schema,fefile) :
|
|||
fefile.write(' '+'ImportNodeOneOut<'+op_type_str+'>(node, '
|
||||
+str(len(schema.inputs))
|
||||
+', ' +str(len(schema.outputs)))
|
||||
|
||||
variadicIn = 'false'
|
||||
variadicOut = 'false'
|
||||
for input in schema.inputs:
|
||||
if OpSchema.FormalParameterOption.Variadic == input.option:
|
||||
if input.isHomogeneous:
|
||||
variadicIn = 'true'
|
||||
handle_variadic = True
|
||||
for output in schema.outputs:
|
||||
if OpSchema.FormalParameterOption.Variadic == output.option:
|
||||
if output.isHomogeneous:
|
||||
variadicOut = 'true'
|
||||
if not handle_variadic:
|
||||
fefile.write(');\n')
|
||||
else:
|
||||
fefile.write(', '+variadicIn+', '+variadicOut+');\n')
|
||||
|
||||
def gen_attr_ins(schema, isfirst) :
|
||||
special_defaults = dict([
|
||||
|
|
|
@ -104,14 +104,14 @@ test_to_enable = [
|
|||
"test_leakyrelu_example_cpu",
|
||||
|
||||
# Max Op:
|
||||
# "test_max_example_cpu", <- error
|
||||
"test_max_example_cpu",
|
||||
"test_max_one_input_cpu",
|
||||
# "test_max_two_inputs_cpu", <- error
|
||||
"test_max_two_inputs_cpu",
|
||||
|
||||
# Min Op:
|
||||
# "test_min_example_cpu", <- error
|
||||
"test_min_example_cpu",
|
||||
"test_min_one_input_cpu",
|
||||
# "test_min_two_inputs_cpu", <- error
|
||||
"test_min_two_inputs_cpu",
|
||||
|
||||
# Mul Op:
|
||||
"test_mul_cpu",
|
||||
|
@ -139,9 +139,9 @@ test_to_enable = [
|
|||
"test_softmax_large_number_cpu",
|
||||
|
||||
# Sum Op:
|
||||
#"test_sum_example_cpu", <- error
|
||||
"test_sum_example_cpu",
|
||||
"test_sum_one_input_cpu",
|
||||
#"test_sum_two_inputs_cpu", <- error
|
||||
"test_sum_two_inputs_cpu",
|
||||
|
||||
# Reciprocal Op:
|
||||
"test_reciprocal_cpu",
|
||||
|
|
Loading…
Reference in New Issue