diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td index 027af43..dba9f14 100644 --- a/src/dialect/onnx/onnx.td +++ b/src/dialect/onnx/onnx.td @@ -90,17 +90,17 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> { // or outputs. This decision affects only ONNX operations with optional // arguments not ONNX operations with variadic operands. -def ONNXFullGemmOp: ONNX_Op<"FullGemm", +def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias", [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { - let summary = "ONNX general matrix multiply operation"; + let summary = "ONNX general matrix multiply operation without bias."; let description = [{ - The "onnx.gemm" generic matrix multiplication with bias. + The "onnx.Gemm" generic matrix multiplication without bias. }]; - let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$lhs_in, AnyTypeOf<[AnyMemRef, AnyTensor]>:$rhs_in); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXConv1Op:ONNX_Op<"Conv1", diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 4b68fe7..985f63d 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -347,9 +347,9 @@ void ONNXGemmOp::inferShapes() { getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType())); } -// FullGemm +// GemmNoBias -void ONNXFullGemmOp::inferShapes() { +void ONNXGemmNoBiasOp::inferShapes() { // Cannot infer shape if no shape exists. 
if (!getOperand(0).getType().isa<RankedTensorType>() || !getOperand(1).getType().isa<RankedTensorType>()) diff --git a/src/pass/onnx_combine.td b/src/pass/onnx_combine.td index 8a40928..25a4656 100644 --- a/src/pass/onnx_combine.td +++ b/src/pass/onnx_combine.td @@ -30,9 +30,9 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>; // Pattern-Match and Rewrite //===----------------------------------------------------------------------===// -// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.FullGemm(%X, %Y, %Z) +// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.Gemm(%X, %Y, %Z) def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3), - (ONNXFullGemmOp $m1, $m2, $m3), + (ONNXGemmOp $m1, $m2, $m3), [(HasOneUse $res)]>; // ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X) diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp index cbdf04b..5ccb9a4 100644 --- a/src/pass/shape_inference_pass.cpp +++ b/src/pass/shape_inference_pass.cpp @@ -114,7 +114,7 @@ public: op->getName().getStringRef() != "onnx.Identity" && op->getName().getStringRef() != "onnx.MatMul" && op->getName().getStringRef() != "onnx.Gemm" && - op->getName().getStringRef() != "onnx.FullGemm" && + op->getName().getStringRef() != "onnx.GemmNoBias" && op->getName().getStringRef() != "onnx.Reshape" && op->getName().getStringRef() != "onnx.Transpose") return false; diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir index ad4fc6a..75fd4a5 100644 --- a/test/mlir/onnx/onnx_canonicalization.mlir +++ b/test/mlir/onnx/onnx_canonicalization.mlir @@ -2,7 +2,7 @@ func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> { // CHECK-LABEL: test_matmul_add_simplification - // CHECK: %{{[0-9]+}} = "onnx.FullGemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + // CHECK: %{{[0-9]+}} = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, 
tensor<10x10xf32>) -> tensor<10x10xf32> %0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> %1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> "std.return"(%1) : (tensor<10x10xf32>) -> ()