Fix convolution translation to MLIR.

Author: Doru Bercea
Date:   2020-01-15 13:26:50 -05:00
commit 67ec9e9009 (parent fc352745e0)
4 changed files with 13 additions and 24 deletions

File 1 of 4:

@@ -611,7 +611,7 @@ private:
    * a specialized function is used
    */
   void ImportNodeConv(
-      onnx::NodeProto node, int nIn, int nOut,
+      onnx::NodeProto node, int nOut,
       std::initializer_list<std::tuple<std::string, std::string, std::string>>
           attrs) {
     // Conv has attribute dilations, kernel_shape, pads, the default value of
@@ -624,12 +624,12 @@ private:
     // similar situation for pads, strides in AveragePool
     // axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool
     // TODO: fix this after type inference
-    if (node.input().size() == 1) {
-      ImportNodeOneOut<mlir::ONNXConv1Op>(node, nIn, nOut, attrs);
-    } else {
-      ImportNodeOneOut<mlir::ONNXConv3Op>(node, nIn, nOut, attrs);
-    }
+    int nOps = node.input().size();
+    if (nOps == 2)
+      ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(node, nOps, nOut, attrs);
+    else
+      ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, attrs);
   }
   void ImportNode(onnx::NodeProto node) {
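
The ONNX Conv node has two required inputs (X and W) and an optional bias B, so the importer can no longer hard-code the operand count: the fixed ImportNodeConv above counts the node's actual inputs and builds ONNXConvNoBiasOp for two operands and ONNXConvOp otherwise. The following is a minimal, standalone C++ sketch of that operand-count dispatch; the Node struct and the build* functions are illustrative stand-ins for onnx::NodeProto and ImportNodeOneOut<...>, not the project's real API.

// Standalone sketch of the operand-count dispatch (stand-in types, not the
// onnx/MLIR API used by the importer itself).
#include <cstdio>
#include <string>
#include <vector>

struct Node {
  std::vector<std::string> inputs; // e.g. {"X", "W"} or {"X", "W", "B"}
};

static void buildConvNoBias(const Node &) { std::puts("ONNXConvNoBiasOp"); }
static void buildConv(const Node &) { std::puts("ONNXConvOp"); }

// Mirrors the fixed logic: the number of operands actually present on the
// node, not a hard-coded count, decides which op gets built.
static void importConv(const Node &node) {
  const int nOps = static_cast<int>(node.inputs.size());
  if (nOps == 2)
    buildConvNoBias(node); // X and W only: bias-less variant
  else
    buildConv(node);       // X, W and B: full Conv
}

int main() {
  importConv({{"X", "W"}});      // prints ONNXConvNoBiasOp
  importConv({{"X", "W", "B"}}); // prints ONNXConvOp
  return 0;
}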

File 2 of 4:

@@ -87,7 +87,7 @@
        {"value","", ""}
        });
    }else if (OpName == "Conv") {
-       ImportNodeConv(node, 3, 1, {
+       ImportNodeConv(node, 1, {
        {"auto_pad","str","NOTSET"}
        ,{"dilations","", ""}
        ,{"group","int", "1"}

File 3 of 4:

@@ -92,26 +92,15 @@ def ONNXFullGemmOp: ONNX_Op<"FullGemm",
   let results = (outs AnyTensor);
 }
-def ONNXConv1Op:ONNX_Op<"Conv1",
+def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
   [NoSideEffect]> {
-  let summary = "ONNX Conv operation";
+  let summary = "ONNX Conv operation with no Bias operand.";
   let description = [{
     "The convolution operator consumes an input tensor and a filter, and"
     "computes the output."
   }];
-  let arguments = (ins AnyTensor:$X);
-  let results = (outs AnyTensor);
-}
-def ONNXConv3Op:ONNX_Op<"Conv3",
-  [NoSideEffect]> {
-  let summary = "ONNX Conv operation";
-  let description = [{
-    "The convolution operator consumes an input tensor and a filter, and"
-    "computes the output."
-  }];
-  let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 #endif // ONNX_OPS

File 4 of 4:

@@ -315,14 +315,14 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 def ONNXConvOp:ONNX_Op<"Conv",
   [NoSideEffect]> {
   let summary = "ONNX Conv operation";
   let description = [{
     "The convolution operator consumes an input tensor and a filter, and"
     "computes the output."
   }];
-  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$B);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
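
With bias-less convolutions now routed to ONNXConvNoBiasOp, the bias operand of ONNXConvOp changes from Variadic to a required operand, so code consuming the op no longer has to branch on whether a third operand is present. Below is a minimal sketch of that invariant, with an op's operands mocked as a vector of strings rather than MLIR values (illustrative only, not the dialect's generated accessors).

// Sketch: with Variadic<...>:$B an ONNXConvOp could carry {X, W} or {X, W, B},
// so any use of the bias had to guard on the operand count; with a required $B
// (and the bias-less form split out into ONNXConvNoBiasOp), index 2 is always B.
#include <cassert>
#include <string>
#include <vector>

using OperandList = std::vector<std::string>; // stand-in for an op's operands

static const std::string &convBias(const OperandList &convOperands) {
  assert(convOperands.size() == 3 && "ONNXConvOp now always carries X, W and B");
  return convOperands[2];
}

int main() {
  const OperandList conv = {"X", "W", "B"};
  assert(convBias(conv) == "B");
  return 0;
}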