Merge pull request #32 from clang-ykt/fix-conv
Fix convolution translation to ONNX dialect
This commit is contained in:
commit
c2d31c0b78
|
@ -611,7 +611,7 @@ private:
|
|||
* a specialized function is used
|
||||
*/
|
||||
void ImportNodeConv(
|
||||
onnx::NodeProto node, int nIn, int nOut,
|
||||
onnx::NodeProto node, int nOut,
|
||||
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
||||
attrs) {
|
||||
// Conv has attribute dilations, kernel_shape, pads, the default value of
|
||||
|
@ -624,12 +624,12 @@ private:
|
|||
// similar situation for pads, strides in AveragePool
|
||||
// axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool
|
||||
// TODO: fix this after type inference
|
||||
int nOps = node.input().size();
|
||||
|
||||
if (node.input().size() == 1) {
|
||||
ImportNodeOneOut<mlir::ONNXConv1Op>(node, nIn, nOut, attrs);
|
||||
} else {
|
||||
ImportNodeOneOut<mlir::ONNXConv3Op>(node, nIn, nOut, attrs);
|
||||
}
|
||||
if (nOps == 2)
|
||||
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(node, nOps, nOut, attrs);
|
||||
else
|
||||
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, attrs);
|
||||
}
|
||||
|
||||
/*!
|
||||
|
|
|
@ -87,7 +87,7 @@
|
|||
{"value","", ""}
|
||||
});
|
||||
}else if (OpName == "Conv") {
|
||||
ImportNodeConv(node, 3, 1, {
|
||||
ImportNodeConv(node, 1, {
|
||||
{"auto_pad","str","NOTSET"}
|
||||
,{"dilations","", ""}
|
||||
,{"group","int", "1"}
|
||||
|
|
|
@ -103,26 +103,15 @@ def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
|
|||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||
}
|
||||
|
||||
def ONNXConv1Op:ONNX_Op<"Conv1",
|
||||
def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
|
||||
[NoSideEffect]> {
|
||||
let summary = "ONNX Conv operation";
|
||||
let summary = "ONNX Conv operation with no Bias operand.";
|
||||
let description = [{
|
||||
"The convolution operator consumes an input tensor and a filter, and"
|
||||
"computes the output."
|
||||
}];
|
||||
let arguments = (ins AnyTensor:$X);
|
||||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
def ONNXConv3Op:ONNX_Op<"Conv3",
|
||||
[NoSideEffect]> {
|
||||
let summary = "ONNX Conv operation";
|
||||
let description = [{
|
||||
"The convolution operator consumes an input tensor and a filter, and"
|
||||
"computes the output."
|
||||
}];
|
||||
let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B);
|
||||
let results = (outs AnyTensor);
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||
}
|
||||
|
||||
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
||||
|
|
|
@ -315,7 +315,7 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
|
|||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||
}
|
||||
|
||||
def ONNXConvOp:ONNX_Op<"Conv",
|
||||
def ONNXConvOp:ONNX_Op<"Conv",
|
||||
[NoSideEffect]> {
|
||||
let summary = "ONNX Conv operation";
|
||||
let description = [{
|
||||
|
|
Loading…
Reference in New Issue