From f1d20e368f778c45d165cf88f7e86cdf3ccc2179 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Fri, 21 Feb 2020 00:55:24 +0900 Subject: [PATCH] Add support of GemmNoBias (#91) * Add support of GemmNoBias * Fix a wrong indentation --- .../rewrite_patterns/math/gemm.inc | 52 +++++++++++-------- src/dialect/onnx/onnx_ops.cpp | 15 +++++- test/backend/test.py | 2 +- test/mlir/onnx/onnx_lowering.mlir | 32 +++++++++++- 4 files changed, 76 insertions(+), 25 deletions(-) diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc index d8bbc55..af1da9e 100644 --- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc @@ -8,31 +8,34 @@ // //===----------------------------------------------------------------------===// +template struct ONNXGemmOpLowering : public ConversionPattern { ONNXGemmOpLowering(MLIRContext *ctx) - : ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {} + : ConversionPattern(GemmOp::getOperationName(), 1, ctx) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); + auto has_bias = (operands.size() == 3); Value A, B, C; A = operands[0]; B = operands[1]; - C = operands[2]; + if (has_bias) + C = operands[2]; auto memRefType = convertToMemRefType(*op->result_type_begin()); auto alphaAttr = FloatAttr::get(memRefType.getElementType(), - llvm::dyn_cast(op).alpha().convertToFloat()); + llvm::dyn_cast(op).alpha().convertToFloat()); auto betaAttr = FloatAttr::get(memRefType.getElementType(), - llvm::dyn_cast(op).beta().convertToFloat()); + llvm::dyn_cast(op).beta().convertToFloat()); auto alpha = rewriter.create(loc, alphaAttr); auto beta = rewriter.create(loc, betaAttr); - bool isTransA = (llvm::dyn_cast(op).transA() != 0); - bool isTransB = (llvm::dyn_cast(op).transB() != 0); + bool isTransA = (llvm::dyn_cast(op).transA() != 0); + bool isTransB = (llvm::dyn_cast(op).transB() != 0); // Insert an allocation and deallocation for the result of this operation. Value alloc; @@ -116,14 +119,16 @@ struct ONNXGemmOpLowering : public ConversionPattern { // GemmOp supports unidirectional broadcasting from C to A*B. // Hence, it must be enough to get broadcasting information for C only. std::map broadcastedDimInfo; - auto shape = C.getType().cast().getShape(); - for (int i = 0; i < shape.size(); ++i) { - if (shape[i] < 0) { - auto dim = rewriter.create(loc, C, i).getResult(); - auto one = rewriter.create(loc, 1); - auto isBroadcasted = - rewriter.create(loc, CmpIPredicate::eq, dim, one); - broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted)); + if (has_bias) { + auto shape = C.getType().cast().getShape(); + for (int i = 0; i < shape.size(); ++i) { + if (shape[i] < 0) { + auto dim = rewriter.create(loc, C, i).getResult(); + auto one = rewriter.create(loc, 1); + auto isBroadcasted = + rewriter.create(loc, CmpIPredicate::eq, dim, one); + broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted)); + } } } @@ -155,14 +160,18 @@ struct ONNXGemmOpLowering : public ConversionPattern { auto matmulIterateOp = rewriter.create(loc, reductionPack); // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting) - auto loopCIVs = getLoopIVsForBroadcasting( - loc, rewriter, loopMNIVs, C, broadcastedDimInfo); - auto loadedC = rewriter.create(loc, C, loopCIVs); auto loadedAB = rewriter.create(loc, alloc, loopMNIVs); auto alphaAB = rewriter.create(loc, alpha, loadedAB); - auto betaC = rewriter.create(loc, beta, loadedC); - auto Y = rewriter.create(loc, alphaAB, betaC); - rewriter.create(loc, Y, alloc, loopMNIVs); + if (has_bias) { + auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C, + broadcastedDimInfo); + auto loadedC = rewriter.create(loc, C, loopCIVs); + auto betaC = rewriter.create(loc, beta, loadedC); + auto Y = rewriter.create(loc, alphaAB, betaC); + rewriter.create(loc, Y, alloc, loopMNIVs); + } else { + rewriter.create(loc, alphaAB, alloc, loopMNIVs); + } // Insert instructions to do matrix multiplication: A*B Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front(); @@ -203,5 +212,6 @@ struct ONNXGemmOpLowering : public ConversionPattern { void populateLoweringONNXGemmOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx) { - patterns.insert(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); } diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index f7e3623..4524467 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -586,9 +586,20 @@ void ONNXGemmNoBiasOp::inferShapes() { return; auto lhsTy = getOperand(0).getType().cast(); auto rhsTy = getOperand(1).getType().cast(); + + int64_t M, N, K_A, K_B; + M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1]; + K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0]; + N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0]; + K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1]; + + if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) { + emitError("Tensor shapes mismatched."); + } + SmallVector dims; - dims.emplace_back(lhsTy.getShape()[0]); - dims.emplace_back(rhsTy.getShape()[1]); + dims.emplace_back(M); + dims.emplace_back(N); getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType())); } diff --git a/test/backend/test.py b/test/backend/test.py index 2d4e012..495f1c6 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -98,7 +98,7 @@ test_to_enable = [ "test_gemm_alpha_cpu", "test_gemm_beta_cpu", "test_gemm_default_matrix_bias_cpu", - # "test_gemm_default_no_bias_cpu", <- error, need support for optional operands + "test_gemm_default_no_bias_cpu", "test_gemm_default_scalar_bias_cpu", "test_gemm_default_single_elem_vector_bias_cpu", "test_gemm_default_vector_bias_cpu", diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 8f4843a..08b2cf1 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -795,12 +795,42 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tenso // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32 // CHECK: store [[SUM]], [[RES]][%arg3, %arg4] : memref<10x10xf32> // CHECK: } - // CHECK: [[C:%.+]] = load %arg2[%arg4] : memref<10xf32> // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg3, %arg4] : memref<10x10xf32> // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32 + // CHECK: [[C:%.+]] = load %arg2[%arg4] : memref<10xf32> // CHECK: [[BETA_C:%.+]] = mulf [[BETA]], [[C]] : f32 // CHECK: [[Y_RES:%.+]] = addf [[ALPHA_AB]], [[BETA_C]] : f32 // CHECK: store [[Y_RES]], [[RES]][%arg3, %arg4] : memref<10x10xf32> + // CHECK: } + // CHECK: return [[RES]] : memref<10x10xf32> + // CHECK: } +} + +func @test_gemm_no_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> { + %0 ="onnx.GemmNoBias"(%arg0, %arg1) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_gemm_no_bias + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32> + // CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32 + // CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32 + // CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3 + // CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2 + // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { + // CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg4 = 0 to 5) { + // CHECK: [[A:%.+]] = load %arg0[%arg4, %arg2] : memref<5x10xf32> + // CHECK: [[B:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32> + // CHECK: [[Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32> + // CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32 + // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32 + // CHECK: store [[SUM]], [[RES]][%arg2, %arg3] : memref<10x10xf32> + // CHECK: } + // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32> + // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32 + // CHECK: store [[ALPHA_AB]], [[RES]][%arg2, %arg3] : memref<10x10xf32> + // CHECK: } // CHECK: return [[RES]] : memref<10x10xf32> // CHECK: } }