Fix convolution translation to MLIR.

This commit is contained in:
Doru Bercea 2020-01-15 13:26:50 -05:00
parent fc352745e0
commit 67ec9e9009
4 changed files with 13 additions and 24 deletions

View File

@ -611,7 +611,7 @@ private:
* a specialized function is used
*/
void ImportNodeConv(
onnx::NodeProto node, int nIn, int nOut,
onnx::NodeProto node, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) {
// Conv has attribute dilations, kernel_shape, pads, the default value of
@ -624,12 +624,12 @@ private:
// similar situation for pads, strides in AveragePool
// axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool
// TODO: fix this after type inference
int nOps = node.input().size();
if (node.input().size() == 1) {
ImportNodeOneOut<mlir::ONNXConv1Op>(node, nIn, nOut, attrs);
} else {
ImportNodeOneOut<mlir::ONNXConv3Op>(node, nIn, nOut, attrs);
}
if (nOps == 2)
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(node, nOps, nOut, attrs);
else
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, attrs);
}
void ImportNode(onnx::NodeProto node) {

View File

@ -87,7 +87,7 @@
{"value","", ""}
});
}else if (OpName == "Conv") {
ImportNodeConv(node, 3, 1, {
ImportNodeConv(node, 1, {
{"auto_pad","str","NOTSET"}
,{"dilations","", ""}
,{"group","int", "1"}

View File

@ -92,26 +92,15 @@ def ONNXFullGemmOp: ONNX_Op<"FullGemm",
let results = (outs AnyTensor);
}
def ONNXConv1Op:ONNX_Op<"Conv1",
def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
[NoSideEffect]> {
let summary = "ONNX Conv operation";
let summary = "ONNX Conv operation with no Bias operand.";
let description = [{
"The convolution operator consumes an input tensor and a filter, and"
"computes the output."
}];
let arguments = (ins AnyTensor:$X);
let results = (outs AnyTensor);
}
def ONNXConv3Op:ONNX_Op<"Conv3",
[NoSideEffect]> {
let summary = "ONNX Conv operation";
let description = [{
"The convolution operator consumes an input tensor and a filter, and"
"computes the output."
}];
let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B);
let results = (outs AnyTensor);
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
}
#endif // ONNX_OPS

View File

@ -322,7 +322,7 @@ def ONNXConvOp:ONNX_Op<"Conv",
"The convolution operator consumes an input tensor and a filter, and"
"computes the output."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$B);
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
}