Merge branch 'master' into matmul-shape

This commit is contained in:
Gheorghe-Teodor Bercea 2020-01-15 18:03:03 -05:00 committed by GitHub
commit a87f01747a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 23 additions and 34 deletions

View File

@ -611,7 +611,7 @@ private:
* a specialized function is used
*/
void ImportNodeConv(
onnx::NodeProto node, int nIn, int nOut,
onnx::NodeProto node, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) {
// Conv has attribute dilations, kernel_shape, pads, the default value of
@ -624,12 +624,12 @@ private:
// similar situation for pads, strides in AveragePool
// axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool
// TODO: fix this after type inference
int nOps = node.input().size();
if (node.input().size() == 1) {
ImportNodeOneOut<mlir::ONNXConv1Op>(node, nIn, nOut, attrs);
} else {
ImportNodeOneOut<mlir::ONNXConv3Op>(node, nIn, nOut, attrs);
}
if (nOps == 2)
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(node, nOps, nOut, attrs);
else
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, attrs);
}
/*!

View File

@ -87,7 +87,7 @@
{"value","", ""}
});
}else if (OpName == "Conv") {
ImportNodeConv(node, 3, 1, {
ImportNodeConv(node, 1, {
{"auto_pad","str","NOTSET"}
,{"dilations","", ""}
,{"group","int", "1"}

View File

@ -90,39 +90,28 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
// or outputs. This decision affects only ONNX operations with optional
// arguments not ONNX operations with variadic operands.
def ONNXFullGemmOp: ONNX_Op<"FullGemm",
def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX general matrix multiply operation";
let summary = "ONNX general matrix multiply operation without bias.";
let description = [{
The "onnx.gemm" generic matrix multiplication with bias.
The "onnx.Gemm" generic matrix multiplication without bias.
}];
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in);
let results = (outs AnyTensor);
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$lhs_in, AnyTypeOf<[AnyMemRef, AnyTensor]>:$rhs_in);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
}
def ONNXConv1Op:ONNX_Op<"Conv1",
def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
[NoSideEffect]> {
let summary = "ONNX Conv operation";
let summary = "ONNX Conv operation with no Bias operand.";
let description = [{
"The convolution operator consumes an input tensor and a filter, and"
"computes the output."
}];
let arguments = (ins AnyTensor:$X);
let results = (outs AnyTensor);
}
def ONNXConv3Op:ONNX_Op<"Conv3",
[NoSideEffect]> {
let summary = "ONNX Conv operation";
let description = [{
"The convolution operator consumes an input tensor and a filter, and"
"computes the output."
}];
let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B);
let results = (outs AnyTensor);
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
}
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",

View File

@ -424,9 +424,9 @@ void ONNXGemmOp::inferShapes() {
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
// FullGemm
// GemmNoBias
void ONNXFullGemmOp::inferShapes() {
void ONNXGemmNoBiasOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())

View File

@ -30,9 +30,9 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
// Pattern-Match and Rewrite
//===----------------------------------------------------------------------===//
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.FullGemm(%X, %Y, %Z)
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.Gemm(%X, %Y, %Z)
def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
(ONNXFullGemmOp $m1, $m2, $m3),
(ONNXGemmOp $m1, $m2, $m3),
[(HasOneUse $res)]>;
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)

View File

@ -114,7 +114,7 @@ public:
op->getName().getStringRef() != "onnx.Identity" &&
op->getName().getStringRef() != "onnx.MatMul" &&
op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.FullGemm" &&
op->getName().getStringRef() != "onnx.GemmNoBias" &&
op->getName().getStringRef() != "onnx.Reshape" &&
op->getName().getStringRef() != "onnx.Transpose")
return false;

View File

@ -2,7 +2,7 @@
func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> {
// CHECK-LABEL: test_matmul_add_simplification
// CHECK: %{{[0-9]+}} = "onnx.FullGemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
// CHECK: %{{[0-9]+}} = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
"std.return"(%1) : (tensor<10x10xf32>) -> ()