From caeaa390e2c93018e8a2417d20c0e0cf7cf119ca Mon Sep 17 00:00:00 2001 From: Tian Jin Date: Fri, 10 Apr 2020 23:27:00 +0800 Subject: [PATCH] Bug fix, ensure krnl.iterate can lower in the degenerate case. (#78) * Bug fix, ensure krnl.iterate can lower in the degenerate case. * Fix parser issue with degenerate iterate op. * Add a test case. * Remove dead code. Co-authored-by: Gheorghe-Teodor Bercea --- src/Dialect/Krnl/KrnlOps.cpp | 111 +++++++++++------------ src/Transform/LowerKrnl.cpp | 51 ++++++----- test/mlir/conversion/krnl_to_affine.mlir | 15 +++ 3 files changed, 96 insertions(+), 81 deletions(-) create mode 100644 test/mlir/conversion/krnl_to_affine.mlir diff --git a/src/Dialect/Krnl/KrnlOps.cpp b/src/Dialect/Krnl/KrnlOps.cpp index ec240b3..fbe0c34 100644 --- a/src/Dialect/Krnl/KrnlOps.cpp +++ b/src/Dialect/Krnl/KrnlOps.cpp @@ -46,13 +46,13 @@ KrnlOpsDialect::KrnlOpsDialect(MLIRContext *context) // KrnlDefineLoopsOp //===----------------------------------------------------------------------===// -void KrnlDefineLoopsOp::build(Builder *builder, OperationState &result, - int64_t num_loops) { +void KrnlDefineLoopsOp::build( + Builder *builder, OperationState &result, int64_t num_loops) { // Create the same number of dimension handlers as the number of // dimensions in the associated integer set. result.types.append(num_loops, LoopType::get(builder->getContext())); - result.addAttribute(getNumLoopsAttrName(), - builder->getI32IntegerAttr(num_loops)); + result.addAttribute( + getNumLoopsAttrName(), builder->getI32IntegerAttr(num_loops)); } void print(OpAsmPrinter &p, KrnlDefineLoopsOp &op) { @@ -61,15 +61,14 @@ void print(OpAsmPrinter &p, KrnlDefineLoopsOp &op) { p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue(); } -ParseResult parseKrnlDefineLoopsOp(OpAsmParser &parser, - OperationState &result) { +ParseResult parseKrnlDefineLoopsOp( + OpAsmParser &parser, OperationState &result) { // Parse the attribute indicating number of loops defined. IntegerAttr numLoops; auto &builder = parser.getBuilder(); auto intType = builder.getIntegerType(64); if (parser.parseAttribute(numLoops, intType, - KrnlDefineLoopsOp::getNumLoopsAttrName(), - result.attributes)) + KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes)) return failure(); auto loopTypes = llvm::SmallVector( @@ -82,10 +81,10 @@ ParseResult parseKrnlDefineLoopsOp(OpAsmParser &parser, // KrnlOptimizeLoopsOp //===----------------------------------------------------------------------===// -void KrnlOptimizeLoopsOp::build(Builder *builder, OperationState &result, - int num_optimized_loops) { - result.types.append(num_optimized_loops, - LoopType::get(builder->getContext())); +void KrnlOptimizeLoopsOp::build( + Builder *builder, OperationState &result, int num_optimized_loops) { + result.types.append( + num_optimized_loops, LoopType::get(builder->getContext())); // Create a region and a block for the body. // Schedule intrinsics will be placed into this region. Region *region = result.addRegion(); @@ -96,13 +95,13 @@ void KrnlOptimizeLoopsOp::build(Builder *builder, OperationState &result, void print(OpAsmPrinter &p, KrnlOptimizeLoopsOp &op) { p << "krnl.optimize_loops "; p.printRegion(op.region(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/true); + /*printBlockTerminators=*/true); p << " : "; p.printFunctionalType(op); } -ParseResult parseKrnlOptimizeLoopsOp(OpAsmParser &parser, - OperationState &result) { +ParseResult parseKrnlOptimizeLoopsOp( + OpAsmParser &parser, OperationState &result) { // Parse the schedule body region. Region *region = result.addRegion(); if (parser.parseRegion(*region, llvm::None, llvm::None)) @@ -146,14 +145,13 @@ ParseResult parseKrnlOptimizeLoopsOp(OpAsmParser &parser, * %i0 = 10 to N : %i1 = M to 20 */ void KrnlIterateOp::build(Builder *builder, OperationState &result, - KrnlIterateOperandPack operandPack) { + KrnlIterateOperandPack operandPack) { // Record optimized loops and the number of such loops. result.addOperands(operandPack.getOperands()); - result.addAttribute(KrnlIterateOp::getBoundsAttrName(), - operandPack.getAttributes()); - result.addAttribute( - getNumOptimizedLoopsAttrName(), + KrnlIterateOp::getBoundsAttrName(), operandPack.getAttributes()); + + result.addAttribute(getNumOptimizedLoopsAttrName(), builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops())); // Create a region and a block for the body. The arguments of the region are @@ -190,15 +188,17 @@ void print(OpAsmPrinter &p, KrnlIterateOp &op) { p << " -> "; p.printOperand(var); p << " = "; - onnx_mlir::printBound((*boundItr++).cast(), operandItr, "max", p); + onnx_mlir::printBound( + (*boundItr++).cast(), operandItr, "max", p); p << " to "; - onnx_mlir::printBound((*boundItr++).cast(), operandItr, "min", p); + onnx_mlir::printBound( + (*boundItr++).cast(), operandItr, "min", p); delimiter = ", "; } p << ")"; p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); + /*printBlockTerminators=*/false); } ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) { @@ -208,16 +208,15 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) { // Parse optimized loops: SmallVector optimizedLoopRefs; - if (parser.parseOperandList(optimizedLoopRefs, - OpAsmParser::Delimiter::Paren) || + if (parser.parseOperandList( + optimizedLoopRefs, OpAsmParser::Delimiter::Paren) || parser.resolveOperands(optimizedLoopRefs, - LoopType::get(result.getContext()), - result.operands)) + LoopType::get(result.getContext()), result.operands)) return failure(); // Record how many optimized loops did we parse. result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(), - builder.getI64IntegerAttr(optimizedLoopRefs.size())); + builder.getI64IntegerAttr(optimizedLoopRefs.size())); // Parse input loops and their lower and upper bounds. SmallVector inductionVarRefs; @@ -227,16 +226,16 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) { return failure(); // A function to parse a lower or upper bound. - auto parseBound = [&result, &builder, &parser, &operandParser, - &boundMaps](bool isUpper) -> ParseResult { + auto parseBound = [&result, &builder, &parser, &operandParser, &boundMaps]( + bool isUpper) -> ParseResult { // 'min' / 'max' prefixes are generally syntactic sugar, but are required if // the map has multiple results. bool failedToParsedMinMax = failed(parser.parseOptionalKeyword(isUpper ? "min" : "max")); // Try parse an SSA operand. - if (succeeded(operandParser.ParseOptionalOperand(builder.getIndexType(), - result.operands))) { + if (succeeded(operandParser.ParseOptionalOperand( + builder.getIndexType(), result.operands))) { AffineMap map = builder.getSymbolIdentityMap(); boundMaps.emplace_back(AffineMapAttr::get(map)); return success(); @@ -248,8 +247,8 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) { llvm::SMLoc attrLoc = parser.getCurrentLocation(); Attribute boundAttr; llvm::SmallVector tempBoundAttrContainer; - if (parser.parseAttribute(boundAttr, builder.getIndexType(), "temp", - tempBoundAttrContainer)) + if (parser.parseAttribute( + boundAttr, builder.getIndexType(), "temp", tempBoundAttrContainer)) return failure(); if (auto affineMapAttr = boundAttr.dyn_cast()) { @@ -260,15 +259,13 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) { auto map = affineMapAttr.getValue(); if (map.getNumDims() != numDims) - return parser.emitError( - parser.getNameLoc(), + return parser.emitError(parser.getNameLoc(), "dim operand count and integer set dim count must match"); unsigned numDimAndSymbolOperands = result.operands.size() - currentNumOperands; if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) - return parser.emitError( - parser.getNameLoc(), + return parser.emitError(parser.getNameLoc(), "symbol operand count and integer set symbol count must match"); // If the map has multiple results, make sure that we parsed the min/max @@ -276,11 +273,11 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) { if (map.getNumResults() > 1 && failedToParsedMinMax) { if (isUpper) return parser.emitError(attrLoc, - "upper loop bound affine map with multiple " - "results requires 'min' prefix"); + "upper loop bound affine map with multiple " + "results requires 'min' prefix"); return parser.emitError(attrLoc, - "lower loop bound affine mapwith " - "multiple results requires 'max' prefix"); + "lower loop bound affine mapwith " + "multiple results requires 'max' prefix"); } boundMaps.emplace_back(AffineMapAttr::get(map)); return success(); @@ -293,8 +290,7 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) { } }; - bool keepParsing; // Do we keep parsing loops/bounds? - do { + while (failed(parser.parseOptionalRParen())) { // Parse an input loop operand; operandParser.ParseOperand(LoopType::get(context), result.operands); parser.parseArrow(); @@ -315,26 +311,23 @@ ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) { // the entire "{operand bound}, {input_loop_operand}" sequence will // be parsed as an operand list. parser.parseOptionalComma(); - - // If we don't see a RParen token, we keep parsing. - keepParsing = failed(parser.parseOptionalRParen()); - } while (keepParsing); + } // At this point, there shouldn't be any operands left to parse. if (operandParser.hasOperandLeft()) return parser.emitError(parser.getCurrentLocation()); - result.addAttribute(KrnlIterateOp::getBoundsAttrName(), - builder.getArrayAttr(boundMaps)); + result.addAttribute( + KrnlIterateOp::getBoundsAttrName(), builder.getArrayAttr(boundMaps)); Region *region = result.addRegion(); - SmallVector inductionVarTypes(inductionVarRefs.size(), - builder.getIndexType()); + SmallVector inductionVarTypes( + inductionVarRefs.size(), builder.getIndexType()); if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes)) return failure(); // Ensure iterate region is closed off with krnl.terminate. - KrnlIterateOp::ensureTerminator(*region, parser.getBuilder(), - result.location); + KrnlIterateOp::ensureTerminator( + *region, parser.getBuilder(), result.location); return success(); } @@ -353,22 +346,20 @@ void print(OpAsmPrinter &p, KrnlReturnLoopsOp &op) { p.printOperands(op.operand_begin(), op.operand_end()); } -ParseResult parseKrnlReturnLoopsOp(OpAsmParser &parser, - OperationState &result) { +ParseResult parseKrnlReturnLoopsOp( + OpAsmParser &parser, OperationState &result) { // Parse the loops to return. SmallVector timestamp_dim_handlers; if (parser.parseOperandList(timestamp_dim_handlers) || parser.resolveOperands(timestamp_dim_handlers, - LoopType::get(result.getContext()), - result.operands)) + LoopType::get(result.getContext()), result.operands)) return failure(); return success(); } void KrnlEntryPointOp::build(mlir::Builder *builder, OperationState &state, - SymbolRefAttr funcAttr, IntegerAttr numInputs, - IntegerAttr numOutputs) { + SymbolRefAttr funcAttr, IntegerAttr numInputs, IntegerAttr numOutputs) { state.addAttribute(KrnlEntryPointOp::getEntryPointFuncAttrName(), funcAttr); state.addAttribute(KrnlEntryPointOp::getNumInputsAttrName(), numInputs); state.addAttribute(KrnlEntryPointOp::getNumOutputsAttrName(), numOutputs); diff --git a/src/Transform/LowerKrnl.cpp b/src/Transform/LowerKrnl.cpp index 9aca349..552f0f3 100644 --- a/src/Transform/LowerKrnl.cpp +++ b/src/Transform/LowerKrnl.cpp @@ -27,8 +27,8 @@ namespace { struct KrnlIterateOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(KrnlIterateOp iterateOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + KrnlIterateOp iterateOp, PatternRewriter &rewriter) const override { auto boundMapAttrs = iterateOp.getAttrOfType(KrnlIterateOp::getBoundsAttrName()) .getValue(); @@ -48,21 +48,21 @@ struct KrnlIterateOpLowering : public OpRewritePattern { map = boundMapAttrs[boundIdx + boundType] .cast() .getValue(); - operands.insert(operands.end(), operandItr, - operandItr + map.getNumInputs()); + operands.insert( + operands.end(), operandItr, operandItr + map.getNumInputs()); std::advance(operandItr, map.getNumInputs()); } nestedForOps.emplace_back(rewriter.create( iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap)); rewriter.setInsertionPoint(nestedForOps.back().getBody(), - nestedForOps.back().getBody()->begin()); + nestedForOps.back().getBody()->begin()); } // Replace induction variable references from those introduced by a // single krnl.iterate to those introduced by multiple affine.for // operations. - for (size_t i = 0; i < nestedForOps.size() - 1; i++) { + for (int64_t i = 0; i < (int64_t)nestedForOps.size() - 1; i++) { auto iterateIV = iterateOp.bodyRegion().front().getArgument(0); auto forIV = nestedForOps[i].getBody()->getArgument(0); iterateIV.replaceAllUsesWith(forIV); @@ -74,11 +74,21 @@ struct KrnlIterateOpLowering : public OpRewritePattern { while (iterateOp.bodyRegion().front().getNumArguments() > 1) iterateOp.bodyRegion().front().eraseArgument(0); - // Transfer krnl.iterate region to innermost for op. - auto innermostForOp = nestedForOps.back(); - innermostForOp.region().getBlocks().clear(); - rewriter.inlineRegionBefore(iterateOp.bodyRegion(), innermostForOp.region(), - innermostForOp.region().end()); + if (nestedForOps.empty()) { + // If no loops are involved, simply move operations from within iterateOp + // body region to the parent region of iterateOp. + rewriter.setInsertionPoint(iterateOp); + iterateOp.bodyRegion().walk([&](Operation *op) { + if (!op->isKnownTerminator()) + op->replaceAllUsesWith(rewriter.clone(*op)); + }); + } else { + // Transfer krnl.iterate region to innermost for op. + auto innermostForOp = nestedForOps.back(); + innermostForOp.region().getBlocks().clear(); + rewriter.inlineRegionBefore(iterateOp.bodyRegion(), + innermostForOp.region(), innermostForOp.region().end()); + } rewriter.eraseOp(iterateOp); return success(); @@ -93,8 +103,8 @@ class KrnlTerminatorLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(KrnlTerminatorOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + KrnlTerminatorOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op); return success(); } @@ -108,8 +118,8 @@ class KrnlDefineLoopsLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(KrnlDefineLoopsOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + KrnlDefineLoopsOp op, PatternRewriter &rewriter) const override { rewriter.eraseOp(op); return success(); } @@ -123,8 +133,8 @@ class KrnlOptimizeLoopsLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(KrnlOptimizeLoopsOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + KrnlOptimizeLoopsOp op, PatternRewriter &rewriter) const override { rewriter.eraseOp(op); return success(); } @@ -158,8 +168,7 @@ void KrnlToAffineLoweringPass::runOnFunction() { OwningRewritePatternList patterns; patterns.insert( - &getContext()); + KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(&getContext()); if (failed(applyPartialConversion(getFunction(), target, patterns))) { signalPassFailure(); @@ -172,5 +181,5 @@ std::unique_ptr mlir::createLowerKrnlPass() { return std::make_unique(); } -static PassRegistration pass("lower-krnl", - "Lower Krnl dialect."); +static PassRegistration pass( + "lower-krnl", "Lower Krnl dialect."); diff --git a/test/mlir/conversion/krnl_to_affine.mlir b/test/mlir/conversion/krnl_to_affine.mlir new file mode 100644 index 0000000..bd49d3e --- /dev/null +++ b/test/mlir/conversion/krnl_to_affine.mlir @@ -0,0 +1,15 @@ +// RUN: onnx-mlir-opt --lower-krnl %s -split-input-file | FileCheck %s + +func @test_lower_degenerate_iterate(%arg0: memref) -> memref { + %0 = alloc() : memref + krnl.iterate() with () { + %1 = load %arg0[] : memref + store %1, %0[] : memref + } + return %0 : memref + // CHECK-LABEL: test_lower_degenerate_iterate + // CHECK-NEXT: [[ALLOC:%.+]] = alloc() : memref + // CHECK-NEXT: [[LOAD:%.+]] = load %{{.*}}[] : memref + // CHECK-NEXT: store [[LOAD]], [[ALLOC]][] : memref + // CHECK-NEXT: return [[ALLOC]] : memref +} \ No newline at end of file