Merge branch 'master' into fix-gemm

This commit is contained in:
Gheorghe-Teodor Bercea 2020-01-15 17:50:15 -05:00 committed by GitHub
commit 514cbcb1dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 39 additions and 1 deletions

View File

@ -632,6 +632,21 @@ private:
} }
} }
/*!
* Special handle for MaxPool operations.
*/
void ImportNodeMaxPool(
onnx::NodeProto node, int nIn,
std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) {
int nOuts = node.output().size();
if (nOuts == 1) {
ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts, attrs);
} else {
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(node, nIn, nOuts, attrs);
}
}
void ImportNode(onnx::NodeProto node) { void ImportNode(onnx::NodeProto node) {
std::vector<mlir::Value> inputs; std::vector<mlir::Value> inputs;
for (auto item : node.input()) { for (auto item : node.input()) {

View File

@ -303,7 +303,7 @@
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, {
}); });
}else if (OpName == "MaxPool") { }else if (OpName == "MaxPool") {
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(node, 1, 2, { ImportNodeMaxPool(node, 1, {
{"auto_pad","str","NOTSET"} {"auto_pad","str","NOTSET"}
,{"ceil_mode","int","0"} ,{"ceil_mode","int","0"}
,{"dilations","", ""} ,{"dilations","", ""}

View File

@ -365,6 +365,7 @@ special cases:
def gen_code(schema,fefile) : def gen_code(schema,fefile) :
special_handler = dict([ special_handler = dict([
("Conv", "ImportNodeConv"), ("Conv", "ImportNodeConv"),
("MaxPool", "ImportNodeMaxPool"),
#("Transpose", "ImportNodeTranspose") #("Transpose", "ImportNodeTranspose")
]) ])
special_type = dict([ special_type = dict([

View File

@ -78,6 +78,17 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
}]; }];
} }
//===----------------------------------------------------------------------===//
// ONNX Operations for handling optional arguments
//===----------------------------------------------------------------------===//
// To allow pattern matching on operations with optional arguments/outputs we
// implement variants of the original ONNX dialect operations. The ONNX
// operations automatically generated by the `gen_doc.py` script and included
// in the `onnxop.inc` file have all optional arguments and outputs present.
// In the operations below we include the variants with missing operands
// or outputs. This decision affects only ONNX operations with optional
// arguments not ONNX operations with variadic operands.
def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias", def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
@ -114,4 +125,15 @@ def ONNXConv3Op:ONNX_Op<"Conv3",
let results = (outs AnyTensor); let results = (outs AnyTensor);
} }
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
[NoSideEffect]> {
let summary = "ONNX MaxPool operation with a single output.";
let description = [{
"ONNX MaxPool operation with a single output."
"See ONNXMaxPoolOp for a full description of the MaxPool semantics."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
}
#endif // ONNX_OPS #endif // ONNX_OPS