Merge pull request #34 from clang-ykt/fix-maxpool
Fix MaxPool translation to ONNX dialect.
This commit is contained in:
commit
b50fc1fdeb
|
@ -632,6 +632,21 @@ private:
|
|||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* Special handle for MaxPool operations.
|
||||
*/
|
||||
void ImportNodeMaxPool(
|
||||
onnx::NodeProto node, int nIn,
|
||||
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
||||
attrs) {
|
||||
int nOuts = node.output().size();
|
||||
if (nOuts == 1) {
|
||||
ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts, attrs);
|
||||
} else {
|
||||
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(node, nIn, nOuts, attrs);
|
||||
}
|
||||
}
|
||||
|
||||
void ImportNode(onnx::NodeProto node) {
|
||||
std::vector<mlir::Value> inputs;
|
||||
for (auto item : node.input()) {
|
||||
|
|
|
@ -303,7 +303,7 @@
|
|||
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "MaxPool") {
|
||||
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(node, 1, 2, {
|
||||
ImportNodeMaxPool(node, 1, {
|
||||
{"auto_pad","str","NOTSET"}
|
||||
,{"ceil_mode","int","0"}
|
||||
,{"dilations","", ""}
|
||||
|
|
|
@ -365,6 +365,7 @@ special cases:
|
|||
def gen_code(schema,fefile) :
|
||||
special_handler = dict([
|
||||
("Conv", "ImportNodeConv"),
|
||||
("MaxPool", "ImportNodeMaxPool"),
|
||||
#("Transpose", "ImportNodeTranspose")
|
||||
])
|
||||
special_type = dict([
|
||||
|
|
|
@ -78,6 +78,17 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
|
|||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNX Operations for handling optional arguments
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// To allow pattern matching on operations with optional arguments/outputs we
|
||||
// implement variants of the original ONNX dialect operations. The ONNX
|
||||
// operations automatically generated by the `gen_doc.py` script and included
|
||||
// in the `onnxop.inc` file have all optional arguments and outputs present.
|
||||
// In the operations below we include the variants with missing operands
|
||||
// or outputs. This decision affects only ONNX operations with optional
|
||||
// arguments not ONNX operations with variadic operands.
|
||||
|
||||
def ONNXFullGemmOp: ONNX_Op<"FullGemm",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
|
@ -114,4 +125,15 @@ def ONNXConv3Op:ONNX_Op<"Conv3",
|
|||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
||||
[NoSideEffect]> {
|
||||
let summary = "ONNX MaxPool operation with a single output.";
|
||||
let description = [{
|
||||
"ONNX MaxPool operation with a single output."
|
||||
"See ONNXMaxPoolOp for a full description of the MaxPool semantics."
|
||||
}];
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||
}
|
||||
|
||||
#endif // ONNX_OPS
|
||||
|
|
Loading…
Reference in New Issue