Bug fix, ensure krnl.iterate can lower in the degenerate case. (#78)
* Bug fix, ensure krnl.iterate can lower in the degenerate case.
* Fix parser issue with degenerate iterate op.
* Add a test case.
* Remove dead code.

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
7dba324404
commit
caeaa390e2
|
@ -46,13 +46,13 @@ KrnlOpsDialect::KrnlOpsDialect(MLIRContext *context)
|
||||||
// KrnlDefineLoopsOp
|
// KrnlDefineLoopsOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void KrnlDefineLoopsOp::build(Builder *builder, OperationState &result,
|
void KrnlDefineLoopsOp::build(
|
||||||
int64_t num_loops) {
|
Builder *builder, OperationState &result, int64_t num_loops) {
|
||||||
// Create the same number of dimension handlers as the number of
|
// Create the same number of dimension handlers as the number of
|
||||||
// dimensions in the associated integer set.
|
// dimensions in the associated integer set.
|
||||||
result.types.append(num_loops, LoopType::get(builder->getContext()));
|
result.types.append(num_loops, LoopType::get(builder->getContext()));
|
||||||
result.addAttribute(getNumLoopsAttrName(),
|
result.addAttribute(
|
||||||
builder->getI32IntegerAttr(num_loops));
|
getNumLoopsAttrName(), builder->getI32IntegerAttr(num_loops));
|
||||||
}
|
}
|
||||||
|
|
||||||
void print(OpAsmPrinter &p, KrnlDefineLoopsOp &op) {
|
void print(OpAsmPrinter &p, KrnlDefineLoopsOp &op) {
|
||||||
|
@ -61,15 +61,14 @@ void print(OpAsmPrinter &p, KrnlDefineLoopsOp &op) {
|
||||||
p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue();
|
p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult parseKrnlDefineLoopsOp(OpAsmParser &parser,
|
ParseResult parseKrnlDefineLoopsOp(
|
||||||
OperationState &result) {
|
OpAsmParser &parser, OperationState &result) {
|
||||||
// Parse the attribute indicating number of loops defined.
|
// Parse the attribute indicating number of loops defined.
|
||||||
IntegerAttr numLoops;
|
IntegerAttr numLoops;
|
||||||
auto &builder = parser.getBuilder();
|
auto &builder = parser.getBuilder();
|
||||||
auto intType = builder.getIntegerType(64);
|
auto intType = builder.getIntegerType(64);
|
||||||
if (parser.parseAttribute(numLoops, intType,
|
if (parser.parseAttribute(numLoops, intType,
|
||||||
KrnlDefineLoopsOp::getNumLoopsAttrName(),
|
KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes))
|
||||||
result.attributes))
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto loopTypes = llvm::SmallVector<Type, 4>(
|
auto loopTypes = llvm::SmallVector<Type, 4>(
|
||||||
|
@ -82,10 +81,10 @@ ParseResult parseKrnlDefineLoopsOp(OpAsmParser &parser,
|
||||||
// KrnlOptimizeLoopsOp
|
// KrnlOptimizeLoopsOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void KrnlOptimizeLoopsOp::build(Builder *builder, OperationState &result,
|
void KrnlOptimizeLoopsOp::build(
|
||||||
int num_optimized_loops) {
|
Builder *builder, OperationState &result, int num_optimized_loops) {
|
||||||
result.types.append(num_optimized_loops,
|
result.types.append(
|
||||||
LoopType::get(builder->getContext()));
|
num_optimized_loops, LoopType::get(builder->getContext()));
|
||||||
// Create a region and a block for the body.
|
// Create a region and a block for the body.
|
||||||
// Schedule intrinsics will be placed into this region.
|
// Schedule intrinsics will be placed into this region.
|
||||||
Region *region = result.addRegion();
|
Region *region = result.addRegion();
|
||||||
|
@ -96,13 +95,13 @@ void KrnlOptimizeLoopsOp::build(Builder *builder, OperationState &result,
|
||||||
void print(OpAsmPrinter &p, KrnlOptimizeLoopsOp &op) {
|
void print(OpAsmPrinter &p, KrnlOptimizeLoopsOp &op) {
|
||||||
p << "krnl.optimize_loops ";
|
p << "krnl.optimize_loops ";
|
||||||
p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
|
p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
|
||||||
/*printBlockTerminators=*/true);
|
/*printBlockTerminators=*/true);
|
||||||
p << " : ";
|
p << " : ";
|
||||||
p.printFunctionalType(op);
|
p.printFunctionalType(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult parseKrnlOptimizeLoopsOp(OpAsmParser &parser,
|
ParseResult parseKrnlOptimizeLoopsOp(
|
||||||
OperationState &result) {
|
OpAsmParser &parser, OperationState &result) {
|
||||||
// Parse the schedule body region.
|
// Parse the schedule body region.
|
||||||
Region *region = result.addRegion();
|
Region *region = result.addRegion();
|
||||||
if (parser.parseRegion(*region, llvm::None, llvm::None))
|
if (parser.parseRegion(*region, llvm::None, llvm::None))
|
||||||
|
@ -146,14 +145,13 @@ ParseResult parseKrnlOptimizeLoopsOp(OpAsmParser &parser,
|
||||||
* %i0 = 10 to N : %i1 = M to 20
|
* %i0 = 10 to N : %i1 = M to 20
|
||||||
*/
|
*/
|
||||||
void KrnlIterateOp::build(Builder *builder, OperationState &result,
|
void KrnlIterateOp::build(Builder *builder, OperationState &result,
|
||||||
KrnlIterateOperandPack operandPack) {
|
KrnlIterateOperandPack operandPack) {
|
||||||
// Record optimized loops and the number of such loops.
|
// Record optimized loops and the number of such loops.
|
||||||
result.addOperands(operandPack.getOperands());
|
result.addOperands(operandPack.getOperands());
|
||||||
result.addAttribute(KrnlIterateOp::getBoundsAttrName(),
|
|
||||||
operandPack.getAttributes());
|
|
||||||
|
|
||||||
result.addAttribute(
|
result.addAttribute(
|
||||||
getNumOptimizedLoopsAttrName(),
|
KrnlIterateOp::getBoundsAttrName(), operandPack.getAttributes());
|
||||||
|
|
||||||
|
result.addAttribute(getNumOptimizedLoopsAttrName(),
|
||||||
builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops()));
|
builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops()));
|
||||||
|
|
||||||
// Create a region and a block for the body. The arguments of the region are
|
// Create a region and a block for the body. The arguments of the region are
|
||||||
|
@ -190,15 +188,17 @@ void print(OpAsmPrinter &p, KrnlIterateOp &op) {
|
||||||
p << " -> ";
|
p << " -> ";
|
||||||
p.printOperand(var);
|
p.printOperand(var);
|
||||||
p << " = ";
|
p << " = ";
|
||||||
onnx_mlir::printBound((*boundItr++).cast<AffineMapAttr>(), operandItr, "max", p);
|
onnx_mlir::printBound(
|
||||||
|
(*boundItr++).cast<AffineMapAttr>(), operandItr, "max", p);
|
||||||
p << " to ";
|
p << " to ";
|
||||||
onnx_mlir::printBound((*boundItr++).cast<AffineMapAttr>(), operandItr, "min", p);
|
onnx_mlir::printBound(
|
||||||
|
(*boundItr++).cast<AffineMapAttr>(), operandItr, "min", p);
|
||||||
delimiter = ", ";
|
delimiter = ", ";
|
||||||
}
|
}
|
||||||
|
|
||||||
p << ")";
|
p << ")";
|
||||||
p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false,
|
p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false,
|
||||||
/*printBlockTerminators=*/false);
|
/*printBlockTerminators=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
|
ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
|
||||||
|
@ -208,16 +208,15 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
|
||||||
|
|
||||||
// Parse optimized loops:
|
// Parse optimized loops:
|
||||||
SmallVector<OpAsmParser::OperandType, 4> optimizedLoopRefs;
|
SmallVector<OpAsmParser::OperandType, 4> optimizedLoopRefs;
|
||||||
if (parser.parseOperandList(optimizedLoopRefs,
|
if (parser.parseOperandList(
|
||||||
OpAsmParser::Delimiter::Paren) ||
|
optimizedLoopRefs, OpAsmParser::Delimiter::Paren) ||
|
||||||
parser.resolveOperands(optimizedLoopRefs,
|
parser.resolveOperands(optimizedLoopRefs,
|
||||||
LoopType::get(result.getContext()),
|
LoopType::get(result.getContext()), result.operands))
|
||||||
result.operands))
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Record how many optimized loops did we parse.
|
// Record how many optimized loops did we parse.
|
||||||
result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(),
|
result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(),
|
||||||
builder.getI64IntegerAttr(optimizedLoopRefs.size()));
|
builder.getI64IntegerAttr(optimizedLoopRefs.size()));
|
||||||
|
|
||||||
// Parse input loops and their lower and upper bounds.
|
// Parse input loops and their lower and upper bounds.
|
||||||
SmallVector<OpAsmParser::OperandType, 4> inductionVarRefs;
|
SmallVector<OpAsmParser::OperandType, 4> inductionVarRefs;
|
||||||
|
@ -227,16 +226,16 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// A function to parse a lower or upper bound.
|
// A function to parse a lower or upper bound.
|
||||||
auto parseBound = [&result, &builder, &parser, &operandParser,
|
auto parseBound = [&result, &builder, &parser, &operandParser, &boundMaps](
|
||||||
&boundMaps](bool isUpper) -> ParseResult {
|
bool isUpper) -> ParseResult {
|
||||||
// 'min' / 'max' prefixes are generally syntactic sugar, but are required if
|
// 'min' / 'max' prefixes are generally syntactic sugar, but are required if
|
||||||
// the map has multiple results.
|
// the map has multiple results.
|
||||||
bool failedToParsedMinMax =
|
bool failedToParsedMinMax =
|
||||||
failed(parser.parseOptionalKeyword(isUpper ? "min" : "max"));
|
failed(parser.parseOptionalKeyword(isUpper ? "min" : "max"));
|
||||||
|
|
||||||
// Try parse an SSA operand.
|
// Try parse an SSA operand.
|
||||||
if (succeeded(operandParser.ParseOptionalOperand(builder.getIndexType(),
|
if (succeeded(operandParser.ParseOptionalOperand(
|
||||||
result.operands))) {
|
builder.getIndexType(), result.operands))) {
|
||||||
AffineMap map = builder.getSymbolIdentityMap();
|
AffineMap map = builder.getSymbolIdentityMap();
|
||||||
boundMaps.emplace_back(AffineMapAttr::get(map));
|
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||||
return success();
|
return success();
|
||||||
|
@ -248,8 +247,8 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
|
||||||
llvm::SMLoc attrLoc = parser.getCurrentLocation();
|
llvm::SMLoc attrLoc = parser.getCurrentLocation();
|
||||||
Attribute boundAttr;
|
Attribute boundAttr;
|
||||||
llvm::SmallVector<NamedAttribute, 1> tempBoundAttrContainer;
|
llvm::SmallVector<NamedAttribute, 1> tempBoundAttrContainer;
|
||||||
if (parser.parseAttribute(boundAttr, builder.getIndexType(), "temp",
|
if (parser.parseAttribute(
|
||||||
tempBoundAttrContainer))
|
boundAttr, builder.getIndexType(), "temp", tempBoundAttrContainer))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
|
if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
|
||||||
|
@ -260,15 +259,13 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
|
||||||
|
|
||||||
auto map = affineMapAttr.getValue();
|
auto map = affineMapAttr.getValue();
|
||||||
if (map.getNumDims() != numDims)
|
if (map.getNumDims() != numDims)
|
||||||
return parser.emitError(
|
return parser.emitError(parser.getNameLoc(),
|
||||||
parser.getNameLoc(),
|
|
||||||
"dim operand count and integer set dim count must match");
|
"dim operand count and integer set dim count must match");
|
||||||
|
|
||||||
unsigned numDimAndSymbolOperands =
|
unsigned numDimAndSymbolOperands =
|
||||||
result.operands.size() - currentNumOperands;
|
result.operands.size() - currentNumOperands;
|
||||||
if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
|
if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
|
||||||
return parser.emitError(
|
return parser.emitError(parser.getNameLoc(),
|
||||||
parser.getNameLoc(),
|
|
||||||
"symbol operand count and integer set symbol count must match");
|
"symbol operand count and integer set symbol count must match");
|
||||||
|
|
||||||
// If the map has multiple results, make sure that we parsed the min/max
|
// If the map has multiple results, make sure that we parsed the min/max
|
||||||
|
@ -276,11 +273,11 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
|
||||||
if (map.getNumResults() > 1 && failedToParsedMinMax) {
|
if (map.getNumResults() > 1 && failedToParsedMinMax) {
|
||||||
if (isUpper)
|
if (isUpper)
|
||||||
return parser.emitError(attrLoc,
|
return parser.emitError(attrLoc,
|
||||||
"upper loop bound affine map with multiple "
|
"upper loop bound affine map with multiple "
|
||||||
"results requires 'min' prefix");
|
"results requires 'min' prefix");
|
||||||
return parser.emitError(attrLoc,
|
return parser.emitError(attrLoc,
|
||||||
"lower loop bound affine mapwith "
|
"lower loop bound affine mapwith "
|
||||||
"multiple results requires 'max' prefix");
|
"multiple results requires 'max' prefix");
|
||||||
}
|
}
|
||||||
boundMaps.emplace_back(AffineMapAttr::get(map));
|
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||||
return success();
|
return success();
|
||||||
|
@ -293,8 +290,7 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
bool keepParsing; // Do we keep parsing loops/bounds?
|
while (failed(parser.parseOptionalRParen())) {
|
||||||
do {
|
|
||||||
// Parse an input loop operand;
|
// Parse an input loop operand;
|
||||||
operandParser.ParseOperand(LoopType::get(context), result.operands);
|
operandParser.ParseOperand(LoopType::get(context), result.operands);
|
||||||
parser.parseArrow();
|
parser.parseArrow();
|
||||||
|
@ -315,26 +311,23 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
|
||||||
// the entire "{operand bound}, {input_loop_operand}" sequence will
|
// the entire "{operand bound}, {input_loop_operand}" sequence will
|
||||||
// be parsed as an operand list.
|
// be parsed as an operand list.
|
||||||
parser.parseOptionalComma();
|
parser.parseOptionalComma();
|
||||||
|
}
|
||||||
// If we don't see a RParen token, we keep parsing.
|
|
||||||
keepParsing = failed(parser.parseOptionalRParen());
|
|
||||||
} while (keepParsing);
|
|
||||||
|
|
||||||
// At this point, there shouldn't be any operands left to parse.
|
// At this point, there shouldn't be any operands left to parse.
|
||||||
if (operandParser.hasOperandLeft())
|
if (operandParser.hasOperandLeft())
|
||||||
return parser.emitError(parser.getCurrentLocation());
|
return parser.emitError(parser.getCurrentLocation());
|
||||||
result.addAttribute(KrnlIterateOp::getBoundsAttrName(),
|
result.addAttribute(
|
||||||
builder.getArrayAttr(boundMaps));
|
KrnlIterateOp::getBoundsAttrName(), builder.getArrayAttr(boundMaps));
|
||||||
|
|
||||||
Region *region = result.addRegion();
|
Region *region = result.addRegion();
|
||||||
SmallVector<Type, 4> inductionVarTypes(inductionVarRefs.size(),
|
SmallVector<Type, 4> inductionVarTypes(
|
||||||
builder.getIndexType());
|
inductionVarRefs.size(), builder.getIndexType());
|
||||||
if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes))
|
if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Ensure iterate region is closed off with krnl.terminate.
|
// Ensure iterate region is closed off with krnl.terminate.
|
||||||
KrnlIterateOp::ensureTerminator(*region, parser.getBuilder(),
|
KrnlIterateOp::ensureTerminator(
|
||||||
result.location);
|
*region, parser.getBuilder(), result.location);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -353,22 +346,20 @@ void print(OpAsmPrinter &p, KrnlReturnLoopsOp &op) {
|
||||||
p.printOperands(op.operand_begin(), op.operand_end());
|
p.printOperands(op.operand_begin(), op.operand_end());
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult parseKrnlReturnLoopsOp(OpAsmParser &parser,
|
ParseResult parseKrnlReturnLoopsOp(
|
||||||
OperationState &result) {
|
OpAsmParser &parser, OperationState &result) {
|
||||||
// Parse the loops to return.
|
// Parse the loops to return.
|
||||||
SmallVector<OpAsmParser::OperandType, 4> timestamp_dim_handlers;
|
SmallVector<OpAsmParser::OperandType, 4> timestamp_dim_handlers;
|
||||||
if (parser.parseOperandList(timestamp_dim_handlers) ||
|
if (parser.parseOperandList(timestamp_dim_handlers) ||
|
||||||
parser.resolveOperands(timestamp_dim_handlers,
|
parser.resolveOperands(timestamp_dim_handlers,
|
||||||
LoopType::get(result.getContext()),
|
LoopType::get(result.getContext()), result.operands))
|
||||||
result.operands))
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void KrnlEntryPointOp::build(mlir::Builder *builder, OperationState &state,
|
void KrnlEntryPointOp::build(mlir::Builder *builder, OperationState &state,
|
||||||
SymbolRefAttr funcAttr, IntegerAttr numInputs,
|
SymbolRefAttr funcAttr, IntegerAttr numInputs, IntegerAttr numOutputs) {
|
||||||
IntegerAttr numOutputs) {
|
|
||||||
state.addAttribute(KrnlEntryPointOp::getEntryPointFuncAttrName(), funcAttr);
|
state.addAttribute(KrnlEntryPointOp::getEntryPointFuncAttrName(), funcAttr);
|
||||||
state.addAttribute(KrnlEntryPointOp::getNumInputsAttrName(), numInputs);
|
state.addAttribute(KrnlEntryPointOp::getNumInputsAttrName(), numInputs);
|
||||||
state.addAttribute(KrnlEntryPointOp::getNumOutputsAttrName(), numOutputs);
|
state.addAttribute(KrnlEntryPointOp::getNumOutputsAttrName(), numOutputs);
|
||||||
|
|
|
@ -27,8 +27,8 @@ namespace {
|
||||||
struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
|
struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
|
||||||
using OpRewritePattern<KrnlIterateOp>::OpRewritePattern;
|
using OpRewritePattern<KrnlIterateOp>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(KrnlIterateOp iterateOp,
|
LogicalResult matchAndRewrite(
|
||||||
PatternRewriter &rewriter) const override {
|
KrnlIterateOp iterateOp, PatternRewriter &rewriter) const override {
|
||||||
auto boundMapAttrs =
|
auto boundMapAttrs =
|
||||||
iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
|
iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
|
||||||
.getValue();
|
.getValue();
|
||||||
|
@ -48,21 +48,21 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
|
||||||
map = boundMapAttrs[boundIdx + boundType]
|
map = boundMapAttrs[boundIdx + boundType]
|
||||||
.cast<AffineMapAttr>()
|
.cast<AffineMapAttr>()
|
||||||
.getValue();
|
.getValue();
|
||||||
operands.insert(operands.end(), operandItr,
|
operands.insert(
|
||||||
operandItr + map.getNumInputs());
|
operands.end(), operandItr, operandItr + map.getNumInputs());
|
||||||
std::advance(operandItr, map.getNumInputs());
|
std::advance(operandItr, map.getNumInputs());
|
||||||
}
|
}
|
||||||
|
|
||||||
nestedForOps.emplace_back(rewriter.create<AffineForOp>(
|
nestedForOps.emplace_back(rewriter.create<AffineForOp>(
|
||||||
iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap));
|
iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap));
|
||||||
rewriter.setInsertionPoint(nestedForOps.back().getBody(),
|
rewriter.setInsertionPoint(nestedForOps.back().getBody(),
|
||||||
nestedForOps.back().getBody()->begin());
|
nestedForOps.back().getBody()->begin());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Replace induction variable references from those introduced by a
|
// Replace induction variable references from those introduced by a
|
||||||
// single krnl.iterate to those introduced by multiple affine.for
|
// single krnl.iterate to those introduced by multiple affine.for
|
||||||
// operations.
|
// operations.
|
||||||
for (size_t i = 0; i < nestedForOps.size() - 1; i++) {
|
for (int64_t i = 0; i < (int64_t)nestedForOps.size() - 1; i++) {
|
||||||
auto iterateIV = iterateOp.bodyRegion().front().getArgument(0);
|
auto iterateIV = iterateOp.bodyRegion().front().getArgument(0);
|
||||||
auto forIV = nestedForOps[i].getBody()->getArgument(0);
|
auto forIV = nestedForOps[i].getBody()->getArgument(0);
|
||||||
iterateIV.replaceAllUsesWith(forIV);
|
iterateIV.replaceAllUsesWith(forIV);
|
||||||
|
@ -74,11 +74,21 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
|
||||||
while (iterateOp.bodyRegion().front().getNumArguments() > 1)
|
while (iterateOp.bodyRegion().front().getNumArguments() > 1)
|
||||||
iterateOp.bodyRegion().front().eraseArgument(0);
|
iterateOp.bodyRegion().front().eraseArgument(0);
|
||||||
|
|
||||||
// Transfer krnl.iterate region to innermost for op.
|
if (nestedForOps.empty()) {
|
||||||
auto innermostForOp = nestedForOps.back();
|
// If no loops are involved, simply move operations from within iterateOp
|
||||||
innermostForOp.region().getBlocks().clear();
|
// body region to the parent region of iterateOp.
|
||||||
rewriter.inlineRegionBefore(iterateOp.bodyRegion(), innermostForOp.region(),
|
rewriter.setInsertionPoint(iterateOp);
|
||||||
innermostForOp.region().end());
|
iterateOp.bodyRegion().walk([&](Operation *op) {
|
||||||
|
if (!op->isKnownTerminator())
|
||||||
|
op->replaceAllUsesWith(rewriter.clone(*op));
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// Transfer krnl.iterate region to innermost for op.
|
||||||
|
auto innermostForOp = nestedForOps.back();
|
||||||
|
innermostForOp.region().getBlocks().clear();
|
||||||
|
rewriter.inlineRegionBefore(iterateOp.bodyRegion(),
|
||||||
|
innermostForOp.region(), innermostForOp.region().end());
|
||||||
|
}
|
||||||
|
|
||||||
rewriter.eraseOp(iterateOp);
|
rewriter.eraseOp(iterateOp);
|
||||||
return success();
|
return success();
|
||||||
|
@ -93,8 +103,8 @@ class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern;
|
using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(KrnlTerminatorOp op,
|
LogicalResult matchAndRewrite(
|
||||||
PatternRewriter &rewriter) const override {
|
KrnlTerminatorOp op, PatternRewriter &rewriter) const override {
|
||||||
rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op);
|
rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -108,8 +118,8 @@ class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern;
|
using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(KrnlDefineLoopsOp op,
|
LogicalResult matchAndRewrite(
|
||||||
PatternRewriter &rewriter) const override {
|
KrnlDefineLoopsOp op, PatternRewriter &rewriter) const override {
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -123,8 +133,8 @@ class KrnlOptimizeLoopsLowering : public OpRewritePattern<KrnlOptimizeLoopsOp> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern<KrnlOptimizeLoopsOp>::OpRewritePattern;
|
using OpRewritePattern<KrnlOptimizeLoopsOp>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(KrnlOptimizeLoopsOp op,
|
LogicalResult matchAndRewrite(
|
||||||
PatternRewriter &rewriter) const override {
|
KrnlOptimizeLoopsOp op, PatternRewriter &rewriter) const override {
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -158,8 +168,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
|
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
|
||||||
KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(
|
KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(&getContext());
|
||||||
&getContext());
|
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getFunction(), target, patterns))) {
|
if (failed(applyPartialConversion(getFunction(), target, patterns))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
|
@ -172,5 +181,5 @@ std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
|
||||||
return std::make_unique<KrnlToAffineLoweringPass>();
|
return std::make_unique<KrnlToAffineLoweringPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
static PassRegistration<KrnlToAffineLoweringPass> pass("lower-krnl",
|
static PassRegistration<KrnlToAffineLoweringPass> pass(
|
||||||
"Lower Krnl dialect.");
|
"lower-krnl", "Lower Krnl dialect.");
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
// RUN: onnx-mlir-opt --lower-krnl %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
func @test_lower_degenerate_iterate(%arg0: memref<f32>) -> memref<f32> {
|
||||||
|
%0 = alloc() : memref<f32>
|
||||||
|
krnl.iterate() with () {
|
||||||
|
%1 = load %arg0[] : memref<f32>
|
||||||
|
store %1, %0[] : memref<f32>
|
||||||
|
}
|
||||||
|
return %0 : memref<f32>
|
||||||
|
// CHECK-LABEL: test_lower_degenerate_iterate
|
||||||
|
// CHECK-NEXT: [[ALLOC:%.+]] = alloc() : memref<f32>
|
||||||
|
// CHECK-NEXT: [[LOAD:%.+]] = load %{{.*}}[] : memref<f32>
|
||||||
|
// CHECK-NEXT: store [[LOAD]], [[ALLOC]][] : memref<f32>
|
||||||
|
// CHECK-NEXT: return [[ALLOC]] : memref<f32>
|
||||||
|
}
|
Loading…
Reference in New Issue