From a6a40cf98977590c64101b2efb35704f6dfdc018 Mon Sep 17 00:00:00 2001
From: Tian Jin
Date: Thu, 19 Dec 2019 13:27:15 -0500
Subject: [PATCH] Format Key Files using LLVM Style (#403)

* format using llvm style

* merge and format
---
 src/builder/frontend_dialect_transformer.cpp |  22 +-
 src/compiler/dialect/krnl/krnl_ops.cpp       | 132 +++++----
 src/compiler/dialect/krnl/krnl_ops.hpp       |  16 +-
 src/compiler/pass/lower_frontend_to_krnl.cpp | 297 ++++++++++---------
 src/compiler/transform/lower_krnl.cpp        |  47 +--
 src/main.cpp                                 |  14 +-
 6 files changed, 277 insertions(+), 251 deletions(-)

diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp
index 1d23802..3f5d8d9 100644
--- a/src/builder/frontend_dialect_transformer.cpp
+++ b/src/builder/frontend_dialect_transformer.cpp
@@ -382,7 +382,7 @@ private:
       auto end = default_str.find(",", start + 1);
       if (end == std::string::npos) {
         end = default_str.find("}", start + 1);
-        if (end != std::string::npos && end > start+1) {
+        if (end != std::string::npos && end > start + 1) {
           r.push_back(std::stoi(default_str.substr(start + 1, end)));
         }
         break;
@@ -401,7 +401,7 @@ private:
       auto end = default_str.find(",", start + 1);
       if (end == std::string::npos) {
         end = default_str.find("}", start + 1);
-        if (end != std::string::npos && end > start+1) {
+        if (end != std::string::npos && end > start + 1) {
           r.push_back(std::stof(default_str.substr(start + 1, end)));
         }
         break;
@@ -420,7 +420,7 @@ private:
      auto end = default_str.find(",", start + 1);
      if (end == std::string::npos) {
        end = default_str.find("}", start + 1);
-        if (end != std::string::npos && end > start+1) {
+        if (end != std::string::npos && end > start + 1) {
          r.push_back(default_str.substr(start + 1, end));
        }
        break;
@@ -529,18 +529,19 @@ private:
     }
     std::vector<NamedAttribute> attributes;
-    //for (auto [attr_name, attr_type, attr_default] : attrs) {
-    for (auto oneAttr: attrs) {
+    // for (auto [attr_name, attr_type, attr_default] : attrs) {
+    for (auto oneAttr : attrs) {
       std::string attr_name;
       std::string attr_type;
       std::string attr_default;
-      std::tie (attr_name, attr_type, attr_default) = oneAttr;
+      std::tie(attr_name, attr_type, attr_default) = oneAttr;
       if (attr_type != "") {
         auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
         attributes.push_back(attr);
       } else {
         // TODO: the attributes need special handling
-        //std::cout << "missing " << node.op_type() << " " << attr_name << std::endl;
+        // std::cout << "missing " << node.op_type() << " " << attr_name <<
+        // std::endl;
       }
     }
@@ -575,17 +576,18 @@ private:
     }
     std::vector<NamedAttribute> attributes;
-    for (auto oneAttr: attrs) {
+    for (auto oneAttr : attrs) {
       std::string attr_name;
       std::string attr_type;
       std::string attr_default;
-      std::tie (attr_name, attr_type, attr_default) = oneAttr;
+      std::tie(attr_name, attr_type, attr_default) = oneAttr;
       if (attr_type != "") {
         auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
         attributes.push_back(attr);
       } else {
         // TODO: the attributes need special handling
-        //std::cout << "missing " << node.op_type() << " " << attr_name << std::endl;
+        // std::cout << "missing " << node.op_type() << " " << attr_name <<
+        // std::endl;
       }
     }

diff --git a/src/compiler/dialect/krnl/krnl_ops.cpp b/src/compiler/dialect/krnl/krnl_ops.cpp
index ee71925..6454936 100644
--- a/src/compiler/dialect/krnl/krnl_ops.cpp
+++ b/src/compiler/dialect/krnl/krnl_ops.cpp
@@ -9,8 +9,6 @@
 #include
 #include

-#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/SmallBitVector.h"
 #include "mlir/Dialect/AffineOps/AffineOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/IR/Block.h"
@@ -23,6 +21,8 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallBitVector.h"

 #include "src/compiler/dialect/krnl/krnl_helper.hpp"

@@ -31,7 +31,7 @@ using namespace mlir;

 namespace mlir {

-KrnlOpsDialect::KrnlOpsDialect(MLIRContext* context)
+KrnlOpsDialect::KrnlOpsDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
   addOperations<
 #define GET_OP_LIST
@@ -44,29 +44,30 @@ KrnlOpsDialect::KrnlOpsDialect(MLIRContext* context)
 // KrnlDefineLoopsOp
 //===----------------------------------------------------------------------===//

-void KrnlDefineLoopsOp::build(
-    Builder* builder, OperationState& result, int64_t num_loops) {
+void KrnlDefineLoopsOp::build(Builder *builder, OperationState &result,
+                              int64_t num_loops) {
   // Create the same number of dimension handlers as the number of
   // dimensions in the associated integer set.
   result.types.append(num_loops, LoopType::get(builder->getContext()));
-  result.addAttribute(
-      getNumLoopsAttrName(), builder->getI32IntegerAttr(num_loops));
+  result.addAttribute(getNumLoopsAttrName(),
+                      builder->getI32IntegerAttr(num_loops));
 }

-void print(OpAsmPrinter& p, KrnlDefineLoopsOp& op) {
+void print(OpAsmPrinter &p, KrnlDefineLoopsOp &op) {
   auto numLoopAttr =
       op.getAttrOfType<IntegerAttr>(KrnlDefineLoopsOp::getNumLoopsAttrName());
   p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue();
 }

-ParseResult parseKrnlDefineLoopsOp(
-    OpAsmParser& parser, OperationState& result) {
+ParseResult parseKrnlDefineLoopsOp(OpAsmParser &parser,
+                                   OperationState &result) {
   // Parse the attribute indicating number of loops defined.
   IntegerAttr numLoops;
-  auto& builder = parser.getBuilder();
+  auto &builder = parser.getBuilder();
   auto intType = builder.getIntegerType(64);
   if (parser.parseAttribute(numLoops, intType,
-          KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes))
+                            KrnlDefineLoopsOp::getNumLoopsAttrName(),
+                            result.attributes))
     return failure();

   auto loopTypes = llvm::SmallVector<Type, 4>(
@@ -79,29 +80,29 @@ ParseResult parseKrnlDefineLoopsOp(
 // KrnlOptimizeLoopsOp
 //===----------------------------------------------------------------------===//

-void KrnlOptimizeLoopsOp::build(
-    Builder* builder, OperationState& result, int num_optimized_loops) {
-  result.types.append(
-      num_optimized_loops, LoopType::get(builder->getContext()));
+void KrnlOptimizeLoopsOp::build(Builder *builder, OperationState &result,
+                                int num_optimized_loops) {
+  result.types.append(num_optimized_loops,
+                      LoopType::get(builder->getContext()));

   // Create a region and a block for the body.
   // Schedule intrinsics will be placed into this region.
-  Region* region = result.addRegion();
-  auto* body = new Block();
+  Region *region = result.addRegion();
+  auto *body = new Block();
   region->push_back(body);
 }

-void print(OpAsmPrinter& p, KrnlOptimizeLoopsOp& op) {
+void print(OpAsmPrinter &p, KrnlOptimizeLoopsOp &op) {
   p << "krnl.optimize_loops ";
   p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
-      /*printBlockTerminators=*/true);
+                /*printBlockTerminators=*/true);
   p << " : ";
   p.printFunctionalType(op);
 }

-ParseResult parseKrnlOptimizeLoopsOp(
-    OpAsmParser& parser, OperationState& result) {
+ParseResult parseKrnlOptimizeLoopsOp(OpAsmParser &parser,
+                                     OperationState &result) {
   // Parse the schedule body region.
-  Region* region = result.addRegion();
+  Region *region = result.addRegion();
   if (parser.parseRegion(*region, llvm::None, llvm::None))
     return failure();

@@ -142,21 +143,22 @@ ParseResult parseKrnlOptimizeLoopsOp(
  * Then the bounds will be parsed as:
  *   %i0 = 10 to N : %i1 = M to 20
  */
-void KrnlIterateOp::build(Builder* builder, OperationState& result,
-    KrnlIterateOperandPack operandPack) {
+void KrnlIterateOp::build(Builder *builder, OperationState &result,
+                          KrnlIterateOperandPack operandPack) {
   // Record optimized loops and the number of such loops.
   result.addOperands(operandPack.getOperands());
-  result.addAttribute(
-      KrnlIterateOp::getBoundsAttrName(), operandPack.getAttributes());
+  result.addAttribute(KrnlIterateOp::getBoundsAttrName(),
+                      operandPack.getAttributes());

-  result.addAttribute(getNumOptimizedLoopsAttrName(),
+  result.addAttribute(
+      getNumOptimizedLoopsAttrName(),
       builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops()));

   // Create a region and a block for the body. The arguments of the region are
   // the loop induction variables; there can be multiple induction variables
   // associated with the same krnl.iterate operation.
-  Region* bodyRegion = result.addRegion();
-  auto* body = new Block();
+  Region *bodyRegion = result.addRegion();
+  auto *body = new Block();
   auto body_args = llvm::SmallVector<Type, 4>(
       operandPack.getNumInputLoops(), IndexType::get(builder->getContext()));
   body->addArguments(body_args);
@@ -165,7 +167,7 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result,
   ensureTerminator(*bodyRegion, *builder, result.location);
 }

-void print(OpAsmPrinter& p, KrnlIterateOp& op) {
+void print(OpAsmPrinter &p, KrnlIterateOp &op) {
   p << "krnl.iterate(";
   // Print optimized loops:
   auto numOptimizedLoops = op.getNumOptimizedLoops();
@@ -180,7 +182,7 @@ void print(OpAsmPrinter& p, KrnlIterateOp& op) {
   auto operandItr = op.operand_begin() + numOptimizedLoops;

   std::string delimiter;
-  for (auto& var : inductionVars) {
+  for (auto &var : inductionVars) {
     p << delimiter;
     p.printOperand(*operandItr++);
     p << " -> ";
@@ -194,25 +196,26 @@ void print(OpAsmPrinter& p, KrnlIterateOp& op) {
   p << ")";

   p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false,
-      /*printBlockTerminators=*/false);
+                /*printBlockTerminators=*/false);
 }

-ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
+ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
   auto builder = parser.getBuilder();
   auto context = builder.getContext();
   onnf::KrnlDialectOperandParser operandParser(parser);

   // Parse optimized loops:
   SmallVector<OpAsmParser::OperandType, 4> optimizedLoopRefs;
-  if (parser.parseOperandList(
-          optimizedLoopRefs, OpAsmParser::Delimiter::Paren) ||
+  if (parser.parseOperandList(optimizedLoopRefs,
+                              OpAsmParser::Delimiter::Paren) ||
       parser.resolveOperands(optimizedLoopRefs,
-          LoopType::get(result.getContext()), result.operands))
+                             LoopType::get(result.getContext()),
+                             result.operands))
     return failure();

   // Record how many optimized loops did we parse.
   result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(),
-      builder.getI64IntegerAttr(optimizedLoopRefs.size()));
+                      builder.getI64IntegerAttr(optimizedLoopRefs.size()));

   // Parse input loops and their lower and upper bounds.
   SmallVector<OpAsmParser::OperandType, 4> inductionVarRefs;
@@ -222,16 +225,16 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
     return failure();

   // A function to parse a lower or upper bound.
-  auto parseBound = [&result, &builder, &parser, &operandParser, &boundMaps](
-                        bool isUpper) -> ParseResult {
+  auto parseBound = [&result, &builder, &parser, &operandParser,
+                     &boundMaps](bool isUpper) -> ParseResult {
     // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
     // the map has multiple results.
     bool failedToParsedMinMax =
         failed(parser.parseOptionalKeyword(isUpper ? "min" : "max"));

     // Try parse an SSA operand.
-    if (succeeded(operandParser.ParseOptionalOperand(
-            builder.getIndexType(), result.operands))) {
+    if (succeeded(operandParser.ParseOptionalOperand(builder.getIndexType(),
+                                                     result.operands))) {
       AffineMap map = builder.getSymbolIdentityMap();
       boundMaps.emplace_back(AffineMapAttr::get(map));
       return success();
     }
@@ -243,8 +246,8 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
     llvm::SMLoc attrLoc = parser.getCurrentLocation();
     Attribute boundAttr;
     llvm::SmallVector<NamedAttribute, 1> tempBoundAttrContainer;
-    if (parser.parseAttribute(
-            boundAttr, builder.getIndexType(), "temp", tempBoundAttrContainer))
+    if (parser.parseAttribute(boundAttr, builder.getIndexType(), "temp",
+                              tempBoundAttrContainer))
       return failure();

     if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
@@ -255,13 +258,15 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
       auto map = affineMapAttr.getValue();
       if (map.getNumDims() != numDims)
-        return parser.emitError(parser.getNameLoc(),
+        return parser.emitError(
+            parser.getNameLoc(),
             "dim operand count and integer set dim count must match");

       unsigned numDimAndSymbolOperands =
           result.operands.size() - currentNumOperands;
       if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
-        return parser.emitError(parser.getNameLoc(),
+        return parser.emitError(
+            parser.getNameLoc(),
             "symbol operand count and integer set symbol count must match");

       // If the map has multiple results, make sure that we parsed the min/max
@@ -269,11 +274,11 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
       if (map.getNumResults() > 1 && failedToParsedMinMax) {
         if (isUpper)
           return parser.emitError(attrLoc,
-              "upper loop bound affine map with multiple "
-              "results requires 'min' prefix");
+                                  "upper loop bound affine map with multiple "
+                                  "results requires 'min' prefix");
         return parser.emitError(attrLoc,
-            "lower loop bound affine mapwith "
-            "multiple results requires 'max' prefix");
+                                "lower loop bound affine mapwith "
+                                "multiple results requires 'max' prefix");
       }
       boundMaps.emplace_back(AffineMapAttr::get(map));
       return success();
@@ -286,7 +291,7 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
     }
   };

-  bool keepParsing;  // Do we keep parsing loops/bounds?
+  bool keepParsing; // Do we keep parsing loops/bounds?
   do {
     // Parse an input loop operand;
     operandParser.ParseOperand(LoopType::get(context), result.operands);
@@ -316,18 +321,18 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
   // At this point, there shouldn't be any operands left to parse.
   if (operandParser.hasOperandLeft())
     return parser.emitError(parser.getCurrentLocation());
-  result.addAttribute(
-      KrnlIterateOp::getBoundsAttrName(), builder.getArrayAttr(boundMaps));
+  result.addAttribute(KrnlIterateOp::getBoundsAttrName(),
+                      builder.getArrayAttr(boundMaps));

-  Region* region = result.addRegion();
-  SmallVector<Type, 4> inductionVarTypes(
-      inductionVarRefs.size(), builder.getIndexType());
+  Region *region = result.addRegion();
+  SmallVector<Type, 4> inductionVarTypes(inductionVarRefs.size(),
+                                         builder.getIndexType());
   if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes))
     return failure();

   // Ensure iterate region is closed off with krnl.terminate.
-  KrnlIterateOp::ensureTerminator(
-      *region, parser.getBuilder(), result.location);
+  KrnlIterateOp::ensureTerminator(*region, parser.getBuilder(),
+                                  result.location);

   return success();
 }
@@ -341,18 +346,19 @@ static LogicalResult verify(KrnlIterateOp op) {
 // KrnlReturnLoopsOp
 //===----------------------------------------------------------------------===//

-void print(OpAsmPrinter& p, KrnlReturnLoopsOp& op) {
+void print(OpAsmPrinter &p, KrnlReturnLoopsOp &op) {
   p << "krnl.return_loops ";
   p.printOperands(op.operand_begin(), op.operand_end());
 }

-ParseResult parseKrnlReturnLoopsOp(
-    OpAsmParser& parser, OperationState& result) {
+ParseResult parseKrnlReturnLoopsOp(OpAsmParser &parser,
+                                   OperationState &result) {
   // Parse the loops to return.
   SmallVector<OpAsmParser::OperandType, 4> timestamp_dim_handlers;
   if (parser.parseOperandList(timestamp_dim_handlers) ||
       parser.resolveOperands(timestamp_dim_handlers,
-          LoopType::get(result.getContext()), result.operands))
+                             LoopType::get(result.getContext()),
+                             result.operands))
     return failure();

   return success();
@@ -360,4 +366,4 @@ ParseResult parseKrnlReturnLoopsOp(
 #define GET_OP_CLASSES
 #include "src/compiler/krnl.cpp.inc"
-}  // namespace mlir
+} // namespace mlir

diff --git a/src/compiler/dialect/krnl/krnl_ops.hpp b/src/compiler/dialect/krnl/krnl_ops.hpp
index 89d4587..4b9fe4e 100644
--- a/src/compiler/dialect/krnl/krnl_ops.hpp
+++ b/src/compiler/dialect/krnl/krnl_ops.hpp
@@ -19,12 +19,12 @@ namespace mlir {
 class KrnlOpsDialect : public Dialect {
- public:
-  KrnlOpsDialect(MLIRContext* context);
+public:
+  KrnlOpsDialect(MLIRContext *context);
   static StringRef getDialectNamespace() { return "krnl"; }

   /// Parse a type registered to this dialect.
-  Type parseType(DialectAsmParser& parser) const override {
+  Type parseType(DialectAsmParser &parser) const override {
     if (succeeded(parser.parseOptionalKeyword("loop")))
       return LoopType::get(parser.getBuilder().getContext());

@@ -32,15 +32,15 @@ class KrnlOpsDialect : public Dialect {
   }

   /// Print a type registered to this dialect.
-  void printType(Type type, DialectAsmPrinter& os) const override {
+  void printType(Type type, DialectAsmPrinter &os) const override {
     switch (type.getKind()) {
-      case KrnlTypes::Loop:
-        os << "loop";
-        return;
+    case KrnlTypes::Loop:
+      os << "loop";
+      return;
     }
   }
 };

 #define GET_OP_CLASSES
 #include "src/compiler/krnl.hpp.inc"
-}  // namespace mlir
+} // namespace mlir

diff --git a/src/compiler/pass/lower_frontend_to_krnl.cpp b/src/compiler/pass/lower_frontend_to_krnl.cpp
index 9a17df0..c1ca650 100644
--- a/src/compiler/pass/lower_frontend_to_krnl.cpp
+++ b/src/compiler/pass/lower_frontend_to_krnl.cpp
@@ -11,17 +11,16 @@
 #include

-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/Sequence.h"
 #include "mlir/Dialect/AffineOps/AffineOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Sequence.h"

 #include "src/compiler/dialect/krnl/krnl_helper.hpp"
 #include "src/compiler/dialect/krnl/krnl_ops.hpp"
 #include "src/compiler/dialect/onnx/onnx_ops.hpp"
-
 #include "src/compiler/pass/passes.hpp"

 using namespace mlir;
@@ -98,7 +97,7 @@ static Value *insertAllocAndDealloc(MemRefType type, Location loc,

   // Make sure to allocate at the beginning of the block if
   // all dimensions are known.
-  auto* parentBlock = alloc.getOperation()->getBlock();
+  auto *parentBlock = alloc.getOperation()->getBlock();
   if (hasAllConstantDimensions(type))
     alloc.getOperation()->moveBefore(&parentBlock->front());

@@ -113,17 +112,17 @@ static Value *insertAllocAndDealloc(MemRefType type, Location loc,
 // Determine if current function returns the result value of the
 // current op being lowered. If it does then dealloc should not be
 // inserted.
-static bool checkInsertDealloc(Operation* currentOp) {
+static bool checkInsertDealloc(Operation *currentOp) {
   auto parentBlock = currentOp->getBlock();

   bool insertDealloc = true;
   parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
     assert(currentOp->getNumResults() < 2 &&
-        "No more than one result supported (for now).");
+           "No more than one result supported (for now).");
     // If there is at least one result to investigate.
     if (currentOp->getNumResults() > 0) {
       auto result = currentOp->getResult(0);
-      for (const auto& operand : op.getOperands())
+      for (const auto &operand : op.getOperands())
         if (operand == result)
           insertDealloc = false;
     }
@@ -148,7 +147,7 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {

 // Get run-time dimension information for unknown dimensions used for
 // broadcasting.
-std::map<int, std::map<int, Value *> >
+std::map<int, std::map<int, Value *>>
 getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
                       MemRefType memRefType, ArrayRef<Value *> operands) {
   auto memRefShape = memRefType.getShape();
@@ -196,15 +195,15 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
 // given operand.
 std::vector<Value *>
 getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
-    ArrayRef<Value *> loopIVs, Value *operand,
-    std::map<int, Value *> broadcastedDims) {
+                          ArrayRef<Value *> loopIVs, Value *operand,
+                          std::map<int, Value *> broadcastedDims) {
   // `operand` must has a ranked type. This should have been checked by the
   // shape inference pass.
   auto operandShape = operand->getType().cast<MemRefType>().getShape();
   auto rank = operandShape.size();
   auto loopCount = loopIVs.size();

-  std::vector<Value*> newLoopIVs;
+  std::vector<Value *> newLoopIVs;
   for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
     auto dimIdx = rank - 1 - reversedIdx;
     auto loopIdx = loopCount - 1 - reversedIdx;
@@ -218,8 +217,8 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
       // If its value is 1, it is broadcasted dimension.
       // Otherwise, non-broadcasted dimension.
       auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
-      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx],
-          zero, loopIVs[loopIdx]);
+      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
+                                           loopIVs[loopIdx]);
       newLoopIVs.insert(newLoopIVs.begin(), idx);
     } else {
       // Non-broadcasted dimension
@@ -260,26 +259,26 @@ struct ScalarOp {

 template <>
 struct ScalarOp<ONNXAndOp> {
-  using FOp = AndOp;  // not use
+  using FOp = AndOp; // not use
   using IOp = AndOp;
 };

 template <>
 struct ScalarOp<ONNXOrOp> {
-  using FOp = OrOp;  // not use
+  using FOp = OrOp; // not use
   using IOp = OrOp;
 };

 template <>
 struct ScalarOp<ONNXXorOp> {
-  using FOp = XOrOp;  // not use
+  using FOp = XOrOp; // not use
   using IOp = XOrOp;
 };

 template <>
 struct ScalarOp<ONNXExpOp> {
   using FOp = ExpOp;
-  using IOp = ExpOp;  // not use
+  using IOp = ExpOp; // not use
 };

 template <>
@@ -297,18 +296,19 @@ using ScalarIOp = typename ScalarOp<Op>::IOp;
 // Scalar unary ops for lowering to Krnl dialect.
 //===----------------------------------------------------------------------===//
 template <typename UnaryOp>
-Value* mapToLowerScalarOp(Operation* op, ArrayRef<Type> result_types,
-    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
+Value *mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
+                          ArrayRef<Value *> operands,
+                          ConversionPatternRewriter &rewriter) {
   /* Lower UnaryOp to Ops in the Standard dialect.
    */
   auto loc = op->getLoc();
   Type element_type = operands.front()->getType();
   if (element_type.isa<FloatType>()) {
-    return rewriter.create<ScalarFOp<UnaryOp>>(
-        loc, result_types, operands, mlir::None);
+    return rewriter.create<ScalarFOp<UnaryOp>>(loc, result_types, operands,
+                                               mlir::None);
   } else if (element_type.isa<IntegerType>()) {
-    return rewriter.create<ScalarIOp<UnaryOp>>(
-        loc, result_types, operands, mlir::None);
+    return rewriter.create<ScalarIOp<UnaryOp>>(loc, result_types, operands,
+                                               mlir::None);
   } else {
     emitError(loc, "unsupported element type");
     return nullptr;
   }
@@ -319,13 +319,14 @@ Value* mapToLowerScalarOp(Operation* op, ArrayRef<Type> result_types,
 // Scalar unary ops for lowering ONNXTanhOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXTanhOp>(Operation* op,
-    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
-    ConversionPatternRewriter& rewriter) {
+Value *mapToLowerScalarOp<ONNXTanhOp>(Operation *op,
+                                      ArrayRef<Type> result_types,
+                                      ArrayRef<Value *> operands,
+                                      ConversionPatternRewriter &rewriter) {
   // ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
   //                         AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
   auto loc = op->getLoc();
-  Value* operand = operands[0];
+  Value *operand = operands[0];

   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
   auto neg = rewriter.create<SubFOp>(loc, zero, operand);
   auto exp = rewriter.create<ExpOp>(loc, operand);
   auto negExp = rewriter.create<ExpOp>(loc, neg);
   auto result =
       rewriter.create<DivFOp>(loc, rewriter.create<SubFOp>(loc, exp, negExp),
-          rewriter.create<AddFOp>(loc, exp, negExp));
+                              rewriter.create<AddFOp>(loc, exp, negExp));

   return result;
 }
@@ -342,13 +343,14 @@ Value* mapToLowerScalarOp<ONNXTanhOp>(Operation* op,
 // Scalar unary ops for lowering ONNXSinhOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXSinhOp>(Operation* op,
-    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
-    ConversionPatternRewriter& rewriter) {
+Value *mapToLowerScalarOp<ONNXSinhOp>(Operation *op,
+                                      ArrayRef<Type> result_types,
+                                      ArrayRef<Value *> operands,
+                                      ConversionPatternRewriter &rewriter) {
   // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
   //                         ConstantOp 2)
   auto loc = op->getLoc();
-  Value* operand = operands[0];
+  Value *operand = operands[0];

   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
   auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
@@ -365,13 +367,14 @@ Value* mapToLowerScalarOp<ONNXSinhOp>(Operation* op,
 // Scalar unary ops for lowering ONNXCoshOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXCoshOp>(Operation* op,
-    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
-    ConversionPatternRewriter& rewriter) {
+Value *mapToLowerScalarOp<ONNXCoshOp>(Operation *op,
+                                      ArrayRef<Type> result_types,
+                                      ArrayRef<Value *> operands,
+                                      ConversionPatternRewriter &rewriter) {
   // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
   //                         ConstantOp 2)
   auto loc = op->getLoc();
-  Value* operand = operands[0];
+  Value *operand = operands[0];

   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
   auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
@@ -388,13 +391,14 @@ Value* mapToLowerScalarOp<ONNXCoshOp>(Operation* op,
 // Scalar unary ops for lowering ONNXSigmoidOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXSigmoidOp>(Operation* op,
-    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
-    ConversionPatternRewriter& rewriter) {
+Value *mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
+                                         ArrayRef<Type> result_types,
+                                         ArrayRef<Value *> operands,
+                                         ConversionPatternRewriter &rewriter) {
   // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
   //                            AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
   auto loc = op->getLoc();
-  Value* operand = operands[0];
+  Value *operand = operands[0];

   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
   auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
@@ -410,9 +414,9 @@ Value* mapToLowerScalarOp<ONNXSigmoidOp>(Operation* op,
 // Scalar unary ops for lowering ONNXHardSigmoidOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXHardSigmoidOp>(Operation* op,
-    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
-    ConversionPatternRewriter& rewriter) {
+Value *mapToLowerScalarOp<ONNXHardSigmoidOp>(
+    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value *> operands,
+    ConversionPatternRewriter &rewriter) {
   // %Y = AddFOp(MulFOp(alpha, %X), beta)
   // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
   //               %Y,
   //               Constant 0)
   // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1),
   //                                  %Z,
   //                                  Constant 1)
   auto loc = op->getLoc();
-  Value* operand = operands[0];
+  Value *operand = operands[0];
   auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha");
   auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta");
@@ -446,13 +450,14 @@ Value* mapToLowerScalarOp<ONNXHardSigmoidOp>(Operation* op,
 // Scalar unary ops for lowering ONNXEluOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXEluOp>(Operation* op, ArrayRef<Type> result_types,
-    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
+Value *mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
+                                     ArrayRef<Value *> operands,
+                                     ConversionPatternRewriter &rewriter) {
   // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
   //                          MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
   //                          %X)
   auto loc = op->getLoc();
-  Value* operand = operands[0];
+  Value *operand = operands[0];

   auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
@@ -461,9 +466,10 @@ Value* mapToLowerScalarOp<ONNXEluOp>(Operation* op, ArrayRef<Type> result_types,
   auto exp = rewriter.create<ExpOp>(loc, operand);
   auto lessThanZero =
       rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
-  auto result = rewriter.create<SelectOp>(loc, lessThanZero,
-      rewriter.create<MulFOp>(
-          loc, alpha, rewriter.create<SubFOp>(loc, exp, one)),
+  auto result = rewriter.create<SelectOp>(
+      loc, lessThanZero,
+      rewriter.create<MulFOp>(loc, alpha,
+                              rewriter.create<SubFOp>(loc, exp, one)),
       operand);

   return result;
@@ -473,14 +479,15 @@ Value* mapToLowerScalarOp<ONNXEluOp>(Operation* op, ArrayRef<Type> result_types,
 // Scalar unary ops for lowering ONNXReluOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXReluOp>(Operation* op,
-    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
-    ConversionPatternRewriter& rewriter) {
+Value *mapToLowerScalarOp<ONNXReluOp>(Operation *op,
+                                      ArrayRef<Type> result_types,
+                                      ArrayRef<Value *> operands,
+                                      ConversionPatternRewriter &rewriter) {
   // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
   //                           ConstantOp 0,
   //                           %X)
   auto loc = op->getLoc();
-  Value* operand = operands[0];
+  Value *operand = operands[0];

   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
   auto lessThanZero =
@@ -494,14 +501,15 @@ Value* mapToLowerScalarOp<ONNXReluOp>(Operation* op,
 // Scalar unary ops for lowering ONNXLeakyReluOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXLeakyReluOp>(Operation* op,
-    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
-    ConversionPatternRewriter& rewriter) {
+Value *
+mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, ArrayRef<Type> result_types,
+                                    ArrayRef<Value *> operands,
+                                    ConversionPatternRewriter &rewriter) {
   // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
   //                                MulFOp(alpha, %X),
   //                                %X)
   auto loc = op->getLoc();
-  Value* operand = operands[0];
+  Value *operand = operands[0];

   auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha");
   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
@@ -518,16 +526,17 @@ Value* mapToLowerScalarOp<ONNXLeakyReluOp>(Operation* op,
 // Scalar unary ops for lowering ONNXSeluOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXSeluOp>(Operation* op,
-    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
-    ConversionPatternRewriter& rewriter) {
+Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op,
+                                      ArrayRef<Type> result_types,
+                                      ArrayRef<Value *> operands,
+                                      ConversionPatternRewriter &rewriter) {
   // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
   //                           MulFOp(gamma, %X),
   //                           MulFOp(gamma,
   //                                  SubFOp(MulFOp(alpha, ExpOp(%X)),
   //                                         alpha)))
   auto loc = op->getLoc();
-  Value* operand = operands[0];
+  Value *operand = operands[0];
   auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha");
   auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma");
@@ -537,9 +546,10 @@ Value* mapToLowerScalarOp<ONNXSeluOp>(Operation* op,
   auto exp = rewriter.create<ExpOp>(loc, operand);
   auto greaterThanZero =
       rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
-  auto select = rewriter.create<SelectOp>(loc, greaterThanZero, operand,
-      rewriter.create<SubFOp>(
-          loc, rewriter.create<MulFOp>(loc, alpha, exp), alpha));
+  auto select = rewriter.create<SelectOp>(
+      loc, greaterThanZero, operand,
+      rewriter.create<SubFOp>(loc, rewriter.create<MulFOp>(loc, alpha, exp),
+                              alpha));
   auto result = rewriter.create<MulFOp>(loc, gamma, select);

   return result;
@@ -549,11 +559,13 @@ Value* mapToLowerScalarOp<ONNXSeluOp>(Operation* op,
 // Scalar unary ops for lowering ONNXReciprocalOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXReciprocalOp>(Operation* op, ArrayRef<Type> result_types,
-    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
+Value *
+mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op, ArrayRef<Type> result_types,
+                                     ArrayRef<Value *> operands,
+                                     ConversionPatternRewriter &rewriter) {
   // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
   auto loc = op->getLoc();
-  Value* operand = operands[0];
+  Value *operand = operands[0];

   auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
   auto result = rewriter.create<DivFOp>(loc, one, operand);
@@ -565,14 +577,15 @@ Value* mapToLowerScalarOp<ONNXReciprocalOp>(Operation* op,
 // Scalar unary ops for lowering ONNXMaxOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXMaxOp>(Operation* op, ArrayRef<Type> result_types,
-    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
+Value *mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
+                                     ArrayRef<Value *> operands,
+                                     ConversionPatternRewriter &rewriter) {
   // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
   //                              %X,
   //                              %Y)
   auto loc = op->getLoc();
-  Value* lhs = operands[0];
-  Value* rhs = operands[1];
+  Value *lhs = operands[0];
+  Value *rhs = operands[1];
   auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
   auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
   return result;
@@ -582,14 +595,15 @@ Value* mapToLowerScalarOp<ONNXMaxOp>(Operation* op, ArrayRef<Type> result_types,
 // Scalar unary ops for lowering ONNXMinOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXMinOp>(Operation* op, ArrayRef<Type> result_types,
-    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
+Value *mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
+                                     ArrayRef<Value *> operands,
+                                     ConversionPatternRewriter &rewriter) {
   // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
   //                              %X,
   //                              %Y)
   auto loc = op->getLoc();
-  Value* lhs = operands[0];
-  Value* rhs = operands[1];
+  Value *lhs = operands[0];
+  Value *rhs = operands[1];
   auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
   auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
   return result;
@@ -599,10 +613,11 @@ Value* mapToLowerScalarOp<ONNXMinOp>(Operation* op, ArrayRef<Type> result_types,
 //===----------------------------------------------------------------------===//
 template <typename ElementwiseUnaryOp>
 struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
-  ONNXElementwiseUnaryOpLowering(MLIRContext* ctx)
+  ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
       : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
-  PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
-      ConversionPatternRewriter& rewriter) const final {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const final {
     // TODO: Check that the types are valid.
     // An element-wise unary operation must have all operands and the result of
     // the same type. This should have been verified by the verifier.
@@ -618,14 +633,14 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
     // dimensions with the result at this pre-optimization phase.
     // TODO: verify that dimensions match.
    // TODO: can the dimension of the result differ after optimizations?
-    Value* alloc;
+    Value *alloc;
     bool insertDealloc = checkInsertDealloc(op);

     if (hasAllConstantDimensions(memRefType))
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
     else
-      alloc = insertAllocAndDealloc(
-          memRefType, loc, rewriter, insertDealloc, {operands[0]});
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
+                                    {operands[0]});

     // Number of loops
     auto memRefShape = memRefType.getShape();
@@ -633,7 +648,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {

     // Define loops.
     auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
-    std::vector<Value*> originalLoops;
+    std::vector<Value *> originalLoops;
     originalLoops.reserve(rank);
     for (auto result : loopsOp.getResults()) {
       originalLoops.push_back(result);
     }

     // Define loop optimization.
     auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
-    std::vector<Value*> optimizedLoops;
+    std::vector<Value *> optimizedLoops;
     optimizedLoops.reserve(rank);
     for (auto result : optimizedLoopsOp.getResults()) {
       optimizedLoops.push_back(result);
     }
-    Block& optimizationBlock = optimizedLoopsOp.region().front();
+    Block &optimizationBlock = optimizedLoopsOp.region().front();

     KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
     // Iterate over the loop nest.
@@ -664,7 +679,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
     }

     auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
-    Block& iterationBlock = iterateOp.bodyRegion().front();
+    Block &iterationBlock = iterateOp.bodyRegion().front();

     // Now perform the insertions into the body of the
     // just generated instructions:
@@ -681,7 +696,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
     rewriter.setInsertionPointToStart(&iterationBlock);

     // Handle the operation:
-    SmallVector<Value*, 4> loopIVs;
+    SmallVector<Value *, 4> loopIVs;
     for (auto arg : iterationBlock.getArguments())
       loopIVs.push_back(arg);
@@ -701,10 +716,11 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
 //===----------------------------------------------------------------------===//
 template <typename ElementwiseVariadicOp>
 struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
-  ONNXElementwiseVariadicOpLowering(MLIRContext* ctx)
+  ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
       : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
-  PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
-      ConversionPatternRewriter& rewriter) const final {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const final {
     // TODO: Check that the types are valid.
     // An element-wise variadic operation must have all operands and the result
     // of the same type. This should have been verified by the verifier.
@@ -715,7 +731,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {

     // Insert an allocation and deallocation for the result of this operation.
     auto memRefType = convertTensorToMemRef(tensorType);

-    Value* alloc;
+    Value *alloc;
     bool insertDealloc = checkInsertDealloc(op);
     // If the output has a dynamic dimension, we compute its dimension at
     // runtime by using dimensions from the operands.
@@ -725,8 +741,8 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
     if (hasAllConstantDimensions(memRefType))
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
     else
-      alloc = insertAllocAndDealloc(
-          memRefType, loc, rewriter, insertDealloc, operands);
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
+                                    operands);

     // Number of loops
     auto memRefShape = memRefType.getShape();
@@ -734,7 +750,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {

     // Define loops.
     auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
-    std::vector<Value*> originalLoops;
+    std::vector<Value *> originalLoops;
     originalLoops.reserve(rank);
     for (auto result : loopsOp.getResults()) {
       originalLoops.push_back(result);
     }
@@ -742,12 +758,12 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {

     // Define loop optimization.
     auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
-    std::vector<Value*> optimizedLoops;
+    std::vector<Value *> optimizedLoops;
     optimizedLoops.reserve(rank);
     for (auto result : optimizedLoopsOp.getResults()) {
       optimizedLoops.push_back(result);
     }
-    Block& optimizationBlock = optimizedLoopsOp.region().front();
+    Block &optimizationBlock = optimizedLoopsOp.region().front();
     KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
     // Iterate over the loop nest.
@@ -770,7 +786,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
         getBroadcastedDimInfo(loc, rewriter, memRefType, operands);

     auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
-    Block& iterationBlock = iterateOp.bodyRegion().front();
+    Block &iterationBlock = iterateOp.bodyRegion().front();

     // Now perform the insertions into the body of the
     // just generated instructions:
@@ -786,7 +802,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
     rewriter.setInsertionPointToStart(&iterationBlock);

     // Handle the operation:
-    SmallVector<Value*, 4> loopIVs;
+    SmallVector<Value *, 4> loopIVs;
     for (auto arg : iterationBlock.getArguments())
       loopIVs.push_back(arg);
@@ -812,35 +828,36 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
 };

 struct ONNXReshapeOpLowering : public ConversionPattern {
-  ONNXReshapeOpLowering(MLIRContext* ctx)
+  ONNXReshapeOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}

-  PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
-      ConversionPatternRewriter& rewriter) const final {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const final {
     auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto loc = op->getLoc();

     // Insert an allocation and deallocation for the result of this operation.
     auto memRefType = convertTensorToMemRef(tensorType);
-    Value* alloc;
+    Value *alloc;

     // Compute size in bytes.
-    Value* tensorSize = rewriter.create<ConstantOp>(loc,
-        rewriter.getIntegerAttr(
-            rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType)));
+    Value *tensorSize = rewriter.create<ConstantOp>(
+        loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                     getMemRefEltSizeInBytes(memRefType)));
     bool insertDealloc = checkInsertDealloc(op);
     if (hasAllConstantDimensions(memRefType)) {
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
     } else {
       auto memRefShape = memRefType.getShape();
-      SmallVector<Value*, 4> allocOperands;
+      SmallVector<Value *, 4> allocOperands;
       for (int i = 0; i < memRefShape.size(); ++i) {
         // The shape array can always be used to construct shape information of
         // the result.
-        Value* index = rewriter.create<ConstantOp>(
+        Value *index = rewriter.create<ConstantOp>(
             loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
-        Value* loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
-        Value* int64LoadedVal = rewriter.create<ZeroExtendIOp>(
+        Value *loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
+        Value *int64LoadedVal = rewriter.create<ZeroExtendIOp>(
            loc, loadedVal, rewriter.getIntegerType(64));
        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
        allocOperands.push_back(rewriter.create(
@@ -851,7 +868,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {

     // Make sure to allocate at the beginning of the block if
     // all dimensions are known.
-    auto* parentBlock = allocateMemref.getOperation()->getBlock();
+    auto *parentBlock = allocateMemref.getOperation()->getBlock();
     if (insertDealloc) {
       auto dealloc = rewriter.create<DeallocOp>(loc, allocateMemref);
       dealloc.getOperation()->moveBefore(&parentBlock->back());
     }
@@ -874,7 +891,7 @@ struct TensorTypeConverter : public TypeConverter {
   using TypeConverter::TypeConverter;

-  LogicalResult convertType(Type t, SmallVectorImpl<Type>& results) override {
+  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
     if (auto tensor_type = t.dyn_cast<TensorType>()) {
       results.push_back(convertTensorToMemRef(tensor_type));
       return success();
     }
@@ -889,12 +906,12 @@ struct TensorTypeConverter : public TypeConverter {
   /// inputs. Once unranked results can be handled gracefully this
   /// override needs to be removed in favour of the original MLIR one.]
   bool isSignatureLegal(FunctionType funcType) {
-    return llvm::all_of(
-        funcType.getInputs(), [this](Type type) { return isLegal(type); });
+    return llvm::all_of(funcType.getInputs(),
+                        [this](Type type) { return isLegal(type); });
   }
 };

-}  // end anonymous namespace.
+} // end anonymous namespace.

 //===----------------------------------------------------------------------===//
 // Frontend to Krnl Dialect lowering pass
 //===----------------------------------------------------------------------===//
@@ -906,7 +923,7 @@ struct FrontendToKrnlLoweringPass
     : public ModulePass<FrontendToKrnlLoweringPass> {
   void runOnModule() final;
 };

-}  // end anonymous namespace.
+} // end anonymous namespace.

 void FrontendToKrnlLoweringPass::runOnModule() {
   auto module = getModule();
@@ -943,32 +960,32 @@ void FrontendToKrnlLoweringPass::runOnModule() {
   // Type conversion for function signatures.
   // Call MLIR FuncOp signature conversion when result type is
   // a ranked tensor.
-  populateFuncOpTypeConversionPattern(
-      patterns, &getContext(), tensor_to_memref_converter);
+  populateFuncOpTypeConversionPattern(patterns, &getContext(),
+                                      tensor_to_memref_converter);

   // Frontent operation lowering.
   patterns.insert<ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>,
-      ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
-      ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
-      ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
-      ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
-      ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
-      ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
-      ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
-      ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
-      ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
-      ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
-      ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
-      ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
-      ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
-      ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
-      ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
-      ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
-      ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>,
-      ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
-      ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
-      ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
-      ONNXReshapeOpLowering>(&getContext());
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
+                  ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
+                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
+                  ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
+                  ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
+                  ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
+                  ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
+                  ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>,
+                  ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
+                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
+                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
+                  ONNXReshapeOpLowering>(&getContext());

   // With the target and rewrite patterns defined, we can now attempt the
   // conversion. The conversion will signal failure if any of our `illegal`
@@ -981,5 +998,5 @@ std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
   return std::make_unique<FrontendToKrnlLoweringPass>();
 }

-static PassRegistration<FrontendToKrnlLoweringPass> pass(
-    "lower-frontend", "Lower frontend ops to Krnl dialect.");
+static PassRegistration<FrontendToKrnlLoweringPass>
+    pass("lower-frontend", "Lower frontend ops to Krnl dialect.");

diff --git a/src/compiler/transform/lower_krnl.cpp b/src/compiler/transform/lower_krnl.cpp
index 36abb1c..36f8f56 100644
--- a/src/compiler/transform/lower_krnl.cpp
+++ b/src/compiler/transform/lower_krnl.cpp
@@ -17,8 +17,8 @@ namespace {
 struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
   using OpRewritePattern<KrnlIterateOp>::OpRewritePattern;

-  PatternMatchResult matchAndRewrite(
-      KrnlIterateOp iterateOp, PatternRewriter& rewriter) const override {
+  PatternMatchResult matchAndRewrite(KrnlIterateOp iterateOp,
+                                     PatternRewriter &rewriter) const override {
     auto boundMapAttrs =
         iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
             .getValue();
@@ -30,23 +30,23 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
       operandItr++;

       // Organize operands into lower/upper bounds in affine.for ready formats.
-      SmallVector<Value*, 4> lbOperands, ubOperands;
+      SmallVector<Value *, 4> lbOperands, ubOperands;
       AffineMap lbMap, ubMap;
       for (int boundType = 0; boundType < 2; boundType++) {
-        auto& operands = boundType == 0 ? lbOperands : ubOperands;
-        auto& map = boundType == 0 ? lbMap : ubMap;
+        auto &operands = boundType == 0 ? lbOperands : ubOperands;
+        auto &map = boundType == 0 ? lbMap : ubMap;
         map = boundMapAttrs[boundIdx + boundType]
                   .cast<AffineMapAttr>()
                   .getValue();
-        operands.insert(
-            operands.end(), operandItr, operandItr + map.getNumInputs());
+        operands.insert(operands.end(), operandItr,
+                        operandItr + map.getNumInputs());
         std::advance(operandItr, map.getNumInputs());
       }

       nestedForOps.emplace_back(rewriter.create<AffineForOp>(
           iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap));
       rewriter.setInsertionPoint(nestedForOps.back().getBody(),
-          nestedForOps.back().getBody()->begin());
+                                 nestedForOps.back().getBody()->begin());
     }

     // Replace induction variable references from those introduced by a
@@ -68,7 +68,7 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
     auto innermostForOp = nestedForOps.back();
     innermostForOp.region().getBlocks().clear();
     rewriter.inlineRegionBefore(iterateOp.bodyRegion(), innermostForOp.region(),
-        innermostForOp.region().end());
+                                innermostForOp.region().end());

     rewriter.eraseOp(iterateOp);
     return matchSuccess();
@@ -80,11 +80,11 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
 //===----------------------------------------------------------------------===//

 class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> {
- public:
+public:
   using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern;

-  PatternMatchResult matchAndRewrite(
-      KrnlTerminatorOp op, PatternRewriter& rewriter) const override {
+  PatternMatchResult matchAndRewrite(KrnlTerminatorOp op,
+                                     PatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op);
     return matchSuccess();
   }
@@ -95,11 +95,11 @@ class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> {
 //===----------------------------------------------------------------------===//

 class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> {
- public:
+public:
   using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern;

-  PatternMatchResult matchAndRewrite(
-      KrnlDefineLoopsOp op, PatternRewriter& rewriter) const override {
+  PatternMatchResult matchAndRewrite(KrnlDefineLoopsOp op,
+                                     PatternRewriter &rewriter) const override {
     rewriter.eraseOp(op);
     return matchSuccess();
   }
@@ -110,11 +110,11 @@ class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> {
 //===----------------------------------------------------------------------===//

 class KrnlOptimizeLoopsLowering : public OpRewritePattern<KrnlOptimizeLoopsOp> {
- public:
+public:
   using OpRewritePattern<KrnlOptimizeLoopsOp>::OpRewritePattern;

-  PatternMatchResult matchAndRewrite(
-      KrnlOptimizeLoopsOp op, PatternRewriter& rewriter) const override {
+  PatternMatchResult matchAndRewrite(KrnlOptimizeLoopsOp op,
+                                     PatternRewriter &rewriter) const override {
     rewriter.eraseOp(op);
     return matchSuccess();
   }
@@ -132,7 +132,7 @@ struct KrnlToAffineLoweringPass : public FunctionPass<KrnlToAffineLoweringPass> {
   void runOnFunction() final;
 };

-}  // end anonymous namespace.
+} // end anonymous namespace.

 void KrnlToAffineLoweringPass::runOnFunction() {
   auto function = getFunction();
@@ -146,17 +146,18 @@ void KrnlToAffineLoweringPass::runOnFunction() {

   OwningRewritePatternList patterns;
   patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
-      KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(&getContext());
+                  KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(
+      &getContext());

   if (failed(applyPartialConversion(getFunction(), target, patterns)))
     signalPassFailure();
 }

-}  // namespace
+} // namespace

 std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
   return std::make_unique<KrnlToAffineLoweringPass>();
 }

-static PassRegistration<KrnlToAffineLoweringPass> pass(
-    "lower-krnl", "Lower Krnl dialect.");
\ No newline at end of file
+static PassRegistration<KrnlToAffineLoweringPass> pass("lower-krnl",
+                                                       "Lower Krnl dialect.");
\ No newline at end of file

diff --git a/src/main.cpp b/src/main.cpp
index 3767280..4f1e946 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -76,7 +76,7 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext& context,
   }
 }

-int main(int ac, char* av[]) {
+int main(int ac, char *av[]) {
   namespace po = boost::program_options;

   po::options_description desc("ONNF available options");
@@ -91,8 +91,8 @@ int main(int ac, char* av[]) {
   po::positional_options_description p;
   p.add("onnx-model", -1);
   po::variables_map vm;
-  po::store(
-      po::command_line_parser(ac, av).options(desc).positional(p).run(), vm);
+  po::store(po::command_line_parser(ac, av).options(desc).positional(p).run(),
+            vm);

   // TODO: allow multiple input files
   assert(vm.count("onnx-model") < 2 &&
          "At most one input file can be provided!");
@@ -137,10 +137,10 @@ int main(int ac, char* av[]) {

   // Write LLVM bitcode to disk.
   std::error_code EC;
-  llvm::raw_fd_ostream moduleBitcodeStream(
-      "model.bc", EC, llvm::sys::fs::F_None);
+  llvm::raw_fd_ostream moduleBitcodeStream("model.bc", EC,
+                                           llvm::sys::fs::F_None);
-  llvm::WriteBitcodeToFile(
-      *mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
+  llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
+                           moduleBitcodeStream);
   moduleBitcodeStream.flush();

   return 0;
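--

Reproducing this formatting locally: the commit ships no .clang-format file, so
the exact configuration is an assumption, but clang-format's built-in LLVM
style yields the layout seen in the hunks above (right-aligned * and &,
2-space indents, 80-column wrapping). A minimal sketch of an equivalent
config:

    # Hypothetical .clang-format -- not part of this commit.
    BasedOnStyle: LLVM        # LLVM defaults: IndentWidth 2, ColumnLimit 80
    PointerAlignment: Right   # e.g. "OperationState &result", as in this patch

Applied in place per file, e.g. clang-format -style=llvm -i src/main.cpp, or
over the whole change set with git clang-format against the parent commit.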