Merge pull request #32 from clang-ykt/fix-conv

Fix convolution translation to ONNX dialect
This commit is contained in:
Gheorghe-Teodor Bercea 2020-01-15 18:01:41 -05:00 committed by GitHub
commit c2d31c0b78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 12 additions and 23 deletions

View File

@ -611,7 +611,7 @@ private:
* a specialized function is used
*/
void ImportNodeConv(
onnx::NodeProto node, int nIn, int nOut,
onnx::NodeProto node, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) {
// Conv has attribute dilations, kernel_shape, pads, the default value of
@ -624,12 +624,12 @@ private:
// similar situation for pads, strides in AveragePool
// axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool
// TODO: fix this after type inference
int nOps = node.input().size();
if (node.input().size() == 1) {
ImportNodeOneOut<mlir::ONNXConv1Op>(node, nIn, nOut, attrs);
} else {
ImportNodeOneOut<mlir::ONNXConv3Op>(node, nIn, nOut, attrs);
}
if (nOps == 2)
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(node, nOps, nOut, attrs);
else
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, attrs);
}
/*!

View File

@ -87,7 +87,7 @@
{"value","", ""}
});
}else if (OpName == "Conv") {
ImportNodeConv(node, 3, 1, {
ImportNodeConv(node, 1, {
{"auto_pad","str","NOTSET"}
,{"dilations","", ""}
,{"group","int", "1"}

View File

@ -103,26 +103,15 @@ def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
}
def ONNXConv1Op:ONNX_Op<"Conv1",
def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
[NoSideEffect]> {
let summary = "ONNX Conv operation";
let summary = "ONNX Conv operation with no Bias operand.";
let description = [{
"The convolution operator consumes an input tensor and a filter, and"
"computes the output."
}];
let arguments = (ins AnyTensor:$X);
let results = (outs AnyTensor);
}
def ONNXConv3Op:ONNX_Op<"Conv3",
[NoSideEffect]> {
let summary = "ONNX Conv operation";
let description = [{
"The convolution operator consumes an input tensor and a filter, and"
"computes the output."
}];
let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B);
let results = (outs AnyTensor);
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
}
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",

View File

@ -315,7 +315,7 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
}
def ONNXConvOp:ONNX_Op<"Conv",
def ONNXConvOp:ONNX_Op<"Conv",
[NoSideEffect]> {
let summary = "ONNX Conv operation";
let description = [{