Format Key Files using LLVM Style (#403)

* format using llvm style

* merge and format
This commit is contained in:
Tian Jin 2019-12-19 13:27:15 -05:00 committed by Tian Jin
parent 06a968d4a1
commit a6a40cf989
6 changed files with 277 additions and 251 deletions

View File

@ -382,7 +382,7 @@ private:
auto end = default_str.find(",", start + 1); auto end = default_str.find(",", start + 1);
if (end == std::string::npos) { if (end == std::string::npos) {
end = default_str.find("}", start + 1); end = default_str.find("}", start + 1);
if (end != std::string::npos && end > start+1) { if (end != std::string::npos && end > start + 1) {
r.push_back(std::stoi(default_str.substr(start + 1, end))); r.push_back(std::stoi(default_str.substr(start + 1, end)));
} }
break; break;
@ -401,7 +401,7 @@ private:
auto end = default_str.find(",", start + 1); auto end = default_str.find(",", start + 1);
if (end == std::string::npos) { if (end == std::string::npos) {
end = default_str.find("}", start + 1); end = default_str.find("}", start + 1);
if (end != std::string::npos && end > start+1) { if (end != std::string::npos && end > start + 1) {
r.push_back(std::stof(default_str.substr(start + 1, end))); r.push_back(std::stof(default_str.substr(start + 1, end)));
} }
break; break;
@ -420,7 +420,7 @@ private:
auto end = default_str.find(",", start + 1); auto end = default_str.find(",", start + 1);
if (end == std::string::npos) { if (end == std::string::npos) {
end = default_str.find("}", start + 1); end = default_str.find("}", start + 1);
if (end != std::string::npos && end > start+1) { if (end != std::string::npos && end > start + 1) {
r.push_back(default_str.substr(start + 1, end)); r.push_back(default_str.substr(start + 1, end));
} }
break; break;
@ -529,18 +529,19 @@ private:
} }
std::vector<mlir::NamedAttribute> attributes; std::vector<mlir::NamedAttribute> attributes;
//for (auto [attr_name, attr_type, attr_default] : attrs) { // for (auto [attr_name, attr_type, attr_default] : attrs) {
for (auto oneAttr: attrs) { for (auto oneAttr : attrs) {
std::string attr_name; std::string attr_name;
std::string attr_type; std::string attr_type;
std::string attr_default; std::string attr_default;
std::tie (attr_name, attr_type, attr_default) = oneAttr; std::tie(attr_name, attr_type, attr_default) = oneAttr;
if (attr_type != "") { if (attr_type != "") {
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default); auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
attributes.push_back(attr); attributes.push_back(attr);
} else { } else {
// TODO: the attributes need special handling // TODO: the attributes need special handling
//std::cout << "missing " << node.op_type() << " " << attr_name << std::endl; // std::cout << "missing " << node.op_type() << " " << attr_name <<
// std::endl;
} }
} }
@ -575,17 +576,18 @@ private:
} }
std::vector<mlir::NamedAttribute> attributes; std::vector<mlir::NamedAttribute> attributes;
for (auto oneAttr: attrs) { for (auto oneAttr : attrs) {
std::string attr_name; std::string attr_name;
std::string attr_type; std::string attr_type;
std::string attr_default; std::string attr_default;
std::tie (attr_name, attr_type, attr_default) = oneAttr; std::tie(attr_name, attr_type, attr_default) = oneAttr;
if (attr_type != "") { if (attr_type != "") {
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default); auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
attributes.push_back(attr); attributes.push_back(attr);
} else { } else {
// TODO: the attributes need special handling // TODO: the attributes need special handling
//std::cout << "missing " << node.op_type() << " " << attr_name << std::endl; // std::cout << "missing " << node.op_type() << " " << attr_name <<
// std::endl;
} }
} }

View File

@ -9,8 +9,6 @@
#include <iostream> #include <iostream>
#include <queue> #include <queue>
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Block.h" #include "mlir/IR/Block.h"
@ -23,6 +21,8 @@
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "src/compiler/dialect/krnl/krnl_helper.hpp" #include "src/compiler/dialect/krnl/krnl_helper.hpp"
@ -31,7 +31,7 @@
using namespace mlir; using namespace mlir;
namespace mlir { namespace mlir {
KrnlOpsDialect::KrnlOpsDialect(MLIRContext* context) KrnlOpsDialect::KrnlOpsDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) { : Dialect(getDialectNamespace(), context) {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
@ -44,29 +44,30 @@ KrnlOpsDialect::KrnlOpsDialect(MLIRContext* context)
// KrnlDefineLoopsOp // KrnlDefineLoopsOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void KrnlDefineLoopsOp::build( void KrnlDefineLoopsOp::build(Builder *builder, OperationState &result,
Builder* builder, OperationState& result, int64_t num_loops) { int64_t num_loops) {
// Create the same number of dimension handlers as the number of // Create the same number of dimension handlers as the number of
// dimensions in the associated integer set. // dimensions in the associated integer set.
result.types.append(num_loops, LoopType::get(builder->getContext())); result.types.append(num_loops, LoopType::get(builder->getContext()));
result.addAttribute( result.addAttribute(getNumLoopsAttrName(),
getNumLoopsAttrName(), builder->getI32IntegerAttr(num_loops)); builder->getI32IntegerAttr(num_loops));
} }
void print(OpAsmPrinter& p, KrnlDefineLoopsOp& op) { void print(OpAsmPrinter &p, KrnlDefineLoopsOp &op) {
auto numLoopAttr = auto numLoopAttr =
op.getAttrOfType<IntegerAttr>(KrnlDefineLoopsOp::getNumLoopsAttrName()); op.getAttrOfType<IntegerAttr>(KrnlDefineLoopsOp::getNumLoopsAttrName());
p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue(); p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue();
} }
ParseResult parseKrnlDefineLoopsOp( ParseResult parseKrnlDefineLoopsOp(OpAsmParser &parser,
OpAsmParser& parser, OperationState& result) { OperationState &result) {
// Parse the attribute indicating number of loops defined. // Parse the attribute indicating number of loops defined.
IntegerAttr numLoops; IntegerAttr numLoops;
auto& builder = parser.getBuilder(); auto &builder = parser.getBuilder();
auto intType = builder.getIntegerType(64); auto intType = builder.getIntegerType(64);
if (parser.parseAttribute(numLoops, intType, if (parser.parseAttribute(numLoops, intType,
KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes)) KrnlDefineLoopsOp::getNumLoopsAttrName(),
result.attributes))
return failure(); return failure();
auto loopTypes = llvm::SmallVector<Type, 4>( auto loopTypes = llvm::SmallVector<Type, 4>(
@ -79,29 +80,29 @@ ParseResult parseKrnlDefineLoopsOp(
// KrnlOptimizeLoopsOp // KrnlOptimizeLoopsOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void KrnlOptimizeLoopsOp::build( void KrnlOptimizeLoopsOp::build(Builder *builder, OperationState &result,
Builder* builder, OperationState& result, int num_optimized_loops) { int num_optimized_loops) {
result.types.append( result.types.append(num_optimized_loops,
num_optimized_loops, LoopType::get(builder->getContext())); LoopType::get(builder->getContext()));
// Create a region and a block for the body. // Create a region and a block for the body.
// Schedule intrinsics will be placed into this region. // Schedule intrinsics will be placed into this region.
Region* region = result.addRegion(); Region *region = result.addRegion();
auto* body = new Block(); auto *body = new Block();
region->push_back(body); region->push_back(body);
} }
void print(OpAsmPrinter& p, KrnlOptimizeLoopsOp& op) { void print(OpAsmPrinter &p, KrnlOptimizeLoopsOp &op) {
p << "krnl.optimize_loops "; p << "krnl.optimize_loops ";
p.printRegion(op.region(), /*printEntryBlockArgs=*/false, p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true); /*printBlockTerminators=*/true);
p << " : "; p << " : ";
p.printFunctionalType(op); p.printFunctionalType(op);
} }
ParseResult parseKrnlOptimizeLoopsOp( ParseResult parseKrnlOptimizeLoopsOp(OpAsmParser &parser,
OpAsmParser& parser, OperationState& result) { OperationState &result) {
// Parse the schedule body region. // Parse the schedule body region.
Region* region = result.addRegion(); Region *region = result.addRegion();
if (parser.parseRegion(*region, llvm::None, llvm::None)) if (parser.parseRegion(*region, llvm::None, llvm::None))
return failure(); return failure();
@ -142,21 +143,22 @@ ParseResult parseKrnlOptimizeLoopsOp(
* Then the bounds will be parsed as: * Then the bounds will be parsed as:
* %i0 = 10 to N : %i1 = M to 20 * %i0 = 10 to N : %i1 = M to 20
*/ */
void KrnlIterateOp::build(Builder* builder, OperationState& result, void KrnlIterateOp::build(Builder *builder, OperationState &result,
KrnlIterateOperandPack operandPack) { KrnlIterateOperandPack operandPack) {
// Record optimized loops and the number of such loops. // Record optimized loops and the number of such loops.
result.addOperands(operandPack.getOperands()); result.addOperands(operandPack.getOperands());
result.addAttribute( result.addAttribute(KrnlIterateOp::getBoundsAttrName(),
KrnlIterateOp::getBoundsAttrName(), operandPack.getAttributes()); operandPack.getAttributes());
result.addAttribute(getNumOptimizedLoopsAttrName(), result.addAttribute(
getNumOptimizedLoopsAttrName(),
builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops())); builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops()));
// Create a region and a block for the body. The arguments of the region are // Create a region and a block for the body. The arguments of the region are
// the loop induction variables; there can be multiple induction variables // the loop induction variables; there can be multiple induction variables
// associated with the same krnl.iterate operation. // associated with the same krnl.iterate operation.
Region* bodyRegion = result.addRegion(); Region *bodyRegion = result.addRegion();
auto* body = new Block(); auto *body = new Block();
auto body_args = llvm::SmallVector<Type, 4>( auto body_args = llvm::SmallVector<Type, 4>(
operandPack.getNumInputLoops(), IndexType::get(builder->getContext())); operandPack.getNumInputLoops(), IndexType::get(builder->getContext()));
body->addArguments(body_args); body->addArguments(body_args);
@ -165,7 +167,7 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result,
ensureTerminator(*bodyRegion, *builder, result.location); ensureTerminator(*bodyRegion, *builder, result.location);
} }
void print(OpAsmPrinter& p, KrnlIterateOp& op) { void print(OpAsmPrinter &p, KrnlIterateOp &op) {
p << "krnl.iterate("; p << "krnl.iterate(";
// Print optimized loops: // Print optimized loops:
auto numOptimizedLoops = op.getNumOptimizedLoops(); auto numOptimizedLoops = op.getNumOptimizedLoops();
@ -180,7 +182,7 @@ void print(OpAsmPrinter& p, KrnlIterateOp& op) {
auto operandItr = op.operand_begin() + numOptimizedLoops; auto operandItr = op.operand_begin() + numOptimizedLoops;
std::string delimiter; std::string delimiter;
for (auto& var : inductionVars) { for (auto &var : inductionVars) {
p << delimiter; p << delimiter;
p.printOperand(*operandItr++); p.printOperand(*operandItr++);
p << " -> "; p << " -> ";
@ -194,25 +196,26 @@ void print(OpAsmPrinter& p, KrnlIterateOp& op) {
p << ")"; p << ")";
p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false, p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false); /*printBlockTerminators=*/false);
} }
ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) { ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
auto builder = parser.getBuilder(); auto builder = parser.getBuilder();
auto context = builder.getContext(); auto context = builder.getContext();
onnf::KrnlDialectOperandParser operandParser(parser); onnf::KrnlDialectOperandParser operandParser(parser);
// Parse optimized loops: // Parse optimized loops:
SmallVector<OpAsmParser::OperandType, 4> optimizedLoopRefs; SmallVector<OpAsmParser::OperandType, 4> optimizedLoopRefs;
if (parser.parseOperandList( if (parser.parseOperandList(optimizedLoopRefs,
optimizedLoopRefs, OpAsmParser::Delimiter::Paren) || OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(optimizedLoopRefs, parser.resolveOperands(optimizedLoopRefs,
LoopType::get(result.getContext()), result.operands)) LoopType::get(result.getContext()),
result.operands))
return failure(); return failure();
// Record how many optimized loops did we parse. // Record how many optimized loops did we parse.
result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(), result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(),
builder.getI64IntegerAttr(optimizedLoopRefs.size())); builder.getI64IntegerAttr(optimizedLoopRefs.size()));
// Parse input loops and their lower and upper bounds. // Parse input loops and their lower and upper bounds.
SmallVector<OpAsmParser::OperandType, 4> inductionVarRefs; SmallVector<OpAsmParser::OperandType, 4> inductionVarRefs;
@ -222,16 +225,16 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
return failure(); return failure();
// A function to parse a lower or upper bound. // A function to parse a lower or upper bound.
auto parseBound = [&result, &builder, &parser, &operandParser, &boundMaps]( auto parseBound = [&result, &builder, &parser, &operandParser,
bool isUpper) -> ParseResult { &boundMaps](bool isUpper) -> ParseResult {
// 'min' / 'max' prefixes are generally syntactic sugar, but are required if // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
// the map has multiple results. // the map has multiple results.
bool failedToParsedMinMax = bool failedToParsedMinMax =
failed(parser.parseOptionalKeyword(isUpper ? "min" : "max")); failed(parser.parseOptionalKeyword(isUpper ? "min" : "max"));
// Try parse an SSA operand. // Try parse an SSA operand.
if (succeeded(operandParser.ParseOptionalOperand( if (succeeded(operandParser.ParseOptionalOperand(builder.getIndexType(),
builder.getIndexType(), result.operands))) { result.operands))) {
AffineMap map = builder.getSymbolIdentityMap(); AffineMap map = builder.getSymbolIdentityMap();
boundMaps.emplace_back(AffineMapAttr::get(map)); boundMaps.emplace_back(AffineMapAttr::get(map));
return success(); return success();
@ -243,8 +246,8 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
llvm::SMLoc attrLoc = parser.getCurrentLocation(); llvm::SMLoc attrLoc = parser.getCurrentLocation();
Attribute boundAttr; Attribute boundAttr;
llvm::SmallVector<NamedAttribute, 1> tempBoundAttrContainer; llvm::SmallVector<NamedAttribute, 1> tempBoundAttrContainer;
if (parser.parseAttribute( if (parser.parseAttribute(boundAttr, builder.getIndexType(), "temp",
boundAttr, builder.getIndexType(), "temp", tempBoundAttrContainer)) tempBoundAttrContainer))
return failure(); return failure();
if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) { if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
@ -255,13 +258,15 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
auto map = affineMapAttr.getValue(); auto map = affineMapAttr.getValue();
if (map.getNumDims() != numDims) if (map.getNumDims() != numDims)
return parser.emitError(parser.getNameLoc(), return parser.emitError(
parser.getNameLoc(),
"dim operand count and integer set dim count must match"); "dim operand count and integer set dim count must match");
unsigned numDimAndSymbolOperands = unsigned numDimAndSymbolOperands =
result.operands.size() - currentNumOperands; result.operands.size() - currentNumOperands;
if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
return parser.emitError(parser.getNameLoc(), return parser.emitError(
parser.getNameLoc(),
"symbol operand count and integer set symbol count must match"); "symbol operand count and integer set symbol count must match");
// If the map has multiple results, make sure that we parsed the min/max // If the map has multiple results, make sure that we parsed the min/max
@ -269,11 +274,11 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
if (map.getNumResults() > 1 && failedToParsedMinMax) { if (map.getNumResults() > 1 && failedToParsedMinMax) {
if (isUpper) if (isUpper)
return parser.emitError(attrLoc, return parser.emitError(attrLoc,
"upper loop bound affine map with multiple " "upper loop bound affine map with multiple "
"results requires 'min' prefix"); "results requires 'min' prefix");
return parser.emitError(attrLoc, return parser.emitError(attrLoc,
"lower loop bound affine mapwith " "lower loop bound affine mapwith "
"multiple results requires 'max' prefix"); "multiple results requires 'max' prefix");
} }
boundMaps.emplace_back(AffineMapAttr::get(map)); boundMaps.emplace_back(AffineMapAttr::get(map));
return success(); return success();
@ -286,7 +291,7 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
} }
}; };
bool keepParsing; // Do we keep parsing loops/bounds? bool keepParsing; // Do we keep parsing loops/bounds?
do { do {
// Parse an input loop operand; // Parse an input loop operand;
operandParser.ParseOperand(LoopType::get(context), result.operands); operandParser.ParseOperand(LoopType::get(context), result.operands);
@ -316,18 +321,18 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
// At this point, there shouldn't be any operands left to parse. // At this point, there shouldn't be any operands left to parse.
if (operandParser.hasOperandLeft()) if (operandParser.hasOperandLeft())
return parser.emitError(parser.getCurrentLocation()); return parser.emitError(parser.getCurrentLocation());
result.addAttribute( result.addAttribute(KrnlIterateOp::getBoundsAttrName(),
KrnlIterateOp::getBoundsAttrName(), builder.getArrayAttr(boundMaps)); builder.getArrayAttr(boundMaps));
Region* region = result.addRegion(); Region *region = result.addRegion();
SmallVector<Type, 4> inductionVarTypes( SmallVector<Type, 4> inductionVarTypes(inductionVarRefs.size(),
inductionVarRefs.size(), builder.getIndexType()); builder.getIndexType());
if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes)) if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes))
return failure(); return failure();
// Ensure iterate region is closed off with krnl.terminate. // Ensure iterate region is closed off with krnl.terminate.
KrnlIterateOp::ensureTerminator( KrnlIterateOp::ensureTerminator(*region, parser.getBuilder(),
*region, parser.getBuilder(), result.location); result.location);
return success(); return success();
} }
@ -341,18 +346,19 @@ static LogicalResult verify(KrnlIterateOp op) {
// KrnlReturnLoopsOp // KrnlReturnLoopsOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void print(OpAsmPrinter& p, KrnlReturnLoopsOp& op) { void print(OpAsmPrinter &p, KrnlReturnLoopsOp &op) {
p << "krnl.return_loops "; p << "krnl.return_loops ";
p.printOperands(op.operand_begin(), op.operand_end()); p.printOperands(op.operand_begin(), op.operand_end());
} }
ParseResult parseKrnlReturnLoopsOp( ParseResult parseKrnlReturnLoopsOp(OpAsmParser &parser,
OpAsmParser& parser, OperationState& result) { OperationState &result) {
// Parse the loops to return. // Parse the loops to return.
SmallVector<OpAsmParser::OperandType, 4> timestamp_dim_handlers; SmallVector<OpAsmParser::OperandType, 4> timestamp_dim_handlers;
if (parser.parseOperandList(timestamp_dim_handlers) || if (parser.parseOperandList(timestamp_dim_handlers) ||
parser.resolveOperands(timestamp_dim_handlers, parser.resolveOperands(timestamp_dim_handlers,
LoopType::get(result.getContext()), result.operands)) LoopType::get(result.getContext()),
result.operands))
return failure(); return failure();
return success(); return success();
@ -360,4 +366,4 @@ ParseResult parseKrnlReturnLoopsOp(
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "src/compiler/krnl.cpp.inc" #include "src/compiler/krnl.cpp.inc"
} // namespace mlir } // namespace mlir

View File

@ -19,12 +19,12 @@
namespace mlir { namespace mlir {
class KrnlOpsDialect : public Dialect { class KrnlOpsDialect : public Dialect {
public: public:
KrnlOpsDialect(MLIRContext* context); KrnlOpsDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "krnl"; } static StringRef getDialectNamespace() { return "krnl"; }
/// Parse a type registered to this dialect. /// Parse a type registered to this dialect.
Type parseType(DialectAsmParser& parser) const override { Type parseType(DialectAsmParser &parser) const override {
if (succeeded(parser.parseOptionalKeyword("loop"))) if (succeeded(parser.parseOptionalKeyword("loop")))
return LoopType::get(parser.getBuilder().getContext()); return LoopType::get(parser.getBuilder().getContext());
@ -32,15 +32,15 @@ class KrnlOpsDialect : public Dialect {
} }
/// Print a type registered to this dialect. /// Print a type registered to this dialect.
void printType(Type type, DialectAsmPrinter& os) const override { void printType(Type type, DialectAsmPrinter &os) const override {
switch (type.getKind()) { switch (type.getKind()) {
case KrnlTypes::Loop: case KrnlTypes::Loop:
os << "loop"; os << "loop";
return; return;
} }
} }
}; };
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "src/compiler/krnl.hpp.inc" #include "src/compiler/krnl.hpp.inc"
} // namespace mlir } // namespace mlir

View File

@ -11,17 +11,16 @@
#include <map> #include <map>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "src/compiler/dialect/krnl/krnl_helper.hpp" #include "src/compiler/dialect/krnl/krnl_helper.hpp"
#include "src/compiler/dialect/krnl/krnl_ops.hpp" #include "src/compiler/dialect/krnl/krnl_ops.hpp"
#include "src/compiler/dialect/onnx/onnx_ops.hpp" #include "src/compiler/dialect/onnx/onnx_ops.hpp"
#include "src/compiler/pass/passes.hpp" #include "src/compiler/pass/passes.hpp"
using namespace mlir; using namespace mlir;
@ -98,7 +97,7 @@ static Value *insertAllocAndDealloc(MemRefType type, Location loc,
// Make sure to allocate at the beginning of the block if // Make sure to allocate at the beginning of the block if
// all dimensions are known. // all dimensions are known.
auto* parentBlock = alloc.getOperation()->getBlock(); auto *parentBlock = alloc.getOperation()->getBlock();
if (hasAllConstantDimensions(type)) if (hasAllConstantDimensions(type))
alloc.getOperation()->moveBefore(&parentBlock->front()); alloc.getOperation()->moveBefore(&parentBlock->front());
@ -113,17 +112,17 @@ static Value *insertAllocAndDealloc(MemRefType type, Location loc,
// Determine if current function returns the result value of the // Determine if current function returns the result value of the
// current op being lowered. If it does then dealloc should not be // current op being lowered. If it does then dealloc should not be
// inserted. // inserted.
static bool checkInsertDealloc(Operation* currentOp) { static bool checkInsertDealloc(Operation *currentOp) {
auto parentBlock = currentOp->getBlock(); auto parentBlock = currentOp->getBlock();
bool insertDealloc = true; bool insertDealloc = true;
parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) { parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
assert(currentOp->getNumResults() < 2 && assert(currentOp->getNumResults() < 2 &&
"No more than one result supported (for now)."); "No more than one result supported (for now).");
// If there is at least one result to investigate. // If there is at least one result to investigate.
if (currentOp->getNumResults() > 0) { if (currentOp->getNumResults() > 0) {
auto result = currentOp->getResult(0); auto result = currentOp->getResult(0);
for (const auto& operand : op.getOperands()) for (const auto &operand : op.getOperands())
if (operand == result) if (operand == result)
insertDealloc = false; insertDealloc = false;
} }
@ -148,7 +147,7 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
// Get run-time dimension information for unknown dimensions used for // Get run-time dimension information for unknown dimensions used for
// broadcasting. // broadcasting.
std::map<int, std::map<int, Value *> > std::map<int, std::map<int, Value *>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
MemRefType memRefType, ArrayRef<Value *> operands) { MemRefType memRefType, ArrayRef<Value *> operands) {
auto memRefShape = memRefType.getShape(); auto memRefShape = memRefType.getShape();
@ -196,15 +195,15 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
// given operand. // given operand.
std::vector<Value *> std::vector<Value *>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
ArrayRef<Value *> loopIVs, Value *operand, ArrayRef<Value *> loopIVs, Value *operand,
std::map<int, Value *> broadcastedDims) { std::map<int, Value *> broadcastedDims) {
// `operand` must has a ranked type. This should have been checked by the // `operand` must has a ranked type. This should have been checked by the
// shape inference pass. // shape inference pass.
auto operandShape = operand->getType().cast<MemRefType>().getShape(); auto operandShape = operand->getType().cast<MemRefType>().getShape();
auto rank = operandShape.size(); auto rank = operandShape.size();
auto loopCount = loopIVs.size(); auto loopCount = loopIVs.size();
std::vector<Value*> newLoopIVs; std::vector<Value *> newLoopIVs;
for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
auto dimIdx = rank - 1 - reversedIdx; auto dimIdx = rank - 1 - reversedIdx;
auto loopIdx = loopCount - 1 - reversedIdx; auto loopIdx = loopCount - 1 - reversedIdx;
@ -218,8 +217,8 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
// If its value is 1, it is broadcasted dimension. // If its value is 1, it is broadcasted dimension.
// Otherwise, non-broadcasted dimension. // Otherwise, non-broadcasted dimension.
auto zero = rewriter.create<ConstantIndexOp>(loc, 0); auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
zero, loopIVs[loopIdx]); loopIVs[loopIdx]);
newLoopIVs.insert(newLoopIVs.begin(), idx); newLoopIVs.insert(newLoopIVs.begin(), idx);
} else { } else {
// Non-broadcasted dimension // Non-broadcasted dimension
@ -260,26 +259,26 @@ struct ScalarOp<ONNXSubOp> {
template <> template <>
struct ScalarOp<ONNXAndOp> { struct ScalarOp<ONNXAndOp> {
using FOp = AndOp; // not use using FOp = AndOp; // not use
using IOp = AndOp; using IOp = AndOp;
}; };
template <> template <>
struct ScalarOp<ONNXOrOp> { struct ScalarOp<ONNXOrOp> {
using FOp = OrOp; // not use using FOp = OrOp; // not use
using IOp = OrOp; using IOp = OrOp;
}; };
template <> template <>
struct ScalarOp<ONNXXorOp> { struct ScalarOp<ONNXXorOp> {
using FOp = XOrOp; // not use using FOp = XOrOp; // not use
using IOp = XOrOp; using IOp = XOrOp;
}; };
template <> template <>
struct ScalarOp<ONNXExpOp> { struct ScalarOp<ONNXExpOp> {
using FOp = ExpOp; using FOp = ExpOp;
using IOp = ExpOp; // not use using IOp = ExpOp; // not use
}; };
template <> template <>
@ -297,18 +296,19 @@ using ScalarIOp = typename ScalarOp<ElementwiseNaryOp>::IOp;
// Scalar unary ops for lowering to Krnl dialect. // Scalar unary ops for lowering to Krnl dialect.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <typename UnaryOp> template <typename UnaryOp>
Value* mapToLowerScalarOp(Operation* op, ArrayRef<Type> result_types, Value *mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) { ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
/* Lower UnaryOp to Ops in the Standard dialect. /* Lower UnaryOp to Ops in the Standard dialect.
*/ */
auto loc = op->getLoc(); auto loc = op->getLoc();
Type element_type = operands.front()->getType(); Type element_type = operands.front()->getType();
if (element_type.isa<IntegerType>()) { if (element_type.isa<IntegerType>()) {
return rewriter.create<ScalarIOp<UnaryOp>>( return rewriter.create<ScalarIOp<UnaryOp>>(loc, result_types, operands,
loc, result_types, operands, mlir::None); mlir::None);
} else if (element_type.isa<FloatType>()) { } else if (element_type.isa<FloatType>()) {
return rewriter.create<ScalarFOp<UnaryOp>>( return rewriter.create<ScalarFOp<UnaryOp>>(loc, result_types, operands,
loc, result_types, operands, mlir::None); mlir::None);
} else { } else {
emitError(loc, "unsupported element type"); emitError(loc, "unsupported element type");
return nullptr; return nullptr;
@ -319,13 +319,14 @@ Value* mapToLowerScalarOp(Operation* op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXTanhOp // Scalar unary ops for lowering ONNXTanhOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value* mapToLowerScalarOp<ONNXTanhOp>(Operation* op, Value *mapToLowerScalarOp<ONNXTanhOp>(Operation *op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands, ArrayRef<Type> result_types,
ConversionPatternRewriter& rewriter) { ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X)))) // AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
auto loc = op->getLoc(); auto loc = op->getLoc();
Value* operand = operands[0]; Value *operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto neg = rewriter.create<SubFOp>(loc, zero, operand); auto neg = rewriter.create<SubFOp>(loc, zero, operand);
@ -333,7 +334,7 @@ Value* mapToLowerScalarOp<ONNXTanhOp>(Operation* op,
auto negExp = rewriter.create<ExpOp>(loc, neg); auto negExp = rewriter.create<ExpOp>(loc, neg);
auto result = auto result =
rewriter.create<DivFOp>(loc, rewriter.create<SubFOp>(loc, exp, negExp), rewriter.create<DivFOp>(loc, rewriter.create<SubFOp>(loc, exp, negExp),
rewriter.create<AddFOp>(loc, exp, negExp)); rewriter.create<AddFOp>(loc, exp, negExp));
return result; return result;
} }
@ -342,13 +343,14 @@ Value* mapToLowerScalarOp<ONNXTanhOp>(Operation* op,
// Scalar unary ops for lowering ONNXSinhOp // Scalar unary ops for lowering ONNXSinhOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value* mapToLowerScalarOp<ONNXSinhOp>(Operation* op, Value *mapToLowerScalarOp<ONNXSinhOp>(Operation *op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands, ArrayRef<Type> result_types,
ConversionPatternRewriter& rewriter) { ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// ConstantOp 2) // ConstantOp 2)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value* operand = operands[0]; Value *operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f)); auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
@ -365,13 +367,14 @@ Value* mapToLowerScalarOp<ONNXSinhOp>(Operation* op,
// Scalar unary ops for lowering ONNXCoshOp // Scalar unary ops for lowering ONNXCoshOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value* mapToLowerScalarOp<ONNXCoshOp>(Operation* op, Value *mapToLowerScalarOp<ONNXCoshOp>(Operation *op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands, ArrayRef<Type> result_types,
ConversionPatternRewriter& rewriter) { ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// ConstantOp 2) // ConstantOp 2)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value* operand = operands[0]; Value *operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f)); auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
@ -388,13 +391,14 @@ Value* mapToLowerScalarOp<ONNXCoshOp>(Operation* op,
// Scalar unary ops for lowering ONNXSigmoidOp // Scalar unary ops for lowering ONNXSigmoidOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value* mapToLowerScalarOp<ONNXSigmoidOp>(Operation* op, Value *mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands, ArrayRef<Type> result_types,
ConversionPatternRewriter& rewriter) { ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1, // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X)))) // AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
auto loc = op->getLoc(); auto loc = op->getLoc();
Value* operand = operands[0]; Value *operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f)); auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
@ -410,9 +414,9 @@ Value* mapToLowerScalarOp<ONNXSigmoidOp>(Operation* op,
// Scalar unary ops for lowering ONNXHardSigmoidOp // Scalar unary ops for lowering ONNXHardSigmoidOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value* mapToLowerScalarOp<ONNXHardSigmoidOp>(Operation* op, Value *mapToLowerScalarOp<ONNXHardSigmoidOp>(
ArrayRef<Type> result_types, ArrayRef<Value*> operands, Operation *op, ArrayRef<Type> result_types, ArrayRef<Value *> operands,
ConversionPatternRewriter& rewriter) { ConversionPatternRewriter &rewriter) {
// %Y = AddFOp(MulFOp(alpha, %X), beta) // %Y = AddFOp(MulFOp(alpha, %X), beta)
// %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0), // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
// %Y, // %Y,
@ -421,7 +425,7 @@ Value* mapToLowerScalarOp<ONNXHardSigmoidOp>(Operation* op,
// %Z, // %Z,
// Constant 1) // Constant 1)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value* operand = operands[0]; Value *operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha");
auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta"); auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta");
@ -446,13 +450,14 @@ Value* mapToLowerScalarOp<ONNXHardSigmoidOp>(Operation* op,
// Scalar unary ops for lowering ONNXEluOp // Scalar unary ops for lowering ONNXEluOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value* mapToLowerScalarOp<ONNXEluOp>(Operation* op, ArrayRef<Type> result_types, Value *mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) { ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// MulFOp(alpha, SubFOp(ExpOp(%X), 1)), // MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
// %X) // %X)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value* operand = operands[0]; Value *operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
@ -461,9 +466,10 @@ Value* mapToLowerScalarOp<ONNXEluOp>(Operation* op, ArrayRef<Type> result_types,
auto exp = rewriter.create<ExpOp>(loc, operand); auto exp = rewriter.create<ExpOp>(loc, operand);
auto lessThanZero = auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero); rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
auto result = rewriter.create<SelectOp>(loc, lessThanZero, auto result = rewriter.create<SelectOp>(
rewriter.create<MulFOp>( loc, lessThanZero,
loc, alpha, rewriter.create<SubFOp>(loc, exp, one)), rewriter.create<MulFOp>(loc, alpha,
rewriter.create<SubFOp>(loc, exp, one)),
operand); operand);
return result; return result;
@ -473,14 +479,15 @@ Value* mapToLowerScalarOp<ONNXEluOp>(Operation* op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXReluOp // Scalar unary ops for lowering ONNXReluOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value* mapToLowerScalarOp<ONNXReluOp>(Operation* op, Value *mapToLowerScalarOp<ONNXReluOp>(Operation *op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands, ArrayRef<Type> result_types,
ConversionPatternRewriter& rewriter) { ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// ConstantOp 0, // ConstantOp 0,
// %X) // %X)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value* operand = operands[0]; Value *operand = operands[0];
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto lessThanZero = auto lessThanZero =
@ -494,14 +501,15 @@ Value* mapToLowerScalarOp<ONNXReluOp>(Operation* op,
// Scalar unary ops for lowering ONNXLeakyReluOp // Scalar unary ops for lowering ONNXLeakyReluOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value* mapToLowerScalarOp<ONNXLeakyReluOp>(Operation* op, Value *
ArrayRef<Type> result_types, ArrayRef<Value*> operands, mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, ArrayRef<Type> result_types,
ConversionPatternRewriter& rewriter) { ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// MulFOp(alpha, %X), // MulFOp(alpha, %X),
// %X) // %X)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value* operand = operands[0]; Value *operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha");
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f)); auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
@ -518,16 +526,17 @@ Value* mapToLowerScalarOp<ONNXLeakyReluOp>(Operation* op,
// Scalar unary ops for lowering ONNXSeluOp // Scalar unary ops for lowering ONNXSeluOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value* mapToLowerScalarOp<ONNXSeluOp>(Operation* op, Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands, ArrayRef<Type> result_types,
ConversionPatternRewriter& rewriter) { ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0), // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
// MulFOp(gamma, %X), // MulFOp(gamma, %X),
// MulFOp(gamma, // MulFOp(gamma,
// SubFOp(MulFOp(alpha, ExpOp(%X)), // SubFOp(MulFOp(alpha, ExpOp(%X)),
// alpha))) // alpha)))
auto loc = op->getLoc(); auto loc = op->getLoc();
Value* operand = operands[0]; Value *operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha");
auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma"); auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma");
@ -537,9 +546,10 @@ Value* mapToLowerScalarOp<ONNXSeluOp>(Operation* op,
auto exp = rewriter.create<ExpOp>(loc, operand); auto exp = rewriter.create<ExpOp>(loc, operand);
auto greaterThanZero = auto greaterThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero); rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
auto select = rewriter.create<SelectOp>(loc, greaterThanZero, operand, auto select = rewriter.create<SelectOp>(
rewriter.create<SubFOp>( loc, greaterThanZero, operand,
loc, rewriter.create<MulFOp>(loc, alpha, exp), alpha)); rewriter.create<SubFOp>(loc, rewriter.create<MulFOp>(loc, alpha, exp),
alpha));
auto result = rewriter.create<MulFOp>(loc, gamma, select); auto result = rewriter.create<MulFOp>(loc, gamma, select);
return result; return result;
@ -549,11 +559,13 @@ Value* mapToLowerScalarOp<ONNXSeluOp>(Operation* op,
// Scalar unary ops for lowering ONNXReciprocalOp // Scalar unary ops for lowering ONNXReciprocalOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value* mapToLowerScalarOp<ONNXReciprocalOp>(Operation* op, ArrayRef<Type> result_types, Value *
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) { mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X) // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value* operand = operands[0]; Value *operand = operands[0];
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f)); auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
auto result = rewriter.create<DivFOp>(loc, one, operand); auto result = rewriter.create<DivFOp>(loc, one, operand);
@ -565,14 +577,15 @@ Value* mapToLowerScalarOp<ONNXReciprocalOp>(Operation* op, ArrayRef<Type> result
// Scalar unary ops for lowering ONNXMaxOp // Scalar unary ops for lowering ONNXMaxOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value* mapToLowerScalarOp<ONNXMaxOp>(Operation* op, ArrayRef<Type> result_types, Value *mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) { ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y), // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
// %X, // %X,
// %Y) // %Y)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value* lhs = operands[0]; Value *lhs = operands[0];
Value* rhs = operands[1]; Value *rhs = operands[1];
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs); auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs); auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
return result; return result;
@ -582,14 +595,15 @@ Value* mapToLowerScalarOp<ONNXMaxOp>(Operation* op, ArrayRef<Type> result_types,
// Scalar unary ops for lowering ONNXMinOp // Scalar unary ops for lowering ONNXMinOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
Value* mapToLowerScalarOp<ONNXMinOp>(Operation* op, ArrayRef<Type> result_types, Value *mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) { ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y), // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
// %X, // %X,
// %Y) // %Y)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value* lhs = operands[0]; Value *lhs = operands[0];
Value* rhs = operands[1]; Value *rhs = operands[1];
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs); auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs); auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
return result; return result;
@ -599,10 +613,11 @@ Value* mapToLowerScalarOp<ONNXMinOp>(Operation* op, ArrayRef<Type> result_types,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <typename ElementwiseUnaryOp> template <typename ElementwiseUnaryOp>
struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
ONNXElementwiseUnaryOpLowering(MLIRContext* ctx) ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
: ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {} : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands, PatternMatchResult
ConversionPatternRewriter& rewriter) const final { matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const final {
// TODO: Check that the types are valid. // TODO: Check that the types are valid.
// An element-wise unary operation must have all operands and the result of // An element-wise unary operation must have all operands and the result of
// the same type. This should have been verified by the verifier. // the same type. This should have been verified by the verifier.
@ -618,14 +633,14 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
// dimensions with the result at this pre-optimization phase. // dimensions with the result at this pre-optimization phase.
// TODO: verify that dimensions match. // TODO: verify that dimensions match.
// TODO: can the dimension of the result differ after optimizations? // TODO: can the dimension of the result differ after optimizations?
Value* alloc; Value *alloc;
bool insertDealloc = checkInsertDealloc(op); bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType)) if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else else
alloc = insertAllocAndDealloc( alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
memRefType, loc, rewriter, insertDealloc, {operands[0]}); {operands[0]});
// Number of loops // Number of loops
auto memRefShape = memRefType.getShape(); auto memRefShape = memRefType.getShape();
@ -633,7 +648,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
// Define loops. // Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank); auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
std::vector<Value*> originalLoops; std::vector<Value *> originalLoops;
originalLoops.reserve(rank); originalLoops.reserve(rank);
for (auto result : loopsOp.getResults()) { for (auto result : loopsOp.getResults()) {
originalLoops.push_back(result); originalLoops.push_back(result);
@ -641,12 +656,12 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
// Define loop optimization. // Define loop optimization.
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank); auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
std::vector<Value*> optimizedLoops; std::vector<Value *> optimizedLoops;
optimizedLoops.reserve(rank); optimizedLoops.reserve(rank);
for (auto result : optimizedLoopsOp.getResults()) { for (auto result : optimizedLoopsOp.getResults()) {
optimizedLoops.push_back(result); optimizedLoops.push_back(result);
} }
Block& optimizationBlock = optimizedLoopsOp.region().front(); Block &optimizationBlock = optimizedLoopsOp.region().front();
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
// Iterate over the loop nest. // Iterate over the loop nest.
@ -664,7 +679,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
} }
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack); auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
Block& iterationBlock = iterateOp.bodyRegion().front(); Block &iterationBlock = iterateOp.bodyRegion().front();
// Now perform the insertions into the body of the // Now perform the insertions into the body of the
// just generated instructions: // just generated instructions:
@ -681,7 +696,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
rewriter.setInsertionPointToStart(&iterationBlock); rewriter.setInsertionPointToStart(&iterationBlock);
// Handle the operation: // Handle the operation:
SmallVector<Value*, 4> loopIVs; SmallVector<Value *, 4> loopIVs;
for (auto arg : iterationBlock.getArguments()) for (auto arg : iterationBlock.getArguments())
loopIVs.push_back(arg); loopIVs.push_back(arg);
@ -701,10 +716,11 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <typename ElementwiseVariadicOp> template <typename ElementwiseVariadicOp>
struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
ONNXElementwiseVariadicOpLowering(MLIRContext* ctx) ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
: ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {} : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands, PatternMatchResult
ConversionPatternRewriter& rewriter) const final { matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const final {
// TODO: Check that the types are valid. // TODO: Check that the types are valid.
// An element-wise variadic operation must have all operands and the result // An element-wise variadic operation must have all operands and the result
// of the same type. This should have been verified by the verifier. // of the same type. This should have been verified by the verifier.
@ -715,7 +731,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType); auto memRefType = convertTensorToMemRef(tensorType);
Value* alloc; Value *alloc;
bool insertDealloc = checkInsertDealloc(op); bool insertDealloc = checkInsertDealloc(op);
// If the output has a dynamic dimension, we compute its dimension at // If the output has a dynamic dimension, we compute its dimension at
// runtime by using dimensions from the operands. // runtime by using dimensions from the operands.
@ -725,8 +741,8 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
if (hasAllConstantDimensions(memRefType)) if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else else
alloc = insertAllocAndDealloc( alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
memRefType, loc, rewriter, insertDealloc, operands); operands);
// Number of loops // Number of loops
auto memRefShape = memRefType.getShape(); auto memRefShape = memRefType.getShape();
@ -734,7 +750,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
// Define loops. // Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank); auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
std::vector<Value*> originalLoops; std::vector<Value *> originalLoops;
originalLoops.reserve(rank); originalLoops.reserve(rank);
for (auto result : loopsOp.getResults()) { for (auto result : loopsOp.getResults()) {
originalLoops.push_back(result); originalLoops.push_back(result);
@ -742,12 +758,12 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
// Define loop optimization. // Define loop optimization.
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank); auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
std::vector<Value*> optimizedLoops; std::vector<Value *> optimizedLoops;
optimizedLoops.reserve(rank); optimizedLoops.reserve(rank);
for (auto result : optimizedLoopsOp.getResults()) { for (auto result : optimizedLoopsOp.getResults()) {
optimizedLoops.push_back(result); optimizedLoops.push_back(result);
} }
Block& optimizationBlock = optimizedLoopsOp.region().front(); Block &optimizationBlock = optimizedLoopsOp.region().front();
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
// Iterate over the loop nest. // Iterate over the loop nest.
@ -770,7 +786,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
getBroadcastedDimInfo(loc, rewriter, memRefType, operands); getBroadcastedDimInfo(loc, rewriter, memRefType, operands);
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack); auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
Block& iterationBlock = iterateOp.bodyRegion().front(); Block &iterationBlock = iterateOp.bodyRegion().front();
// Now perform the insertions into the body of the // Now perform the insertions into the body of the
// just generated instructions: // just generated instructions:
@ -786,7 +802,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
rewriter.setInsertionPointToStart(&iterationBlock); rewriter.setInsertionPointToStart(&iterationBlock);
// Handle the operation: // Handle the operation:
SmallVector<Value*, 4> loopIVs; SmallVector<Value *, 4> loopIVs;
for (auto arg : iterationBlock.getArguments()) for (auto arg : iterationBlock.getArguments())
loopIVs.push_back(arg); loopIVs.push_back(arg);
@ -812,35 +828,36 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
}; };
struct ONNXReshapeOpLowering : public ConversionPattern { struct ONNXReshapeOpLowering : public ConversionPattern {
ONNXReshapeOpLowering(MLIRContext* ctx) ONNXReshapeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands, PatternMatchResult
ConversionPatternRewriter& rewriter) const final { matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>(); auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc(); auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType); auto memRefType = convertTensorToMemRef(tensorType);
Value* alloc; Value *alloc;
// Compute size in bytes. // Compute size in bytes.
Value* tensorSize = rewriter.create<ConstantOp>(loc, Value *tensorSize = rewriter.create<ConstantOp>(
rewriter.getIntegerAttr( loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType))); getMemRefEltSizeInBytes(memRefType)));
bool insertDealloc = checkInsertDealloc(op); bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType)) { if (hasAllConstantDimensions(memRefType)) {
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
} else { } else {
auto memRefShape = memRefType.getShape(); auto memRefShape = memRefType.getShape();
SmallVector<Value*, 4> allocOperands; SmallVector<Value *, 4> allocOperands;
for (int i = 0; i < memRefShape.size(); ++i) { for (int i = 0; i < memRefShape.size(); ++i) {
// The shape array can always be used to construct shape information of // The shape array can always be used to construct shape information of
// the result. // the result.
Value* index = rewriter.create<ConstantOp>( Value *index = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
Value* loadedVal = rewriter.create<LoadOp>(loc, operands[1], index); Value *loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
Value* int64LoadedVal = rewriter.create<ZeroExtendIOp>( Value *int64LoadedVal = rewriter.create<ZeroExtendIOp>(
loc, loadedVal, rewriter.getIntegerType(64)); loc, loadedVal, rewriter.getIntegerType(64));
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal); tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
allocOperands.push_back(rewriter.create<IndexCastOp>( allocOperands.push_back(rewriter.create<IndexCastOp>(
@ -851,7 +868,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
// Make sure to allocate at the beginning of the block if // Make sure to allocate at the beginning of the block if
// all dimensions are known. // all dimensions are known.
auto* parentBlock = allocateMemref.getOperation()->getBlock(); auto *parentBlock = allocateMemref.getOperation()->getBlock();
if (insertDealloc) { if (insertDealloc) {
auto dealloc = rewriter.create<DeallocOp>(loc, allocateMemref); auto dealloc = rewriter.create<DeallocOp>(loc, allocateMemref);
dealloc.getOperation()->moveBefore(&parentBlock->back()); dealloc.getOperation()->moveBefore(&parentBlock->back());
@ -874,7 +891,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
struct TensorTypeConverter : public TypeConverter { struct TensorTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter; using TypeConverter::TypeConverter;
LogicalResult convertType(Type t, SmallVectorImpl<Type>& results) override { LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
if (auto tensor_type = t.dyn_cast<TensorType>()) { if (auto tensor_type = t.dyn_cast<TensorType>()) {
results.push_back(convertTensorToMemRef(tensor_type)); results.push_back(convertTensorToMemRef(tensor_type));
return success(); return success();
@ -889,12 +906,12 @@ struct TensorTypeConverter : public TypeConverter {
/// inputs. Once unranked results can be handled gracefully this /// inputs. Once unranked results can be handled gracefully this
/// override needs to be removed in favour of the original MLIR one.] /// override needs to be removed in favour of the original MLIR one.]
bool isSignatureLegal(FunctionType funcType) { bool isSignatureLegal(FunctionType funcType) {
return llvm::all_of( return llvm::all_of(funcType.getInputs(),
funcType.getInputs(), [this](Type type) { return isLegal(type); }); [this](Type type) { return isLegal(type); });
} }
}; };
} // end anonymous namespace. } // end anonymous namespace.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Frontend to Krnl Dialect lowering pass // Frontend to Krnl Dialect lowering pass
@ -906,7 +923,7 @@ struct FrontendToKrnlLoweringPass
: public ModulePass<FrontendToKrnlLoweringPass> { : public ModulePass<FrontendToKrnlLoweringPass> {
void runOnModule() final; void runOnModule() final;
}; };
} // end anonymous namespace. } // end anonymous namespace.
void FrontendToKrnlLoweringPass::runOnModule() { void FrontendToKrnlLoweringPass::runOnModule() {
auto module = getModule(); auto module = getModule();
@ -943,32 +960,32 @@ void FrontendToKrnlLoweringPass::runOnModule() {
// Type conversion for function signatures. // Type conversion for function signatures.
// Call MLIR FuncOp signature conversion when result type is // Call MLIR FuncOp signature conversion when result type is
// a ranked tensor. // a ranked tensor.
populateFuncOpTypeConversionPattern( populateFuncOpTypeConversionPattern(patterns, &getContext(),
patterns, &getContext(), tensor_to_memref_converter); tensor_to_memref_converter);
// Frontent operation lowering. // Frontent operation lowering.
patterns.insert<ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>, patterns.insert<ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
ONNXReshapeOpLowering>(&getContext()); ONNXReshapeOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the // With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal` // conversion. The conversion will signal failure if any of our `illegal`
@ -981,5 +998,5 @@ std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
return std::make_unique<FrontendToKrnlLoweringPass>(); return std::make_unique<FrontendToKrnlLoweringPass>();
} }
static PassRegistration<FrontendToKrnlLoweringPass> pass( static PassRegistration<FrontendToKrnlLoweringPass>
"lower-frontend", "Lower frontend ops to Krnl dialect."); pass("lower-frontend", "Lower frontend ops to Krnl dialect.");

View File

@ -17,8 +17,8 @@ namespace {
struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> { struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
using OpRewritePattern<KrnlIterateOp>::OpRewritePattern; using OpRewritePattern<KrnlIterateOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite( PatternMatchResult matchAndRewrite(KrnlIterateOp iterateOp,
KrnlIterateOp iterateOp, PatternRewriter& rewriter) const override { PatternRewriter &rewriter) const override {
auto boundMapAttrs = auto boundMapAttrs =
iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName()) iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
.getValue(); .getValue();
@ -30,23 +30,23 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
operandItr++; operandItr++;
// Organize operands into lower/upper bounds in affine.for ready formats. // Organize operands into lower/upper bounds in affine.for ready formats.
SmallVector<Value*, 4> lbOperands, ubOperands; SmallVector<Value *, 4> lbOperands, ubOperands;
AffineMap lbMap, ubMap; AffineMap lbMap, ubMap;
for (int boundType = 0; boundType < 2; boundType++) { for (int boundType = 0; boundType < 2; boundType++) {
auto& operands = boundType == 0 ? lbOperands : ubOperands; auto &operands = boundType == 0 ? lbOperands : ubOperands;
auto& map = boundType == 0 ? lbMap : ubMap; auto &map = boundType == 0 ? lbMap : ubMap;
map = boundMapAttrs[boundIdx + boundType] map = boundMapAttrs[boundIdx + boundType]
.cast<AffineMapAttr>() .cast<AffineMapAttr>()
.getValue(); .getValue();
operands.insert( operands.insert(operands.end(), operandItr,
operands.end(), operandItr, operandItr + map.getNumInputs()); operandItr + map.getNumInputs());
std::advance(operandItr, map.getNumInputs()); std::advance(operandItr, map.getNumInputs());
} }
nestedForOps.emplace_back(rewriter.create<AffineForOp>( nestedForOps.emplace_back(rewriter.create<AffineForOp>(
iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap)); iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap));
rewriter.setInsertionPoint(nestedForOps.back().getBody(), rewriter.setInsertionPoint(nestedForOps.back().getBody(),
nestedForOps.back().getBody()->begin()); nestedForOps.back().getBody()->begin());
} }
// Replace induction variable references from those introduced by a // Replace induction variable references from those introduced by a
@ -68,7 +68,7 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
auto innermostForOp = nestedForOps.back(); auto innermostForOp = nestedForOps.back();
innermostForOp.region().getBlocks().clear(); innermostForOp.region().getBlocks().clear();
rewriter.inlineRegionBefore(iterateOp.bodyRegion(), innermostForOp.region(), rewriter.inlineRegionBefore(iterateOp.bodyRegion(), innermostForOp.region(),
innermostForOp.region().end()); innermostForOp.region().end());
rewriter.eraseOp(iterateOp); rewriter.eraseOp(iterateOp);
return matchSuccess(); return matchSuccess();
@ -80,11 +80,11 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> { class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> {
public: public:
using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern; using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite( PatternMatchResult matchAndRewrite(KrnlTerminatorOp op,
KrnlTerminatorOp op, PatternRewriter& rewriter) const override { PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op); rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op);
return matchSuccess(); return matchSuccess();
} }
@ -95,11 +95,11 @@ class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> { class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> {
public: public:
using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern; using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite( PatternMatchResult matchAndRewrite(KrnlDefineLoopsOp op,
KrnlDefineLoopsOp op, PatternRewriter& rewriter) const override { PatternRewriter &rewriter) const override {
rewriter.eraseOp(op); rewriter.eraseOp(op);
return matchSuccess(); return matchSuccess();
} }
@ -110,11 +110,11 @@ class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class KrnlOptimizeLoopsLowering : public OpRewritePattern<KrnlOptimizeLoopsOp> { class KrnlOptimizeLoopsLowering : public OpRewritePattern<KrnlOptimizeLoopsOp> {
public: public:
using OpRewritePattern<KrnlOptimizeLoopsOp>::OpRewritePattern; using OpRewritePattern<KrnlOptimizeLoopsOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite( PatternMatchResult matchAndRewrite(KrnlOptimizeLoopsOp op,
KrnlOptimizeLoopsOp op, PatternRewriter& rewriter) const override { PatternRewriter &rewriter) const override {
rewriter.eraseOp(op); rewriter.eraseOp(op);
return matchSuccess(); return matchSuccess();
} }
@ -132,7 +132,7 @@ struct KrnlToAffineLoweringPass
: public FunctionPass<KrnlToAffineLoweringPass> { : public FunctionPass<KrnlToAffineLoweringPass> {
void runOnFunction() final; void runOnFunction() final;
}; };
} // end anonymous namespace. } // end anonymous namespace.
void KrnlToAffineLoweringPass::runOnFunction() { void KrnlToAffineLoweringPass::runOnFunction() {
auto function = getFunction(); auto function = getFunction();
@ -146,17 +146,18 @@ void KrnlToAffineLoweringPass::runOnFunction() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering, patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(&getContext()); KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(
&getContext());
if (failed(applyPartialConversion(getFunction(), target, patterns))) if (failed(applyPartialConversion(getFunction(), target, patterns)))
signalPassFailure(); signalPassFailure();
} }
} // namespace } // namespace
std::unique_ptr<Pass> mlir::createLowerKrnlPass() { std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
return std::make_unique<KrnlToAffineLoweringPass>(); return std::make_unique<KrnlToAffineLoweringPass>();
} }
static PassRegistration<KrnlToAffineLoweringPass> pass( static PassRegistration<KrnlToAffineLoweringPass> pass("lower-krnl",
"lower-krnl", "Lower Krnl dialect."); "Lower Krnl dialect.");

View File

@ -76,7 +76,7 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext& context,
} }
} }
int main(int ac, char* av[]) { int main(int ac, char *av[]) {
namespace po = boost::program_options; namespace po = boost::program_options;
po::options_description desc("ONNF available options"); po::options_description desc("ONNF available options");
@ -91,8 +91,8 @@ int main(int ac, char* av[]) {
po::positional_options_description p; po::positional_options_description p;
p.add("onnx-model", -1); p.add("onnx-model", -1);
po::variables_map vm; po::variables_map vm;
po::store( po::store(po::command_line_parser(ac, av).options(desc).positional(p).run(),
po::command_line_parser(ac, av).options(desc).positional(p).run(), vm); vm);
// TODO: allow multiple input files // TODO: allow multiple input files
assert(vm.count("onnx-model") < 2 && "At most one input file can be provided!"); assert(vm.count("onnx-model") < 2 && "At most one input file can be provided!");
@ -137,10 +137,10 @@ int main(int ac, char* av[]) {
// Write LLVM bitcode to disk. // Write LLVM bitcode to disk.
std::error_code EC; std::error_code EC;
llvm::raw_fd_ostream moduleBitcodeStream( llvm::raw_fd_ostream moduleBitcodeStream("model.bc", EC,
"model.bc", EC, llvm::sys::fs::F_None); llvm::sys::fs::F_None);
llvm::WriteBitcodeToFile( llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream); moduleBitcodeStream);
moduleBitcodeStream.flush(); moduleBitcodeStream.flush();
return 0; return 0;