Merge pull request #32 from clang-ykt/fix-conv

Fix convolution translation to ONNX dialect
This commit is contained in:
Gheorghe-Teodor Bercea 2020-01-15 18:01:41 -05:00 committed by GitHub
commit c2d31c0b78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 12 additions and 23 deletions

View File

@ -611,7 +611,7 @@ private:
* a specialized function is used * a specialized function is used
*/ */
void ImportNodeConv( void ImportNodeConv(
onnx::NodeProto node, int nIn, int nOut, onnx::NodeProto node, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>> std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) { attrs) {
// Conv has attribute dilations, kernel_shape, pads, the default value of // Conv has attribute dilations, kernel_shape, pads, the default value of
@ -624,12 +624,12 @@ private:
// similar situation for pads, strides in AveragePool // similar situation for pads, strides in AveragePool
// axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool // axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool
// TODO: fix this after type inference // TODO: fix this after type inference
int nOps = node.input().size();
if (node.input().size() == 1) { if (nOps == 2)
ImportNodeOneOut<mlir::ONNXConv1Op>(node, nIn, nOut, attrs); ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(node, nOps, nOut, attrs);
} else { else
ImportNodeOneOut<mlir::ONNXConv3Op>(node, nIn, nOut, attrs); ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, attrs);
}
} }
/*! /*!

View File

@ -87,7 +87,7 @@
{"value","", ""} {"value","", ""}
}); });
}else if (OpName == "Conv") { }else if (OpName == "Conv") {
ImportNodeConv(node, 3, 1, { ImportNodeConv(node, 1, {
{"auto_pad","str","NOTSET"} {"auto_pad","str","NOTSET"}
,{"dilations","", ""} ,{"dilations","", ""}
,{"group","int", "1"} ,{"group","int", "1"}

View File

@ -103,26 +103,15 @@ def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
} }
def ONNXConv1Op:ONNX_Op<"Conv1", def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
[NoSideEffect]> { [NoSideEffect]> {
let summary = "ONNX Conv operation"; let summary = "ONNX Conv operation with no Bias operand.";
let description = [{ let description = [{
"The convolution operator consumes an input tensor and a filter, and" "The convolution operator consumes an input tensor and a filter, and"
"computes the output." "computes the output."
}]; }];
let arguments = (ins AnyTensor:$X); let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W);
let results = (outs AnyTensor); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
}
def ONNXConv3Op:ONNX_Op<"Conv3",
[NoSideEffect]> {
let summary = "ONNX Conv operation";
let description = [{
"The convolution operator consumes an input tensor and a filter, and"
"computes the output."
}];
let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B);
let results = (outs AnyTensor);
} }
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",

View File

@ -315,7 +315,7 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
} }
def ONNXConvOp:ONNX_Op<"Conv", def ONNXConvOp:ONNX_Op<"Conv",
[NoSideEffect]> { [NoSideEffect]> {
let summary = "ONNX Conv operation"; let summary = "ONNX Conv operation";
let description = [{ let description = [{