Lower Matmul operation to Krnl dialect (#57)

* Allocate memory for matmul's result

* Group cases

* Add support of N-D x N-D, N>=2

* Revise createIterateOperandPack

* Add 1-D x 1-D

* Add 1-D x N-D

* Add MLIR tests

* Change variable names

* Change type from int to int64_t for indices

* Change variable names

* Change int64_t back to int

* Change int64_t back to int

* Change int64_t back to int

* Use decltype

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
Tung D. Le 2020-02-15 00:43:17 +09:00 committed by GitHub
parent dab862e4f1
commit b521719587
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 608 additions and 46 deletions

View File

@ -442,7 +442,7 @@ void ONNXMatMulOp::inferShapes() {
lhsShape[0] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices.");
for (int i = 0; i < rhsRank - 2; ++i)
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
dims.emplace_back(rhsShape[i]);
dims.emplace_back(rhsShape[rhsRank - 1]);
} else if (lhsShape.size() >= 2 && rhsShape.size() == 1) {
@ -460,7 +460,7 @@ void ONNXMatMulOp::inferShapes() {
lhsShape[lhsRank - 1] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
for (int i = 0; i < lhsRank - 2; ++i)
for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
dims.emplace_back(lhsShape[i]);
dims.emplace_back(lhsShape[lhsRank - 2]);
} else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
@ -474,7 +474,7 @@ void ONNXMatMulOp::inferShapes() {
lhsShape[lhsRank - 1] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
for (int i = 0; i < lhsRank - 1; ++i)
for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i)
dims.emplace_back(lhsShape[i]);
dims.emplace_back(rhsShape[1]);
} else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
@ -488,7 +488,7 @@ void ONNXMatMulOp::inferShapes() {
lhsShape[1] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices.");
for (int i = 0; i < rhsRank - 2; ++i)
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
dims.emplace_back(rhsShape[i]);
dims.emplace_back(lhsShape[0]);
dims.emplace_back(rhsShape[rhsRank - 1]);
@ -506,10 +506,10 @@ void ONNXMatMulOp::inferShapes() {
// Check and perform broadcasting for the shapes.
SmallVector<int64_t, 2> lhsBcastShape;
for (int i = 0; i < lhsRank - 2; ++i)
for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
lhsBcastShape.emplace_back(lhsShape[i]);
SmallVector<int64_t, 2> rhsBcastShape;
for (int i = 0; i < rhsRank - 2; ++i)
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
rhsBcastShape.emplace_back(rhsShape[i]);
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
emitError("Broadcasted dimensions are incompatible.");

View File

@ -588,9 +588,11 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
// Constant 1)
auto loc = op->getLoc();
Value operand = operands[0];
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
auto alphaAttribute = FloatAttr::get(
rewriter.getF32Type(),
llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat());
auto betaAttribute = FloatAttr::get(rewriter.getF32Type(),
auto betaAttribute = FloatAttr::get(
rewriter.getF32Type(),
llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
auto elementType = result_types[0];
@ -625,7 +627,8 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
Value operand = operands[0];
auto elementType = result_types[0];
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
auto alphaAttribute =
FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
@ -679,7 +682,8 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
Value operand = operands[0];
auto elementType = result_types[0];
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
auto alphaAttribute = FloatAttr::get(
rewriter.getF32Type(),
llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
@ -705,9 +709,11 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
// alpha)))
auto loc = op->getLoc();
Value operand = operands[0];
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
auto alphaAttribute =
FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat());
auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(),
auto gammaAttribute =
FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
auto elementType = result_types[0];
@ -748,8 +754,9 @@ Value mapToLowerScalarOp<ONNXReciprocalOp>(
// Scalar unary ops for lowering ONNXSoftplusOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSoftplusOp>(
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
Value mapToLowerScalarOp<ONNXSoftplusOp>(Operation *op,
ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
// ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1))
auto loc = op->getLoc();
@ -768,8 +775,9 @@ Value mapToLowerScalarOp<ONNXSoftplusOp>(
// Scalar unary ops for lowering ONNXSoftsignOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSoftsignOp>(
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
Value mapToLowerScalarOp<ONNXSoftsignOp>(Operation *op,
ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
// ONNXSoftsignOp(%X) = DivFOp(ConstantOp 1, %X)
auto loc = op->getLoc();
@ -1408,6 +1416,337 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
}
};
struct ONNXMatMulOpLowering : public ConversionPattern {
ONNXMatMulOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
Value A = operands[0];
Value B = operands[1];
auto AShape = A.getType().cast<MemRefType>().getShape();
auto BShape = B.getType().cast<MemRefType>().getShape();
// There are three cases related to the shapes of the two arguments:
// - Both arguments are N-D, N >= 2
// - Either argument is 1-D, the other is N-D, N >= 2
// - Both arguments are 1-D
// Result type
auto memRefType = convertTensorToMemRef(tensorType);
auto elementType = memRefType.getElementType();
auto memRefShape = memRefType.getShape();
// A value zero
Value zero;
if (elementType.isa<IntegerType>()) {
zero = rewriter.create<ConstantOp>(
loc, IntegerAttr::get(memRefType.getElementType(), 0));
} else if (elementType.isa<FloatType>()) {
zero = rewriter.create<ConstantOp>(
loc, FloatAttr::get(memRefType.getElementType(), 0));
} else {
emitError(loc, "unsupported element type");
}
// Insert an allocation and deallocation for the result of this operation.
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else {
SmallVector<Value, 4> allocOperands;
if (AShape.size() >= 2 && BShape.size() >= 2) {
// Both arguments are N-D, N >= 2
// (s1 x s2 x... x sK x M x K) MATMUL (K x N)
// =>
// (s1 x s2 x... x sK x M x N)
for (int i = 0; i < memRefShape.size() - 2; ++i) {
if (memRefShape[i] < 0) {
if ((AShape.size() == 2) && (BShape.size() > 2))
allocOperands.emplace_back(rewriter.create<DimOp>(loc, B, i));
else if ((AShape.size() > 2) && (BShape.size() == 2))
allocOperands.emplace_back(rewriter.create<DimOp>(loc, A, i));
}
}
if (memRefShape[memRefShape.size() - 2] < 0) {
auto dim = rewriter.create<DimOp>(loc, A, memRefShape.size() - 2);
allocOperands.emplace_back(dim);
}
if (memRefShape[memRefShape.size() - 1] < 0) {
auto dim = rewriter.create<DimOp>(loc, B, memRefShape.size() - 1);
allocOperands.emplace_back(dim);
}
} else if (AShape.size() == 1 && BShape.size() >= 2) {
// Either argument is 1-D
// K MATMUL (s1 x s2 x... x sK x K x N)
// =>
// (s1 x s2 x... x sK x N)
for (int i = 0; i < memRefShape.size() - 1; ++i) {
if (memRefShape[i] < 0) {
auto dim = rewriter.create<DimOp>(loc, B, i);
allocOperands.emplace_back(dim);
}
}
if (memRefShape[memRefShape.size() - 1] < 0) {
auto dim = rewriter.create<DimOp>(loc, B, BShape.size() - 1);
allocOperands.emplace_back(dim);
}
} else if (AShape.size() >= 2 && BShape.size() == 1) {
// Either argument is 1-D
// (s1 x s2 x... x sK x M x K) MATMUL K
// =>
// (s1 x s2 x... x sK x M)
for (int i = 0; i < memRefShape.size() - 1; ++i) {
if (memRefShape[i] < 0) {
auto dim = rewriter.create<DimOp>(loc, A, i);
allocOperands.emplace_back(dim);
}
}
if (memRefShape[memRefShape.size() - 1] < 0) {
auto dim = rewriter.create<DimOp>(loc, A, AShape.size() - 2);
allocOperands.emplace_back(dim);
}
} else if (AShape.size() == 1 && BShape.size() == 1) {
// Both arguments are 1-D
if (memRefShape[0] < 0) {
auto dim = rewriter.create<DimOp>(loc, A, 0);
allocOperands.emplace_back(dim);
}
} else {
emitError(loc, "Invalid shapes");
}
alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
}
if (AShape.size() >= 2 || BShape.size() >= 2) {
// Cases 1 and 2:
// - Both arguments are N-D, N >= 2
// - Either argument is 1-D, the other is N-D, N >= 2
// Define loops for batch dimensions.
std::vector<Value> originalLoops;
std::vector<Value> optimizedLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
optimizedLoops, memRefShape.size());
// Outer KrnlIterateOp
SmallVector<Value, 4> loopBatchIVs;
bool hasBatchLoop = false;
if (AShape.size() > 2 || BShape.size() > 2) {
SmallVector<int, 4> batchAxes;
int matmulResultDims =
((AShape.size() == 1 || BShape.size() == 1)) ? 1 : 2;
for (int i = 0; i < memRefShape.size() - matmulResultDims; ++i)
batchAxes.emplace_back(i);
std::vector<Value> outerLoops, optimizedOuterLoops;
outerLoops.reserve(batchAxes.size());
optimizedOuterLoops.reserve(batchAxes.size());
for (int i = 0; i < batchAxes.size(); ++i) {
outerLoops.push_back(originalLoops[i]);
optimizedOuterLoops.push_back(optimizedLoops[i]);
}
KrnlIterateOperandPack outerPack(rewriter, outerLoops,
optimizedOuterLoops);
for (int i = 0; i < batchAxes.size(); ++i) {
addDimensionToPack(rewriter, loc, outerPack, alloc, i);
}
auto outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);
// No optimization
rewriter.setInsertionPointToEnd(optimizationBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
// Insert instructions into the outer KrnlIterateOp.
Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&outerIterationBlock);
// Induction variables: non-matrix-multiplication variables.
for (auto arg : outerIterationBlock.getArguments()) {
loopBatchIVs.emplace_back(arg);
}
hasBatchLoop = true;
}
// Now, we define loops for matrix multiplication.
// Create a KrnlIterateOp for matrix multiplication.
KrnlIterateOp matmulIterateOp;
std::vector<Value> matmulLoops, optimizedMatmulLoops;
if (AShape.size() >= 2 && BShape.size() >= 2) {
// 2-D x 2-D. Result has two dimensions.
matmulLoops.reserve(2);
optimizedMatmulLoops.reserve(2);
for (int i = 2; i > 0; --i) {
matmulLoops.emplace_back(originalLoops[memRefShape.size() - i]);
optimizedMatmulLoops.emplace_back(
optimizedLoops[memRefShape.size() - i]);
}
KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
optimizedMatmulLoops);
for (int i = 2; i > 0; --i) {
addDimensionToPack(rewriter, loc, matmulPack, alloc,
memRefShape.size() - i);
}
matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
} else {
// 1-D x 2-D, and vice versa. Result has one dimension.
matmulLoops.reserve(1);
optimizedMatmulLoops.reserve(1);
matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]);
optimizedMatmulLoops.emplace_back(
optimizedLoops[memRefShape.size() - 1]);
KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
optimizedMatmulLoops);
addDimensionToPack(rewriter, loc, matmulPack, alloc,
memRefShape.size() - 1);
matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
}
if (!hasBatchLoop) {
// No optimization
rewriter.setInsertionPointToEnd(optimizationBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
}
// Insert instructions into the matmul KrnlIterateOp.
Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&matmulIterationBlock);
// Induction variables: M, N
SmallVector<Value, 4> loopMNIVs;
for (auto arg : matmulIterationBlock.getArguments()) {
loopMNIVs.emplace_back(arg);
}
// Induction variables for the final result.
SmallVector<Value, 4> loopBatchMNIVs;
for (auto arg : loopBatchIVs) {
loopBatchMNIVs.emplace_back(arg);
}
for (auto arg : loopMNIVs) {
loopBatchMNIVs.emplace_back(arg);
}
// Fill the output with value 0.
rewriter.create<StoreOp>(loc, zero, alloc, loopBatchMNIVs);
// Iterate along the reduction dimension.
// Use a value from A.
std::vector<Value> reduceLoops;
std::vector<Value> optimizedReduceLoops;
Block *optimizationReduceBlock =
defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
optimizedReduceLoops);
addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1);
auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
// No optimization
rewriter.setInsertionPointToEnd(optimizationReduceBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, reduceLoops);
// Insert instructions into the reduction KrnlIterateOp.
Block &reduceIterationBlock = reduceIterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&reduceIterationBlock);
// Induction variables
SmallVector<Value, 4> loopKIVs, loopBatchMKIVs, loopBatchKNIVs;
// K
loopKIVs.emplace_back(reduceIterationBlock.getArguments()[0]);
// MK
if (AShape.size() > 2)
for (auto arg : loopBatchIVs)
loopBatchMKIVs.emplace_back(arg);
if (AShape.size() >= 2)
loopBatchMKIVs.emplace_back(loopMNIVs[0]);
loopBatchMKIVs.emplace_back(loopKIVs[0]);
// KN
if (BShape.size() > 2)
for (auto arg : loopBatchIVs)
loopBatchKNIVs.emplace_back(arg);
loopBatchKNIVs.emplace_back(loopKIVs[0]);
if (BShape.size() >= 2)
if (AShape.size() >= 2)
loopBatchKNIVs.emplace_back(loopMNIVs[1]);
else
loopBatchKNIVs.emplace_back(loopMNIVs[0]);
// Matmul computation
auto loadedA = rewriter.create<LoadOp>(loc, A, loopBatchMKIVs);
auto loadedB = rewriter.create<LoadOp>(loc, B, loopBatchKNIVs);
auto loadedY = rewriter.create<LoadOp>(loc, alloc, loopBatchMNIVs);
if (elementType.isa<IntegerType>()) {
auto AB = rewriter.create<MulIOp>(loc, loadedA, loadedB);
auto accumulated = rewriter.create<AddIOp>(loc, loadedY, AB);
rewriter.create<StoreOp>(loc, accumulated, alloc, loopBatchMNIVs);
} else if (elementType.isa<FloatType>()) {
auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
rewriter.create<StoreOp>(loc, accumulated, alloc, loopBatchMNIVs);
}
} else if ((AShape.size() == 1) && (BShape.size() == 1)) {
// Case 3:
// - Both arguments are 1-D
// Fill the output with value 0.
Value zeroIndex = rewriter.create<ConstantIndexOp>(loc, 0);
rewriter.create<StoreOp>(loc, zero, alloc, zeroIndex);
// Iterate along the reduction dimension.
// Use a value from A.
std::vector<Value> reduceLoops;
std::vector<Value> optimizedReduceLoops;
Block *optimizationReduceBlock =
defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
optimizedReduceLoops);
addDimensionToPack(rewriter, loc, reducePack, A, 0);
auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
// No optimization
rewriter.setInsertionPointToEnd(optimizationReduceBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, reduceLoops);
// Insert instructions into the reduction KrnlIterateOp.
Block &reduceIterationBlock = reduceIterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&reduceIterationBlock);
// Induction variables
SmallVector<Value, 4> loopKIVs;
// K
loopKIVs.emplace_back(reduceIterationBlock.getArgument(0));
// Matmul computation
auto loadedA = rewriter.create<LoadOp>(loc, A, loopKIVs);
auto loadedB = rewriter.create<LoadOp>(loc, B, loopKIVs);
auto loadedY = rewriter.create<LoadOp>(loc, alloc, zeroIndex);
if (elementType.isa<IntegerType>()) {
auto AB = rewriter.create<MulIOp>(loc, loadedA, loadedB);
auto accumulated = rewriter.create<AddIOp>(loc, loadedY, AB);
rewriter.create<StoreOp>(loc, accumulated, alloc, zeroIndex);
} else if (elementType.isa<FloatType>()) {
auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
rewriter.create<StoreOp>(loc, accumulated, alloc, zeroIndex);
}
} else {
// No scalar matrix multiplication.
llvm_unreachable("Unsupported scalar matrix multiplication.");
}
rewriter.replaceOp(op, alloc);
return matchSuccess();
}
};
struct ONNXGemmOpLowering : public ConversionPattern {
ONNXGemmOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {}
@ -1423,9 +1762,11 @@ struct ONNXGemmOpLowering : public ConversionPattern {
B = operands[1];
C = operands[2];
auto alphaAttr = FloatAttr::get(tensorType.getElementType(),
auto alphaAttr =
FloatAttr::get(tensorType.getElementType(),
llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
auto betaAttr = FloatAttr::get(tensorType.getElementType(),
auto betaAttr =
FloatAttr::get(tensorType.getElementType(),
llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@ -1482,8 +1823,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
outerLoops.push_back(originalLoops[i]);
optimizedOuterLoops.push_back(optimizedLoops[i]);
}
KrnlIterateOperandPack outerPack(rewriter, outerLoops,
optimizedOuterLoops);
KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
// Induction variables for the outer loops
for (int i = 0; i < 2; ++i)
addDimensionToPack(rewriter, loc, outerPack, alloc, i);
@ -1501,13 +1841,12 @@ struct ONNXGemmOpLowering : public ConversionPattern {
// If it failed then use a dynamic value.
auto ATy = A.getType().cast<MemRefType>();
auto BTy = B.getType().cast<MemRefType>();
int64_t K_A_Idx = (isTransA) ? 0 : 1;
int64_t K_B_Idx = (isTransB) ? 1 : 0;
int K_A_Idx = (isTransA) ? 0 : 1;
int K_B_Idx = (isTransB) ? 1 : 0;
reductionPack.pushConstantBound(0);
if (ATy.getShape()[K_A_Idx] != -1)
reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
else
if (BTy.getShape()[K_B_Idx] != -1)
else if (BTy.getShape()[K_B_Idx] != -1)
reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
else
reductionPack.pushOperandBound(
@ -1557,8 +1896,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
auto matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, reductionPack);
// Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
auto loopCIVs = getLoopIVsForBroadcasting(
loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
broadcastedDimInfo);
auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
@ -1650,8 +1989,8 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
Value dimVal = nullptr;
if (memRefShape[outIdx] < 0) {
Value index = rewriter.create<DimOp>(loc, operands[0], inIdx);
dimVal = rewriter.create<IndexCastOp>(
loc, index, rewriter.getIntegerType(64));
dimVal = rewriter.create<IndexCastOp>(loc, index,
rewriter.getIntegerType(64));
allocOperands.emplace_back(index);
} else {
dimVal = rewriter.create<ConstantOp>(
@ -2362,8 +2701,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
ONNXReductionOpLowering<mlir::ONNXReduceSumOp>,
ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
ONNXUnsqueezeOpLowering, ONNXTransposeOpLowering,
ONNXIdentityOpLowering, ONNXConvNoBiasOpLowering
>(&getContext());
ONNXIdentityOpLowering, ONNXConvNoBiasOpLowering,
ONNXMatMulOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`

View File

@ -295,6 +295,12 @@ test_to_enable = [
# Sign Op:
"test_sign_cpu",
# MatmulOp
"test_matmul_2d_cpu",
"test_matmul_3d_cpu",
"test_matmul_4d_cpu",
]
# Extract name of all test cases.

View File

@ -930,6 +930,223 @@ func @test_sign_i(%arg0 : tensor<?x10xi32>) -> tensor<*xi32> {
// CHECK: return [[RES]] : memref<?x10xi32>
}
// 2-D x 2-D
func @test_matmul1(%arg0 : tensor<10x5xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
%0 ="onnx.MatMul"(%arg0, %arg1) : (tensor<10x5xf32>, tensor<5x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_matmul1
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
// CHECK: [[CONSTANT:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[LOOPS]]#0 -> %arg2 = 0 to 10, [[LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: store [[CONSTANT]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: [[LOOPS_REDUCE:%.+]] = krnl.define_loops 1
// CHECK: [[OPT_LOOPS_REDUCE:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS_REDUCE]]
// CHECK: } : () -> !krnl.loop
// CHECK: krnl.iterate([[OPT_LOOPS_REDUCE]]) with ([[LOOPS_REDUCE]] -> %arg4 = 0 to 5) {
// CHECK: [[LOAD_0:%.+]] = load %arg0[%arg2, %arg4] : memref<10x5xf32>
// CHECK: [[LOAD_1:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
// CHECK: [[LOAD_RES:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: [[MUL:%.+]] = mulf [[LOAD_0]], [[LOAD_1]] : f32
// CHECK: [[ADD:%.+]] = addf [[LOAD_RES]], [[MUL]] : f32
// CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: }
// CHECK: }
// CHECK: return [[RES]] : memref<10x10xf32>
}
// 2-D x N-D
func @test_matmul2(%arg0 : tensor<10x5xf32>, %arg1 : tensor<2x3x5x10xf32>) -> tensor<*xf32> {
%0 ="onnx.MatMul"(%arg0, %arg1) : (tensor<10x5xf32>, tensor<2x3x5x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_matmul2
// CHECK: [[RES:%.+]] = alloc() : memref<2x3x10x10xf32>
// CHECK: [[CONSTANT:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[LOOPS:%.+]]:4 = krnl.define_loops 4
// CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1, [[LOOPS]]#2, [[LOOPS]]#3
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[LOOPS]]#0 -> %arg2 = 0 to 2, [[LOOPS]]#1 -> %arg3 = 0 to 3) {
// CHECK: krnl.iterate([[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[LOOPS]]#2 -> %arg4 = 0 to 10, [[LOOPS]]#3 -> %arg5 = 0 to 10) {
// CHECK: store [[CONSTANT]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<2x3x10x10xf32>
// CHECK: [[LOOPS_REDUCE:%.+]] = krnl.define_loops 1
// CHECK: [[OPT_LOOPS_REDUCE:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS_REDUCE]]
// CHECK: } : () -> !krnl.loop
// CHECK: krnl.iterate([[OPT_LOOPS_REDUCE]]) with ([[LOOPS_REDUCE]] -> %arg6 = 0 to 5) {
// CHECK: [[LOAD_0:%.+]] = load %arg0[%arg4, %arg6] : memref<10x5xf32>
// CHECK: [[LOAD_1:%.+]] = load %arg1[%arg2, %arg3, %arg6, %arg5] : memref<2x3x5x10xf32>
// CHECK: [[LOAD_RES:%.+]] = load [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<2x3x10x10xf32>
// CHECK: [[MUL:%.+]] = mulf [[LOAD_0]], [[LOAD_1]] : f32
// CHECK: [[ADD:%.+]] = addf [[LOAD_RES]], [[MUL]] : f32
// CHECK: store [[ADD]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<2x3x10x10xf32>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: return [[RES]] : memref<2x3x10x10xf32>
}
// N-D x N-D
func @test_matmul3(%arg0 : tensor<2x3x10x5xf32>, %arg1 : tensor<2x3x5x10xf32>) -> tensor<*xf32> {
%0 ="onnx.MatMul"(%arg0, %arg1) : (tensor<2x3x10x5xf32>, tensor<2x3x5x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_matmul3
// CHECK: [[RES:%.+]] = alloc() : memref<2x3x10x10xf32>
// CHECK: [[CONSTANT:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[LOOPS:%.+]]:4 = krnl.define_loops 4
// CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1, [[LOOPS]]#2, [[LOOPS]]#3
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[LOOPS]]#0 -> %arg2 = 0 to 2, [[LOOPS]]#1 -> %arg3 = 0 to 3) {
// CHECK: krnl.iterate([[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[LOOPS]]#2 -> %arg4 = 0 to 10, [[LOOPS]]#3 -> %arg5 = 0 to 10) {
// CHECK: store [[CONSTANT]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<2x3x10x10xf32>
// CHECK: [[LOOPS_REDUCE:%.+]] = krnl.define_loops 1
// CHECK: [[OPT_LOOPS_REDUCE:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS_REDUCE]]
// CHECK: } : () -> !krnl.loop
// CHECK: krnl.iterate([[OPT_LOOPS_REDUCE]]) with ([[LOOPS_REDUCE]] -> %arg6 = 0 to 5) {
// CHECK: [[LOAD_0:%.+]] = load %arg0[%arg2, %arg3, %arg4, %arg6] : memref<2x3x10x5xf32>
// CHECK: [[LOAD_1:%.+]] = load %arg1[%arg2, %arg3, %arg6, %arg5] : memref<2x3x5x10xf32>
// CHECK: [[LOAD_RES:%.+]] = load [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<2x3x10x10xf32>
// CHECK: [[MUL:%.+]] = mulf [[LOAD_0]], [[LOAD_1]] : f32
// CHECK: [[ADD:%.+]] = addf [[LOAD_RES]], [[MUL]] : f32
// CHECK: store [[ADD]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<2x3x10x10xf32>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: return [[RES]] : memref<2x3x10x10xf32>
}
// 1-D x 2-D
func @test_matmul4(%arg0 : tensor<5xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
%0 ="onnx.MatMul"(%arg0, %arg1) : (tensor<5xf32>, tensor<5x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_matmul4
// CHECK: [[RES:%.+]] = alloc() : memref<10xf32>
// CHECK: [[CONSTANT:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[LOOPS:%.+]] = krnl.define_loops 1
// CHECK: [[OPT_LOOPS:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS]]
// CHECK: } : () -> !krnl.loop
// CHECK: krnl.iterate([[OPT_LOOPS]]) with ([[LOOPS]] -> %arg2 = 0 to 10) {
// CHECK: store [[CONSTANT]], [[RES]][%arg2] : memref<10xf32>
// CHECK: [[LOOPS_REDUCE:%.+]] = krnl.define_loops 1
// CHECK: [[OPT_LOOPS_REDUCE:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS_REDUCE]]
// CHECK: } : () -> !krnl.loop
// CHECK: krnl.iterate([[OPT_LOOPS_REDUCE]]) with ([[LOOPS_REDUCE]] -> %arg3 = 0 to 5) {
// CHECK: [[LOAD_0:%.+]] = load %arg0[%arg3] : memref<5xf32>
// CHECK: [[LOAD_1:%.+]] = load %arg1[%arg3, %arg2] : memref<5x10xf32>
// CHECK: [[LOAD_RES:%.+]] = load [[RES]][%arg2] : memref<10xf32>
// CHECK: [[MUL:%.+]] = mulf [[LOAD_0]], [[LOAD_1]] : f32
// CHECK: [[ADD:%.+]] = addf [[LOAD_RES]], [[MUL]] : f32
// CHECK: store [[ADD]], [[RES]][%arg2] : memref<10xf32>
// CHECK: }
// CHECK: }
// CHECK: return [[RES]] : memref<10xf32>
}
// 1-D x N-D
func @test_matmul5(%arg0 : tensor<5xf32>, %arg1 : tensor<?x5x10xf32>) -> tensor<*xf32> {
%0 ="onnx.MatMul"(%arg0, %arg1) : (tensor<5xf32>, tensor<?x5x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_matmul5
// CHECK: [[CONSTANT:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[DIM_0:%.+]] = dim %arg1, 0 : memref<?x5x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_1:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0) with ([[LOOPS]]#0 -> %arg2 = 0 to [[DIM_1]]) {
// CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: store [[CONSTANT]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOOPS_REDUCE:%.+]] = krnl.define_loops 1
// CHECK: [[OPT_LOOPS_REDUCE:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS_REDUCE]]
// CHECK: } : () -> !krnl.loop
// CHECK: krnl.iterate([[OPT_LOOPS_REDUCE]]) with ([[LOOPS_REDUCE]] -> %arg4 = 0 to 5) {
// CHECK: [[LOAD_0:%.+]] = load %arg0[%arg4] : memref<5xf32>
// CHECK: [[LOAD_1:%.+]] = load %arg1[%arg2, %arg4, %arg3] : memref<?x5x10xf32>
// CHECK: [[LOAD_RES:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[MUL:%.+]] = mulf [[LOAD_0]], [[LOAD_1]] : f32
// CHECK: [[ADD:%.+]] = addf [[LOAD_RES]], [[MUL]] : f32
// CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: return [[RES]] : memref<?x10xf32>
}
// N-D x 1-D
func @test_matmul6(%arg0 : tensor<?x10x5xf32>, %arg1 : tensor<5xf32>) -> tensor<*xf32> {
%0 ="onnx.MatMul"(%arg0, %arg1) : (tensor<?x10x5xf32>, tensor<5xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_matmul6
// CHECK: [[CONSTANT:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10x5xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_1:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0) with ([[LOOPS]]#0 -> %arg2 = 0 to [[DIM_1]]) {
// CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: store [[CONSTANT]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOOPS_REDUCE:%.+]] = krnl.define_loops 1
// CHECK: [[OPT_LOOPS_REDUCE:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS_REDUCE]]
// CHECK: } : () -> !krnl.loop
// CHECK: krnl.iterate([[OPT_LOOPS_REDUCE]]) with ([[LOOPS_REDUCE]] -> %arg4 = 0 to 5) {
// CHECK: [[LOAD_0:%.+]] = load %arg0[%arg2, %arg3, %arg4] : memref<?x10x5xf32>
// CHECK: [[LOAD_1:%.+]] = load %arg1[%arg4] : memref<5xf32>
// CHECK: [[LOAD_RES:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[MUL:%.+]] = mulf [[LOAD_0]], [[LOAD_1]] : f32
// CHECK: [[ADD:%.+]] = addf [[LOAD_RES]], [[MUL]] : f32
// CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: return [[RES]] : memref<?x10xf32>
}
// 1-D x 1-D
func @test_matmul7(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<*xf32> {
%0 ="onnx.MatMul"(%arg0, %arg1) : (tensor<5xf32>, tensor<5xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_matmul7
// CHECK: [[RES:%.+]] = alloc() : memref<1xf32>
// CHECK: [[CONSTANT:%.+]] = constant 0.000000e+00 : f32
// CHECK: %[[CONSTANT_INDEX:.+]] = constant 0 : index
// CHECK: store [[CONSTANT]], [[RES]][%[[CONSTANT_INDEX]]] : memref<1xf32>
// CHECK: [[LOOPS_REDUCE:%.+]] = krnl.define_loops 1
// CHECK: [[OPT_LOOPS_REDUCE:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[LOOPS_REDUCE]]
// CHECK: } : () -> !krnl.loop
// CHECK: krnl.iterate([[OPT_LOOPS_REDUCE]]) with ([[LOOPS_REDUCE]] -> %arg2 = 0 to 5) {
// CHECK: [[LOAD_0:%.+]] = load %arg0[%arg2] : memref<5xf32>
// CHECK: [[LOAD_1:%.+]] = load %arg1[%arg2] : memref<5xf32>
// CHECK: [[LOAD_RES:%.+]] = load [[RES]][%[[CONSTANT_INDEX]]] : memref<1xf32>
// CHECK: [[MUL:%.+]] = mulf [[LOAD_0]], [[LOAD_1]] : f32
// CHECK: [[ADD:%.+]] = addf [[LOAD_RES]], [[MUL]] : f32
// CHECK: store [[ADD]], [[RES]][%[[CONSTANT_INDEX]]] : memref<1xf32>
// CHECK: }
// CHECK: return [[RES]] : memref<1xf32>
}
func @test_conv_no_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()