Bug fix, ensure krnl.iterate can lower in the degenerate case. (#78)

* Bug fix, ensure krnl.iterate can lower in the degenerate case.

* Fix parser issue with degenerate iterate op.

* Add a test case.

* Remove dead code.

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tian Jin 2020-04-10 23:27:00 +08:00 committed by GitHub
parent 7dba324404
commit caeaa390e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 96 additions and 81 deletions

View File

@ -46,13 +46,13 @@ KrnlOpsDialect::KrnlOpsDialect(MLIRContext *context)
// KrnlDefineLoopsOp // KrnlDefineLoopsOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void KrnlDefineLoopsOp::build(Builder *builder, OperationState &result, void KrnlDefineLoopsOp::build(
int64_t num_loops) { Builder *builder, OperationState &result, int64_t num_loops) {
// Create the same number of dimension handlers as the number of // Create the same number of dimension handlers as the number of
// dimensions in the associated integer set. // dimensions in the associated integer set.
result.types.append(num_loops, LoopType::get(builder->getContext())); result.types.append(num_loops, LoopType::get(builder->getContext()));
result.addAttribute(getNumLoopsAttrName(), result.addAttribute(
builder->getI32IntegerAttr(num_loops)); getNumLoopsAttrName(), builder->getI32IntegerAttr(num_loops));
} }
void print(OpAsmPrinter &p, KrnlDefineLoopsOp &op) { void print(OpAsmPrinter &p, KrnlDefineLoopsOp &op) {
@ -61,15 +61,14 @@ void print(OpAsmPrinter &p, KrnlDefineLoopsOp &op) {
p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue(); p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue();
} }
ParseResult parseKrnlDefineLoopsOp(OpAsmParser &parser, ParseResult parseKrnlDefineLoopsOp(
OperationState &result) { OpAsmParser &parser, OperationState &result) {
// Parse the attribute indicating number of loops defined. // Parse the attribute indicating number of loops defined.
IntegerAttr numLoops; IntegerAttr numLoops;
auto &builder = parser.getBuilder(); auto &builder = parser.getBuilder();
auto intType = builder.getIntegerType(64); auto intType = builder.getIntegerType(64);
if (parser.parseAttribute(numLoops, intType, if (parser.parseAttribute(numLoops, intType,
KrnlDefineLoopsOp::getNumLoopsAttrName(), KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes))
result.attributes))
return failure(); return failure();
auto loopTypes = llvm::SmallVector<Type, 4>( auto loopTypes = llvm::SmallVector<Type, 4>(
@ -82,10 +81,10 @@ ParseResult parseKrnlDefineLoopsOp(OpAsmParser &parser,
// KrnlOptimizeLoopsOp // KrnlOptimizeLoopsOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void KrnlOptimizeLoopsOp::build(Builder *builder, OperationState &result, void KrnlOptimizeLoopsOp::build(
int num_optimized_loops) { Builder *builder, OperationState &result, int num_optimized_loops) {
result.types.append(num_optimized_loops, result.types.append(
LoopType::get(builder->getContext())); num_optimized_loops, LoopType::get(builder->getContext()));
// Create a region and a block for the body. // Create a region and a block for the body.
// Schedule intrinsics will be placed into this region. // Schedule intrinsics will be placed into this region.
Region *region = result.addRegion(); Region *region = result.addRegion();
@ -101,8 +100,8 @@ void print(OpAsmPrinter &p, KrnlOptimizeLoopsOp &op) {
p.printFunctionalType(op); p.printFunctionalType(op);
} }
ParseResult parseKrnlOptimizeLoopsOp(OpAsmParser &parser, ParseResult parseKrnlOptimizeLoopsOp(
OperationState &result) { OpAsmParser &parser, OperationState &result) {
// Parse the schedule body region. // Parse the schedule body region.
Region *region = result.addRegion(); Region *region = result.addRegion();
if (parser.parseRegion(*region, llvm::None, llvm::None)) if (parser.parseRegion(*region, llvm::None, llvm::None))
@ -149,11 +148,10 @@ void KrnlIterateOp::build(Builder *builder, OperationState &result,
KrnlIterateOperandPack operandPack) { KrnlIterateOperandPack operandPack) {
// Record optimized loops and the number of such loops. // Record optimized loops and the number of such loops.
result.addOperands(operandPack.getOperands()); result.addOperands(operandPack.getOperands());
result.addAttribute(KrnlIterateOp::getBoundsAttrName(),
operandPack.getAttributes());
result.addAttribute( result.addAttribute(
getNumOptimizedLoopsAttrName(), KrnlIterateOp::getBoundsAttrName(), operandPack.getAttributes());
result.addAttribute(getNumOptimizedLoopsAttrName(),
builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops())); builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops()));
// Create a region and a block for the body. The arguments of the region are // Create a region and a block for the body. The arguments of the region are
@ -190,9 +188,11 @@ void print(OpAsmPrinter &p, KrnlIterateOp &op) {
p << " -> "; p << " -> ";
p.printOperand(var); p.printOperand(var);
p << " = "; p << " = ";
onnx_mlir::printBound((*boundItr++).cast<AffineMapAttr>(), operandItr, "max", p); onnx_mlir::printBound(
(*boundItr++).cast<AffineMapAttr>(), operandItr, "max", p);
p << " to "; p << " to ";
onnx_mlir::printBound((*boundItr++).cast<AffineMapAttr>(), operandItr, "min", p); onnx_mlir::printBound(
(*boundItr++).cast<AffineMapAttr>(), operandItr, "min", p);
delimiter = ", "; delimiter = ", ";
} }
@ -208,11 +208,10 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
// Parse optimized loops: // Parse optimized loops:
SmallVector<OpAsmParser::OperandType, 4> optimizedLoopRefs; SmallVector<OpAsmParser::OperandType, 4> optimizedLoopRefs;
if (parser.parseOperandList(optimizedLoopRefs, if (parser.parseOperandList(
OpAsmParser::Delimiter::Paren) || optimizedLoopRefs, OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(optimizedLoopRefs, parser.resolveOperands(optimizedLoopRefs,
LoopType::get(result.getContext()), LoopType::get(result.getContext()), result.operands))
result.operands))
return failure(); return failure();
// Record how many optimized loops did we parse. // Record how many optimized loops did we parse.
@ -227,16 +226,16 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
return failure(); return failure();
// A function to parse a lower or upper bound. // A function to parse a lower or upper bound.
auto parseBound = [&result, &builder, &parser, &operandParser, auto parseBound = [&result, &builder, &parser, &operandParser, &boundMaps](
&boundMaps](bool isUpper) -> ParseResult { bool isUpper) -> ParseResult {
// 'min' / 'max' prefixes are generally syntactic sugar, but are required if // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
// the map has multiple results. // the map has multiple results.
bool failedToParsedMinMax = bool failedToParsedMinMax =
failed(parser.parseOptionalKeyword(isUpper ? "min" : "max")); failed(parser.parseOptionalKeyword(isUpper ? "min" : "max"));
// Try parse an SSA operand. // Try parse an SSA operand.
if (succeeded(operandParser.ParseOptionalOperand(builder.getIndexType(), if (succeeded(operandParser.ParseOptionalOperand(
result.operands))) { builder.getIndexType(), result.operands))) {
AffineMap map = builder.getSymbolIdentityMap(); AffineMap map = builder.getSymbolIdentityMap();
boundMaps.emplace_back(AffineMapAttr::get(map)); boundMaps.emplace_back(AffineMapAttr::get(map));
return success(); return success();
@ -248,8 +247,8 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
llvm::SMLoc attrLoc = parser.getCurrentLocation(); llvm::SMLoc attrLoc = parser.getCurrentLocation();
Attribute boundAttr; Attribute boundAttr;
llvm::SmallVector<NamedAttribute, 1> tempBoundAttrContainer; llvm::SmallVector<NamedAttribute, 1> tempBoundAttrContainer;
if (parser.parseAttribute(boundAttr, builder.getIndexType(), "temp", if (parser.parseAttribute(
tempBoundAttrContainer)) boundAttr, builder.getIndexType(), "temp", tempBoundAttrContainer))
return failure(); return failure();
if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) { if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
@ -260,15 +259,13 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
auto map = affineMapAttr.getValue(); auto map = affineMapAttr.getValue();
if (map.getNumDims() != numDims) if (map.getNumDims() != numDims)
return parser.emitError( return parser.emitError(parser.getNameLoc(),
parser.getNameLoc(),
"dim operand count and integer set dim count must match"); "dim operand count and integer set dim count must match");
unsigned numDimAndSymbolOperands = unsigned numDimAndSymbolOperands =
result.operands.size() - currentNumOperands; result.operands.size() - currentNumOperands;
if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
return parser.emitError( return parser.emitError(parser.getNameLoc(),
parser.getNameLoc(),
"symbol operand count and integer set symbol count must match"); "symbol operand count and integer set symbol count must match");
// If the map has multiple results, make sure that we parsed the min/max // If the map has multiple results, make sure that we parsed the min/max
@ -293,8 +290,7 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
} }
}; };
bool keepParsing; // Do we keep parsing loops/bounds? while (failed(parser.parseOptionalRParen())) {
do {
// Parse an input loop operand; // Parse an input loop operand;
operandParser.ParseOperand(LoopType::get(context), result.operands); operandParser.ParseOperand(LoopType::get(context), result.operands);
parser.parseArrow(); parser.parseArrow();
@ -315,26 +311,23 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
// the entire "{operand bound}, {input_loop_operand}" sequence will // the entire "{operand bound}, {input_loop_operand}" sequence will
// be parsed as an operand list. // be parsed as an operand list.
parser.parseOptionalComma(); parser.parseOptionalComma();
}
// If we don't see a RParen token, we keep parsing.
keepParsing = failed(parser.parseOptionalRParen());
} while (keepParsing);
// At this point, there shouldn't be any operands left to parse. // At this point, there shouldn't be any operands left to parse.
if (operandParser.hasOperandLeft()) if (operandParser.hasOperandLeft())
return parser.emitError(parser.getCurrentLocation()); return parser.emitError(parser.getCurrentLocation());
result.addAttribute(KrnlIterateOp::getBoundsAttrName(), result.addAttribute(
builder.getArrayAttr(boundMaps)); KrnlIterateOp::getBoundsAttrName(), builder.getArrayAttr(boundMaps));
Region *region = result.addRegion(); Region *region = result.addRegion();
SmallVector<Type, 4> inductionVarTypes(inductionVarRefs.size(), SmallVector<Type, 4> inductionVarTypes(
builder.getIndexType()); inductionVarRefs.size(), builder.getIndexType());
if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes)) if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes))
return failure(); return failure();
// Ensure iterate region is closed off with krnl.terminate. // Ensure iterate region is closed off with krnl.terminate.
KrnlIterateOp::ensureTerminator(*region, parser.getBuilder(), KrnlIterateOp::ensureTerminator(
result.location); *region, parser.getBuilder(), result.location);
return success(); return success();
} }
@ -353,22 +346,20 @@ void print(OpAsmPrinter &p, KrnlReturnLoopsOp &op) {
p.printOperands(op.operand_begin(), op.operand_end()); p.printOperands(op.operand_begin(), op.operand_end());
} }
ParseResult parseKrnlReturnLoopsOp(OpAsmParser &parser, ParseResult parseKrnlReturnLoopsOp(
OperationState &result) { OpAsmParser &parser, OperationState &result) {
// Parse the loops to return. // Parse the loops to return.
SmallVector<OpAsmParser::OperandType, 4> timestamp_dim_handlers; SmallVector<OpAsmParser::OperandType, 4> timestamp_dim_handlers;
if (parser.parseOperandList(timestamp_dim_handlers) || if (parser.parseOperandList(timestamp_dim_handlers) ||
parser.resolveOperands(timestamp_dim_handlers, parser.resolveOperands(timestamp_dim_handlers,
LoopType::get(result.getContext()), LoopType::get(result.getContext()), result.operands))
result.operands))
return failure(); return failure();
return success(); return success();
} }
void KrnlEntryPointOp::build(mlir::Builder *builder, OperationState &state, void KrnlEntryPointOp::build(mlir::Builder *builder, OperationState &state,
SymbolRefAttr funcAttr, IntegerAttr numInputs, SymbolRefAttr funcAttr, IntegerAttr numInputs, IntegerAttr numOutputs) {
IntegerAttr numOutputs) {
state.addAttribute(KrnlEntryPointOp::getEntryPointFuncAttrName(), funcAttr); state.addAttribute(KrnlEntryPointOp::getEntryPointFuncAttrName(), funcAttr);
state.addAttribute(KrnlEntryPointOp::getNumInputsAttrName(), numInputs); state.addAttribute(KrnlEntryPointOp::getNumInputsAttrName(), numInputs);
state.addAttribute(KrnlEntryPointOp::getNumOutputsAttrName(), numOutputs); state.addAttribute(KrnlEntryPointOp::getNumOutputsAttrName(), numOutputs);

View File

@ -27,8 +27,8 @@ namespace {
struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> { struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
using OpRewritePattern<KrnlIterateOp>::OpRewritePattern; using OpRewritePattern<KrnlIterateOp>::OpRewritePattern;
LogicalResult matchAndRewrite(KrnlIterateOp iterateOp, LogicalResult matchAndRewrite(
PatternRewriter &rewriter) const override { KrnlIterateOp iterateOp, PatternRewriter &rewriter) const override {
auto boundMapAttrs = auto boundMapAttrs =
iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName()) iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
.getValue(); .getValue();
@ -48,8 +48,8 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
map = boundMapAttrs[boundIdx + boundType] map = boundMapAttrs[boundIdx + boundType]
.cast<AffineMapAttr>() .cast<AffineMapAttr>()
.getValue(); .getValue();
operands.insert(operands.end(), operandItr, operands.insert(
operandItr + map.getNumInputs()); operands.end(), operandItr, operandItr + map.getNumInputs());
std::advance(operandItr, map.getNumInputs()); std::advance(operandItr, map.getNumInputs());
} }
@ -62,7 +62,7 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
// Replace induction variable references from those introduced by a // Replace induction variable references from those introduced by a
// single krnl.iterate to those introduced by multiple affine.for // single krnl.iterate to those introduced by multiple affine.for
// operations. // operations.
for (size_t i = 0; i < nestedForOps.size() - 1; i++) { for (int64_t i = 0; i < (int64_t)nestedForOps.size() - 1; i++) {
auto iterateIV = iterateOp.bodyRegion().front().getArgument(0); auto iterateIV = iterateOp.bodyRegion().front().getArgument(0);
auto forIV = nestedForOps[i].getBody()->getArgument(0); auto forIV = nestedForOps[i].getBody()->getArgument(0);
iterateIV.replaceAllUsesWith(forIV); iterateIV.replaceAllUsesWith(forIV);
@ -74,11 +74,21 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
while (iterateOp.bodyRegion().front().getNumArguments() > 1) while (iterateOp.bodyRegion().front().getNumArguments() > 1)
iterateOp.bodyRegion().front().eraseArgument(0); iterateOp.bodyRegion().front().eraseArgument(0);
if (nestedForOps.empty()) {
// If no loops are involved, simply move operations from within iterateOp
// body region to the parent region of iterateOp.
rewriter.setInsertionPoint(iterateOp);
iterateOp.bodyRegion().walk([&](Operation *op) {
if (!op->isKnownTerminator())
op->replaceAllUsesWith(rewriter.clone(*op));
});
} else {
// Transfer krnl.iterate region to innermost for op. // Transfer krnl.iterate region to innermost for op.
auto innermostForOp = nestedForOps.back(); auto innermostForOp = nestedForOps.back();
innermostForOp.region().getBlocks().clear(); innermostForOp.region().getBlocks().clear();
rewriter.inlineRegionBefore(iterateOp.bodyRegion(), innermostForOp.region(), rewriter.inlineRegionBefore(iterateOp.bodyRegion(),
innermostForOp.region().end()); innermostForOp.region(), innermostForOp.region().end());
}
rewriter.eraseOp(iterateOp); rewriter.eraseOp(iterateOp);
return success(); return success();
@ -93,8 +103,8 @@ class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> {
public: public:
using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern; using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern;
LogicalResult matchAndRewrite(KrnlTerminatorOp op, LogicalResult matchAndRewrite(
PatternRewriter &rewriter) const override { KrnlTerminatorOp op, PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op); rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op);
return success(); return success();
} }
@ -108,8 +118,8 @@ class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> {
public: public:
using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern; using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern;
LogicalResult matchAndRewrite(KrnlDefineLoopsOp op, LogicalResult matchAndRewrite(
PatternRewriter &rewriter) const override { KrnlDefineLoopsOp op, PatternRewriter &rewriter) const override {
rewriter.eraseOp(op); rewriter.eraseOp(op);
return success(); return success();
} }
@ -123,8 +133,8 @@ class KrnlOptimizeLoopsLowering : public OpRewritePattern<KrnlOptimizeLoopsOp> {
public: public:
using OpRewritePattern<KrnlOptimizeLoopsOp>::OpRewritePattern; using OpRewritePattern<KrnlOptimizeLoopsOp>::OpRewritePattern;
LogicalResult matchAndRewrite(KrnlOptimizeLoopsOp op, LogicalResult matchAndRewrite(
PatternRewriter &rewriter) const override { KrnlOptimizeLoopsOp op, PatternRewriter &rewriter) const override {
rewriter.eraseOp(op); rewriter.eraseOp(op);
return success(); return success();
} }
@ -158,8 +168,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering, patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>( KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(&getContext());
&getContext());
if (failed(applyPartialConversion(getFunction(), target, patterns))) { if (failed(applyPartialConversion(getFunction(), target, patterns))) {
signalPassFailure(); signalPassFailure();
@ -172,5 +181,5 @@ std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
return std::make_unique<KrnlToAffineLoweringPass>(); return std::make_unique<KrnlToAffineLoweringPass>();
} }
static PassRegistration<KrnlToAffineLoweringPass> pass("lower-krnl", static PassRegistration<KrnlToAffineLoweringPass> pass(
"Lower Krnl dialect."); "lower-krnl", "Lower Krnl dialect.");

View File

@ -0,0 +1,15 @@
// RUN: onnx-mlir-opt --lower-krnl %s -split-input-file | FileCheck %s
// Regression test for lowering a degenerate krnl.iterate, i.e. one that is
// given no loops and no bounds. The RUN line pipes this function through the
// --lower-krnl pass; the FileCheck directives at the end of the body expect the
// load/store from the iterate body to appear directly in the function body,
// immediately after the alloc and immediately before the return.
func @test_lower_degenerate_iterate(%arg0: memref<f32>) -> memref<f32> {
%0 = alloc() : memref<f32>
// Empty loop list and empty bound list: there is nothing to iterate over.
krnl.iterate() with () {
%1 = load %arg0[] : memref<f32>
store %1, %0[] : memref<f32>
}
return %0 : memref<f32>
// Expected output: the four operations on consecutive lines, with the store
// using the value produced by the load and the buffer produced by the alloc.
// CHECK-LABEL: test_lower_degenerate_iterate
// CHECK-NEXT: [[ALLOC:%.+]] = alloc() : memref<f32>
// CHECK-NEXT: [[LOAD:%.+]] = load %{{.*}}[] : memref<f32>
// CHECK-NEXT: store [[LOAD]], [[ALLOC]][] : memref<f32>
// CHECK-NEXT: return [[ALLOC]] : memref<f32>
}