Merge branch 'master' into fix-gemm
This commit is contained in:
commit
514cbcb1dc
|
@ -632,6 +632,21 @@ private:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Special handle for MaxPool operations.
|
||||||
|
*/
|
||||||
|
void ImportNodeMaxPool(
|
||||||
|
onnx::NodeProto node, int nIn,
|
||||||
|
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
||||||
|
attrs) {
|
||||||
|
int nOuts = node.output().size();
|
||||||
|
if (nOuts == 1) {
|
||||||
|
ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts, attrs);
|
||||||
|
} else {
|
||||||
|
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(node, nIn, nOuts, attrs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ImportNode(onnx::NodeProto node) {
|
void ImportNode(onnx::NodeProto node) {
|
||||||
std::vector<mlir::Value> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
for (auto item : node.input()) {
|
for (auto item : node.input()) {
|
||||||
|
|
|
@ -303,7 +303,7 @@
|
||||||
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "MaxPool") {
|
}else if (OpName == "MaxPool") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(node, 1, 2, {
|
ImportNodeMaxPool(node, 1, {
|
||||||
{"auto_pad","str","NOTSET"}
|
{"auto_pad","str","NOTSET"}
|
||||||
,{"ceil_mode","int","0"}
|
,{"ceil_mode","int","0"}
|
||||||
,{"dilations","", ""}
|
,{"dilations","", ""}
|
||||||
|
|
|
@ -365,6 +365,7 @@ special cases:
|
||||||
def gen_code(schema,fefile) :
|
def gen_code(schema,fefile) :
|
||||||
special_handler = dict([
|
special_handler = dict([
|
||||||
("Conv", "ImportNodeConv"),
|
("Conv", "ImportNodeConv"),
|
||||||
|
("MaxPool", "ImportNodeMaxPool"),
|
||||||
#("Transpose", "ImportNodeTranspose")
|
#("Transpose", "ImportNodeTranspose")
|
||||||
])
|
])
|
||||||
special_type = dict([
|
special_type = dict([
|
||||||
|
|
|
@ -78,6 +78,17 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNX Operations for handling optional arguments
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// To allow pattern matching on operations with optional arguments/outputs we
|
||||||
|
// implement variants of the original ONNX dialect operations. The ONNX
|
||||||
|
// operations automatically generated by the `gen_doc.py` script and included
|
||||||
|
// in the `onnxop.inc` file have all optional arguments and outputs present.
|
||||||
|
// In the operations below we include the variants with missing operands
|
||||||
|
// or outputs. This decision affects only ONNX operations with optional
|
||||||
|
// arguments not ONNX operations with variadic operands.
|
||||||
|
|
||||||
def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
|
def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
|
||||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
|
@ -114,4 +125,15 @@ def ONNXConv3Op:ONNX_Op<"Conv3",
|
||||||
let results = (outs AnyTensor);
|
let results = (outs AnyTensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
||||||
|
[NoSideEffect]> {
|
||||||
|
let summary = "ONNX MaxPool operation with a single output.";
|
||||||
|
let description = [{
|
||||||
|
"ONNX MaxPool operation with a single output."
|
||||||
|
"See ONNXMaxPoolOp for a full description of the MaxPool semantics."
|
||||||
|
}];
|
||||||
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X);
|
||||||
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||||
|
}
|
||||||
|
|
||||||
#endif // ONNX_OPS
|
#endif // ONNX_OPS
|
||||||
|
|
Loading…
Reference in New Issue