Fix Gemm translation to ONNX dialect.

This commit is contained in:
Doru Bercea 2020-01-15 14:11:32 -05:00
parent deef363309
commit a42fdd08f3
4 changed files with 8 additions and 8 deletions

View File

@ -79,17 +79,17 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
} }
def ONNXFullGemmOp: ONNX_Op<"FullGemm", def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX general matrix multiply operation"; let summary = "ONNX general matrix multiply operation without bias.";
let description = [{ let description = [{
The "onnx.gemm" generic matrix multiplication with bias. The "onnx.Gemm" generic matrix multiplication without bias.
}]; }];
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in); let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$lhs_in, AnyTypeOf<[AnyMemRef, AnyTensor]>:$rhs_in);
let results = (outs AnyTensor); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
} }
def ONNXConv1Op:ONNX_Op<"Conv1", def ONNXConv1Op:ONNX_Op<"Conv1",

View File

@ -349,7 +349,7 @@ void ONNXGemmOp::inferShapes() {
// FullGemm // FullGemm
void ONNXFullGemmOp::inferShapes() { void ONNXGemmNoBiasOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())

View File

@ -32,7 +32,7 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.FullGemm(%X, %Y, %Z) // onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.FullGemm(%X, %Y, %Z)
def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3), def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
(ONNXFullGemmOp $m1, $m2, $m3), (ONNXGemmOp $m1, $m2, $m3),
[(HasOneUse $res)]>; [(HasOneUse $res)]>;
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X) // ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)

View File

@ -2,7 +2,7 @@
func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> { func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> {
// CHECK-LABEL: test_matmul_add_simplification // CHECK-LABEL: test_matmul_add_simplification
// CHECK: %{{[0-9]+}} = "onnx.FullGemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> // CHECK: %{{[0-9]+}} = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> %0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> %1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
"std.return"(%1) : (tensor<10x10xf32>) -> () "std.return"(%1) : (tensor<10x10xf32>) -> ()