From a42fdd08f335a321606d6e20411668261d67925f Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Wed, 15 Jan 2020 14:11:32 -0500 Subject: [PATCH] Fix Gemm translation to ONNX dialect. --- src/dialect/onnx/onnx.td | 10 +++++----- src/dialect/onnx/onnx_ops.cpp | 2 +- src/pass/onnx_combine.td | 2 +- test/mlir/onnx/onnx_canonicalization.mlir | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td index 87901ce..0e7f0d9 100644 --- a/src/dialect/onnx/onnx.td +++ b/src/dialect/onnx/onnx.td @@ -79,17 +79,17 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> { } -def ONNXFullGemmOp: ONNX_Op<"FullGemm", +def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias", [NoSideEffect, DeclareOpInterfaceMethods]> { - let summary = "ONNX general matrix multiply operation"; + let summary = "ONNX general matrix multiply operation without bias."; let description = [{ - The "onnx.gemm" generic matrix multiplication with bias. + The "onnx.Gemm" generic matrix multiplication without bias. }]; - let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$lhs_in, AnyTypeOf<[AnyMemRef, AnyTensor]>:$rhs_in); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXConv1Op:ONNX_Op<"Conv1", diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 4b68fe7..babdcb0 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -349,7 +349,7 @@ void ONNXGemmOp::inferShapes() { // FullGemm -void ONNXFullGemmOp::inferShapes() { +void ONNXGemmNoBiasOp::inferShapes() { // Cannot infer shape if no shape exists. if (!getOperand(0).getType().isa() || !getOperand(1).getType().isa()) diff --git a/src/pass/onnx_combine.td b/src/pass/onnx_combine.td index 8a40928..bec67cd 100644 --- a/src/pass/onnx_combine.td +++ b/src/pass/onnx_combine.td @@ -32,7 +32,7 @@ def HasOneUse : Constraint>; // onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.FullGemm(%X, %Y, %Z) def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3), - (ONNXFullGemmOp $m1, $m2, $m3), + (ONNXGemmOp $m1, $m2, $m3), [(HasOneUse $res)]>; // ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X) diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir index ad4fc6a..75fd4a5 100644 --- a/test/mlir/onnx/onnx_canonicalization.mlir +++ b/test/mlir/onnx/onnx_canonicalization.mlir @@ -2,7 +2,7 @@ func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> { // CHECK-LABEL: test_matmul_add_simplification - // CHECK: %{{[0-9]+}} = "onnx.FullGemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + // CHECK: %{{[0-9]+}} = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> %0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> %1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> "std.return"(%1) : (tensor<10x10xf32>) -> ()