Merge branch 'master' into matmul-shape

This commit is contained in:
Gheorghe-Teodor Bercea 2020-01-15 18:03:03 -05:00 committed by GitHub
commit a87f01747a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 23 additions and 34 deletions

View File

@ -611,7 +611,7 @@ private:
* a specialized function is used * a specialized function is used
*/ */
void ImportNodeConv( void ImportNodeConv(
onnx::NodeProto node, int nIn, int nOut, onnx::NodeProto node, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>> std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) { attrs) {
// Conv has attribute dilations, kernel_shape, pads, the default value of // Conv has attribute dilations, kernel_shape, pads, the default value of
@ -624,12 +624,12 @@ private:
// similar situation for pads, strides in AveragePool // similar situation for pads, strides in AveragePool
// axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool // axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool
// TODO: fix this after type inference // TODO: fix this after type inference
int nOps = node.input().size();
if (node.input().size() == 1) { if (nOps == 2)
ImportNodeOneOut<mlir::ONNXConv1Op>(node, nIn, nOut, attrs); ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(node, nOps, nOut, attrs);
} else { else
ImportNodeOneOut<mlir::ONNXConv3Op>(node, nIn, nOut, attrs); ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, attrs);
}
} }
/*! /*!

View File

@ -87,7 +87,7 @@
{"value","", ""} {"value","", ""}
}); });
}else if (OpName == "Conv") { }else if (OpName == "Conv") {
ImportNodeConv(node, 3, 1, { ImportNodeConv(node, 1, {
{"auto_pad","str","NOTSET"} {"auto_pad","str","NOTSET"}
,{"dilations","", ""} ,{"dilations","", ""}
,{"group","int", "1"} ,{"group","int", "1"}

View File

@ -90,39 +90,28 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
// or outputs. This decision affects only ONNX operations with optional // or outputs. This decision affects only ONNX operations with optional
// arguments not ONNX operations with variadic operands. // arguments not ONNX operations with variadic operands.
def ONNXFullGemmOp: ONNX_Op<"FullGemm", def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX general matrix multiply operation"; let summary = "ONNX general matrix multiply operation without bias.";
let description = [{ let description = [{
The "onnx.gemm" generic matrix multiplication with bias. The "onnx.Gemm" generic matrix multiplication without bias.
}]; }];
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in); let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$lhs_in, AnyTypeOf<[AnyMemRef, AnyTensor]>:$rhs_in);
let results = (outs AnyTensor); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
} }
def ONNXConv1Op:ONNX_Op<"Conv1", def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
[NoSideEffect]> { [NoSideEffect]> {
let summary = "ONNX Conv operation"; let summary = "ONNX Conv operation with no Bias operand.";
let description = [{ let description = [{
"The convolution operator consumes an input tensor and a filter, and" "The convolution operator consumes an input tensor and a filter, and"
"computes the output." "computes the output."
}]; }];
let arguments = (ins AnyTensor:$X); let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W);
let results = (outs AnyTensor); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
}
def ONNXConv3Op:ONNX_Op<"Conv3",
[NoSideEffect]> {
let summary = "ONNX Conv operation";
let description = [{
"The convolution operator consumes an input tensor and a filter, and"
"computes the output."
}];
let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B);
let results = (outs AnyTensor);
} }
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",

View File

@ -424,9 +424,9 @@ void ONNXGemmOp::inferShapes() {
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType())); getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
} }
// FullGemm // GemmNoBias
void ONNXFullGemmOp::inferShapes() { void ONNXGemmNoBiasOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())

View File

@ -30,9 +30,9 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
// Pattern-Match and Rewrite // Pattern-Match and Rewrite
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.FullGemm(%X, %Y, %Z) // onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.Gemm(%X, %Y, %Z)
def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3), def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
(ONNXFullGemmOp $m1, $m2, $m3), (ONNXGemmOp $m1, $m2, $m3),
[(HasOneUse $res)]>; [(HasOneUse $res)]>;
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X) // ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)

View File

@ -114,7 +114,7 @@ public:
op->getName().getStringRef() != "onnx.Identity" && op->getName().getStringRef() != "onnx.Identity" &&
op->getName().getStringRef() != "onnx.MatMul" && op->getName().getStringRef() != "onnx.MatMul" &&
op->getName().getStringRef() != "onnx.Gemm" && op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.FullGemm" && op->getName().getStringRef() != "onnx.GemmNoBias" &&
op->getName().getStringRef() != "onnx.Reshape" && op->getName().getStringRef() != "onnx.Reshape" &&
op->getName().getStringRef() != "onnx.Transpose") op->getName().getStringRef() != "onnx.Transpose")
return false; return false;

View File

@ -2,7 +2,7 @@
func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> { func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> {
// CHECK-LABEL: test_matmul_add_simplification // CHECK-LABEL: test_matmul_add_simplification
// CHECK: %{{[0-9]+}} = "onnx.FullGemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> // CHECK: %{{[0-9]+}} = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> %0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> %1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
"std.return"(%1) : (tensor<10x10xf32>) -> () "std.return"(%1) : (tensor<10x10xf32>) -> ()