Fix MaxPool translation to ONNX dialect.

This commit is contained in:
Doru Bercea 2020-01-15 15:16:45 -05:00
parent deef363309
commit 3f6efdf4a4
4 changed files with 28 additions and 1 deletions

View File

@ -632,6 +632,21 @@ private:
} }
} }
/*!
* Special handle for MaxPool operations.
*/
void ImportNodeMaxPool(
onnx::NodeProto node, int nIn,
std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) {
int nOuts = node.output().size();
if (nOuts == 1) {
ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts, attrs);
} else {
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(node, nIn, nOuts, attrs);
}
}
void ImportNode(onnx::NodeProto node) { void ImportNode(onnx::NodeProto node) {
std::vector<mlir::Value> inputs; std::vector<mlir::Value> inputs;
for (auto item : node.input()) { for (auto item : node.input()) {

View File

@ -303,7 +303,7 @@
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, {
}); });
}else if (OpName == "MaxPool") { }else if (OpName == "MaxPool") {
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(node, 1, 2, { ImportNodeMaxPool(node, 1, {
{"auto_pad","str","NOTSET"} {"auto_pad","str","NOTSET"}
,{"ceil_mode","int","0"} ,{"ceil_mode","int","0"}
,{"dilations","", ""} ,{"dilations","", ""}

View File

@ -365,6 +365,7 @@ special cases:
def gen_code(schema,fefile) : def gen_code(schema,fefile) :
special_handler = dict([ special_handler = dict([
("Conv", "ImportNodeConv"), ("Conv", "ImportNodeConv"),
("MaxPool", "ImportNodeMaxPool"),
#("Transpose", "ImportNodeTranspose") #("Transpose", "ImportNodeTranspose")
]) ])
special_type = dict([ special_type = dict([

View File

@ -114,4 +114,15 @@ def ONNXConv3Op:ONNX_Op<"Conv3",
let results = (outs AnyTensor); let results = (outs AnyTensor);
} }
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
[NoSideEffect]> {
let summary = "ONNX MaxPool operation with a single output.";
let description = [{
"ONNX MaxPool operation with a single output."
"See ONNXMaxPoolOp for a full description of the MaxPool semantics."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
}
#endif // ONNX_OPS #endif // ONNX_OPS