// Source file: onnx-mlir/src/Conversion/ONNXToKrnl/Math/MatMul.cpp
//===----------------- Matmul.cpp - Lowering Matmul Op --------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Matmul Operator to Krnl dialect.
//
//===----------------------------------------------------------------------===//
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
using namespace mlir;
/// Lowers ONNXMatMulOp (numpy.matmul semantics) to Krnl loops.
///
/// Three rank combinations are handled:
///   1. Both arguments are N-D, N >= 2 (optionally with batch dimensions).
///   2. One argument is 1-D, the other is N-D, N >= 2.
///   3. Both arguments are 1-D (dot product; result holds one element).
struct ONNXMatMulOpLowering : public ConversionPattern {
  ONNXMatMulOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    Value A = operands[0];
    Value B = operands[1];
    auto AShape = A.getType().cast<MemRefType>().getShape();
    auto BShape = B.getType().cast<MemRefType>().getShape();

    // There are three cases related to the shapes of the two arguments:
    // - Both arguments are N-D, N >= 2
    // - Either argument is 1-D, the other is N-D, N >= 2
    // - Both arguments are 1-D

    // Result type.
    auto memRefType = convertToMemRefType(*op->result_type_begin());
    auto elementType = memRefType.getElementType();
    auto memRefShape = memRefType.getShape();

    // A value zero, used to initialize the accumulator.
    auto zero = emitConstantOp(rewriter, loc, memRefType.getElementType(), 0);

    // Insert an allocation and deallocation for the result of this operation.
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else {
      // The result has dynamic dimensions: collect one SSA extent per dynamic
      // dimension, in dimension order, for the AllocOp.
      SmallVector<Value, 4> allocOperands;
      if (AShape.size() >= 2 && BShape.size() >= 2) {
        // Both arguments are N-D, N >= 2
        // (s1 x s2 x... x sK x M x K) MATMUL (K x N)
        // =>
        // (s1 x s2 x... x sK x M x N)
        for (int i = 0; i < (int)memRefShape.size() - 2; ++i) {
          if (memRefShape[i] < 0) {
            if ((AShape.size() == 2) && (BShape.size() > 2))
              // Only B carries batch dimensions.
              allocOperands.emplace_back(rewriter.create<DimOp>(loc, B, i));
            else if ((AShape.size() > 2) && (BShape.size() == 2))
              // Only A carries batch dimensions.
              allocOperands.emplace_back(rewriter.create<DimOp>(loc, A, i));
            else
              // Both operands carry (matching) batch dimensions; read the
              // dynamic extent from A. Previously this case emitted nothing,
              // leaving the AllocOp with too few operands.
              allocOperands.emplace_back(rewriter.create<DimOp>(loc, A, i));
          }
        }
        if (memRefShape[memRefShape.size() - 2] < 0) {
          // M is the second-to-last dimension of A. Index within A's own
          // rank: the result rank can exceed A's when B supplies the batch
          // dimensions, so memRefShape.size() - 2 would be out of range.
          auto dim = rewriter.create<DimOp>(loc, A, AShape.size() - 2);
          allocOperands.emplace_back(dim);
        }
        if (memRefShape[memRefShape.size() - 1] < 0) {
          // N is the last dimension of B, likewise indexed in B's own rank.
          auto dim = rewriter.create<DimOp>(loc, B, BShape.size() - 1);
          allocOperands.emplace_back(dim);
        }
      } else if (AShape.size() == 1 && BShape.size() >= 2) {
        // Either argument is 1-D
        // K MATMUL (s1 x s2 x... x sK x K x N)
        // =>
        // (s1 x s2 x... x sK x N)
        for (int i = 0; i < (int)memRefShape.size() - 1; ++i) {
          if (memRefShape[i] < 0) {
            auto dim = rewriter.create<DimOp>(loc, B, i);
            allocOperands.emplace_back(dim);
          }
        }
        if (memRefShape[memRefShape.size() - 1] < 0) {
          // N is the last dimension of B.
          auto dim = rewriter.create<DimOp>(loc, B, BShape.size() - 1);
          allocOperands.emplace_back(dim);
        }
      } else if (AShape.size() >= 2 && BShape.size() == 1) {
        // Either argument is 1-D
        // (s1 x s2 x... x sK x M x K) MATMUL K
        // =>
        // (s1 x s2 x... x sK x M)
        for (int i = 0; i < (int)memRefShape.size() - 1; ++i) {
          if (memRefShape[i] < 0) {
            auto dim = rewriter.create<DimOp>(loc, A, i);
            allocOperands.emplace_back(dim);
          }
        }
        if (memRefShape[memRefShape.size() - 1] < 0) {
          // M is the second-to-last dimension of A.
          auto dim = rewriter.create<DimOp>(loc, A, AShape.size() - 2);
          allocOperands.emplace_back(dim);
        }
      } else if (AShape.size() == 1 && BShape.size() == 1) {
        // Both arguments are 1-D; the (rank-1, single-element) result can
        // only be dynamic in its one dimension.
        if (memRefShape[0] < 0) {
          auto dim = rewriter.create<DimOp>(loc, A, 0);
          allocOperands.emplace_back(dim);
        }
      } else {
        // Unsupported rank combination (e.g. a rank-0 operand). Fail the
        // pattern instead of falling through and creating an AllocOp with an
        // inconsistent operand list.
        emitError(loc, "Invalid shapes");
        return matchFailure();
      }
      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
    }

