Lowering Gemm (#19)

* Initial implementation

* Support transposing inputs

* Revise unidirectional broadcasting and unknown dimensions

* Revise gemm

* Add testcase

* Rename some variables

* Update SharingWork.md

* Change from the use of Value* to Value

* Insert deallocation

* Initialize the output matrix and fix wrong computation

* Add end-to-end testcases

* Edit lowering tests

* Change attribute names

* Use emplace_back for SmallVector

* Use the new way of getting attributes

* Revise the use of attributes

* Check the bias's shape

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
Tung D. Le 2020-01-30 01:11:49 +09:00 committed by GitHub
parent 9e82d388f0
commit 400676e371
5 changed files with 287 additions and 6 deletions
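For reference, the new lowering implements the standard ONNX Gemm semantics, Y = alpha * A' * B' + beta * C, where A' and B' are A and B optionally transposed (transA/transB) and C is unidirectionally broadcast to the M x N shape of A' * B'. A minimal scalar sketch of those semantics, assuming dense row-major buffers (illustrative only, not code from this commit):

#include <cstdint>
#include <vector>

// Scalar reference for Y = alpha * op(A) * op(B) + beta * C.
// op(A) is M x K, op(B) is K x N; C is cRows x cCols where each
// dimension is either 1 (broadcast) or matches M / N respectively.
std::vector<float> gemmReference(const std::vector<float> &A, const std::vector<float> &B,
                                 const std::vector<float> &C, int64_t M, int64_t N, int64_t K,
                                 int64_t cRows, int64_t cCols, float alpha, float beta,
                                 bool transA, bool transB) {
  std::vector<float> Y(M * N);
  for (int64_t i = 0; i < M; ++i)
    for (int64_t j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (int64_t k = 0; k < K; ++k) {
        float a = transA ? A[k * M + i] : A[i * K + k]; // A is K x M when transA
        float b = transB ? B[j * K + k] : B[k * N + j]; // B is N x K when transB
        acc += a * b;
      }
      // Unidirectional broadcasting of C: a size-1 dimension is indexed with 0.
      int64_t ci = (cRows == 1) ? 0 : i;
      int64_t cj = (cCols == 1) ? 0 : j;
      Y[i * N + j] = alpha * acc + beta * C[ci * cCols + cj];
    }
  return Y;
}

A rank-1 bias corresponds to cRows = 1, and a scalar or single-element bias to cRows = cCols = 1, which is exactly the broadcasting the shape inference and lowering below have to account for.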


@@ -15,7 +15,7 @@ ONNX operations for which some work is needed.
| Elu | Tung | v | v | |
| Exp | Tung | v | v | |
| FullGemm | | | | noU |
| Gemm | | | | noU |
| Gemm | Tung | v | | U |
| HardSigmoid | Tung | v | v | |
| LeakyRelu | Tung | v | v | |
| MatMul | | | | noM |


@@ -487,13 +487,37 @@ void ONNXMatMulOp::inferShapes() {
void ONNXGemmOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
!getOperand(1).getType().isa<RankedTensorType>() ||
!getOperand(2).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
auto biasTy = getOperand(2).getType().cast<RankedTensorType>();
int64_t M, N, K_A, K_B;
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
emitError("Tensor shapes mismatched.");
}
// Check whether bias is unidirectional broadcasting or not.
auto shape = biasTy.getShape();
int rank = shape.size();
if ((rank > 2) ||
(rank >= 1 && shape[rank - 1] != -1 && N != -1 && N != shape[rank - 1] &&
shape[rank - 1] != 1) ||
(rank == 2 && shape[rank - 2] != -1 && M != -1 && M != shape[rank - 2] &&
shape[rank - 2] != 1)) {
emitError("Bias shape mismatched.");
}
SmallVector<int64_t, 2> dims;
dims.emplace_back(lhsTy.getShape()[0]);
dims.emplace_back(rhsTy.getShape()[1]);
dims.emplace_back(M);
dims.emplace_back(N);
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
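As a worked example of the dimension bookkeeping above, using the shapes from the lowering test added later in this commit (the standalone snippet is illustrative, not part of the patch):

#include <array>
#include <cassert>
#include <cstdint>

int main() {
  // A : tensor<5x10xf32> with transA = 1, B : tensor<5x10xf32>, C : tensor<10xf32>
  std::array<int64_t, 2> aShape = {5, 10}, bShape = {5, 10};
  bool transA = true, transB = false;
  int64_t M = transA ? aShape[1] : aShape[0];   // 10
  int64_t K_A = transA ? aShape[0] : aShape[1]; // 5
  int64_t N = transB ? bShape[0] : bShape[1];   // 10
  int64_t K_B = transB ? bShape[1] : bShape[0]; // 5
  assert(K_A == K_B); // the reduction dimensions agree
  // C has rank 1 and its only dimension (10) equals N, so the unidirectional
  // broadcast check passes and the inferred result type is tensor<10x10xf32>.
  assert(M == 10 && N == 10);
  return 0;
}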


@@ -1176,6 +1176,219 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
}
};
struct ONNXGemmOpLowering : public ConversionPattern {
ONNXGemmOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
Value A, B, C;
A = operands[0];
B = operands[1];
C = operands[2];
auto alphaAttr = FloatAttr::get(tensorType.getElementType(),
llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
auto betaAttr = FloatAttr::get(tensorType.getElementType(),
llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0);
bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0);
// Result type
auto memRefType = convertTensorToMemRef(tensorType);
// Insert an allocation and deallocation for the result of this operation.
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else {
auto memRefShape = memRefType.getShape();
SmallVector<Value, 2> allocOperands;
if (memRefShape[0] < 0) {
auto dim = rewriter.create<DimOp>(loc, A, (isTransA) ? 1 : 0);
allocOperands.emplace_back(dim);
}
if (memRefShape[1] < 0) {
auto dim = rewriter.create<DimOp>(loc, B, (isTransB) ? 0 : 1);
allocOperands.emplace_back(dim);
}
alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
if (insertDealloc) {
auto *parentBlock = alloc.getDefiningOp()->getBlock();
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
dealloc.getOperation()->moveBefore(&parentBlock->back());
}
}
// Number of loops
auto memRefShape = memRefType.getShape();
int64_t numLoops = 3;
// Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
std::vector<Value> originalLoops;
originalLoops.reserve(numLoops);
for (auto result : loopsOp.getResults()) {
originalLoops.push_back(result);
}
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
std::vector<Value> optimizedLoops;
optimizedLoops.reserve(numLoops);
for (auto result : optimizedLoopsOp.getResults()) {
optimizedLoops.push_back(result);
}
Block &optimizationBlock = optimizedLoopsOp.region().front();
// We have two Krnl loops:
// - Outer loop iterates over the output matrix dimensions, and
// - Reduction loop iterates over the reduction dimension.
// Outer loop
std::vector<Value> outerLoops, optimizedOuterLoops;
outerLoops.reserve(2);
optimizedOuterLoops.reserve(2);
for (int i = 0; i < 2; ++i) {
outerLoops.push_back(originalLoops[i]);
optimizedOuterLoops.push_back(optimizedLoops[i]);
}
KrnlIterateOperandPack outerPack(rewriter, outerLoops,
optimizedOuterLoops);
// Induction variables for the outer loops
for (int i = 0; i < 2; ++i) {
if (memRefShape[i] < 0) {
outerPack.pushConstantBound(0);
outerPack.pushOperandBound(
rewriter.create<DimOp>(loc, alloc, i).getResult());
} else {
outerPack.pushConstantBound(0);
outerPack.pushConstantBound(memRefShape[i]);
}
}
// Reduction loop
std::vector<Value> reductionLoops, optimizedReductionLoops;
reductionLoops.reserve(1);
optimizedReductionLoops.reserve(1);
reductionLoops.push_back(originalLoops[2]);
optimizedReductionLoops.push_back(optimizedLoops[2]);
KrnlIterateOperandPack reductionPack(rewriter, reductionLoops,
optimizedReductionLoops);
// Induction variable for the reduction dimension
// Try to find and use a static value from A or B first.
// If it failed then use a dynamic value.
auto ATy = A.getType().cast<MemRefType>();
auto BTy = B.getType().cast<MemRefType>();
int64_t K_A_Idx = (isTransA) ? 0 : 1;
int64_t K_B_Idx = (isTransB) ? 1 : 0;
reductionPack.pushConstantBound(0);
if (ATy.getShape()[K_A_Idx] != -1)
reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
else
if (BTy.getShape()[K_B_Idx] != -1)
reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
else
reductionPack.pushOperandBound(
rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());
// Get run-time dimension information for unknown dimensions used for
// broadcasting.
// GemmOp supports unidirectional broadcasting from C to A*B.
// Hence, it must be enough to get broadcasting information for C only.
std::map<int, Value> broadcastedDimInfo;
auto shape = C.getType().cast<MemRefType>().getShape();
for (int i = 0; i < shape.size(); ++i) {
if (shape[i] < 0) {
auto dim = rewriter.create<DimOp>(loc, C, i).getResult();
auto one = rewriter.create<ConstantIndexOp>(loc, 1);
auto isBroadcasted =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted));
}
}
auto outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);
// Now perform the insertions into the body of the
// just generated instructions:
// No optimization
rewriter.setInsertionPointToEnd(&optimizationBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
rewriter.setInsertionPoint(optimizedLoopsOp);
// Insert instructions inside the outer loop.
Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&outerIterationBlock);
// Induction variables
SmallVector<Value, 4> loopMNIVs;
for (auto arg : outerIterationBlock.getArguments()) {
loopMNIVs.emplace_back(arg);
}
// Initialize the output of A*B
auto zero = rewriter.create<ConstantOp>(
loc, FloatAttr::get(memRefType.getElementType(), 0));
rewriter.create<StoreOp>(loc, zero, alloc, loopMNIVs);
// Compute A*B
auto matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, reductionPack);
// Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
auto loopCIVs = getLoopIVsForBroadcasting(
loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
rewriter.create<StoreOp>(loc, Y, alloc, loopMNIVs);
// Insert instructions to do matrix multiplication: A*B
Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&matmulIterationBlock);
// Induction variables
SmallVector<Value, 4> loopKIVs, loopAIVs, loopBIVs;
for (auto arg : matmulIterationBlock.getArguments())
loopKIVs.emplace_back(arg);
if (isTransA) {
loopAIVs.emplace_back(loopKIVs[0]);
loopAIVs.emplace_back(loopMNIVs[0]);
} else {
loopAIVs.emplace_back(loopMNIVs[0]);
loopAIVs.emplace_back(loopKIVs[0]);
}
if (isTransB) {
loopBIVs.emplace_back(loopMNIVs[1]);
loopBIVs.emplace_back(loopKIVs[0]);
} else {
loopBIVs.emplace_back(loopKIVs[0]);
loopBIVs.emplace_back(loopMNIVs[1]);
}
// Matmul computation
auto loadedA = rewriter.create<LoadOp>(loc, A, loopAIVs);
auto loadedB = rewriter.create<LoadOp>(loc, B, loopBIVs);
auto loadedY = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
rewriter.create<StoreOp>(loc, accumulated, alloc, loopMNIVs);
rewriter.replaceOp(op, alloc);
return matchSuccess();
}
};
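Written as plain loops, the schedule built above amounts to the following: the two outer loops cover the M x N output, each output element is zero-initialized, the reduction loop accumulates A*B, and the alpha/beta epilogue with the broadcast access into C runs after the reduction loop but still inside the outer loops. A sketch under those assumptions (the accessor callbacks stand in for the transA/transB index swaps and for getLoopIVsForBroadcasting; they are not part of the patch):

#include <cstdint>
#include <functional>

// Sketch of the loop nest the lowering emits for Y = alpha*(A*B) + beta*C.
void gemmSchedule(float *Y, int64_t M, int64_t N, int64_t K, float alpha, float beta,
                  std::function<float(int64_t, int64_t)> loadA,  // honors transA
                  std::function<float(int64_t, int64_t)> loadB,  // honors transB
                  std::function<float(int64_t, int64_t)> loadC)  // honors broadcasting of C
{
  for (int64_t i = 0; i < M; ++i)
    for (int64_t j = 0; j < N; ++j) {
      Y[i * N + j] = 0.0f;                  // initialize the output of A*B
      for (int64_t k = 0; k < K; ++k)       // reduction krnl.iterate
        Y[i * N + j] += loadA(i, k) * loadB(k, j);
      // Epilogue after the reduction loop: add beta*C to alpha*(A*B).
      Y[i * N + j] = alpha * Y[i * N + j] + beta * loadC(i, j);
    }
}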
struct ONNXUnsqueezeOpLowering : public ConversionPattern {
ONNXUnsqueezeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
@@ -1242,7 +1455,6 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
dealloc.getOperation()->moveBefore(&parentBlock->back());
}
}
rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
rewriter.replaceOp(op, alloc);
return matchSuccess();
@@ -1377,7 +1589,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
ONNXReshapeOpLowering, ONNXEntryPointLowering,
ONNXSoftmaxOpLowering, ONNXUnsqueezeOpLowering>(&getContext());
ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
ONNXUnsqueezeOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`


@@ -93,6 +93,19 @@ test_to_enable = [
"test_exp_cpu",
"test_exp_example_cpu",
# Gemm Op:
"test_gemm_all_attributes_cpu",
"test_gemm_alpha_cpu",
"test_gemm_beta_cpu",
"test_gemm_default_matrix_bias_cpu",
# "test_gemm_default_no_bias_cpu", <- error, need support for optional operands
# "test_gemm_default_scalar_bias_cpu", <- error, shapes mismatch, why?
"test_gemm_default_single_elem_vector_bias_cpu",
"test_gemm_default_vector_bias_cpu",
"test_gemm_default_zero_bias_cpu",
"test_gemm_transposeA_cpu",
"test_gemm_transposeB_cpu",
# Hard Sigmoid Op:
"test_hardsigmoid_cpu",
"test_hardsigmoid_default_cpu",


@@ -633,6 +633,37 @@ func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
// CHECK: return [[RES]] : memref<10x10xf32>
}
func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tensor<10xf32>) -> tensor<*xf32> {
%0 ="onnx.Gemm"(%arg0, %arg1, %arg2) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>, tensor<10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_gemm
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
// CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
// CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
// CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
// CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg3 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg4 = 0 to 10) {
// CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg5 = 0 to 5) {
// CHECK: [[A:%.+]] = load %arg0[%arg5, %arg3] : memref<5x10xf32>
// CHECK: [[B:%.+]] = load %arg1[%arg5, %arg4] : memref<5x10xf32>
// CHECK: [[Y:%.+]] = load [[RES]][%arg3, %arg4] : memref<10x10xf32>
// CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
// CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
// CHECK: store [[SUM]], [[RES]][%arg3, %arg4] : memref<10x10xf32>
// CHECK: }
// CHECK: [[C:%.+]] = load %arg2[%arg4] : memref<10xf32>
// CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg3, %arg4] : memref<10x10xf32>
// CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
// CHECK: [[BETA_C:%.+]] = mulf [[BETA]], [[C]] : f32
// CHECK: [[Y_RES:%.+]] = addf [[ALPHA_AB]], [[BETA_C]] : f32
// CHECK: store [[Y_RES]], [[RES]][%arg3, %arg4] : memref<10x10xf32>
// CHECK: return [[RES]] : memref<10x10xf32>
// CHECK: }
}
func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()