diff --git a/doc/Dialects/onnx.md b/doc/Dialects/onnx.md
index 95746f6..69349aa 100644
--- a/doc/Dialects/onnx.md
+++ b/doc/Dialects/onnx.md
@@ -1558,33 +1558,6 @@ ONNX Gather operation
 
 1. `output`: memref of any type values or tensor of any type values
 
-### onnx.GemmNoBias (ONNXGemmNoBiasOp)
-ONNX general matrix multiply operation without bias.
-
-#### Description:
-
-
-The "onnx.Gemm" generic matrix multiplication without bias.
-
-
-#### Operands:
-
-1. `A`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
-
-#### Attributes:
-
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `alpha` | `FloatAttr` | 32-bit float attribute attribute |
-| `beta` | `FloatAttr` | 32-bit float attribute attribute |
-| `transA` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `transB` | `IntegerAttr` | 64-bit integer attribute attribute |
-
-#### Results:
-
-1. `o_Y`: memref of any type values or tensor of any type values
-
 ### onnx.Gemm (ONNXGemmOp)
 ONNX Gemm operation
 
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc
index 8a9bf8e..ee395b5 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc
+++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc
@@ -17,9 +17,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
-    // The first predicate is unnecessary when we remove ONXGemmNoBiasOp.
-    bool hasBias = (operands.size() == 3) &&
-                   (!op->getOperand(2).getType().isa<NoneType>());
+    bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
 
     Value A, B, C;
     A = operands[0];
@@ -215,5 +213,4 @@ struct ONNXGemmOpLowering : public ConversionPattern {
 void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
                                        MLIRContext *ctx) {
   patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
-  patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
 }
diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td
index 43d4a10..1cc88c3 100644
--- a/src/dialect/onnx/onnx.td
+++ b/src/dialect/onnx/onnx.td
@@ -90,25 +90,6 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
 // or outputs. This decision affects only ONNX operations with optional
 // arguments not ONNX operations with variadic operands.
 
-def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
-    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
-  let summary = "ONNX general matrix multiply operation without bias.";
-  let description = [{
-
-  The "onnx.Gemm" generic matrix multiplication without bias.
-
-  }];
-
-  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
-                   AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
-                   DefaultValuedAttr<F32Attr, "1.0">:$alpha,
-                   DefaultValuedAttr<F32Attr, "1.0">:$beta,
-                   DefaultValuedAttr<I64Attr, "0">:$transA,
-                   DefaultValuedAttr<I64Attr, "0">:$transB);
-
-  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
-}
-
 def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let hasCanonicalizer = 1;
diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 4de481a..5d93020 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -565,32 +565,6 @@ void ONNXGemmOp::inferShapes() {
   getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
 
-// GemmNoBias
-
-void ONNXGemmNoBiasOp::inferShapes() {
-  // Cannot infer shape if no shape exists.
-  if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
-    return;
-  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
-
-  int64_t M, N, K_A, K_B;
-  M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
-  K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
-  N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
-  K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
-
-  if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
-    emitError("Tensor shapes mismatched.");
-  }
-
-  SmallVector<int64_t, 2> dims;
-  dims.emplace_back(M);
-  dims.emplace_back(N);
-  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
-}
-
 /// BatchNormalizationTestMode
 void ONNXBatchNormalizationTestModeOp::inferShapes() {
   // Cannot infer shape if no shape exists.
diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp
index 7ff0374..4038ec3 100644
--- a/src/pass/shape_inference_pass.cpp
+++ b/src/pass/shape_inference_pass.cpp
@@ -118,7 +118,6 @@ public:
            op->getName().getStringRef() != "onnx.Identity" &&
            op->getName().getStringRef() != "onnx.MatMul" &&
            op->getName().getStringRef() != "onnx.Gemm" &&
-           op->getName().getStringRef() != "onnx.GemmNoBias" &&
            op->getName().getStringRef() != "onnx.Reshape" &&
            op->getName().getStringRef() != "onnx.Transpose" &&
            op->getName().getStringRef() != "onnx.ReduceMax" &&
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 9da12ac..c35536d 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -806,35 +806,6 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tenso
   // CHECK: }
 }
 
-func @test_gemm_no_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
-  %0 ="onnx.GemmNoBias"(%arg0, %arg1) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>) -> tensor<*xf32>
-  "std.return"(%0) : (tensor<*xf32>) -> ()
-
-  // CHECK-LABEL: test_gemm_no_bias
-  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
-  // CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
-  // CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
-  // CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
-  // CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
-  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
-  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg4 = 0 to 5) {
-  // CHECK: [[A:%.+]] = load %arg0[%arg4, %arg2] : memref<5x10xf32>
-  // CHECK: [[B:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
-  // CHECK: [[Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
-  // CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
-  // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
-  // CHECK: store [[SUM]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
-  // CHECK: }
-  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
-  // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
-  // CHECK: store [[ALPHA_AB]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
-  // CHECK: }
-  // CHECK: return [[RES]] : memref<10x10xf32>
-  // CHECK: }
-}
-
 func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
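Reviewer note (not part of the patch): with onnx.GemmNoBias removed, a Gemm without bias is expected to be written as a plain onnx.Gemm whose third operand carries MLIR's `none` type; the lowering above tests exactly this with `isa<NoneType>()` and skips the bias accumulation when it holds. A minimal sketch of what such IR could look like, assuming the absent bias is materialized with the standard dialect's `constant unit` (that materialization choice and the function name are assumptions, not shown in this patch):

  func @gemm_without_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
    // Absent bias materialized as a none-typed value (assumed mechanism).
    %bias = constant unit
    // The third operand has NoneType, so hasBias is false in the lowering and
    // only the alpha * transpose(A) * B part of Gemm is emitted.
    %0 = "onnx.Gemm"(%arg0, %arg1, %bias) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>, none) -> tensor<*xf32>
    "std.return"(%0) : (tensor<*xf32>) -> ()
  }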