    if (AShape.size() >= 2 || BShape.size() >= 2) {
      // Cases 1 and 2:
      // - Both arguments are N-D, N >= 2
      // - Either argument is 1-D, the other is N-D, N >= 2

      // Define loops for all result dimensions (batch + matmul dims).
      std::vector<Value> originalLoops;
      std::vector<Value> optimizedLoops;
      Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
                                             optimizedLoops, memRefShape.size());

      // Outer KrnlIterateOp over the batch dimensions, if any.
      SmallVector<Value, 4> loopBatchIVs;
      bool hasBatchLoop = false;
      if (AShape.size() > 2 || BShape.size() > 2) {
        SmallVector<int, 4> batchAxes;
        // The matmul itself contributes 1 result dim if either operand is
        // 1-D, otherwise 2; everything before that is a batch axis.
        int matmulResultDims =
            ((AShape.size() == 1 || BShape.size() == 1)) ? 1 : 2;
        for (int i = 0; i < (int)memRefShape.size() - matmulResultDims; ++i)
          batchAxes.emplace_back(i);

        std::vector<Value> outerLoops, optimizedOuterLoops;
        outerLoops.reserve(batchAxes.size());
        optimizedOuterLoops.reserve(batchAxes.size());
        for (int i = 0; i < (int)batchAxes.size(); ++i) {
          outerLoops.push_back(originalLoops[i]);
          optimizedOuterLoops.push_back(optimizedLoops[i]);
        }
        KrnlIterateOperandPack outerPack(rewriter, outerLoops,
                                         optimizedOuterLoops);
        for (int i = 0; i < (int)batchAxes.size(); ++i) {
          addDimensionToPack(rewriter, loc, outerPack, alloc, i);
        }
        auto outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);

        // No optimization.
        rewriter.setInsertionPointToEnd(optimizationBlock);
        rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

        // Insert instructions into the outer KrnlIterateOp.
        Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
        rewriter.setInsertionPointToStart(&outerIterationBlock);

