Lower Matmul operation to Krnl dialect (#57)
* Allocate memory for matmul's result
* Group cases
* Add support of N-D x N-D, N>=2
* Revise createIterateOperandPack
* Add 1-D x 1-D
* Add 1-D x N-D
* Add MLIR tests
* Change variable names
* Change type from int to int64_t for indices
* Change variable names
* Change int64_t back to int
* Change int64_t back to int
* Change int64_t back to int
* Use decltype

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
dab862e4f1
commit
b521719587
|
@ -442,7 +442,7 @@ void ONNXMatMulOp::inferShapes() {
|
|||
lhsShape[0] != rhsShape[rhsRank - 2])
|
||||
emitError("Attempt to multiply incompatible matrices.");
|
||||
|
||||
for (int i = 0; i < rhsRank - 2; ++i)
|
||||
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
|
||||
dims.emplace_back(rhsShape[i]);
|
||||
dims.emplace_back(rhsShape[rhsRank - 1]);
|
||||
} else if (lhsShape.size() >= 2 && rhsShape.size() == 1) {
|
||||
|
@ -460,7 +460,7 @@ void ONNXMatMulOp::inferShapes() {
|
|||
lhsShape[lhsRank - 1] != rhsShape[0])
|
||||
emitError("Attempt to multiply incompatible matrices.");
|
||||
|
||||
for (int i = 0; i < lhsRank - 2; ++i)
|
||||
for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
|
||||
dims.emplace_back(lhsShape[i]);
|
||||
dims.emplace_back(lhsShape[lhsRank - 2]);
|
||||
} else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
|
||||
|
@ -474,7 +474,7 @@ void ONNXMatMulOp::inferShapes() {
|
|||
lhsShape[lhsRank - 1] != rhsShape[0])
|
||||
emitError("Attempt to multiply incompatible matrices.");
|
||||
|
||||
for (int i = 0; i < lhsRank - 1; ++i)
|
||||
for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i)
|
||||
dims.emplace_back(lhsShape[i]);
|
||||
dims.emplace_back(rhsShape[1]);
|
||||
} else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
|
||||
|
@ -488,7 +488,7 @@ void ONNXMatMulOp::inferShapes() {
|
|||
lhsShape[1] != rhsShape[rhsRank - 2])
|
||||
emitError("Attempt to multiply incompatible matrices.");
|
||||
|
||||
for (int i = 0; i < rhsRank - 2; ++i)
|
||||
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
|
||||
dims.emplace_back(rhsShape[i]);
|
||||
dims.emplace_back(lhsShape[0]);
|
||||
dims.emplace_back(rhsShape[rhsRank - 1]);
|
||||
|
@ -506,10 +506,10 @@ void ONNXMatMulOp::inferShapes() {
|
|||
|
||||
// Check and perform broadcasting for the shapes.
|
||||
SmallVector<int64_t, 2> lhsBcastShape;
|
||||
for (int i = 0; i < lhsRank - 2; ++i)
|
||||
for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
|
||||
lhsBcastShape.emplace_back(lhsShape[i]);
|
||||
SmallVector<int64_t, 2> rhsBcastShape;
|
||||
for (int i = 0; i < rhsRank - 2; ++i)
|
||||
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
|
||||
rhsBcastShape.emplace_back(rhsShape[i]);
|
||||
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
|
||||
emitError("Broadcasted dimensions are incompatible.");
|
||||
|
|
|
@ -284,7 +284,7 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
|
|||
auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
|
||||
auto one = rewriter.create<ConstantIndexOp>(loc, 1);
|
||||
auto isBroadcasted =
|
||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
|
||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
|
||||
broadcastedDims.insert(std::make_pair(j, isBroadcasted));
|
||||
}
|
||||
}
|
||||
|
@ -588,9 +588,11 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
|||
// Constant 1)
|
||||
auto loc = op->getLoc();
|
||||
Value operand = operands[0];
|
||||
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||
auto alphaAttribute = FloatAttr::get(
|
||||
rewriter.getF32Type(),
|
||||
llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat());
|
||||
auto betaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||
auto betaAttribute = FloatAttr::get(
|
||||
rewriter.getF32Type(),
|
||||
llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
|
||||
auto elementType = result_types[0];
|
||||
|
||||
|
@ -625,8 +627,9 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
|
|||
Value operand = operands[0];
|
||||
auto elementType = result_types[0];
|
||||
|
||||
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||
llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
|
||||
auto alphaAttribute =
|
||||
FloatAttr::get(rewriter.getF32Type(),
|
||||
llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
|
||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
|
||||
|
@ -679,7 +682,8 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
|
|||
Value operand = operands[0];
|
||||
auto elementType = result_types[0];
|
||||
|
||||
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||
auto alphaAttribute = FloatAttr::get(
|
||||
rewriter.getF32Type(),
|
||||
llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());
|
||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
|
||||
|
@ -705,10 +709,12 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
|
|||
// alpha)))
|
||||
auto loc = op->getLoc();
|
||||
Value operand = operands[0];
|
||||
auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||
llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat());
|
||||
auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(),
|
||||
llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
|
||||
auto alphaAttribute =
|
||||
FloatAttr::get(rewriter.getF32Type(),
|
||||
llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat());
|
||||
auto gammaAttribute =
|
||||
FloatAttr::get(rewriter.getF32Type(),
|
||||
llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
|
||||
auto elementType = result_types[0];
|
||||
|
||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||
|
@ -748,9 +754,10 @@ Value mapToLowerScalarOp<ONNXReciprocalOp>(
|
|||
// Scalar unary ops for lowering ONNXSoftplusOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <>
|
||||
Value mapToLowerScalarOp<ONNXSoftplusOp>(
|
||||
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Value mapToLowerScalarOp<ONNXSoftplusOp>(Operation *op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
// ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1))
|
||||
auto loc = op->getLoc();
|
||||
Value operand = operands[0];
|
||||
|
@ -768,9 +775,10 @@ Value mapToLowerScalarOp<ONNXSoftplusOp>(
|
|||
// Scalar unary ops for lowering ONNXSoftsignOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <>
|
||||
Value mapToLowerScalarOp<ONNXSoftsignOp>(
|
||||
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Value mapToLowerScalarOp<ONNXSoftsignOp>(Operation *op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
// ONNXSoftsignOp(%X) = DivFOp(ConstantOp 1, %X)
|
||||
auto loc = op->getLoc();
|
||||
Value operand = operands[0];
|
||||
|
@ -1408,6 +1416,337 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
|||
}
|
||||
};
|
||||
|
||||
struct ONNXMatMulOpLowering : public ConversionPattern {
|
||||
ONNXMatMulOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
Value A = operands[0];
|
||||
Value B = operands[1];
|
||||
auto AShape = A.getType().cast<MemRefType>().getShape();
|
||||
auto BShape = B.getType().cast<MemRefType>().getShape();
|
||||
|
||||
// There are three cases related to the shapes of the two arguments:
|
||||
// - Both arguments are N-D, N >= 2
|
||||
// - Either argument is 1-D, the other is N-D, N >= 2
|
||||
// - Both arguments are 1-D
|
||||
|
||||
// Result type
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
auto elementType = memRefType.getElementType();
|
||||
auto memRefShape = memRefType.getShape();
|
||||
|
||||
// A value zero
|
||||
Value zero;
|
||||
if (elementType.isa<IntegerType>()) {
|
||||
zero = rewriter.create<ConstantOp>(
|
||||
loc, IntegerAttr::get(memRefType.getElementType(), 0));
|
||||
} else if (elementType.isa<FloatType>()) {
|
||||
zero = rewriter.create<ConstantOp>(
|
||||
loc, FloatAttr::get(memRefType.getElementType(), 0));
|
||||
} else {
|
||||
emitError(loc, "unsupported element type");
|
||||
}
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
Value alloc;
|
||||
bool insertDealloc = checkInsertDealloc(op);
|
||||
if (hasAllConstantDimensions(memRefType))
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||
else {
|
||||
SmallVector<Value, 4> allocOperands;
|
||||
if (AShape.size() >= 2 && BShape.size() >= 2) {
|
||||
// Both arguments are N-D, N >= 2
|
||||
// (s1 x s2 x... x sK x M x K) MATMUL (K x N)
|
||||
// =>
|
||||
// (s1 x s2 x... x sK x M x N)
|
||||
for (int i = 0; i < memRefShape.size() - 2; ++i) {
|
||||
if (memRefShape[i] < 0) {
|
||||
if ((AShape.size() == 2) && (BShape.size() > 2))
|
||||
allocOperands.emplace_back(rewriter.create<DimOp>(loc, B, i));
|
||||
else if ((AShape.size() > 2) && (BShape.size() == 2))
|
||||
allocOperands.emplace_back(rewriter.create<DimOp>(loc, A, i));
|
||||
}
|
||||
}
|
||||
if (memRefShape[memRefShape.size() - 2] < 0) {
|
||||
auto dim = rewriter.create<DimOp>(loc, A, memRefShape.size() - 2);
|
||||
allocOperands.emplace_back(dim);
|
||||
}
|
||||
if (memRefShape[memRefShape.size() - 1] < 0) {
|
||||
auto dim = rewriter.create<DimOp>(loc, B, memRefShape.size() - 1);
|
||||
allocOperands.emplace_back(dim);
|
||||
}
|
||||
} else if (AShape.size() == 1 && BShape.size() >= 2) {
|
||||
// Either argument is 1-D
|
||||
// K MATMUL (s1 x s2 x... x sK x K x N)
|
||||
// =>
|
||||
// (s1 x s2 x... x sK x N)
|
||||
for (int i = 0; i < memRefShape.size() - 1; ++i) {
|
||||
if (memRefShape[i] < 0) {
|
||||
auto dim = rewriter.create<DimOp>(loc, B, i);
|
||||
allocOperands.emplace_back(dim);
|
||||
}
|
||||
}
|
||||
if (memRefShape[memRefShape.size() - 1] < 0) {
|
||||
auto dim = rewriter.create<DimOp>(loc, B, BShape.size() - 1);
|
||||
allocOperands.emplace_back(dim);
|
||||
}
|
||||
} else if (AShape.size() >= 2 && BShape.size() == 1) {
|
||||
// Either argument is 1-D
|
||||
// (s1 x s2 x... x sK x M x K) MATMUL K
|
||||
// =>
|
||||
// (s1 x s2 x... x sK x M)
|
||||
for (int i = 0; i < memRefShape.size() - 1; ++i) {
|
||||
if (memRefShape[i] < 0) {
|
||||
auto dim = rewriter.create<DimOp>(loc, A, i);
|
||||
allocOperands.emplace_back(dim);
|
||||
}
|
||||
}
|
||||
if (memRefShape[memRefShape.size() - 1] < 0) {
|
||||
auto dim = rewriter.create<DimOp>(loc, A, AShape.size() - 2);
|
||||
allocOperands.emplace_back(dim);
|
||||
}
|
||||
} else if (AShape.size() == 1 && BShape.size() == 1) {
|
||||
// Both arguments are 1-D
|
||||
if (memRefShape[0] < 0) {
|
||||
auto dim = rewriter.create<DimOp>(loc, A, 0);
|
||||
allocOperands.emplace_back(dim);
|
||||
}
|
||||
} else {
|
||||
emitError(loc, "Invalid shapes");
|
||||
}
|
||||
|
||||
alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
|
||||
}
|
||||
|
||||
if (AShape.size() >= 2 || BShape.size() >= 2) {
|
||||
// Cases 1 and 2:
|
||||
// - Both arguments are N-D, N >= 2
|
||||
// - Either argument is 1-D, the other is N-D, N >= 2
|
||||
|
||||
// Define loops for batch dimensions.
|
||||
std::vector<Value> originalLoops;
|
||||
std::vector<Value> optimizedLoops;
|
||||
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
|
||||
optimizedLoops, memRefShape.size());
|
||||
|
||||
// Outer KrnlIterateOp
|
||||
SmallVector<Value, 4> loopBatchIVs;
|
||||
bool hasBatchLoop = false;
|
||||
if (AShape.size() > 2 || BShape.size() > 2) {
|
||||
SmallVector<int, 4> batchAxes;
|
||||
int matmulResultDims =
|
||||
((AShape.size() == 1 || BShape.size() == 1)) ? 1 : 2;
|
||||
for (int i = 0; i < memRefShape.size() - matmulResultDims; ++i)
|
||||
batchAxes.emplace_back(i);
|
||||
|
||||
std::vector<Value> outerLoops, optimizedOuterLoops;
|
||||
outerLoops.reserve(batchAxes.size());
|
||||
optimizedOuterLoops.reserve(batchAxes.size());
|
||||
for (int i = 0; i < batchAxes.size(); ++i) {
|
||||
outerLoops.push_back(originalLoops[i]);
|
||||
optimizedOuterLoops.push_back(optimizedLoops[i]);
|
||||
}
|
||||
KrnlIterateOperandPack outerPack(rewriter, outerLoops,
|
||||
optimizedOuterLoops);
|
||||
for (int i = 0; i < batchAxes.size(); ++i) {
|
||||
addDimensionToPack(rewriter, loc, outerPack, alloc, i);
|
||||
}
|
||||
auto outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);
|
||||
|
||||
// No optimization
|
||||
rewriter.setInsertionPointToEnd(optimizationBlock);
|
||||
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
|
||||
|
||||
// Insert instructions into the outer KrnlIterateOp.
|
||||
Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
|
||||
rewriter.setInsertionPointToStart(&outerIterationBlock);
|
||||
|
||||
// Induction variables: non-matrix-multiplication variables.
|
||||
for (auto arg : outerIterationBlock.getArguments()) {
|
||||
loopBatchIVs.emplace_back(arg);
|
||||
}
|
||||
|
||||
hasBatchLoop = true;
|
||||
}
|
||||
|
||||
// Now, we define loops for matrix multiplication.
|
||||
|
||||
// Create a KrnlIterateOp for matrix multiplication.
|
||||
KrnlIterateOp matmulIterateOp;
|
||||
std::vector<Value> matmulLoops, optimizedMatmulLoops;
|
||||
if (AShape.size() >= 2 && BShape.size() >= 2) {
|
||||
// 2-D x 2-D. Result has two dimensions.
|
||||
matmulLoops.reserve(2);
|
||||
optimizedMatmulLoops.reserve(2);
|
||||
for (int i = 2; i > 0; --i) {
|
||||
matmulLoops.emplace_back(originalLoops[memRefShape.size() - i]);
|
||||
optimizedMatmulLoops.emplace_back(
|
||||
optimizedLoops[memRefShape.size() - i]);
|
||||
}
|
||||
KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
|
||||
optimizedMatmulLoops);
|
||||
for (int i = 2; i > 0; --i) {
|
||||
addDimensionToPack(rewriter, loc, matmulPack, alloc,
|
||||
memRefShape.size() - i);
|
||||
}
|
||||
matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
|
||||
} else {
|
||||
// 1-D x 2-D, and vice versa. Result has one dimension.
|
||||
matmulLoops.reserve(1);
|
||||
optimizedMatmulLoops.reserve(1);
|
||||
matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]);
|
||||
optimizedMatmulLoops.emplace_back(
|
||||
optimizedLoops[memRefShape.size() - 1]);
|
||||
KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
|
||||
optimizedMatmulLoops);
|
||||
addDimensionToPack(rewriter, loc, matmulPack, alloc,
|
||||
memRefShape.size() - 1);
|
||||
matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
|
||||
}
|
||||
|
||||
if (!hasBatchLoop) {
|
||||
// No optimization
|
||||
rewriter.setInsertionPointToEnd(optimizationBlock);
|
||||
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
|
||||
}
|
||||
|
||||
// Insert instructions into the matmul KrnlIterateOp.
|
||||
Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front();
|
||||
rewriter.setInsertionPointToStart(&matmulIterationBlock);
|
||||
|
||||
// Induction variables: M, N
|
||||
SmallVector<Value, 4> loopMNIVs;
|
||||
for (auto arg : matmulIterationBlock.getArguments()) {
|
||||
loopMNIVs.emplace_back(arg);
|
||||
}
|
||||
// Induction variables for the final result.
|
||||
SmallVector<Value, 4> loopBatchMNIVs;
|
||||
for (auto arg : loopBatchIVs) {
|
||||
loopBatchMNIVs.emplace_back(arg);
|
||||
}
|
||||
for (auto arg : loopMNIVs) {
|
||||
loopBatchMNIVs.emplace_back(arg);
|
||||
}
|
||||
|
||||
// Fill the output with value 0.
|
||||
rewriter.create<StoreOp>(loc, zero, alloc, loopBatchMNIVs);
|
||||
|
||||
// Iterate along the reduction dimension.
|
||||
// Use a value from A.
|
||||
std::vector<Value> reduceLoops;
|
||||
std::vector<Value> optimizedReduceLoops;
|
||||
Block *optimizationReduceBlock =
|
||||
defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
|
||||
KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
|
||||
optimizedReduceLoops);
|
||||
addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1);
|
||||
auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
|
||||
|
||||
// No optimization
|
||||
rewriter.setInsertionPointToEnd(optimizationReduceBlock);
|
||||
rewriter.create<KrnlReturnLoopsOp>(loc, reduceLoops);
|
||||
|
||||
// Insert instructions into the reduction KrnlIterateOp.
|
||||
Block &reduceIterationBlock = reduceIterateOp.bodyRegion().front();
|
||||
rewriter.setInsertionPointToStart(&reduceIterationBlock);
|
||||
|
||||
// Induction variables
|
||||
SmallVector<Value, 4> loopKIVs, loopBatchMKIVs, loopBatchKNIVs;
|
||||
// K
|
||||
loopKIVs.emplace_back(reduceIterationBlock.getArguments()[0]);
|
||||
// MK
|
||||
if (AShape.size() > 2)
|
||||
for (auto arg : loopBatchIVs)
|
||||
loopBatchMKIVs.emplace_back(arg);
|
||||
if (AShape.size() >= 2)
|
||||
loopBatchMKIVs.emplace_back(loopMNIVs[0]);
|
||||
loopBatchMKIVs.emplace_back(loopKIVs[0]);
|
||||
// KN
|
||||
if (BShape.size() > 2)
|
||||
for (auto arg : loopBatchIVs)
|
||||
loopBatchKNIVs.emplace_back(arg);
|
||||
loopBatchKNIVs.emplace_back(loopKIVs[0]);
|
||||
if (BShape.size() >= 2)
|
||||
if (AShape.size() >= 2)
|
||||
loopBatchKNIVs.emplace_back(loopMNIVs[1]);
|
||||
else
|
||||
loopBatchKNIVs.emplace_back(loopMNIVs[0]);
|
||||
|
||||
// Matmul computation
|
||||
auto loadedA = rewriter.create<LoadOp>(loc, A, loopBatchMKIVs);
|
||||
auto loadedB = rewriter.create<LoadOp>(loc, B, loopBatchKNIVs);
|
||||
auto loadedY = rewriter.create<LoadOp>(loc, alloc, loopBatchMNIVs);
|
||||
if (elementType.isa<IntegerType>()) {
|
||||
auto AB = rewriter.create<MulIOp>(loc, loadedA, loadedB);
|
||||
auto accumulated = rewriter.create<AddIOp>(loc, loadedY, AB);
|
||||
rewriter.create<StoreOp>(loc, accumulated, alloc, loopBatchMNIVs);
|
||||
} else if (elementType.isa<FloatType>()) {
|
||||
auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
|
||||
auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
|
||||
rewriter.create<StoreOp>(loc, accumulated, alloc, loopBatchMNIVs);
|
||||
}
|
||||
} else if ((AShape.size() == 1) && (BShape.size() == 1)) {
|
||||
// Case 3:
|
||||
// - Both arguments are 1-D
|
||||
|
||||
// Fill the output with value 0.
|
||||
Value zeroIndex = rewriter.create<ConstantIndexOp>(loc, 0);
|
||||
rewriter.create<StoreOp>(loc, zero, alloc, zeroIndex);
|
||||
|
||||
// Iterate along the reduction dimension.
|
||||
// Use a value from A.
|
||||
std::vector<Value> reduceLoops;
|
||||
std::vector<Value> optimizedReduceLoops;
|
||||
Block *optimizationReduceBlock =
|
||||
defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
|
||||
KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
|
||||
optimizedReduceLoops);
|
||||
addDimensionToPack(rewriter, loc, reducePack, A, 0);
|
||||
auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
|
||||
|
||||
// No optimization
|
||||
rewriter.setInsertionPointToEnd(optimizationReduceBlock);
|
||||
rewriter.create<KrnlReturnLoopsOp>(loc, reduceLoops);
|
||||
|
||||
// Insert instructions into the reduction KrnlIterateOp.
|
||||
Block &reduceIterationBlock = reduceIterateOp.bodyRegion().front();
|
||||
rewriter.setInsertionPointToStart(&reduceIterationBlock);
|
||||
|
||||
// Induction variables
|
||||
SmallVector<Value, 4> loopKIVs;
|
||||
// K
|
||||
loopKIVs.emplace_back(reduceIterationBlock.getArgument(0));
|
||||
|
||||
// Matmul computation
|
||||
auto loadedA = rewriter.create<LoadOp>(loc, A, loopKIVs);
|
||||
auto loadedB = rewriter.create<LoadOp>(loc, B, loopKIVs);
|
||||
auto loadedY = rewriter.create<LoadOp>(loc, alloc, zeroIndex);
|
||||
if (elementType.isa<IntegerType>()) {
|
||||
auto AB = rewriter.create<MulIOp>(loc, loadedA, loadedB);
|
||||
auto accumulated = rewriter.create<AddIOp>(loc, loadedY, AB);
|
||||
rewriter.create<StoreOp>(loc, accumulated, alloc, zeroIndex);
|
||||
} else if (elementType.isa<FloatType>()) {
|
||||
auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
|
||||
auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
|
||||
rewriter.create<StoreOp>(loc, accumulated, alloc, zeroIndex);
|
||||
}
|
||||
} else {
|
||||
// No scalar matrix multiplication.
|
||||
llvm_unreachable("Unsupported scalar matrix multiplication.");
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, alloc);
|
||||
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
struct ONNXGemmOpLowering : public ConversionPattern {
|
||||
ONNXGemmOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {}
|
||||
|
@ -1423,10 +1762,12 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
B = operands[1];
|
||||
C = operands[2];
|
||||
|
||||
auto alphaAttr = FloatAttr::get(tensorType.getElementType(),
|
||||
llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
|
||||
auto betaAttr = FloatAttr::get(tensorType.getElementType(),
|
||||
llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
|
||||
auto alphaAttr =
|
||||
FloatAttr::get(tensorType.getElementType(),
|
||||
llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
|
||||
auto betaAttr =
|
||||
FloatAttr::get(tensorType.getElementType(),
|
||||
llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
|
||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
||||
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
|
||||
|
||||
|
@ -1482,8 +1823,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
outerLoops.push_back(originalLoops[i]);
|
||||
optimizedOuterLoops.push_back(optimizedLoops[i]);
|
||||
}
|
||||
KrnlIterateOperandPack outerPack(rewriter, outerLoops,
|
||||
optimizedOuterLoops);
|
||||
KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
|
||||
// Induction variables for the outer loops
|
||||
for (int i = 0; i < 2; ++i)
|
||||
addDimensionToPack(rewriter, loc, outerPack, alloc, i);
|
||||
|
@ -1501,17 +1841,16 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
// If it failed then use a dynamic value.
|
||||
auto ATy = A.getType().cast<MemRefType>();
|
||||
auto BTy = B.getType().cast<MemRefType>();
|
||||
int64_t K_A_Idx = (isTransA) ? 0 : 1;
|
||||
int64_t K_B_Idx = (isTransB) ? 1 : 0;
|
||||
int K_A_Idx = (isTransA) ? 0 : 1;
|
||||
int K_B_Idx = (isTransB) ? 1 : 0;
|
||||
reductionPack.pushConstantBound(0);
|
||||
if (ATy.getShape()[K_A_Idx] != -1)
|
||||
reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
|
||||
reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
|
||||
else if (BTy.getShape()[K_B_Idx] != -1)
|
||||
reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
|
||||
else
|
||||
if (BTy.getShape()[K_B_Idx] != -1)
|
||||
reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
|
||||
else
|
||||
reductionPack.pushOperandBound(
|
||||
rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());
|
||||
reductionPack.pushOperandBound(
|
||||
rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());
|
||||
|
||||
// Get run-time dimension information for unknown dimensions used for
|
||||
// broadcasting.
|
||||
|
@ -1524,7 +1863,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
auto dim = rewriter.create<DimOp>(loc, C, i).getResult();
|
||||
auto one = rewriter.create<ConstantIndexOp>(loc, 1);
|
||||
auto isBroadcasted =
|
||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
|
||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
|
||||
broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted));
|
||||
}
|
||||
}
|
||||
|
@ -1557,8 +1896,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
auto matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, reductionPack);
|
||||
|
||||
// Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
|
||||
auto loopCIVs = getLoopIVsForBroadcasting(
|
||||
loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
|
||||
auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
|
||||
broadcastedDimInfo);
|
||||
auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
|
||||
auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
|
||||
auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
|
||||
|
@ -1650,8 +1989,8 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
|
|||
Value dimVal = nullptr;
|
||||
if (memRefShape[outIdx] < 0) {
|
||||
Value index = rewriter.create<DimOp>(loc, operands[0], inIdx);
|
||||
dimVal = rewriter.create<IndexCastOp>(
|
||||
loc, index, rewriter.getIntegerType(64));
|
||||
dimVal = rewriter.create<IndexCastOp>(loc, index,
|
||||
rewriter.getIntegerType(64));
|
||||
allocOperands.emplace_back(index);
|
||||
} else {
|
||||
dimVal = rewriter.create<ConstantOp>(
|
||||
|
@ -1738,8 +2077,8 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
|
|||
// TODO: Remove when perm is guaranteed to be present (even for
|
||||
// the default case). This means that perm was added by shape
|
||||
// inference or another pass to contain the values corresponding
|
||||
// to the default behavior of Transpose.
|
||||
for (int i = iterationBlock.getArguments().size()-1; i >= 0; i--)
|
||||
// to the default behavior of Transpose.
|
||||
for (int i = iterationBlock.getArguments().size() - 1; i >= 0; i--)
|
||||
perm.emplace_back(i);
|
||||
}
|
||||
|
||||
|
@ -1748,7 +2087,7 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
|
|||
inLoopIVs.emplace_back(arg);
|
||||
|
||||
SmallVector<Value, 4> outLoopIVs;
|
||||
for (int i=0; i<iterationBlock.getArguments().size(); ++i)
|
||||
for (int i = 0; i < iterationBlock.getArguments().size(); ++i)
|
||||
outLoopIVs.emplace_back(iterationBlock.getArguments()[perm[i]]);
|
||||
|
||||
auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
|
||||
|
@ -2362,8 +2701,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
|||
ONNXReductionOpLowering<mlir::ONNXReduceSumOp>,
|
||||
ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
|
||||
ONNXUnsqueezeOpLowering, ONNXTransposeOpLowering,
|
||||
ONNXIdentityOpLowering, ONNXConvNoBiasOpLowering
|
||||
>(&getContext());
|
||||
ONNXIdentityOpLowering, ONNXConvNoBiasOpLowering,
|
||||
ONNXMatMulOpLowering>(&getContext());
|
||||
|
||||
// With the target and rewrite patterns defined, we can now attempt the
|
||||
// conversion. The conversion will signal failure if any of our `illegal`
|
||||
|
|
|
@ -295,6 +295,12 @@ test_to_enable = [
|
|||
|
||||
# Sign Op:
|
||||
"test_sign_cpu",
|
||||
|
||||
# MatmulOp
|
||||
"test_matmul_2d_cpu",
|
||||
"test_matmul_3d_cpu",
|
||||
"test_matmul_4d_cpu",
|
||||
|
||||
]
|
||||
|
||||
# Extract name of all test cases.
|
||||
|
|
|
@ -930,6 +930,223 @@ func @test_sign_i(%arg0 : tensor<?x10xi32>) -> tensor<*xi32> {
|
|||
// CHECK: return [[RES]] : memref<?x10xi32>
|
||||
}
|
||||
|
||||
// 2-D x 2-D
|
||||
// Lowering test for onnx.MatMul on static 2-D operands (10x5 times 5x10).
// Expected output: a 10x10 result buffer whose elements are set to zero inside
// a two-loop krnl.iterate, with a nested one-loop reduction (0 to 5) that
// loads one element from each operand, multiplies them and accumulates the
// product into the result.
func @test_matmul1(%arg0 : tensor<10x5xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.MatMul"(%arg0, %arg1) : (tensor<10x5xf32>, tensor<5x10xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||

||||
// CHECK-LABEL: test_matmul1
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
|
||||
// CHECK: [[CONSTANT:%.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: [[LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[LOOPS]]#0 -> %arg2 = 0 to 10, [[LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||
// CHECK: store [[CONSTANT]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
|
||||
// CHECK: [[LOOPS_REDUCE:%.+]] = krnl.define_loops 1
|
||||
// CHECK: [[OPT_LOOPS_REDUCE:%.+]] = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[LOOPS_REDUCE]]
|
||||
// CHECK: } : () -> !krnl.loop
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS_REDUCE]]) with ([[LOOPS_REDUCE]] -> %arg4 = 0 to 5) {
|
||||
// CHECK: [[LOAD_0:%.+]] = load %arg0[%arg2, %arg4] : memref<10x5xf32>
|
||||
// CHECK: [[LOAD_1:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
|
||||
// CHECK: [[LOAD_RES:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
|
||||
// CHECK: [[MUL:%.+]] = mulf [[LOAD_0]], [[LOAD_1]] : f32
|
||||
// CHECK: [[ADD:%.+]] = addf [[LOAD_RES]], [[MUL]] : f32
|
||||
// CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: return [[RES]] : memref<10x10xf32>
|
||||
}
|
||||
|
||||
// 2-D x N-D
|
||||
// Lowering test for onnx.MatMul with a static 2-D left operand (10x5) and a
// static 4-D right operand (2x3x5x10).
// Expected output: a 2x3x10x10 result buffer, an outer krnl.iterate over the
// two leading dimensions (0 to 2, 0 to 3), an inner krnl.iterate over the two
// trailing dimensions that stores zero, and a nested one-loop reduction
// (0 to 5). The left operand is indexed without the outer induction variables;
// the right operand is indexed with them.
func @test_matmul2(%arg0 : tensor<10x5xf32>, %arg1 : tensor<2x3x5x10xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.MatMul"(%arg0, %arg1) : (tensor<10x5xf32>, tensor<2x3x5x10xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||

||||
// CHECK-LABEL: test_matmul2
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<2x3x10x10xf32>
|
||||
// CHECK: [[CONSTANT:%.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: [[LOOPS:%.+]]:4 = krnl.define_loops 4
|
||||
// CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1, [[LOOPS]]#2, [[LOOPS]]#3
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[LOOPS]]#0 -> %arg2 = 0 to 2, [[LOOPS]]#1 -> %arg3 = 0 to 3) {
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[LOOPS]]#2 -> %arg4 = 0 to 10, [[LOOPS]]#3 -> %arg5 = 0 to 10) {
|
||||
// CHECK: store [[CONSTANT]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<2x3x10x10xf32>
|
||||
// CHECK: [[LOOPS_REDUCE:%.+]] = krnl.define_loops 1
|
||||
// CHECK: [[OPT_LOOPS_REDUCE:%.+]] = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[LOOPS_REDUCE]]
|
||||
// CHECK: } : () -> !krnl.loop
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS_REDUCE]]) with ([[LOOPS_REDUCE]] -> %arg6 = 0 to 5) {
|
||||
// CHECK: [[LOAD_0:%.+]] = load %arg0[%arg4, %arg6] : memref<10x5xf32>
|
||||
// CHECK: [[LOAD_1:%.+]] = load %arg1[%arg2, %arg3, %arg6, %arg5] : memref<2x3x5x10xf32>
|
||||
// CHECK: [[LOAD_RES:%.+]] = load [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<2x3x10x10xf32>
|
||||
// CHECK: [[MUL:%.+]] = mulf [[LOAD_0]], [[LOAD_1]] : f32
|
||||
// CHECK: [[ADD:%.+]] = addf [[LOAD_RES]], [[MUL]] : f32
|
||||
// CHECK: store [[ADD]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<2x3x10x10xf32>
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: return [[RES]] : memref<2x3x10x10xf32>
|
||||
}
|
||||
|
||||
// N-D x N-D
|
||||
// Lowering test for onnx.MatMul on two static 4-D operands (2x3x10x5 times
// 2x3x5x10).
// Expected output: a 2x3x10x10 result buffer, an outer krnl.iterate over the
// two leading dimensions (0 to 2, 0 to 3), an inner krnl.iterate over the two
// trailing dimensions that stores zero, and a nested one-loop reduction
// (0 to 5). Both operands are indexed with the outer induction variables.
func @test_matmul3(%arg0 : tensor<2x3x10x5xf32>, %arg1 : tensor<2x3x5x10xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.MatMul"(%arg0, %arg1) : (tensor<2x3x10x5xf32>, tensor<2x3x5x10xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||

||||
// CHECK-LABEL: test_matmul3
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<2x3x10x10xf32>
|
||||
// CHECK: [[CONSTANT:%.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: [[LOOPS:%.+]]:4 = krnl.define_loops 4
|
||||
// CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1, [[LOOPS]]#2, [[LOOPS]]#3
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[LOOPS]]#0 -> %arg2 = 0 to 2, [[LOOPS]]#1 -> %arg3 = 0 to 3) {
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[LOOPS]]#2 -> %arg4 = 0 to 10, [[LOOPS]]#3 -> %arg5 = 0 to 10) {
|
||||
// CHECK: store [[CONSTANT]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<2x3x10x10xf32>
|
||||
// CHECK: [[LOOPS_REDUCE:%.+]] = krnl.define_loops 1
|
||||
// CHECK: [[OPT_LOOPS_REDUCE:%.+]] = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[LOOPS_REDUCE]]
|
||||
// CHECK: } : () -> !krnl.loop
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS_REDUCE]]) with ([[LOOPS_REDUCE]] -> %arg6 = 0 to 5) {
|
||||
// CHECK: [[LOAD_0:%.+]] = load %arg0[%arg2, %arg3, %arg4, %arg6] : memref<2x3x10x5xf32>
|
||||
// CHECK: [[LOAD_1:%.+]] = load %arg1[%arg2, %arg3, %arg6, %arg5] : memref<2x3x5x10xf32>
|
||||
// CHECK: [[LOAD_RES:%.+]] = load [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<2x3x10x10xf32>
|
||||
// CHECK: [[MUL:%.+]] = mulf [[LOAD_0]], [[LOAD_1]] : f32
|
||||
// CHECK: [[ADD:%.+]] = addf [[LOAD_RES]], [[MUL]] : f32
|
||||
// CHECK: store [[ADD]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<2x3x10x10xf32>
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: return [[RES]] : memref<2x3x10x10xf32>
|
||||
}
|
||||
|
||||
// 1-D x 2-D: multiplies a 5-element vector by a 5x10 matrix.
// The expected lowering allocates a 10-element result, stores the zero
// constant into each element inside a single output loop (extent 10), and
// accumulates %arg0[k] * %arg1[k, j] in a nested reduction loop of extent 5.
|
||||
func @test_matmul4(%arg0 : tensor<5xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.MatMul"(%arg0, %arg1) : (tensor<5xf32>, tensor<5x10xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_matmul4
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<10xf32>
|
||||
// CHECK: [[CONSTANT:%.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: [[LOOPS:%.+]] = krnl.define_loops 1
|
||||
// CHECK: [[OPT_LOOPS:%.+]] = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[LOOPS]]
|
||||
// CHECK: } : () -> !krnl.loop
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]) with ([[LOOPS]] -> %arg2 = 0 to 10) {
|
||||
// CHECK: store [[CONSTANT]], [[RES]][%arg2] : memref<10xf32>
|
||||
// CHECK: [[LOOPS_REDUCE:%.+]] = krnl.define_loops 1
|
||||
// CHECK: [[OPT_LOOPS_REDUCE:%.+]] = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[LOOPS_REDUCE]]
|
||||
// CHECK: } : () -> !krnl.loop
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS_REDUCE]]) with ([[LOOPS_REDUCE]] -> %arg3 = 0 to 5) {
|
||||
// CHECK: [[LOAD_0:%.+]] = load %arg0[%arg3] : memref<5xf32>
|
||||
// CHECK: [[LOAD_1:%.+]] = load %arg1[%arg3, %arg2] : memref<5x10xf32>
|
||||
// CHECK: [[LOAD_RES:%.+]] = load [[RES]][%arg2] : memref<10xf32>
|
||||
// CHECK: [[MUL:%.+]] = mulf [[LOAD_0]], [[LOAD_1]] : f32
|
||||
// CHECK: [[ADD:%.+]] = addf [[LOAD_RES]], [[MUL]] : f32
|
||||
// CHECK: store [[ADD]], [[RES]][%arg2] : memref<10xf32>
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: return [[RES]] : memref<10xf32>
|
||||
}
|
||||
|
||||
// 1-D x N-D: multiplies a 5-element vector by a ?x5x10 tensor whose leading
// dimension is dynamic (`?`).
// The expected lowering reads that dimension from %arg1 with `dim`, passes it
// to the alloc of the ?x10 result, and bounds the outer output loop with
// dimension 0 of the result itself; the inner output loop has extent 10 and
// the reduction loop has extent 5.
|
||||
func @test_matmul5(%arg0 : tensor<5xf32>, %arg1 : tensor<?x5x10xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.MatMul"(%arg0, %arg1) : (tensor<5xf32>, tensor<?x5x10xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_matmul5
|
||||
// CHECK: [[CONSTANT:%.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: [[DIM_0:%.+]] = dim %arg1, 0 : memref<?x5x10xf32>
|
||||
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
|
||||
// CHECK: [[LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||
// CHECK: [[DIM_1:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0) with ([[LOOPS]]#0 -> %arg2 = 0 to [[DIM_1]]) {
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||
// CHECK: store [[CONSTANT]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
|
||||
// CHECK: [[LOOPS_REDUCE:%.+]] = krnl.define_loops 1
|
||||
// CHECK: [[OPT_LOOPS_REDUCE:%.+]] = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[LOOPS_REDUCE]]
|
||||
// CHECK: } : () -> !krnl.loop
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS_REDUCE]]) with ([[LOOPS_REDUCE]] -> %arg4 = 0 to 5) {
|
||||
// CHECK: [[LOAD_0:%.+]] = load %arg0[%arg4] : memref<5xf32>
|
||||
// CHECK: [[LOAD_1:%.+]] = load %arg1[%arg2, %arg4, %arg3] : memref<?x5x10xf32>
|
||||
// CHECK: [[LOAD_RES:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
|
||||
// CHECK: [[MUL:%.+]] = mulf [[LOAD_0]], [[LOAD_1]] : f32
|
||||
// CHECK: [[ADD:%.+]] = addf [[LOAD_RES]], [[MUL]] : f32
|
||||
// CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||
}
|
||||
|
||||
// N-D x 1-D: multiplies a ?x10x5 tensor, whose leading dimension is dynamic
// (`?`), by a 5-element vector.
// The expected lowering reads that dimension from %arg0 with `dim`, passes it
// to the alloc of the ?x10 result, and bounds the outer output loop with
// dimension 0 of the result itself; the inner output loop has extent 10 and
// the reduction loop has extent 5.
|
||||
func @test_matmul6(%arg0 : tensor<?x10x5xf32>, %arg1 : tensor<5xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.MatMul"(%arg0, %arg1) : (tensor<?x10x5xf32>, tensor<5xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_matmul6
|
||||
// CHECK: [[CONSTANT:%.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10x5xf32>
|
||||
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
|
||||
// CHECK: [[LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||
// CHECK: [[DIM_1:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0) with ([[LOOPS]]#0 -> %arg2 = 0 to [[DIM_1]]) {
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||
// CHECK: store [[CONSTANT]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
|
||||
// CHECK: [[LOOPS_REDUCE:%.+]] = krnl.define_loops 1
|
||||
// CHECK: [[OPT_LOOPS_REDUCE:%.+]] = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[LOOPS_REDUCE]]
|
||||
// CHECK: } : () -> !krnl.loop
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS_REDUCE]]) with ([[LOOPS_REDUCE]] -> %arg4 = 0 to 5) {
|
||||
// CHECK: [[LOAD_0:%.+]] = load %arg0[%arg2, %arg3, %arg4] : memref<?x10x5xf32>
|
||||
// CHECK: [[LOAD_1:%.+]] = load %arg1[%arg4] : memref<5xf32>
|
||||
// CHECK: [[LOAD_RES:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
|
||||
// CHECK: [[MUL:%.+]] = mulf [[LOAD_0]], [[LOAD_1]] : f32
|
||||
// CHECK: [[ADD:%.+]] = addf [[LOAD_RES]], [[MUL]] : f32
|
||||
// CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||
}
|
||||
|
||||
// 1-D x 1-D: multiplies two 5-element vectors.
// The expected lowering allocates a memref<1xf32> result and addresses its
// single element through a constant index 0: the zero constant is stored
// there first, then %arg0[k] * %arg1[k] is accumulated into it by a reduction
// loop of extent 5. Only the reduction loop is matched; no output loops are
// part of the patterns.
|
||||
func @test_matmul7(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<*xf32> {
|
||||
%0 ="onnx.MatMul"(%arg0, %arg1) : (tensor<5xf32>, tensor<5xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_matmul7
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<1xf32>
|
||||
// CHECK: [[CONSTANT:%.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: %[[CONSTANT_INDEX:.+]] = constant 0 : index
|
||||
// CHECK: store [[CONSTANT]], [[RES]][%[[CONSTANT_INDEX]]] : memref<1xf32>
|
||||
// CHECK: [[LOOPS_REDUCE:%.+]] = krnl.define_loops 1
|
||||
// CHECK: [[OPT_LOOPS_REDUCE:%.+]] = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[LOOPS_REDUCE]]
|
||||
// CHECK: } : () -> !krnl.loop
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS_REDUCE]]) with ([[LOOPS_REDUCE]] -> %arg2 = 0 to 5) {
|
||||
// CHECK: [[LOAD_0:%.+]] = load %arg0[%arg2] : memref<5xf32>
|
||||
// CHECK: [[LOAD_1:%.+]] = load %arg1[%arg2] : memref<5xf32>
|
||||
// CHECK: [[LOAD_RES:%.+]] = load [[RES]][%[[CONSTANT_INDEX]]] : memref<1xf32>
|
||||
// CHECK: [[MUL:%.+]] = mulf [[LOAD_0]], [[LOAD_1]] : f32
|
||||
// CHECK: [[ADD:%.+]] = addf [[LOAD_RES]], [[MUL]] : f32
|
||||
// CHECK: store [[ADD]], [[RES]][%[[CONSTANT_INDEX]]] : memref<1xf32>
|
||||
// CHECK: }
|
||||
// CHECK: return [[RES]] : memref<1xf32>
|
||||
}
|
||||
|
||||
func @test_conv_no_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
|
Loading…
Reference in New Issue