From 400676e37190b6b62323a20faca40e82f7c9bbb9 Mon Sep 17 00:00:00 2001
From: "Tung D. Le"
Date: Thu, 30 Jan 2020 01:11:49 +0900
Subject: [PATCH] Lowering Gemm (#19)

* Initial implementation
* Support transposing inputs
* Revise unidirectional broadcasting and unknown dimensions
* Revise gemm
* Add testcase
* Rename some variables
* Update SharingWork.md
* Change from the use of Value* to Value
* Insert deallocation
* Initialize the output matrix and fix wrong computation
* Add end-to-end testcases
* Edit lowering tests
* Change attribute names
* Use emplace_back for SmallVector
* Use the new way of getting attributes
* Revise the use of attributes
* Check the bias's shape

Co-authored-by: Gheorghe-Teodor Bercea
---
 SharingWork.md                      |   2 +-
 src/dialect/onnx/onnx_ops.cpp       |  30 +++-
 src/pass/lower_frontend_to_krnl.cpp | 217 +++++++++++++++++++++++++++
 test/backend/test.py                |  13 ++
 test/mlir/onnx/onnx_lowering.mlir   |  31 ++++
 5 files changed, 287 insertions(+), 6 deletions(-)

diff --git a/SharingWork.md b/SharingWork.md
index fe43494..625db96 100644
--- a/SharingWork.md
+++ b/SharingWork.md
@@ -15,7 +15,7 @@ ONNX operations for which some work is needed.
 | Elu | Tung | v | v | |
 | Exp | Tung | v | v | |
 | FullGemm | | | | noU |
-| Gemm | | | | noU |
+| Gemm | Tung | v | | U |
 | HardSigmoid | Tung | v | v | |
 | LeakyRelu | Tung | v | v | |
 | MatMul | | | | noM |
diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index bcb9b27..33db13e 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -487,13 +487,37 @@ void ONNXMatMulOp::inferShapes() {
 void ONNXGemmOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
+      !getOperand(1).getType().isa<RankedTensorType>() ||
+      !getOperand(2).getType().isa<RankedTensorType>())
     return;
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+  auto biasTy = getOperand(2).getType().cast<RankedTensorType>();
+
+  int64_t M, N, K_A, K_B;
+  M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
+  K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
+  N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
+  K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
+
+  if ((K_A != -1) && (K_B != -1) && (K_A != K_B)) {
+    emitError("Tensor shapes mismatched.");
+  }
+
+  // Check whether the bias is unidirectionally broadcastable to A*B or not.
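+  // Per the ONNX Gemm specification, C must be unidirectionally
+  // broadcastable to the (M, N) shape of A*B: its rank is at most 2, and
+  // every statically known dimension must either be 1 or match the
+  // corresponding dimension of (M, N). Unknown dimensions (-1) are not
+  // rejected here; they are resolved at lowering time.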
+  auto shape = biasTy.getShape();
+  int rank = shape.size();
+  if ((rank > 2) ||
+      (rank >= 1 && shape[rank - 1] != -1 && N != -1 && N != shape[rank - 1] &&
+       shape[rank - 1] != 1) ||
+      (rank == 2 && shape[rank - 2] != -1 && M != -1 && M != shape[rank - 2] &&
+       shape[rank - 2] != 1)) {
+    emitError("Bias shape mismatched.");
+  }
+
   SmallVector<int64_t, 2> dims;
-  dims.emplace_back(lhsTy.getShape()[0]);
-  dims.emplace_back(rhsTy.getShape()[1]);
+  dims.emplace_back(M);
+  dims.emplace_back(N);
   getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp
index 0695c26..9c45dfe 100644
--- a/src/pass/lower_frontend_to_krnl.cpp
+++ b/src/pass/lower_frontend_to_krnl.cpp
@@ -1176,6 +1176,219 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
   }
 };
 
+struct ONNXGemmOpLowering : public ConversionPattern {
+  ONNXGemmOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+    auto loc = op->getLoc();
+
+    Value A, B, C;
+    A = operands[0];
+    B = operands[1];
+    C = operands[2];
+
+    auto alphaAttr = FloatAttr::get(tensorType.getElementType(),
+        llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
+    auto betaAttr = FloatAttr::get(tensorType.getElementType(),
+        llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
+    auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
+    auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
+
+    bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0);
+    bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0);
+
+    // Result type
+    auto memRefType = convertTensorToMemRef(tensorType);
+
+    // Insert an allocation and deallocation for the result of this operation.
+    Value alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+    if (hasAllConstantDimensions(memRefType))
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+    else {
+      auto memRefShape = memRefType.getShape();
+      SmallVector<Value, 4> allocOperands;
+      if (memRefShape[0] < 0) {
+        auto dim = rewriter.create<DimOp>(loc, A, (isTransA) ? 1 : 0);
+        allocOperands.emplace_back(dim);
+      }
+      if (memRefShape[1] < 0) {
+        auto dim = rewriter.create<DimOp>(loc, B, (isTransB) ? 0 : 1);
+        allocOperands.emplace_back(dim);
+      }
+      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
+      if (insertDealloc) {
+        auto *parentBlock = alloc.getDefiningOp()->getBlock();
+        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
+        dealloc.getOperation()->moveBefore(&parentBlock->back());
+      }
+    }
+
+    // Number of loops
+    auto memRefShape = memRefType.getShape();
+    int64_t numLoops = 3;
+
+    // Define loops.
+    auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
+    std::vector<Value> originalLoops;
+    originalLoops.reserve(numLoops);
+    for (auto result : loopsOp.getResults()) {
+      originalLoops.push_back(result);
+    }
+
+    auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
+    std::vector<Value> optimizedLoops;
+    optimizedLoops.reserve(numLoops);
+    for (auto result : optimizedLoopsOp.getResults()) {
+      optimizedLoops.push_back(result);
+    }
+    Block &optimizationBlock = optimizedLoopsOp.region().front();
+
+    // We have two Krnl loops:
+    // - Outer loop iterates over the output matrix dimensions, and
+    // - Reduction loop iterates over the reduction dimension.
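+    // In pseudo-code, the generated loop nest computes, for each (m, n):
+    //   Y[m, n] = 0;
+    //   for k in 0 .. K:
+    //     Y[m, n] += A[m, k] * B[k, n];            // indices swap if transA/transB
+    //   Y[m, n] = alpha * Y[m, n] + beta * C[m, n];  // C broadcast to (M, N)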
+
+    // Outer loop
+    std::vector<Value> outerLoops, optimizedOuterLoops;
+    outerLoops.reserve(2);
+    optimizedOuterLoops.reserve(2);
+    for (int i = 0; i < 2; ++i) {
+      outerLoops.push_back(originalLoops[i]);
+      optimizedOuterLoops.push_back(optimizedLoops[i]);
+    }
+    KrnlIterateOperandPack outerPack(rewriter, outerLoops,
+                                     optimizedOuterLoops);
+    // Induction variables for the outer loops
+    for (int i = 0; i < 2; ++i) {
+      if (memRefShape[i] < 0) {
+        outerPack.pushConstantBound(0);
+        outerPack.pushOperandBound(
+            rewriter.create<DimOp>(loc, alloc, i).getResult());
+      } else {
+        outerPack.pushConstantBound(0);
+        outerPack.pushConstantBound(memRefShape[i]);
+      }
+    }
+    // Reduction loop
+    std::vector<Value> reductionLoops, optimizedReductionLoops;
+    reductionLoops.reserve(1);
+    optimizedReductionLoops.reserve(1);
+    reductionLoops.push_back(originalLoops[2]);
+    optimizedReductionLoops.push_back(optimizedLoops[2]);
+    KrnlIterateOperandPack reductionPack(rewriter, reductionLoops,
+                                         optimizedReductionLoops);
+    // Induction variable for the reduction dimension.
+    // Try to find and use a static value from A or B first; if that fails,
+    // use a dynamic value.
+    auto ATy = A.getType().cast<MemRefType>();
+    auto BTy = B.getType().cast<MemRefType>();
+    int64_t K_A_Idx = (isTransA) ? 0 : 1;
+    int64_t K_B_Idx = (isTransB) ? 1 : 0;
+    reductionPack.pushConstantBound(0);
+    if (ATy.getShape()[K_A_Idx] != -1)
+      reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
+    else if (BTy.getShape()[K_B_Idx] != -1)
+      reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
+    else
+      reductionPack.pushOperandBound(
+          rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());
+
+    // Get run-time dimension information for unknown dimensions used for
+    // broadcasting.
+    // GemmOp supports unidirectional broadcasting from C to A*B.
+    // Hence, it is enough to get the broadcasting information for C only.
+    std::map<int, Value> broadcastedDimInfo;
+    auto shape = C.getType().cast<MemRefType>().getShape();
+    for (int i = 0; i < shape.size(); ++i) {
+      if (shape[i] < 0) {
+        auto dim = rewriter.create<DimOp>(loc, C, i).getResult();
+        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
+        auto isBroadcasted =
+            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
+        broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted));
+      }
+    }
+
+    auto outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);
+
+    // Now perform the insertions into the body of the
+    // just generated instructions:
+
+    // No optimization
+    rewriter.setInsertionPointToEnd(&optimizationBlock);
+    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
+    rewriter.setInsertionPoint(optimizedLoopsOp);
+
+    // Insert instructions inside the outer loop.
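+    // The body of the outer loop is filled in three steps below:
+    //   (1) store 0 into Y[m, n],
+    //   (2) emit a krnl.iterate over k that accumulates A*B into Y[m, n],
+    //   (3) rescale: Y[m, n] = alpha * Y[m, n] + beta * C, where C is loaded
+    //       with broadcast-aware indices.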
+    Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
+    rewriter.setInsertionPointToStart(&outerIterationBlock);
+
+    // Induction variables
+    SmallVector<Value, 4> loopMNIVs;
+    for (auto arg : outerIterationBlock.getArguments()) {
+      loopMNIVs.emplace_back(arg);
+    }
+
+    // Initialize the output of A*B
+    auto zero = rewriter.create<ConstantOp>(
+        loc, FloatAttr::get(memRefType.getElementType(), 0));
+    rewriter.create<StoreOp>(loc, zero, alloc, loopMNIVs);
+
+    // Compute A*B
+    auto matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, reductionPack);
+
+    // Compute beta*C and add it to alpha*A*B (C is unidirectionally broadcast).
+    auto loopCIVs = getLoopIVsForBroadcasting(
+        loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
+    auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
+    auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
+    auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
+    auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
+    auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
+    rewriter.create<StoreOp>(loc, Y, alloc, loopMNIVs);
+
+    // Insert instructions to do matrix multiplication: A*B
+    Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front();
+    rewriter.setInsertionPointToStart(&matmulIterationBlock);
+
+    // Induction variables
+    SmallVector<Value, 4> loopKIVs, loopAIVs, loopBIVs;
+    for (auto arg : matmulIterationBlock.getArguments())
+      loopKIVs.emplace_back(arg);
+    if (isTransA) {
+      loopAIVs.emplace_back(loopKIVs[0]);
+      loopAIVs.emplace_back(loopMNIVs[0]);
+    } else {
+      loopAIVs.emplace_back(loopMNIVs[0]);
+      loopAIVs.emplace_back(loopKIVs[0]);
+    }
+    if (isTransB) {
+      loopBIVs.emplace_back(loopMNIVs[1]);
+      loopBIVs.emplace_back(loopKIVs[0]);
+    } else {
+      loopBIVs.emplace_back(loopKIVs[0]);
+      loopBIVs.emplace_back(loopMNIVs[1]);
+    }
+
+    // Matmul computation
+    auto loadedA = rewriter.create<LoadOp>(loc, A, loopAIVs);
+    auto loadedB = rewriter.create<LoadOp>(loc, B, loopBIVs);
+    auto loadedY = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
+    auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
+    auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
+    rewriter.create<StoreOp>(loc, accumulated, alloc, loopMNIVs);
+
+    rewriter.replaceOp(op, alloc);
+
+    return matchSuccess();
+  }
+};
+
 struct ONNXUnsqueezeOpLowering : public ConversionPattern {
   ONNXUnsqueezeOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
@@ -1242,7 +1455,6 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
         dealloc.getOperation()->moveBefore(&parentBlock->back());
       }
     }
-
     rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
     rewriter.replaceOp(op, alloc);
     return matchSuccess();
@@ -1377,7 +1589,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
       ONNXElementwiseVariadicOpLowering,
       ONNXElementwiseVariadicOpLowering,
       ONNXReshapeOpLowering, ONNXEntryPointLowering,
-      ONNXSoftmaxOpLowering, ONNXUnsqueezeOpLowering>(&getContext());
+      ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
+      ONNXUnsqueezeOpLowering>(&getContext());
 
   // With the target and rewrite patterns defined, we can now attempt the
   // conversion. The conversion will signal failure if any of our `illegal`
diff --git a/test/backend/test.py b/test/backend/test.py
index 6ec754e..1d45c14 100644
--- a/test/backend/test.py
+++ b/test/backend/test.py
@@ -93,6 +93,19 @@ test_to_enable = [
     "test_exp_cpu",
     "test_exp_example_cpu",
 
+    # Gemm Op:
+    "test_gemm_all_attributes_cpu",
+    "test_gemm_alpha_cpu",
+    "test_gemm_beta_cpu",
+    "test_gemm_default_matrix_bias_cpu",
+    # "test_gemm_default_no_bias_cpu", <- error, need support for optional operands
+    # "test_gemm_default_scalar_bias_cpu", <- error, shapes mismatch, why?
+    "test_gemm_default_single_elem_vector_bias_cpu",
+    "test_gemm_default_vector_bias_cpu",
+    "test_gemm_default_zero_bias_cpu",
+    "test_gemm_transposeA_cpu",
+    "test_gemm_transposeB_cpu",
+
     # Hard Sigmoid Op:
     "test_hardsigmoid_cpu",
     "test_hardsigmoid_default_cpu",
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 2217403..394ef3c 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -633,6 +633,37 @@ func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
   // CHECK: return [[RES]] : memref<10x10xf32>
 }
 
+func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tensor<10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Gemm"(%arg0, %arg1, %arg2) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>, tensor<10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_gemm
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
+  // CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
+  // CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
+  // CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg3 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg4 = 0 to 10) {
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg5 = 0 to 5) {
+  // CHECK: [[A:%.+]] = load %arg0[%arg5, %arg3] : memref<5x10xf32>
+  // CHECK: [[B:%.+]] = load %arg1[%arg5, %arg4] : memref<5x10xf32>
+  // CHECK: [[Y:%.+]] = load [[RES]][%arg3, %arg4] : memref<10x10xf32>
+  // CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
+  // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
+  // CHECK: store [[SUM]], [[RES]][%arg3, %arg4] : memref<10x10xf32>
+  // CHECK: }
+  // CHECK: [[C:%.+]] = load %arg2[%arg4] : memref<10xf32>
+  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg3, %arg4] : memref<10x10xf32>
+  // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
+  // CHECK: [[BETA_C:%.+]] = mulf [[BETA]], [[C]] : f32
+  // CHECK: [[Y_RES:%.+]] = addf [[ALPHA_AB]], [[BETA_C]] : f32
+  // CHECK: store [[Y_RES]], [[RES]][%arg3, %arg4] : memref<10x10xf32>
+  // CHECK: return [[RES]] : memref<10x10xf32>
+  // CHECK: }
+}
+
 func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()