        // Induction variables: non-matrix-multiplication (batch) variables.
        for (auto arg : outerIterationBlock.getArguments()) {
          loopBatchIVs.emplace_back(arg);
        }
        hasBatchLoop = true;
      }

      // Now, we define loops for matrix multiplication.
      KrnlIterateOp matmulIterateOp;
      std::vector<Value> matmulLoops, optimizedMatmulLoops;
      if (AShape.size() >= 2 && BShape.size() >= 2) {
        // 2-D x 2-D. Result has two dimensions (M and N, the last two).
        matmulLoops.reserve(2);
        optimizedMatmulLoops.reserve(2);
        for (int i = 2; i > 0; --i) {
          matmulLoops.emplace_back(originalLoops[memRefShape.size() - i]);
          optimizedMatmulLoops.emplace_back(
              optimizedLoops[memRefShape.size() - i]);
        }
        KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
                                          optimizedMatmulLoops);
        for (int i = 2; i > 0; --i) {
          addDimensionToPack(rewriter, loc, matmulPack, alloc,
                             memRefShape.size() - i);
        }
        matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
      } else {
        // 1-D x 2-D, and vice versa. Result has one matmul dimension.
        matmulLoops.reserve(1);
        optimizedMatmulLoops.reserve(1);
        matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]);
        optimizedMatmulLoops.emplace_back(
            optimizedLoops[memRefShape.size() - 1]);
        KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
                                          optimizedMatmulLoops);
        addDimensionToPack(rewriter, loc, matmulPack, alloc,
                           memRefShape.size() - 1);
        matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
      }

      if (!hasBatchLoop) {
        // No batch loop was emitted, so the loop-return for the optimization
        // block has not been created yet. No optimization.
        rewriter.setInsertionPointToEnd(optimizationBlock);
        rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
      }

      // Insert instructions into the matmul KrnlIterateOp.
      Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front();
      rewriter.setInsertionPointToStart(&matmulIterationBlock);

      // Induction variables: M, N.
      SmallVector<Value, 4> loopMNIVs;
      for (auto arg : matmulIterationBlock.getArguments()) {
        loopMNIVs.emplace_back(arg);
      }
      // Induction variables for the final result: batch IVs then M/N IVs.
      SmallVector<Value, 4> loopBatchMNIVs;
      for (auto arg : loopBatchIVs) {
        loopBatchMNIVs.emplace_back(arg);
      }
      for (auto arg : loopMNIVs) {
        loopBatchMNIVs.emplace_back(arg);
      }

      // Fill the output with value 0 before accumulating.
      rewriter.create<StoreOp>(loc, zero, alloc, loopBatchMNIVs);

      // Iterate along the reduction dimension (K). Its extent is taken from
      // A's last dimension.
      std::vector<Value> reduceLoops;
      std::vector<Value> optimizedReduceLoops;
      Block *optimizationReduceBlock =
          defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
      KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
                                        optimizedReduceLoops);
      addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1);
      auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);

      // No optimization.
      rewriter.setInsertionPointToEnd(optimizationReduceBlock);
      rewriter.create<KrnlReturnLoopsOp>(loc, reduceLoops);

      // Insert instructions into the reduction KrnlIterateOp.
      Block &reduceIterationBlock = reduceIterateOp.bodyRegion().front();
      rewriter.setInsertionPointToStart(&reduceIterationBlock);

      // Induction variables.
      SmallVector<Value, 4> loopKIVs, loopBatchMKIVs, loopBatchKNIVs;
      // K
      loopKIVs.emplace_back(reduceIterationBlock.getArguments()[0]);
      // Access indices for A: [batch...,] [M,] K.
      if (AShape.size() > 2)
        for (auto arg : loopBatchIVs)
          loopBatchMKIVs.emplace_back(arg);
      if (AShape.size() >= 2)
        loopBatchMKIVs.emplace_back(loopMNIVs[0]);
      loopBatchMKIVs.emplace_back(loopKIVs[0]);
      // Access indices for B: [batch...,] K [, N].
      if (BShape.size() > 2)
        for (auto arg : loopBatchIVs)
          loopBatchKNIVs.emplace_back(arg);
      loopBatchKNIVs.emplace_back(loopKIVs[0]);
      if (BShape.size() >= 2) {
        if (AShape.size() >= 2)
          // Both matmul dims present: N is the second matmul IV.
          loopBatchKNIVs.emplace_back(loopMNIVs[1]);
        else
          // A is 1-D: the single matmul IV is N.
          loopBatchKNIVs.emplace_back(loopMNIVs[0]);
      }

      // Matmul computation: Y[ivs] += A[..., m, k] * B[..., k, n].
      auto loadedA = rewriter.create<LoadOp>(loc, A, loopBatchMKIVs);
      auto loadedB = rewriter.create<LoadOp>(loc, B, loopBatchKNIVs);
      auto loadedY = rewriter.create<LoadOp>(loc, alloc, loopBatchMNIVs);
      if (elementType.isa<IntegerType>()) {
        auto AB = rewriter.create<MulIOp>(loc, loadedA, loadedB);
        auto accumulated = rewriter.create<AddIOp>(loc, loadedY, AB);
        rewriter.create<StoreOp>(loc, accumulated, alloc, loopBatchMNIVs);
      } else if (elementType.isa<FloatType>()) {
        auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
        auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
        rewriter.create<StoreOp>(loc, accumulated, alloc, loopBatchMNIVs);
      }
    } else if ((AShape.size() == 1) && (BShape.size() == 1)) {
      // Case 3:
      // - Both arguments are 1-D (dot product into element 0 of the result).

      // Fill the output with value 0 before accumulating.
      Value zeroIndex = rewriter.create<ConstantIndexOp>(loc, 0);
      rewriter.create<StoreOp>(loc, zero, alloc, zeroIndex);

      // Iterate along the reduction dimension, using A's extent.
      std::vector<Value> reduceLoops;
      std::vector<Value> optimizedReduceLoops;
      Block *optimizationReduceBlock =
          defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
      KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
                                        optimizedReduceLoops);
      addDimensionToPack(rewriter, loc, reducePack, A, 0);
      auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);

      // No optimization.
      rewriter.setInsertionPointToEnd(optimizationReduceBlock);
      rewriter.create<KrnlReturnLoopsOp>(loc, reduceLoops);

      // Insert instructions into the reduction KrnlIterateOp.
      Block &reduceIterationBlock = reduceIterateOp.bodyRegion().front();
      rewriter.setInsertionPointToStart(&reduceIterationBlock);

      // Induction variables.
      SmallVector<Value, 4> loopKIVs;
      // K
      loopKIVs.emplace_back(reduceIterationBlock.getArgument(0));

      // Matmul computation: Y[0] += A[k] * B[k].
      auto loadedA = rewriter.create<LoadOp>(loc, A, loopKIVs);
      auto loadedB = rewriter.create<LoadOp>(loc, B, loopKIVs);
      auto loadedY = rewriter.create<LoadOp>(loc, alloc, zeroIndex);
      if (elementType.isa<IntegerType>()) {
        auto AB = rewriter.create<MulIOp>(loc, loadedA, loadedB);
        auto accumulated = rewriter.create<AddIOp>(loc, loadedY, AB);
        rewriter.create<StoreOp>(loc, accumulated, alloc, zeroIndex);
      } else if (elementType.isa<FloatType>()) {
        auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
        auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
        rewriter.create<StoreOp>(loc, accumulated, alloc, zeroIndex);
      }
    } else {
      // No scalar matrix multiplication.
      llvm_unreachable("Unsupported scalar matrix multiplication.");
    }

    rewriter.replaceOp(op, alloc);
    return matchSuccess();
  }
};
/// Registers the ONNXMatMulOp-to-Krnl lowering pattern into the given
/// pattern list; called by the ONNX-to-Krnl conversion pass when it
/// collects its patterns.
void populateLoweringONNXMatMulOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ONNXMatMulOpLowering>(ctx);
}