Lowering Gemm (#19)
* Initial implementation
* Support transposing inputs
* Revise unidirectional broadcasting and unknown dimensions
* Revise gemm
* Add testcase
* Rename some variables
* Update SharingWork.md
* Change from the use of Value* to Value
* Insert deallocation
* Initialize the output matrix and fix wrong computation
* Add end-to-end testcases
* Edit lowering tests
* Change attribute names
* Use emplace_back for SmallVector
* Use the new way of getting attributes
* Revise the use of attributes
* Check the bias's shape

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
parent 9e82d388f0
commit 400676e371
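For context: ONNX `Gemm` computes Y = alpha * A' * B' + beta * C, where A' is A transposed when `transA` is nonzero (likewise B' for `transB`), A' is M x K, B' is K x N, and C must be unidirectionally broadcastable to the M x N result. The shape inference, Krnl lowering, and tests in the hunks below follow this definition.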
@@ -15,7 +15,7 @@ ONNX operations for which some work is needed.
 | Elu | Tung | v | v | |
 | Exp | Tung | v | v | |
 | FullGemm | | | | noU |
-| Gemm | | | | noU |
+| Gemm | Tung | v | | U |
 | HardSigmoid | Tung | v | v | |
 | LeakyRelu | Tung | v | v | |
 | MatMul | | | | noM |
@@ -487,13 +487,37 @@ void ONNXMatMulOp::inferShapes() {
 void ONNXGemmOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
+      !getOperand(1).getType().isa<RankedTensorType>() ||
+      !getOperand(2).getType().isa<RankedTensorType>())
     return;
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+  auto biasTy = getOperand(2).getType().cast<RankedTensorType>();
+
+  int64_t M, N, K_A, K_B;
+  M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
+  K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
+  N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
+  K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
+
+  if ((K_A != -1) && (K_B != -1) && (K_A != K_B)) {
+    emitError("Tensor shapes mismatched.");
+  }
+
+  // Check whether the bias is unidirectionally broadcastable to A*B or not.
+  auto shape = biasTy.getShape();
+  int rank = shape.size();
+  if ((rank > 2) ||
+      (rank >= 1 && shape[rank - 1] != -1 && N != -1 && N != shape[rank - 1] &&
+       shape[rank - 1] != 1) ||
+      (rank == 2 && shape[rank - 2] != -1 && M != -1 && M != shape[rank - 2] &&
+       shape[rank - 2] != 1)) {
+    emitError("Bias shape mismatched.");
+  }
+
   SmallVector<int64_t, 2> dims;
-  dims.emplace_back(lhsTy.getShape()[0]);
-  dims.emplace_back(rhsTy.getShape()[1]);
+  dims.emplace_back(M);
+  dims.emplace_back(N);
   getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
 
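As a worked example using the shapes from the lowering test further down: A is 5x10 with transA = 1, so M = 10 and K_A = 5; B is 5x10 with transB = 0, so K_B = 5 and N = 10; the 10-element bias matches N and broadcasts along M, so the inferred result type is tensor<10x10xf32>.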
@@ -1176,6 +1176,219 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
   }
 };
 
+struct ONNXGemmOpLowering : public ConversionPattern {
+  ONNXGemmOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+    auto loc = op->getLoc();
+
+    Value A, B, C;
+    A = operands[0];
+    B = operands[1];
+    C = operands[2];
+
+    auto alphaAttr = FloatAttr::get(tensorType.getElementType(),
+        llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
+    auto betaAttr = FloatAttr::get(tensorType.getElementType(),
+        llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
+    auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
+    auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
+
+    bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0);
+    bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0);
+
+    // Result type
+    auto memRefType = convertTensorToMemRef(tensorType);
+
+    // Insert an allocation and deallocation for the result of this operation.
+    Value alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+    if (hasAllConstantDimensions(memRefType))
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+    else {
+      auto memRefShape = memRefType.getShape();
+      SmallVector<Value, 2> allocOperands;
+      if (memRefShape[0] < 0) {
+        auto dim = rewriter.create<DimOp>(loc, A, (isTransA) ? 1 : 0);
+        allocOperands.emplace_back(dim);
+      }
+      if (memRefShape[1] < 0) {
+        auto dim = rewriter.create<DimOp>(loc, B, (isTransB) ? 0 : 1);
+        allocOperands.emplace_back(dim);
+      }
+      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
+      if (insertDealloc) {
+        auto *parentBlock = alloc.getDefiningOp()->getBlock();
+        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
+        dealloc.getOperation()->moveBefore(&parentBlock->back());
+      }
+    }
+
+    // Number of loops
+    auto memRefShape = memRefType.getShape();
+    int64_t numLoops = 3;
+
+    // Define loops.
+    auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
+    std::vector<Value> originalLoops;
+    originalLoops.reserve(numLoops);
+    for (auto result : loopsOp.getResults()) {
+      originalLoops.push_back(result);
+    }
+
+    auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
+    std::vector<Value> optimizedLoops;
+    optimizedLoops.reserve(numLoops);
+    for (auto result : optimizedLoopsOp.getResults()) {
+      optimizedLoops.push_back(result);
+    }
+    Block &optimizationBlock = optimizedLoopsOp.region().front();
+
+    // We have two Krnl loops:
+    // - The outer loop iterates over the output matrix dimensions, and
+    // - The reduction loop iterates over the reduction dimension.
+
+    // Outer loop
+    std::vector<Value> outerLoops, optimizedOuterLoops;
+    outerLoops.reserve(2);
+    optimizedOuterLoops.reserve(2);
+    for (int i = 0; i < 2; ++i) {
+      outerLoops.push_back(originalLoops[i]);
+      optimizedOuterLoops.push_back(optimizedLoops[i]);
+    }
+    KrnlIterateOperandPack outerPack(rewriter, outerLoops,
+                                     optimizedOuterLoops);
+    // Induction variables for the outer loops
+    for (int i = 0; i < 2; ++i) {
+      if (memRefShape[i] < 0) {
+        outerPack.pushConstantBound(0);
+        outerPack.pushOperandBound(
+            rewriter.create<DimOp>(loc, alloc, i).getResult());
+      } else {
+        outerPack.pushConstantBound(0);
+        outerPack.pushConstantBound(memRefShape[i]);
+      }
+    }
+    // Reduction loop
+    std::vector<Value> reductionLoops, optimizedReductionLoops;
+    reductionLoops.reserve(1);
+    optimizedReductionLoops.reserve(1);
+    reductionLoops.push_back(originalLoops[2]);
+    optimizedReductionLoops.push_back(optimizedLoops[2]);
+    KrnlIterateOperandPack reductionPack(rewriter, reductionLoops,
+                                         optimizedReductionLoops);
+    // Induction variable for the reduction dimension.
+    // Try to find and use a static value from A or B first.
+    // If that fails, use a dynamic value.
+    auto ATy = A.getType().cast<MemRefType>();
+    auto BTy = B.getType().cast<MemRefType>();
+    int64_t K_A_Idx = (isTransA) ? 0 : 1;
+    int64_t K_B_Idx = (isTransB) ? 1 : 0;
+    reductionPack.pushConstantBound(0);
+    if (ATy.getShape()[K_A_Idx] != -1)
+      reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
+    else if (BTy.getShape()[K_B_Idx] != -1)
+      reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
+    else
+      reductionPack.pushOperandBound(
+          rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());
+
+    // Get run-time dimension information for unknown dimensions used for
+    // broadcasting.
+    // GemmOp supports unidirectional broadcasting from C to A*B.
+    // Hence, it is enough to get broadcasting information for C only.
+    std::map<int, Value> broadcastedDimInfo;
+    auto shape = C.getType().cast<MemRefType>().getShape();
+    for (int i = 0; i < shape.size(); ++i) {
+      if (shape[i] < 0) {
+        auto dim = rewriter.create<DimOp>(loc, C, i).getResult();
+        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
+        auto isBroadcasted =
+            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
+        broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted));
+      }
+    }
+
+    auto outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);
+
+    // Now perform the insertions into the body of the
+    // just generated instructions:
+
+    // No optimization
+    rewriter.setInsertionPointToEnd(&optimizationBlock);
+    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
+    rewriter.setInsertionPoint(optimizedLoopsOp);
+
+    // Insert instructions inside the outer loop.
+    Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
+    rewriter.setInsertionPointToStart(&outerIterationBlock);
+
+    // Induction variables
+    SmallVector<Value, 4> loopMNIVs;
+    for (auto arg : outerIterationBlock.getArguments()) {
+      loopMNIVs.emplace_back(arg);
+    }
+
+    // Initialize the output of A*B
+    auto zero = rewriter.create<ConstantOp>(
+        loc, FloatAttr::get(memRefType.getElementType(), 0));
+    rewriter.create<StoreOp>(loc, zero, alloc, loopMNIVs);
+
+    // Compute A*B
+    auto matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, reductionPack);
+
+    // Compute beta*C and add it to alpha*A*B (unidirectional broadcasting of C)
+    auto loopCIVs = getLoopIVsForBroadcasting(
+        loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
+    auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
+    auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
+    auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
+    auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
+    auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
+    rewriter.create<StoreOp>(loc, Y, alloc, loopMNIVs);
+
+    // Insert instructions to do matrix multiplication: A*B
+    Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front();
+    rewriter.setInsertionPointToStart(&matmulIterationBlock);
+
+    // Induction variables
+    SmallVector<Value, 4> loopKIVs, loopAIVs, loopBIVs;
+    for (auto arg : matmulIterationBlock.getArguments())
+      loopKIVs.emplace_back(arg);
+    if (isTransA) {
+      loopAIVs.emplace_back(loopKIVs[0]);
+      loopAIVs.emplace_back(loopMNIVs[0]);
+    } else {
+      loopAIVs.emplace_back(loopMNIVs[0]);
+      loopAIVs.emplace_back(loopKIVs[0]);
+    }
+    if (isTransB) {
+      loopBIVs.emplace_back(loopMNIVs[1]);
+      loopBIVs.emplace_back(loopKIVs[0]);
+    } else {
+      loopBIVs.emplace_back(loopKIVs[0]);
+      loopBIVs.emplace_back(loopMNIVs[1]);
+    }
+
+    // Matmul computation
+    auto loadedA = rewriter.create<LoadOp>(loc, A, loopAIVs);
+    auto loadedB = rewriter.create<LoadOp>(loc, B, loopBIVs);
+    auto loadedY = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
+    auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
+    auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
+    rewriter.create<StoreOp>(loc, accumulated, alloc, loopMNIVs);
+
+    rewriter.replaceOp(op, alloc);
+
+    return matchSuccess();
+  }
+};
+
 struct ONNXUnsqueezeOpLowering : public ConversionPattern {
   ONNXUnsqueezeOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
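As a reading aid, not part of the commit, the scalar computation that the emitted Krnl loops perform per output element can be sketched in plain C++ as below. The function name, the row-major buffer layout, and the `biasIsVector` flag (modelling the vector-bias case from the test) are illustrative assumptions.

```cpp
#include <cstdint>

// Reference loop nest equivalent to the lowering above (assumed layout):
// Y = alpha * op(A) * op(B) + beta * C, with C broadcast over rows when it
// is a length-N vector.
void gemmReference(const float *A, const float *B, const float *C, float *Y,
                   int64_t M, int64_t N, int64_t K, float alpha, float beta,
                   bool transA, bool transB, bool biasIsVector) {
  for (int64_t i = 0; i < M; ++i) {
    for (int64_t j = 0; j < N; ++j) {
      // Matches the StoreOp of `zero` that initializes the output buffer.
      float acc = 0.0f;
      // Matches the inner krnl.iterate over the reduction dimension.
      for (int64_t k = 0; k < K; ++k) {
        float a = transA ? A[k * M + i] : A[i * K + k];
        float b = transB ? B[j * K + k] : B[k * N + j];
        acc += a * b;
      }
      // Matches the alpha/beta combine emitted after the reduction loop.
      float c = biasIsVector ? C[j] : C[i * N + j];
      Y[i * N + j] = alpha * acc + beta * c;
    }
  }
}
```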
@@ -1242,7 +1455,6 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
         dealloc.getOperation()->moveBefore(&parentBlock->back());
       }
     }
-
     rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
     rewriter.replaceOp(op, alloc);
     return matchSuccess();
@@ -1377,7 +1589,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
                   ONNXReshapeOpLowering, ONNXEntryPointLowering,
-                  ONNXSoftmaxOpLowering, ONNXUnsqueezeOpLowering>(&getContext());
+                  ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
+                  ONNXUnsqueezeOpLowering>(&getContext());
 
   // With the target and rewrite patterns defined, we can now attempt the
   // conversion. The conversion will signal failure if any of our `illegal`
@@ -93,6 +93,19 @@ test_to_enable = [
     "test_exp_cpu",
     "test_exp_example_cpu",
 
+    # Gemm Op:
+    "test_gemm_all_attributes_cpu",
+    "test_gemm_alpha_cpu",
+    "test_gemm_beta_cpu",
+    "test_gemm_default_matrix_bias_cpu",
+    # "test_gemm_default_no_bias_cpu", <- error, need support for optional operands
+    # "test_gemm_default_scalar_bias_cpu", <- error, shapes mismatch, why?
+    "test_gemm_default_single_elem_vector_bias_cpu",
+    "test_gemm_default_vector_bias_cpu",
+    "test_gemm_default_zero_bias_cpu",
+    "test_gemm_transposeA_cpu",
+    "test_gemm_transposeB_cpu",
+
     # Hard Sigmoid Op:
     "test_hardsigmoid_cpu",
     "test_hardsigmoid_default_cpu",
@@ -633,6 +633,37 @@ func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
   // CHECK: return [[RES]] : memref<10x10xf32>
 }
 
+func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2 : tensor<10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Gemm"(%arg0, %arg1, %arg2) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>, tensor<10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_gemm
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
+  // CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
+  // CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
+  // CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg3 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg4 = 0 to 10) {
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg5 = 0 to 5) {
+  // CHECK: [[A:%.+]] = load %arg0[%arg5, %arg3] : memref<5x10xf32>
+  // CHECK: [[B:%.+]] = load %arg1[%arg5, %arg4] : memref<5x10xf32>
+  // CHECK: [[Y:%.+]] = load [[RES]][%arg3, %arg4] : memref<10x10xf32>
+  // CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
+  // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
+  // CHECK: store [[SUM]], [[RES]][%arg3, %arg4] : memref<10x10xf32>
+  // CHECK: }
+  // CHECK: [[C:%.+]] = load %arg2[%arg4] : memref<10xf32>
+  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg3, %arg4] : memref<10x10xf32>
+  // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
+  // CHECK: [[BETA_C:%.+]] = mulf [[BETA]], [[C]] : f32
+  // CHECK: [[Y_RES:%.+]] = addf [[ALPHA_AB]], [[BETA_C]] : f32
+  // CHECK: store [[Y_RES]], [[RES]][%arg3, %arg4] : memref<10x10xf32>
+  // CHECK: return [[RES]] : memref<10x10xf32>
+  // CHECK: }
+}
+
 func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()