Add support for GemmNoBias (#91)

* Add support for GemmNoBias
* Fix a wrong indentation
parent a3f042220e
commit f1d20e368f
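For context (not part of the patch itself): ONNX Gemm computes Y = alpha * op(A) * op(B) + beta * C, where op(X) is X or its transpose depending on transA/transB, and C is broadcast unidirectionally to the M x N result. GemmNoBias is the same operation without the C operand, so the lowering below only has to detect the missing operand and skip the beta * C part. A minimal scalar sketch of that semantics follows; refGemm, its row-major layout, and the assumption that C is already broadcast to M x N are illustrative only and not part of this commit.

// Rough scalar reference for ONNX Gemm semantics (illustration only).
// Pass nullptr for C to get the GemmNoBias behaviour.
#include <cstddef>
#include <vector>

std::vector<float> refGemm(const std::vector<float> &A,
                           const std::vector<float> &B,
                           const std::vector<float> *C, float alpha, float beta,
                           std::size_t M, std::size_t N, std::size_t K,
                           bool transA, bool transB) {
  std::vector<float> Y(M * N, 0.0f);
  for (std::size_t i = 0; i < M; ++i)
    for (std::size_t j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (std::size_t k = 0; k < K; ++k) {
        // op(A) is stored K x M when transA is set, M x K otherwise; same for B.
        float a = transA ? A[k * M + i] : A[i * K + k];
        float b = transB ? B[j * K + k] : B[k * N + j];
        acc += a * b;
      }
      // With a bias: Y = alpha*A*B + beta*C. Without one: Y = alpha*A*B,
      // which is exactly the else-branch added to the lowering below.
      Y[i * N + j] = C ? alpha * acc + beta * (*C)[i * N + j] : alpha * acc;
    }
  return Y;
}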
					
@@ -8,31 +8,34 @@
 //
 //===----------------------------------------------------------------------===//
 
+template <typename GemmOp>
 struct ONNXGemmOpLowering : public ConversionPattern {
   ONNXGemmOpLowering(MLIRContext *ctx)
-      : ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {}
+      : ConversionPattern(GemmOp::getOperationName(), 1, ctx) {}
 
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
+    auto has_bias = (operands.size() == 3);
 
     Value A, B, C;
     A = operands[0];
     B = operands[1];
-    C = operands[2];
+    if (has_bias)
+      C = operands[2];
 
     auto memRefType = convertToMemRefType(*op->result_type_begin());
 
     auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
-        llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
+        llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
     auto betaAttr = FloatAttr::get(memRefType.getElementType(),
-        llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
+        llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
     auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
     auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
 
-    bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0);
-    bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0);
+    bool isTransA = (llvm::dyn_cast<GemmOp>(op).transA() != 0);
+    bool isTransB = (llvm::dyn_cast<GemmOp>(op).transB() != 0);
 
     // Insert an allocation and deallocation for the result of this operation.
     Value alloc;
@@ -116,14 +119,16 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     // GemmOp supports unidirectional broadcasting from C to A*B.
     // Hence, it must be enough to get broadcasting information for C only.
     std::map<int, Value> broadcastedDimInfo;
-    auto shape = C.getType().cast<MemRefType>().getShape();
-    for (int i = 0; i < shape.size(); ++i) {
-      if (shape[i] < 0) {
-        auto dim = rewriter.create<DimOp>(loc, C, i).getResult();
-        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
-        auto isBroadcasted =
-          rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
-        broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted));
+    if (has_bias) {
+      auto shape = C.getType().cast<MemRefType>().getShape();
+      for (int i = 0; i < shape.size(); ++i) {
+        if (shape[i] < 0) {
+          auto dim = rewriter.create<DimOp>(loc, C, i).getResult();
+          auto one = rewriter.create<ConstantIndexOp>(loc, 1);
+          auto isBroadcasted =
+              rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
+          broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted));
+        }
       }
     }
 
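When the op is ONNXGemmNoBiasOp, has_bias is false, so C is never touched here and broadcastedDimInfo simply stays empty; the map is only consulted again in the bias branch of the epilogue below.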
@@ -155,14 +160,18 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     auto matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, reductionPack);
 
     // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
-    auto loopCIVs = getLoopIVsForBroadcasting(
-        loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
-    auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
     auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
     auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
-    auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
-    auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
-    rewriter.create<StoreOp>(loc, Y, alloc, loopMNIVs);
+    if (has_bias) {
+      auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
+                                                broadcastedDimInfo);
+      auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
+      auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
+      auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
+      rewriter.create<StoreOp>(loc, Y, alloc, loopMNIVs);
+    } else {
+      rewriter.create<StoreOp>(loc, alphaAB, alloc, loopMNIVs);
+    }
 
     // Insert instructions to do matrix multiplication: A*B
     Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front();
@@ -203,5 +212,6 @@ struct ONNXGemmOpLowering : public ConversionPattern {
 
 void populateLoweringONNXGemmOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ONNXGemmOpLowering>(ctx);
+  patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
+  patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
 }
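The design choice here is to turn the existing Gemm lowering into a template over the op class rather than duplicating it: ONNXGemmOp and ONNXGemmNoBiasOp share the same pattern body and differ only in the has_bias paths above, and populateLoweringONNXGemmOpPattern registers one instantiation per op.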
@@ -586,9 +586,20 @@ void ONNXGemmNoBiasOp::inferShapes() {
     return;
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+
+  int64_t M, N, K_A, K_B;
+  M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
+  K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
+  N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
+  K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
+
+  if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
+    emitError("Tensor shapes mismatched.");
+  }
+
   SmallVector<int64_t, 2> dims;
-  dims.emplace_back(lhsTy.getShape()[0]);
-  dims.emplace_back(rhsTy.getShape()[1]);
+  dims.emplace_back(M);
+  dims.emplace_back(N);
   getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
 
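Worked example for the shape inference above, using the same shapes as the new lowering test further down (transA = 1, transB = 0, A and B both 5x10): M = 10, K_A = 5, N = 10, K_B = 5, so the inferred result type is tensor<10x10xf32>. A dynamic dimension is encoded as -1, which is why the mismatch check only fires when both K dimensions are static. A hypothetical standalone check of the same arithmetic (not project code):

#include <cassert>
#include <cstdint>

int main() {
  const int64_t aShape[2] = {5, 10}, bShape[2] = {5, 10};
  const bool transA = true, transB = false;
  const int64_t M = transA ? aShape[1] : aShape[0];   // 10
  const int64_t K_A = transA ? aShape[0] : aShape[1]; // 5
  const int64_t N = transB ? bShape[0] : bShape[1];   // 10
  const int64_t K_B = transB ? bShape[1] : bShape[0]; // 5
  assert(K_A == K_B);         // reduction dims must agree (or be dynamic, -1)
  assert(M == 10 && N == 10); // result type: tensor<10x10xf32>
  return 0;
}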
@@ -98,7 +98,7 @@ test_to_enable = [
     "test_gemm_alpha_cpu",
     "test_gemm_beta_cpu",
     "test_gemm_default_matrix_bias_cpu",
-    # "test_gemm_default_no_bias_cpu", <- error, need support for optional operands
+    "test_gemm_default_no_bias_cpu",
     "test_gemm_default_scalar_bias_cpu",
     "test_gemm_default_single_elem_vector_bias_cpu",
     "test_gemm_default_vector_bias_cpu",
@@ -795,12 +795,42 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tenso
   // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
   // CHECK: store [[SUM]], [[RES]][%arg3, %arg4] : memref<10x10xf32>
   // CHECK: }
-  // CHECK: [[C:%.+]] = load %arg2[%arg4] : memref<10xf32>
   // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg3, %arg4] : memref<10x10xf32>
   // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
+  // CHECK: [[C:%.+]] = load %arg2[%arg4] : memref<10xf32>
   // CHECK: [[BETA_C:%.+]] = mulf [[BETA]], [[C]] : f32
   // CHECK: [[Y_RES:%.+]] = addf [[ALPHA_AB]], [[BETA_C]] : f32
   // CHECK: store [[Y_RES]], [[RES]][%arg3, %arg4] : memref<10x10xf32>
   // CHECK: }
   // CHECK: return [[RES]] : memref<10x10xf32>
   // CHECK: }
 }
 
+func @test_gemm_no_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
+  %0 ="onnx.GemmNoBias"(%arg0, %arg1) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_gemm_no_bias
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
+  // CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
+  // CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
+  // CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops  {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg4 = 0 to 5) {
+  // CHECK: [[A:%.+]] = load %arg0[%arg4, %arg2] : memref<5x10xf32>
+  // CHECK: [[B:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
+  // CHECK: [[Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
+  // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
+  // CHECK: store [[SUM]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: }
+  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
+  // CHECK: store [[ALPHA_AB]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: }
+  // CHECK: return [[RES]] : memref<10x10xf32>
+  // CHECK: }
+}
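The new test_gemm_no_bias case mirrors test_gemm: the reduction loop that accumulates A*B is unchanged, but the epilogue now only multiplies the accumulated value by ALPHA and stores it, with no load of a bias operand and no BETA multiply, matching the else-branch added to the lowering.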