Format Key Files using LLVM Style (#403)
* format using llvm style * merge and format
This commit is contained in:
parent
06a968d4a1
commit
a6a40cf989
|
@ -382,7 +382,7 @@ private:
|
||||||
auto end = default_str.find(",", start + 1);
|
auto end = default_str.find(",", start + 1);
|
||||||
if (end == std::string::npos) {
|
if (end == std::string::npos) {
|
||||||
end = default_str.find("}", start + 1);
|
end = default_str.find("}", start + 1);
|
||||||
if (end != std::string::npos && end > start+1) {
|
if (end != std::string::npos && end > start + 1) {
|
||||||
r.push_back(std::stoi(default_str.substr(start + 1, end)));
|
r.push_back(std::stoi(default_str.substr(start + 1, end)));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
@ -401,7 +401,7 @@ private:
|
||||||
auto end = default_str.find(",", start + 1);
|
auto end = default_str.find(",", start + 1);
|
||||||
if (end == std::string::npos) {
|
if (end == std::string::npos) {
|
||||||
end = default_str.find("}", start + 1);
|
end = default_str.find("}", start + 1);
|
||||||
if (end != std::string::npos && end > start+1) {
|
if (end != std::string::npos && end > start + 1) {
|
||||||
r.push_back(std::stof(default_str.substr(start + 1, end)));
|
r.push_back(std::stof(default_str.substr(start + 1, end)));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
@ -420,7 +420,7 @@ private:
|
||||||
auto end = default_str.find(",", start + 1);
|
auto end = default_str.find(",", start + 1);
|
||||||
if (end == std::string::npos) {
|
if (end == std::string::npos) {
|
||||||
end = default_str.find("}", start + 1);
|
end = default_str.find("}", start + 1);
|
||||||
if (end != std::string::npos && end > start+1) {
|
if (end != std::string::npos && end > start + 1) {
|
||||||
r.push_back(default_str.substr(start + 1, end));
|
r.push_back(default_str.substr(start + 1, end));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
@ -529,18 +529,19 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<mlir::NamedAttribute> attributes;
|
std::vector<mlir::NamedAttribute> attributes;
|
||||||
//for (auto [attr_name, attr_type, attr_default] : attrs) {
|
// for (auto [attr_name, attr_type, attr_default] : attrs) {
|
||||||
for (auto oneAttr: attrs) {
|
for (auto oneAttr : attrs) {
|
||||||
std::string attr_name;
|
std::string attr_name;
|
||||||
std::string attr_type;
|
std::string attr_type;
|
||||||
std::string attr_default;
|
std::string attr_default;
|
||||||
std::tie (attr_name, attr_type, attr_default) = oneAttr;
|
std::tie(attr_name, attr_type, attr_default) = oneAttr;
|
||||||
if (attr_type != "") {
|
if (attr_type != "") {
|
||||||
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
|
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
|
||||||
attributes.push_back(attr);
|
attributes.push_back(attr);
|
||||||
} else {
|
} else {
|
||||||
// TODO: the attributes need special handling
|
// TODO: the attributes need special handling
|
||||||
//std::cout << "missing " << node.op_type() << " " << attr_name << std::endl;
|
// std::cout << "missing " << node.op_type() << " " << attr_name <<
|
||||||
|
// std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -575,17 +576,18 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<mlir::NamedAttribute> attributes;
|
std::vector<mlir::NamedAttribute> attributes;
|
||||||
for (auto oneAttr: attrs) {
|
for (auto oneAttr : attrs) {
|
||||||
std::string attr_name;
|
std::string attr_name;
|
||||||
std::string attr_type;
|
std::string attr_type;
|
||||||
std::string attr_default;
|
std::string attr_default;
|
||||||
std::tie (attr_name, attr_type, attr_default) = oneAttr;
|
std::tie(attr_name, attr_type, attr_default) = oneAttr;
|
||||||
if (attr_type != "") {
|
if (attr_type != "") {
|
||||||
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
|
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
|
||||||
attributes.push_back(attr);
|
attributes.push_back(attr);
|
||||||
} else {
|
} else {
|
||||||
// TODO: the attributes need special handling
|
// TODO: the attributes need special handling
|
||||||
//std::cout << "missing " << node.op_type() << " " << attr_name << std::endl;
|
// std::cout << "missing " << node.op_type() << " " << attr_name <<
|
||||||
|
// std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,8 +9,6 @@
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
|
|
||||||
#include "llvm/ADT/SetVector.h"
|
|
||||||
#include "llvm/ADT/SmallBitVector.h"
|
|
||||||
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
#include "mlir/IR/Block.h"
|
#include "mlir/IR/Block.h"
|
||||||
|
@ -23,6 +21,8 @@
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
|
|
||||||
#include "src/compiler/dialect/krnl/krnl_helper.hpp"
|
#include "src/compiler/dialect/krnl/krnl_helper.hpp"
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
KrnlOpsDialect::KrnlOpsDialect(MLIRContext* context)
|
KrnlOpsDialect::KrnlOpsDialect(MLIRContext *context)
|
||||||
: Dialect(getDialectNamespace(), context) {
|
: Dialect(getDialectNamespace(), context) {
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
|
@ -44,29 +44,30 @@ KrnlOpsDialect::KrnlOpsDialect(MLIRContext* context)
|
||||||
// KrnlDefineLoopsOp
|
// KrnlDefineLoopsOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void KrnlDefineLoopsOp::build(
|
void KrnlDefineLoopsOp::build(Builder *builder, OperationState &result,
|
||||||
Builder* builder, OperationState& result, int64_t num_loops) {
|
int64_t num_loops) {
|
||||||
// Create the same number of dimension handlers as the number of
|
// Create the same number of dimension handlers as the number of
|
||||||
// dimensions in the associated integer set.
|
// dimensions in the associated integer set.
|
||||||
result.types.append(num_loops, LoopType::get(builder->getContext()));
|
result.types.append(num_loops, LoopType::get(builder->getContext()));
|
||||||
result.addAttribute(
|
result.addAttribute(getNumLoopsAttrName(),
|
||||||
getNumLoopsAttrName(), builder->getI32IntegerAttr(num_loops));
|
builder->getI32IntegerAttr(num_loops));
|
||||||
}
|
}
|
||||||
|
|
||||||
void print(OpAsmPrinter& p, KrnlDefineLoopsOp& op) {
|
void print(OpAsmPrinter &p, KrnlDefineLoopsOp &op) {
|
||||||
auto numLoopAttr =
|
auto numLoopAttr =
|
||||||
op.getAttrOfType<IntegerAttr>(KrnlDefineLoopsOp::getNumLoopsAttrName());
|
op.getAttrOfType<IntegerAttr>(KrnlDefineLoopsOp::getNumLoopsAttrName());
|
||||||
p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue();
|
p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult parseKrnlDefineLoopsOp(
|
ParseResult parseKrnlDefineLoopsOp(OpAsmParser &parser,
|
||||||
OpAsmParser& parser, OperationState& result) {
|
OperationState &result) {
|
||||||
// Parse the attribute indicating number of loops defined.
|
// Parse the attribute indicating number of loops defined.
|
||||||
IntegerAttr numLoops;
|
IntegerAttr numLoops;
|
||||||
auto& builder = parser.getBuilder();
|
auto &builder = parser.getBuilder();
|
||||||
auto intType = builder.getIntegerType(64);
|
auto intType = builder.getIntegerType(64);
|
||||||
if (parser.parseAttribute(numLoops, intType,
|
if (parser.parseAttribute(numLoops, intType,
|
||||||
KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes))
|
KrnlDefineLoopsOp::getNumLoopsAttrName(),
|
||||||
|
result.attributes))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto loopTypes = llvm::SmallVector<Type, 4>(
|
auto loopTypes = llvm::SmallVector<Type, 4>(
|
||||||
|
@ -79,29 +80,29 @@ ParseResult parseKrnlDefineLoopsOp(
|
||||||
// KrnlOptimizeLoopsOp
|
// KrnlOptimizeLoopsOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void KrnlOptimizeLoopsOp::build(
|
void KrnlOptimizeLoopsOp::build(Builder *builder, OperationState &result,
|
||||||
Builder* builder, OperationState& result, int num_optimized_loops) {
|
int num_optimized_loops) {
|
||||||
result.types.append(
|
result.types.append(num_optimized_loops,
|
||||||
num_optimized_loops, LoopType::get(builder->getContext()));
|
LoopType::get(builder->getContext()));
|
||||||
// Create a region and a block for the body.
|
// Create a region and a block for the body.
|
||||||
// Schedule intrinsics will be placed into this region.
|
// Schedule intrinsics will be placed into this region.
|
||||||
Region* region = result.addRegion();
|
Region *region = result.addRegion();
|
||||||
auto* body = new Block();
|
auto *body = new Block();
|
||||||
region->push_back(body);
|
region->push_back(body);
|
||||||
}
|
}
|
||||||
|
|
||||||
void print(OpAsmPrinter& p, KrnlOptimizeLoopsOp& op) {
|
void print(OpAsmPrinter &p, KrnlOptimizeLoopsOp &op) {
|
||||||
p << "krnl.optimize_loops ";
|
p << "krnl.optimize_loops ";
|
||||||
p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
|
p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
|
||||||
/*printBlockTerminators=*/true);
|
/*printBlockTerminators=*/true);
|
||||||
p << " : ";
|
p << " : ";
|
||||||
p.printFunctionalType(op);
|
p.printFunctionalType(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult parseKrnlOptimizeLoopsOp(
|
ParseResult parseKrnlOptimizeLoopsOp(OpAsmParser &parser,
|
||||||
OpAsmParser& parser, OperationState& result) {
|
OperationState &result) {
|
||||||
// Parse the schedule body region.
|
// Parse the schedule body region.
|
||||||
Region* region = result.addRegion();
|
Region *region = result.addRegion();
|
||||||
if (parser.parseRegion(*region, llvm::None, llvm::None))
|
if (parser.parseRegion(*region, llvm::None, llvm::None))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -142,21 +143,22 @@ ParseResult parseKrnlOptimizeLoopsOp(
|
||||||
* Then the bounds will be parsed as:
|
* Then the bounds will be parsed as:
|
||||||
* %i0 = 10 to N : %i1 = M to 20
|
* %i0 = 10 to N : %i1 = M to 20
|
||||||
*/
|
*/
|
||||||
void KrnlIterateOp::build(Builder* builder, OperationState& result,
|
void KrnlIterateOp::build(Builder *builder, OperationState &result,
|
||||||
KrnlIterateOperandPack operandPack) {
|
KrnlIterateOperandPack operandPack) {
|
||||||
// Record optimized loops and the number of such loops.
|
// Record optimized loops and the number of such loops.
|
||||||
result.addOperands(operandPack.getOperands());
|
result.addOperands(operandPack.getOperands());
|
||||||
result.addAttribute(
|
result.addAttribute(KrnlIterateOp::getBoundsAttrName(),
|
||||||
KrnlIterateOp::getBoundsAttrName(), operandPack.getAttributes());
|
operandPack.getAttributes());
|
||||||
|
|
||||||
result.addAttribute(getNumOptimizedLoopsAttrName(),
|
result.addAttribute(
|
||||||
|
getNumOptimizedLoopsAttrName(),
|
||||||
builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops()));
|
builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops()));
|
||||||
|
|
||||||
// Create a region and a block for the body. The arguments of the region are
|
// Create a region and a block for the body. The arguments of the region are
|
||||||
// the loop induction variables; there can be multiple induction variables
|
// the loop induction variables; there can be multiple induction variables
|
||||||
// associated with the same krnl.iterate operation.
|
// associated with the same krnl.iterate operation.
|
||||||
Region* bodyRegion = result.addRegion();
|
Region *bodyRegion = result.addRegion();
|
||||||
auto* body = new Block();
|
auto *body = new Block();
|
||||||
auto body_args = llvm::SmallVector<Type, 4>(
|
auto body_args = llvm::SmallVector<Type, 4>(
|
||||||
operandPack.getNumInputLoops(), IndexType::get(builder->getContext()));
|
operandPack.getNumInputLoops(), IndexType::get(builder->getContext()));
|
||||||
body->addArguments(body_args);
|
body->addArguments(body_args);
|
||||||
|
@ -165,7 +167,7 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result,
|
||||||
ensureTerminator(*bodyRegion, *builder, result.location);
|
ensureTerminator(*bodyRegion, *builder, result.location);
|
||||||
}
|
}
|
||||||
|
|
||||||
void print(OpAsmPrinter& p, KrnlIterateOp& op) {
|
void print(OpAsmPrinter &p, KrnlIterateOp &op) {
|
||||||
p << "krnl.iterate(";
|
p << "krnl.iterate(";
|
||||||
// Print optimized loops:
|
// Print optimized loops:
|
||||||
auto numOptimizedLoops = op.getNumOptimizedLoops();
|
auto numOptimizedLoops = op.getNumOptimizedLoops();
|
||||||
|
@ -180,7 +182,7 @@ void print(OpAsmPrinter& p, KrnlIterateOp& op) {
|
||||||
auto operandItr = op.operand_begin() + numOptimizedLoops;
|
auto operandItr = op.operand_begin() + numOptimizedLoops;
|
||||||
|
|
||||||
std::string delimiter;
|
std::string delimiter;
|
||||||
for (auto& var : inductionVars) {
|
for (auto &var : inductionVars) {
|
||||||
p << delimiter;
|
p << delimiter;
|
||||||
p.printOperand(*operandItr++);
|
p.printOperand(*operandItr++);
|
||||||
p << " -> ";
|
p << " -> ";
|
||||||
|
@ -194,25 +196,26 @@ void print(OpAsmPrinter& p, KrnlIterateOp& op) {
|
||||||
|
|
||||||
p << ")";
|
p << ")";
|
||||||
p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false,
|
p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false,
|
||||||
/*printBlockTerminators=*/false);
|
/*printBlockTerminators=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
|
ParseResult parseKrnlIterateOp(OpAsmParser &parser, OperationState &result) {
|
||||||
auto builder = parser.getBuilder();
|
auto builder = parser.getBuilder();
|
||||||
auto context = builder.getContext();
|
auto context = builder.getContext();
|
||||||
onnf::KrnlDialectOperandParser operandParser(parser);
|
onnf::KrnlDialectOperandParser operandParser(parser);
|
||||||
|
|
||||||
// Parse optimized loops:
|
// Parse optimized loops:
|
||||||
SmallVector<OpAsmParser::OperandType, 4> optimizedLoopRefs;
|
SmallVector<OpAsmParser::OperandType, 4> optimizedLoopRefs;
|
||||||
if (parser.parseOperandList(
|
if (parser.parseOperandList(optimizedLoopRefs,
|
||||||
optimizedLoopRefs, OpAsmParser::Delimiter::Paren) ||
|
OpAsmParser::Delimiter::Paren) ||
|
||||||
parser.resolveOperands(optimizedLoopRefs,
|
parser.resolveOperands(optimizedLoopRefs,
|
||||||
LoopType::get(result.getContext()), result.operands))
|
LoopType::get(result.getContext()),
|
||||||
|
result.operands))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Record how many optimized loops did we parse.
|
// Record how many optimized loops did we parse.
|
||||||
result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(),
|
result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(),
|
||||||
builder.getI64IntegerAttr(optimizedLoopRefs.size()));
|
builder.getI64IntegerAttr(optimizedLoopRefs.size()));
|
||||||
|
|
||||||
// Parse input loops and their lower and upper bounds.
|
// Parse input loops and their lower and upper bounds.
|
||||||
SmallVector<OpAsmParser::OperandType, 4> inductionVarRefs;
|
SmallVector<OpAsmParser::OperandType, 4> inductionVarRefs;
|
||||||
|
@ -222,16 +225,16 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// A function to parse a lower or upper bound.
|
// A function to parse a lower or upper bound.
|
||||||
auto parseBound = [&result, &builder, &parser, &operandParser, &boundMaps](
|
auto parseBound = [&result, &builder, &parser, &operandParser,
|
||||||
bool isUpper) -> ParseResult {
|
&boundMaps](bool isUpper) -> ParseResult {
|
||||||
// 'min' / 'max' prefixes are generally syntactic sugar, but are required if
|
// 'min' / 'max' prefixes are generally syntactic sugar, but are required if
|
||||||
// the map has multiple results.
|
// the map has multiple results.
|
||||||
bool failedToParsedMinMax =
|
bool failedToParsedMinMax =
|
||||||
failed(parser.parseOptionalKeyword(isUpper ? "min" : "max"));
|
failed(parser.parseOptionalKeyword(isUpper ? "min" : "max"));
|
||||||
|
|
||||||
// Try parse an SSA operand.
|
// Try parse an SSA operand.
|
||||||
if (succeeded(operandParser.ParseOptionalOperand(
|
if (succeeded(operandParser.ParseOptionalOperand(builder.getIndexType(),
|
||||||
builder.getIndexType(), result.operands))) {
|
result.operands))) {
|
||||||
AffineMap map = builder.getSymbolIdentityMap();
|
AffineMap map = builder.getSymbolIdentityMap();
|
||||||
boundMaps.emplace_back(AffineMapAttr::get(map));
|
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||||
return success();
|
return success();
|
||||||
|
@ -243,8 +246,8 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
|
||||||
llvm::SMLoc attrLoc = parser.getCurrentLocation();
|
llvm::SMLoc attrLoc = parser.getCurrentLocation();
|
||||||
Attribute boundAttr;
|
Attribute boundAttr;
|
||||||
llvm::SmallVector<NamedAttribute, 1> tempBoundAttrContainer;
|
llvm::SmallVector<NamedAttribute, 1> tempBoundAttrContainer;
|
||||||
if (parser.parseAttribute(
|
if (parser.parseAttribute(boundAttr, builder.getIndexType(), "temp",
|
||||||
boundAttr, builder.getIndexType(), "temp", tempBoundAttrContainer))
|
tempBoundAttrContainer))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
|
if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
|
||||||
|
@ -255,13 +258,15 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
|
||||||
|
|
||||||
auto map = affineMapAttr.getValue();
|
auto map = affineMapAttr.getValue();
|
||||||
if (map.getNumDims() != numDims)
|
if (map.getNumDims() != numDims)
|
||||||
return parser.emitError(parser.getNameLoc(),
|
return parser.emitError(
|
||||||
|
parser.getNameLoc(),
|
||||||
"dim operand count and integer set dim count must match");
|
"dim operand count and integer set dim count must match");
|
||||||
|
|
||||||
unsigned numDimAndSymbolOperands =
|
unsigned numDimAndSymbolOperands =
|
||||||
result.operands.size() - currentNumOperands;
|
result.operands.size() - currentNumOperands;
|
||||||
if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
|
if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
|
||||||
return parser.emitError(parser.getNameLoc(),
|
return parser.emitError(
|
||||||
|
parser.getNameLoc(),
|
||||||
"symbol operand count and integer set symbol count must match");
|
"symbol operand count and integer set symbol count must match");
|
||||||
|
|
||||||
// If the map has multiple results, make sure that we parsed the min/max
|
// If the map has multiple results, make sure that we parsed the min/max
|
||||||
|
@ -269,11 +274,11 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
|
||||||
if (map.getNumResults() > 1 && failedToParsedMinMax) {
|
if (map.getNumResults() > 1 && failedToParsedMinMax) {
|
||||||
if (isUpper)
|
if (isUpper)
|
||||||
return parser.emitError(attrLoc,
|
return parser.emitError(attrLoc,
|
||||||
"upper loop bound affine map with multiple "
|
"upper loop bound affine map with multiple "
|
||||||
"results requires 'min' prefix");
|
"results requires 'min' prefix");
|
||||||
return parser.emitError(attrLoc,
|
return parser.emitError(attrLoc,
|
||||||
"lower loop bound affine mapwith "
|
"lower loop bound affine mapwith "
|
||||||
"multiple results requires 'max' prefix");
|
"multiple results requires 'max' prefix");
|
||||||
}
|
}
|
||||||
boundMaps.emplace_back(AffineMapAttr::get(map));
|
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||||
return success();
|
return success();
|
||||||
|
@ -286,7 +291,7 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
bool keepParsing; // Do we keep parsing loops/bounds?
|
bool keepParsing; // Do we keep parsing loops/bounds?
|
||||||
do {
|
do {
|
||||||
// Parse an input loop operand;
|
// Parse an input loop operand;
|
||||||
operandParser.ParseOperand(LoopType::get(context), result.operands);
|
operandParser.ParseOperand(LoopType::get(context), result.operands);
|
||||||
|
@ -316,18 +321,18 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
|
||||||
// At this point, there shouldn't be any operands left to parse.
|
// At this point, there shouldn't be any operands left to parse.
|
||||||
if (operandParser.hasOperandLeft())
|
if (operandParser.hasOperandLeft())
|
||||||
return parser.emitError(parser.getCurrentLocation());
|
return parser.emitError(parser.getCurrentLocation());
|
||||||
result.addAttribute(
|
result.addAttribute(KrnlIterateOp::getBoundsAttrName(),
|
||||||
KrnlIterateOp::getBoundsAttrName(), builder.getArrayAttr(boundMaps));
|
builder.getArrayAttr(boundMaps));
|
||||||
|
|
||||||
Region* region = result.addRegion();
|
Region *region = result.addRegion();
|
||||||
SmallVector<Type, 4> inductionVarTypes(
|
SmallVector<Type, 4> inductionVarTypes(inductionVarRefs.size(),
|
||||||
inductionVarRefs.size(), builder.getIndexType());
|
builder.getIndexType());
|
||||||
if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes))
|
if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Ensure iterate region is closed off with krnl.terminate.
|
// Ensure iterate region is closed off with krnl.terminate.
|
||||||
KrnlIterateOp::ensureTerminator(
|
KrnlIterateOp::ensureTerminator(*region, parser.getBuilder(),
|
||||||
*region, parser.getBuilder(), result.location);
|
result.location);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -341,18 +346,19 @@ static LogicalResult verify(KrnlIterateOp op) {
|
||||||
// KrnlReturnLoopsOp
|
// KrnlReturnLoopsOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void print(OpAsmPrinter& p, KrnlReturnLoopsOp& op) {
|
void print(OpAsmPrinter &p, KrnlReturnLoopsOp &op) {
|
||||||
p << "krnl.return_loops ";
|
p << "krnl.return_loops ";
|
||||||
p.printOperands(op.operand_begin(), op.operand_end());
|
p.printOperands(op.operand_begin(), op.operand_end());
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult parseKrnlReturnLoopsOp(
|
ParseResult parseKrnlReturnLoopsOp(OpAsmParser &parser,
|
||||||
OpAsmParser& parser, OperationState& result) {
|
OperationState &result) {
|
||||||
// Parse the loops to return.
|
// Parse the loops to return.
|
||||||
SmallVector<OpAsmParser::OperandType, 4> timestamp_dim_handlers;
|
SmallVector<OpAsmParser::OperandType, 4> timestamp_dim_handlers;
|
||||||
if (parser.parseOperandList(timestamp_dim_handlers) ||
|
if (parser.parseOperandList(timestamp_dim_handlers) ||
|
||||||
parser.resolveOperands(timestamp_dim_handlers,
|
parser.resolveOperands(timestamp_dim_handlers,
|
||||||
LoopType::get(result.getContext()), result.operands))
|
LoopType::get(result.getContext()),
|
||||||
|
result.operands))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
|
@ -360,4 +366,4 @@ ParseResult parseKrnlReturnLoopsOp(
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "src/compiler/krnl.cpp.inc"
|
#include "src/compiler/krnl.cpp.inc"
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -19,12 +19,12 @@
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class KrnlOpsDialect : public Dialect {
|
class KrnlOpsDialect : public Dialect {
|
||||||
public:
|
public:
|
||||||
KrnlOpsDialect(MLIRContext* context);
|
KrnlOpsDialect(MLIRContext *context);
|
||||||
static StringRef getDialectNamespace() { return "krnl"; }
|
static StringRef getDialectNamespace() { return "krnl"; }
|
||||||
|
|
||||||
/// Parse a type registered to this dialect.
|
/// Parse a type registered to this dialect.
|
||||||
Type parseType(DialectAsmParser& parser) const override {
|
Type parseType(DialectAsmParser &parser) const override {
|
||||||
if (succeeded(parser.parseOptionalKeyword("loop")))
|
if (succeeded(parser.parseOptionalKeyword("loop")))
|
||||||
return LoopType::get(parser.getBuilder().getContext());
|
return LoopType::get(parser.getBuilder().getContext());
|
||||||
|
|
||||||
|
@ -32,15 +32,15 @@ class KrnlOpsDialect : public Dialect {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Print a type registered to this dialect.
|
/// Print a type registered to this dialect.
|
||||||
void printType(Type type, DialectAsmPrinter& os) const override {
|
void printType(Type type, DialectAsmPrinter &os) const override {
|
||||||
switch (type.getKind()) {
|
switch (type.getKind()) {
|
||||||
case KrnlTypes::Loop:
|
case KrnlTypes::Loop:
|
||||||
os << "loop";
|
os << "loop";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "src/compiler/krnl.hpp.inc"
|
#include "src/compiler/krnl.hpp.inc"
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -11,17 +11,16 @@
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
|
||||||
#include "llvm/ADT/Sequence.h"
|
|
||||||
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/Sequence.h"
|
||||||
|
|
||||||
#include "src/compiler/dialect/krnl/krnl_helper.hpp"
|
#include "src/compiler/dialect/krnl/krnl_helper.hpp"
|
||||||
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||||
|
|
||||||
#include "src/compiler/pass/passes.hpp"
|
#include "src/compiler/pass/passes.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
@ -98,7 +97,7 @@ static Value *insertAllocAndDealloc(MemRefType type, Location loc,
|
||||||
|
|
||||||
// Make sure to allocate at the beginning of the block if
|
// Make sure to allocate at the beginning of the block if
|
||||||
// all dimensions are known.
|
// all dimensions are known.
|
||||||
auto* parentBlock = alloc.getOperation()->getBlock();
|
auto *parentBlock = alloc.getOperation()->getBlock();
|
||||||
if (hasAllConstantDimensions(type))
|
if (hasAllConstantDimensions(type))
|
||||||
alloc.getOperation()->moveBefore(&parentBlock->front());
|
alloc.getOperation()->moveBefore(&parentBlock->front());
|
||||||
|
|
||||||
|
@ -113,17 +112,17 @@ static Value *insertAllocAndDealloc(MemRefType type, Location loc,
|
||||||
// Determine if current function returns the result value of the
|
// Determine if current function returns the result value of the
|
||||||
// current op being lowered. If it does then dealloc should not be
|
// current op being lowered. If it does then dealloc should not be
|
||||||
// inserted.
|
// inserted.
|
||||||
static bool checkInsertDealloc(Operation* currentOp) {
|
static bool checkInsertDealloc(Operation *currentOp) {
|
||||||
auto parentBlock = currentOp->getBlock();
|
auto parentBlock = currentOp->getBlock();
|
||||||
|
|
||||||
bool insertDealloc = true;
|
bool insertDealloc = true;
|
||||||
parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
|
parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
|
||||||
assert(currentOp->getNumResults() < 2 &&
|
assert(currentOp->getNumResults() < 2 &&
|
||||||
"No more than one result supported (for now).");
|
"No more than one result supported (for now).");
|
||||||
// If there is at least one result to investigate.
|
// If there is at least one result to investigate.
|
||||||
if (currentOp->getNumResults() > 0) {
|
if (currentOp->getNumResults() > 0) {
|
||||||
auto result = currentOp->getResult(0);
|
auto result = currentOp->getResult(0);
|
||||||
for (const auto& operand : op.getOperands())
|
for (const auto &operand : op.getOperands())
|
||||||
if (operand == result)
|
if (operand == result)
|
||||||
insertDealloc = false;
|
insertDealloc = false;
|
||||||
}
|
}
|
||||||
|
@ -148,7 +147,7 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
|
||||||
|
|
||||||
// Get run-time dimension information for unknown dimensions used for
|
// Get run-time dimension information for unknown dimensions used for
|
||||||
// broadcasting.
|
// broadcasting.
|
||||||
std::map<int, std::map<int, Value *> >
|
std::map<int, std::map<int, Value *>>
|
||||||
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
|
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
|
||||||
MemRefType memRefType, ArrayRef<Value *> operands) {
|
MemRefType memRefType, ArrayRef<Value *> operands) {
|
||||||
auto memRefShape = memRefType.getShape();
|
auto memRefShape = memRefType.getShape();
|
||||||
|
@ -196,15 +195,15 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
|
||||||
// given operand.
|
// given operand.
|
||||||
std::vector<Value *>
|
std::vector<Value *>
|
||||||
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
|
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
|
||||||
ArrayRef<Value *> loopIVs, Value *operand,
|
ArrayRef<Value *> loopIVs, Value *operand,
|
||||||
std::map<int, Value *> broadcastedDims) {
|
std::map<int, Value *> broadcastedDims) {
|
||||||
// `operand` must has a ranked type. This should have been checked by the
|
// `operand` must has a ranked type. This should have been checked by the
|
||||||
// shape inference pass.
|
// shape inference pass.
|
||||||
auto operandShape = operand->getType().cast<MemRefType>().getShape();
|
auto operandShape = operand->getType().cast<MemRefType>().getShape();
|
||||||
auto rank = operandShape.size();
|
auto rank = operandShape.size();
|
||||||
auto loopCount = loopIVs.size();
|
auto loopCount = loopIVs.size();
|
||||||
|
|
||||||
std::vector<Value*> newLoopIVs;
|
std::vector<Value *> newLoopIVs;
|
||||||
for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
|
for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
|
||||||
auto dimIdx = rank - 1 - reversedIdx;
|
auto dimIdx = rank - 1 - reversedIdx;
|
||||||
auto loopIdx = loopCount - 1 - reversedIdx;
|
auto loopIdx = loopCount - 1 - reversedIdx;
|
||||||
|
@ -218,8 +217,8 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
|
||||||
// If its value is 1, it is broadcasted dimension.
|
// If its value is 1, it is broadcasted dimension.
|
||||||
// Otherwise, non-broadcasted dimension.
|
// Otherwise, non-broadcasted dimension.
|
||||||
auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
||||||
auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx],
|
auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
|
||||||
zero, loopIVs[loopIdx]);
|
loopIVs[loopIdx]);
|
||||||
newLoopIVs.insert(newLoopIVs.begin(), idx);
|
newLoopIVs.insert(newLoopIVs.begin(), idx);
|
||||||
} else {
|
} else {
|
||||||
// Non-broadcasted dimension
|
// Non-broadcasted dimension
|
||||||
|
@ -260,26 +259,26 @@ struct ScalarOp<ONNXSubOp> {
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct ScalarOp<ONNXAndOp> {
|
struct ScalarOp<ONNXAndOp> {
|
||||||
using FOp = AndOp; // not use
|
using FOp = AndOp; // not use
|
||||||
using IOp = AndOp;
|
using IOp = AndOp;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct ScalarOp<ONNXOrOp> {
|
struct ScalarOp<ONNXOrOp> {
|
||||||
using FOp = OrOp; // not use
|
using FOp = OrOp; // not use
|
||||||
using IOp = OrOp;
|
using IOp = OrOp;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct ScalarOp<ONNXXorOp> {
|
struct ScalarOp<ONNXXorOp> {
|
||||||
using FOp = XOrOp; // not use
|
using FOp = XOrOp; // not use
|
||||||
using IOp = XOrOp;
|
using IOp = XOrOp;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct ScalarOp<ONNXExpOp> {
|
struct ScalarOp<ONNXExpOp> {
|
||||||
using FOp = ExpOp;
|
using FOp = ExpOp;
|
||||||
using IOp = ExpOp; // not use
|
using IOp = ExpOp; // not use
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -297,18 +296,19 @@ using ScalarIOp = typename ScalarOp<ElementwiseNaryOp>::IOp;
|
||||||
// Scalar unary ops for lowering to Krnl dialect.
|
// Scalar unary ops for lowering to Krnl dialect.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <typename UnaryOp>
|
template <typename UnaryOp>
|
||||||
Value* mapToLowerScalarOp(Operation* op, ArrayRef<Type> result_types,
|
Value *mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
|
ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
/* Lower UnaryOp to Ops in the Standard dialect.
|
/* Lower UnaryOp to Ops in the Standard dialect.
|
||||||
*/
|
*/
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Type element_type = operands.front()->getType();
|
Type element_type = operands.front()->getType();
|
||||||
if (element_type.isa<IntegerType>()) {
|
if (element_type.isa<IntegerType>()) {
|
||||||
return rewriter.create<ScalarIOp<UnaryOp>>(
|
return rewriter.create<ScalarIOp<UnaryOp>>(loc, result_types, operands,
|
||||||
loc, result_types, operands, mlir::None);
|
mlir::None);
|
||||||
} else if (element_type.isa<FloatType>()) {
|
} else if (element_type.isa<FloatType>()) {
|
||||||
return rewriter.create<ScalarFOp<UnaryOp>>(
|
return rewriter.create<ScalarFOp<UnaryOp>>(loc, result_types, operands,
|
||||||
loc, result_types, operands, mlir::None);
|
mlir::None);
|
||||||
} else {
|
} else {
|
||||||
emitError(loc, "unsupported element type");
|
emitError(loc, "unsupported element type");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -319,13 +319,14 @@ Value* mapToLowerScalarOp(Operation* op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXTanhOp
|
// Scalar unary ops for lowering ONNXTanhOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value* mapToLowerScalarOp<ONNXTanhOp>(Operation* op,
|
Value *mapToLowerScalarOp<ONNXTanhOp>(Operation *op,
|
||||||
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
|
ArrayRef<Type> result_types,
|
||||||
ConversionPatternRewriter& rewriter) {
|
ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
// ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
||||||
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
|
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value* operand = operands[0];
|
Value *operand = operands[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||||
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
||||||
|
@ -333,7 +334,7 @@ Value* mapToLowerScalarOp<ONNXTanhOp>(Operation* op,
|
||||||
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
||||||
auto result =
|
auto result =
|
||||||
rewriter.create<DivFOp>(loc, rewriter.create<SubFOp>(loc, exp, negExp),
|
rewriter.create<DivFOp>(loc, rewriter.create<SubFOp>(loc, exp, negExp),
|
||||||
rewriter.create<AddFOp>(loc, exp, negExp));
|
rewriter.create<AddFOp>(loc, exp, negExp));
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -342,13 +343,14 @@ Value* mapToLowerScalarOp<ONNXTanhOp>(Operation* op,
|
||||||
// Scalar unary ops for lowering ONNXSinhOp
|
// Scalar unary ops for lowering ONNXSinhOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value* mapToLowerScalarOp<ONNXSinhOp>(Operation* op,
|
Value *mapToLowerScalarOp<ONNXSinhOp>(Operation *op,
|
||||||
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
|
ArrayRef<Type> result_types,
|
||||||
ConversionPatternRewriter& rewriter) {
|
ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
// ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
||||||
// ConstantOp 2)
|
// ConstantOp 2)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value* operand = operands[0];
|
Value *operand = operands[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||||
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
|
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
|
||||||
|
@ -365,13 +367,14 @@ Value* mapToLowerScalarOp<ONNXSinhOp>(Operation* op,
|
||||||
// Scalar unary ops for lowering ONNXCoshOp
|
// Scalar unary ops for lowering ONNXCoshOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value* mapToLowerScalarOp<ONNXCoshOp>(Operation* op,
|
Value *mapToLowerScalarOp<ONNXCoshOp>(Operation *op,
|
||||||
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
|
ArrayRef<Type> result_types,
|
||||||
ConversionPatternRewriter& rewriter) {
|
ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
// ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
||||||
// ConstantOp 2)
|
// ConstantOp 2)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value* operand = operands[0];
|
Value *operand = operands[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||||
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
|
auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
|
||||||
|
@ -388,13 +391,14 @@ Value* mapToLowerScalarOp<ONNXCoshOp>(Operation* op,
|
||||||
// Scalar unary ops for lowering ONNXSigmoidOp
|
// Scalar unary ops for lowering ONNXSigmoidOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value* mapToLowerScalarOp<ONNXSigmoidOp>(Operation* op,
|
Value *mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
|
||||||
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
|
ArrayRef<Type> result_types,
|
||||||
ConversionPatternRewriter& rewriter) {
|
ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
|
// ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
|
||||||
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
|
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value* operand = operands[0];
|
Value *operand = operands[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||||
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
|
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
|
||||||
|
@ -410,9 +414,9 @@ Value* mapToLowerScalarOp<ONNXSigmoidOp>(Operation* op,
|
||||||
// Scalar unary ops for lowering ONNXHardSigmoidOp
|
// Scalar unary ops for lowering ONNXHardSigmoidOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value* mapToLowerScalarOp<ONNXHardSigmoidOp>(Operation* op,
|
Value *mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
||||||
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
|
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value *> operands,
|
||||||
ConversionPatternRewriter& rewriter) {
|
ConversionPatternRewriter &rewriter) {
|
||||||
// %Y = AddFOp(MulFOp(alpha, %X), beta)
|
// %Y = AddFOp(MulFOp(alpha, %X), beta)
|
||||||
// %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
|
// %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
|
||||||
// %Y,
|
// %Y,
|
||||||
|
@ -421,7 +425,7 @@ Value* mapToLowerScalarOp<ONNXHardSigmoidOp>(Operation* op,
|
||||||
// %Z,
|
// %Z,
|
||||||
// Constant 1)
|
// Constant 1)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value* operand = operands[0];
|
Value *operand = operands[0];
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha");
|
auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha");
|
||||||
auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta");
|
auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta");
|
||||||
|
|
||||||
|
@ -446,13 +450,14 @@ Value* mapToLowerScalarOp<ONNXHardSigmoidOp>(Operation* op,
|
||||||
// Scalar unary ops for lowering ONNXEluOp
|
// Scalar unary ops for lowering ONNXEluOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value* mapToLowerScalarOp<ONNXEluOp>(Operation* op, ArrayRef<Type> result_types,
|
Value *mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
|
ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
// ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
||||||
// MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
|
// MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
|
||||||
// %X)
|
// %X)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value* operand = operands[0];
|
Value *operand = operands[0];
|
||||||
|
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
|
auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||||
|
@ -461,9 +466,10 @@ Value* mapToLowerScalarOp<ONNXEluOp>(Operation* op, ArrayRef<Type> result_types,
|
||||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||||
auto lessThanZero =
|
auto lessThanZero =
|
||||||
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
|
||||||
auto result = rewriter.create<SelectOp>(loc, lessThanZero,
|
auto result = rewriter.create<SelectOp>(
|
||||||
rewriter.create<MulFOp>(
|
loc, lessThanZero,
|
||||||
loc, alpha, rewriter.create<SubFOp>(loc, exp, one)),
|
rewriter.create<MulFOp>(loc, alpha,
|
||||||
|
rewriter.create<SubFOp>(loc, exp, one)),
|
||||||
operand);
|
operand);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -473,14 +479,15 @@ Value* mapToLowerScalarOp<ONNXEluOp>(Operation* op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXReluOp
|
// Scalar unary ops for lowering ONNXReluOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value* mapToLowerScalarOp<ONNXReluOp>(Operation* op,
|
Value *mapToLowerScalarOp<ONNXReluOp>(Operation *op,
|
||||||
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
|
ArrayRef<Type> result_types,
|
||||||
ConversionPatternRewriter& rewriter) {
|
ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
// ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
||||||
// ConstantOp 0,
|
// ConstantOp 0,
|
||||||
// %X)
|
// %X)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value* operand = operands[0];
|
Value *operand = operands[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||||
auto lessThanZero =
|
auto lessThanZero =
|
||||||
|
@ -494,14 +501,15 @@ Value* mapToLowerScalarOp<ONNXReluOp>(Operation* op,
|
||||||
// Scalar unary ops for lowering ONNXLeakyReluOp
|
// Scalar unary ops for lowering ONNXLeakyReluOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value* mapToLowerScalarOp<ONNXLeakyReluOp>(Operation* op,
|
Value *
|
||||||
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
|
mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
ConversionPatternRewriter& rewriter) {
|
ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
// ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
|
||||||
// MulFOp(alpha, %X),
|
// MulFOp(alpha, %X),
|
||||||
// %X)
|
// %X)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value* operand = operands[0];
|
Value *operand = operands[0];
|
||||||
|
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha");
|
auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha");
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||||
|
@ -518,16 +526,17 @@ Value* mapToLowerScalarOp<ONNXLeakyReluOp>(Operation* op,
|
||||||
// Scalar unary ops for lowering ONNXSeluOp
|
// Scalar unary ops for lowering ONNXSeluOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value* mapToLowerScalarOp<ONNXSeluOp>(Operation* op,
|
Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op,
|
||||||
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
|
ArrayRef<Type> result_types,
|
||||||
ConversionPatternRewriter& rewriter) {
|
ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
|
// ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
|
||||||
// MulFOp(gamma, %X),
|
// MulFOp(gamma, %X),
|
||||||
// MulFOp(gamma,
|
// MulFOp(gamma,
|
||||||
// SubFOp(MulFOp(alpha, ExpOp(%X)),
|
// SubFOp(MulFOp(alpha, ExpOp(%X)),
|
||||||
// alpha)))
|
// alpha)))
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value* operand = operands[0];
|
Value *operand = operands[0];
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha");
|
auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha");
|
||||||
auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma");
|
auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma");
|
||||||
|
|
||||||
|
@ -537,9 +546,10 @@ Value* mapToLowerScalarOp<ONNXSeluOp>(Operation* op,
|
||||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||||
auto greaterThanZero =
|
auto greaterThanZero =
|
||||||
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
|
||||||
auto select = rewriter.create<SelectOp>(loc, greaterThanZero, operand,
|
auto select = rewriter.create<SelectOp>(
|
||||||
rewriter.create<SubFOp>(
|
loc, greaterThanZero, operand,
|
||||||
loc, rewriter.create<MulFOp>(loc, alpha, exp), alpha));
|
rewriter.create<SubFOp>(loc, rewriter.create<MulFOp>(loc, alpha, exp),
|
||||||
|
alpha));
|
||||||
auto result = rewriter.create<MulFOp>(loc, gamma, select);
|
auto result = rewriter.create<MulFOp>(loc, gamma, select);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -549,11 +559,13 @@ Value* mapToLowerScalarOp<ONNXSeluOp>(Operation* op,
|
||||||
// Scalar unary ops for lowering ONNXReciprocalOp
|
// Scalar unary ops for lowering ONNXReciprocalOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value* mapToLowerScalarOp<ONNXReciprocalOp>(Operation* op, ArrayRef<Type> result_types,
|
Value *
|
||||||
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
|
mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
|
// ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value* operand = operands[0];
|
Value *operand = operands[0];
|
||||||
|
|
||||||
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
|
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
|
||||||
auto result = rewriter.create<DivFOp>(loc, one, operand);
|
auto result = rewriter.create<DivFOp>(loc, one, operand);
|
||||||
|
@ -565,14 +577,15 @@ Value* mapToLowerScalarOp<ONNXReciprocalOp>(Operation* op, ArrayRef<Type> result
|
||||||
// Scalar unary ops for lowering ONNXMaxOp
|
// Scalar unary ops for lowering ONNXMaxOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value* mapToLowerScalarOp<ONNXMaxOp>(Operation* op, ArrayRef<Type> result_types,
|
Value *mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
|
ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
|
// ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
|
||||||
// %X,
|
// %X,
|
||||||
// %Y)
|
// %Y)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value* lhs = operands[0];
|
Value *lhs = operands[0];
|
||||||
Value* rhs = operands[1];
|
Value *rhs = operands[1];
|
||||||
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
|
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
|
||||||
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
|
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
|
||||||
return result;
|
return result;
|
||||||
|
@ -582,14 +595,15 @@ Value* mapToLowerScalarOp<ONNXMaxOp>(Operation* op, ArrayRef<Type> result_types,
|
||||||
// Scalar unary ops for lowering ONNXMinOp
|
// Scalar unary ops for lowering ONNXMinOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
Value* mapToLowerScalarOp<ONNXMinOp>(Operation* op, ArrayRef<Type> result_types,
|
Value *mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
|
ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
// ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
|
// ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
|
||||||
// %X,
|
// %X,
|
||||||
// %Y)
|
// %Y)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value* lhs = operands[0];
|
Value *lhs = operands[0];
|
||||||
Value* rhs = operands[1];
|
Value *rhs = operands[1];
|
||||||
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
|
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
|
||||||
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
|
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
|
||||||
return result;
|
return result;
|
||||||
|
@ -599,10 +613,11 @@ Value* mapToLowerScalarOp<ONNXMinOp>(Operation* op, ArrayRef<Type> result_types,
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <typename ElementwiseUnaryOp>
|
template <typename ElementwiseUnaryOp>
|
||||||
struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
||||||
ONNXElementwiseUnaryOpLowering(MLIRContext* ctx)
|
ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
|
||||||
: ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
|
||||||
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
|
PatternMatchResult
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
// TODO: Check that the types are valid.
|
// TODO: Check that the types are valid.
|
||||||
// An element-wise unary operation must have all operands and the result of
|
// An element-wise unary operation must have all operands and the result of
|
||||||
// the same type. This should have been verified by the verifier.
|
// the same type. This should have been verified by the verifier.
|
||||||
|
@ -618,14 +633,14 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
||||||
// dimensions with the result at this pre-optimization phase.
|
// dimensions with the result at this pre-optimization phase.
|
||||||
// TODO: verify that dimensions match.
|
// TODO: verify that dimensions match.
|
||||||
// TODO: can the dimension of the result differ after optimizations?
|
// TODO: can the dimension of the result differ after optimizations?
|
||||||
Value* alloc;
|
Value *alloc;
|
||||||
bool insertDealloc = checkInsertDealloc(op);
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
|
|
||||||
if (hasAllConstantDimensions(memRefType))
|
if (hasAllConstantDimensions(memRefType))
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
else
|
else
|
||||||
alloc = insertAllocAndDealloc(
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
|
||||||
memRefType, loc, rewriter, insertDealloc, {operands[0]});
|
{operands[0]});
|
||||||
|
|
||||||
// Number of loops
|
// Number of loops
|
||||||
auto memRefShape = memRefType.getShape();
|
auto memRefShape = memRefType.getShape();
|
||||||
|
@ -633,7 +648,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
||||||
|
|
||||||
// Define loops.
|
// Define loops.
|
||||||
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
|
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
|
||||||
std::vector<Value*> originalLoops;
|
std::vector<Value *> originalLoops;
|
||||||
originalLoops.reserve(rank);
|
originalLoops.reserve(rank);
|
||||||
for (auto result : loopsOp.getResults()) {
|
for (auto result : loopsOp.getResults()) {
|
||||||
originalLoops.push_back(result);
|
originalLoops.push_back(result);
|
||||||
|
@ -641,12 +656,12 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
||||||
|
|
||||||
// Define loop optimization.
|
// Define loop optimization.
|
||||||
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
|
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
|
||||||
std::vector<Value*> optimizedLoops;
|
std::vector<Value *> optimizedLoops;
|
||||||
optimizedLoops.reserve(rank);
|
optimizedLoops.reserve(rank);
|
||||||
for (auto result : optimizedLoopsOp.getResults()) {
|
for (auto result : optimizedLoopsOp.getResults()) {
|
||||||
optimizedLoops.push_back(result);
|
optimizedLoops.push_back(result);
|
||||||
}
|
}
|
||||||
Block& optimizationBlock = optimizedLoopsOp.region().front();
|
Block &optimizationBlock = optimizedLoopsOp.region().front();
|
||||||
|
|
||||||
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
|
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
|
||||||
// Iterate over the loop nest.
|
// Iterate over the loop nest.
|
||||||
|
@ -664,7 +679,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
||||||
Block& iterationBlock = iterateOp.bodyRegion().front();
|
Block &iterationBlock = iterateOp.bodyRegion().front();
|
||||||
|
|
||||||
// Now perform the insertions into the body of the
|
// Now perform the insertions into the body of the
|
||||||
// just generated instructions:
|
// just generated instructions:
|
||||||
|
@ -681,7 +696,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
||||||
rewriter.setInsertionPointToStart(&iterationBlock);
|
rewriter.setInsertionPointToStart(&iterationBlock);
|
||||||
|
|
||||||
// Handle the operation:
|
// Handle the operation:
|
||||||
SmallVector<Value*, 4> loopIVs;
|
SmallVector<Value *, 4> loopIVs;
|
||||||
for (auto arg : iterationBlock.getArguments())
|
for (auto arg : iterationBlock.getArguments())
|
||||||
loopIVs.push_back(arg);
|
loopIVs.push_back(arg);
|
||||||
|
|
||||||
|
@ -701,10 +716,11 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <typename ElementwiseVariadicOp>
|
template <typename ElementwiseVariadicOp>
|
||||||
struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
ONNXElementwiseVariadicOpLowering(MLIRContext* ctx)
|
ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
|
||||||
: ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
|
||||||
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
|
PatternMatchResult
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
// TODO: Check that the types are valid.
|
// TODO: Check that the types are valid.
|
||||||
// An element-wise variadic operation must have all operands and the result
|
// An element-wise variadic operation must have all operands and the result
|
||||||
// of the same type. This should have been verified by the verifier.
|
// of the same type. This should have been verified by the verifier.
|
||||||
|
@ -715,7 +731,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
// Insert an allocation and deallocation for the result of this operation.
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
auto memRefType = convertTensorToMemRef(tensorType);
|
auto memRefType = convertTensorToMemRef(tensorType);
|
||||||
|
|
||||||
Value* alloc;
|
Value *alloc;
|
||||||
bool insertDealloc = checkInsertDealloc(op);
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
// If the output has a dynamic dimension, we compute its dimension at
|
// If the output has a dynamic dimension, we compute its dimension at
|
||||||
// runtime by using dimensions from the operands.
|
// runtime by using dimensions from the operands.
|
||||||
|
@ -725,8 +741,8 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
if (hasAllConstantDimensions(memRefType))
|
if (hasAllConstantDimensions(memRefType))
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
else
|
else
|
||||||
alloc = insertAllocAndDealloc(
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
|
||||||
memRefType, loc, rewriter, insertDealloc, operands);
|
operands);
|
||||||
|
|
||||||
// Number of loops
|
// Number of loops
|
||||||
auto memRefShape = memRefType.getShape();
|
auto memRefShape = memRefType.getShape();
|
||||||
|
@ -734,7 +750,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
|
|
||||||
// Define loops.
|
// Define loops.
|
||||||
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
|
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
|
||||||
std::vector<Value*> originalLoops;
|
std::vector<Value *> originalLoops;
|
||||||
originalLoops.reserve(rank);
|
originalLoops.reserve(rank);
|
||||||
for (auto result : loopsOp.getResults()) {
|
for (auto result : loopsOp.getResults()) {
|
||||||
originalLoops.push_back(result);
|
originalLoops.push_back(result);
|
||||||
|
@ -742,12 +758,12 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
|
|
||||||
// Define loop optimization.
|
// Define loop optimization.
|
||||||
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
|
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
|
||||||
std::vector<Value*> optimizedLoops;
|
std::vector<Value *> optimizedLoops;
|
||||||
optimizedLoops.reserve(rank);
|
optimizedLoops.reserve(rank);
|
||||||
for (auto result : optimizedLoopsOp.getResults()) {
|
for (auto result : optimizedLoopsOp.getResults()) {
|
||||||
optimizedLoops.push_back(result);
|
optimizedLoops.push_back(result);
|
||||||
}
|
}
|
||||||
Block& optimizationBlock = optimizedLoopsOp.region().front();
|
Block &optimizationBlock = optimizedLoopsOp.region().front();
|
||||||
|
|
||||||
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
|
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
|
||||||
// Iterate over the loop nest.
|
// Iterate over the loop nest.
|
||||||
|
@ -770,7 +786,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
getBroadcastedDimInfo(loc, rewriter, memRefType, operands);
|
getBroadcastedDimInfo(loc, rewriter, memRefType, operands);
|
||||||
|
|
||||||
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
||||||
Block& iterationBlock = iterateOp.bodyRegion().front();
|
Block &iterationBlock = iterateOp.bodyRegion().front();
|
||||||
|
|
||||||
// Now perform the insertions into the body of the
|
// Now perform the insertions into the body of the
|
||||||
// just generated instructions:
|
// just generated instructions:
|
||||||
|
@ -786,7 +802,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
rewriter.setInsertionPointToStart(&iterationBlock);
|
rewriter.setInsertionPointToStart(&iterationBlock);
|
||||||
|
|
||||||
// Handle the operation:
|
// Handle the operation:
|
||||||
SmallVector<Value*, 4> loopIVs;
|
SmallVector<Value *, 4> loopIVs;
|
||||||
for (auto arg : iterationBlock.getArguments())
|
for (auto arg : iterationBlock.getArguments())
|
||||||
loopIVs.push_back(arg);
|
loopIVs.push_back(arg);
|
||||||
|
|
||||||
|
@ -812,35 +828,36 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ONNXReshapeOpLowering : public ConversionPattern {
|
struct ONNXReshapeOpLowering : public ConversionPattern {
|
||||||
ONNXReshapeOpLowering(MLIRContext* ctx)
|
ONNXReshapeOpLowering(MLIRContext *ctx)
|
||||||
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
|
PatternMatchResult
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
|
|
||||||
// Insert an allocation and deallocation for the result of this operation.
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
auto memRefType = convertTensorToMemRef(tensorType);
|
auto memRefType = convertTensorToMemRef(tensorType);
|
||||||
Value* alloc;
|
Value *alloc;
|
||||||
|
|
||||||
// Compute size in bytes.
|
// Compute size in bytes.
|
||||||
Value* tensorSize = rewriter.create<ConstantOp>(loc,
|
Value *tensorSize = rewriter.create<ConstantOp>(
|
||||||
rewriter.getIntegerAttr(
|
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||||
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType)));
|
getMemRefEltSizeInBytes(memRefType)));
|
||||||
bool insertDealloc = checkInsertDealloc(op);
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
if (hasAllConstantDimensions(memRefType)) {
|
if (hasAllConstantDimensions(memRefType)) {
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
} else {
|
} else {
|
||||||
auto memRefShape = memRefType.getShape();
|
auto memRefShape = memRefType.getShape();
|
||||||
SmallVector<Value*, 4> allocOperands;
|
SmallVector<Value *, 4> allocOperands;
|
||||||
for (int i = 0; i < memRefShape.size(); ++i) {
|
for (int i = 0; i < memRefShape.size(); ++i) {
|
||||||
// The shape array can always be used to construct shape information of
|
// The shape array can always be used to construct shape information of
|
||||||
// the result.
|
// the result.
|
||||||
Value* index = rewriter.create<ConstantOp>(
|
Value *index = rewriter.create<ConstantOp>(
|
||||||
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
|
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
|
||||||
Value* loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
|
Value *loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
|
||||||
Value* int64LoadedVal = rewriter.create<ZeroExtendIOp>(
|
Value *int64LoadedVal = rewriter.create<ZeroExtendIOp>(
|
||||||
loc, loadedVal, rewriter.getIntegerType(64));
|
loc, loadedVal, rewriter.getIntegerType(64));
|
||||||
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
|
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
|
||||||
allocOperands.push_back(rewriter.create<IndexCastOp>(
|
allocOperands.push_back(rewriter.create<IndexCastOp>(
|
||||||
|
@ -851,7 +868,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
||||||
|
|
||||||
// Make sure to allocate at the beginning of the block if
|
// Make sure to allocate at the beginning of the block if
|
||||||
// all dimensions are known.
|
// all dimensions are known.
|
||||||
auto* parentBlock = allocateMemref.getOperation()->getBlock();
|
auto *parentBlock = allocateMemref.getOperation()->getBlock();
|
||||||
if (insertDealloc) {
|
if (insertDealloc) {
|
||||||
auto dealloc = rewriter.create<DeallocOp>(loc, allocateMemref);
|
auto dealloc = rewriter.create<DeallocOp>(loc, allocateMemref);
|
||||||
dealloc.getOperation()->moveBefore(&parentBlock->back());
|
dealloc.getOperation()->moveBefore(&parentBlock->back());
|
||||||
|
@ -874,7 +891,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
||||||
struct TensorTypeConverter : public TypeConverter {
|
struct TensorTypeConverter : public TypeConverter {
|
||||||
using TypeConverter::TypeConverter;
|
using TypeConverter::TypeConverter;
|
||||||
|
|
||||||
LogicalResult convertType(Type t, SmallVectorImpl<Type>& results) override {
|
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
|
||||||
if (auto tensor_type = t.dyn_cast<TensorType>()) {
|
if (auto tensor_type = t.dyn_cast<TensorType>()) {
|
||||||
results.push_back(convertTensorToMemRef(tensor_type));
|
results.push_back(convertTensorToMemRef(tensor_type));
|
||||||
return success();
|
return success();
|
||||||
|
@ -889,12 +906,12 @@ struct TensorTypeConverter : public TypeConverter {
|
||||||
/// inputs. Once unranked results can be handled gracefully this
|
/// inputs. Once unranked results can be handled gracefully this
|
||||||
/// override needs to be removed in favour of the original MLIR one.]
|
/// override needs to be removed in favour of the original MLIR one.]
|
||||||
bool isSignatureLegal(FunctionType funcType) {
|
bool isSignatureLegal(FunctionType funcType) {
|
||||||
return llvm::all_of(
|
return llvm::all_of(funcType.getInputs(),
|
||||||
funcType.getInputs(), [this](Type type) { return isLegal(type); });
|
[this](Type type) { return isLegal(type); });
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end anonymous namespace.
|
} // end anonymous namespace.
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Frontend to Krnl Dialect lowering pass
|
// Frontend to Krnl Dialect lowering pass
|
||||||
|
@ -906,7 +923,7 @@ struct FrontendToKrnlLoweringPass
|
||||||
: public ModulePass<FrontendToKrnlLoweringPass> {
|
: public ModulePass<FrontendToKrnlLoweringPass> {
|
||||||
void runOnModule() final;
|
void runOnModule() final;
|
||||||
};
|
};
|
||||||
} // end anonymous namespace.
|
} // end anonymous namespace.
|
||||||
|
|
||||||
void FrontendToKrnlLoweringPass::runOnModule() {
|
void FrontendToKrnlLoweringPass::runOnModule() {
|
||||||
auto module = getModule();
|
auto module = getModule();
|
||||||
|
@ -943,32 +960,32 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
||||||
// Type conversion for function signatures.
|
// Type conversion for function signatures.
|
||||||
// Call MLIR FuncOp signature conversion when result type is
|
// Call MLIR FuncOp signature conversion when result type is
|
||||||
// a ranked tensor.
|
// a ranked tensor.
|
||||||
populateFuncOpTypeConversionPattern(
|
populateFuncOpTypeConversionPattern(patterns, &getContext(),
|
||||||
patterns, &getContext(), tensor_to_memref_converter);
|
tensor_to_memref_converter);
|
||||||
|
|
||||||
// Frontent operation lowering.
|
// Frontent operation lowering.
|
||||||
patterns.insert<ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>,
|
patterns.insert<ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
|
||||||
ONNXReshapeOpLowering>(&getContext());
|
ONNXReshapeOpLowering>(&getContext());
|
||||||
|
|
||||||
// With the target and rewrite patterns defined, we can now attempt the
|
// With the target and rewrite patterns defined, we can now attempt the
|
||||||
// conversion. The conversion will signal failure if any of our `illegal`
|
// conversion. The conversion will signal failure if any of our `illegal`
|
||||||
|
@ -981,5 +998,5 @@ std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
|
||||||
return std::make_unique<FrontendToKrnlLoweringPass>();
|
return std::make_unique<FrontendToKrnlLoweringPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
static PassRegistration<FrontendToKrnlLoweringPass> pass(
|
static PassRegistration<FrontendToKrnlLoweringPass>
|
||||||
"lower-frontend", "Lower frontend ops to Krnl dialect.");
|
pass("lower-frontend", "Lower frontend ops to Krnl dialect.");
|
||||||
|
|
|
@ -17,8 +17,8 @@ namespace {
|
||||||
struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
|
struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
|
||||||
using OpRewritePattern<KrnlIterateOp>::OpRewritePattern;
|
using OpRewritePattern<KrnlIterateOp>::OpRewritePattern;
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(
|
PatternMatchResult matchAndRewrite(KrnlIterateOp iterateOp,
|
||||||
KrnlIterateOp iterateOp, PatternRewriter& rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto boundMapAttrs =
|
auto boundMapAttrs =
|
||||||
iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
|
iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
|
||||||
.getValue();
|
.getValue();
|
||||||
|
@ -30,23 +30,23 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
|
||||||
operandItr++;
|
operandItr++;
|
||||||
|
|
||||||
// Organize operands into lower/upper bounds in affine.for ready formats.
|
// Organize operands into lower/upper bounds in affine.for ready formats.
|
||||||
SmallVector<Value*, 4> lbOperands, ubOperands;
|
SmallVector<Value *, 4> lbOperands, ubOperands;
|
||||||
AffineMap lbMap, ubMap;
|
AffineMap lbMap, ubMap;
|
||||||
for (int boundType = 0; boundType < 2; boundType++) {
|
for (int boundType = 0; boundType < 2; boundType++) {
|
||||||
auto& operands = boundType == 0 ? lbOperands : ubOperands;
|
auto &operands = boundType == 0 ? lbOperands : ubOperands;
|
||||||
auto& map = boundType == 0 ? lbMap : ubMap;
|
auto &map = boundType == 0 ? lbMap : ubMap;
|
||||||
map = boundMapAttrs[boundIdx + boundType]
|
map = boundMapAttrs[boundIdx + boundType]
|
||||||
.cast<AffineMapAttr>()
|
.cast<AffineMapAttr>()
|
||||||
.getValue();
|
.getValue();
|
||||||
operands.insert(
|
operands.insert(operands.end(), operandItr,
|
||||||
operands.end(), operandItr, operandItr + map.getNumInputs());
|
operandItr + map.getNumInputs());
|
||||||
std::advance(operandItr, map.getNumInputs());
|
std::advance(operandItr, map.getNumInputs());
|
||||||
}
|
}
|
||||||
|
|
||||||
nestedForOps.emplace_back(rewriter.create<AffineForOp>(
|
nestedForOps.emplace_back(rewriter.create<AffineForOp>(
|
||||||
iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap));
|
iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap));
|
||||||
rewriter.setInsertionPoint(nestedForOps.back().getBody(),
|
rewriter.setInsertionPoint(nestedForOps.back().getBody(),
|
||||||
nestedForOps.back().getBody()->begin());
|
nestedForOps.back().getBody()->begin());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Replace induction variable references from those introduced by a
|
// Replace induction variable references from those introduced by a
|
||||||
|
@ -68,7 +68,7 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
|
||||||
auto innermostForOp = nestedForOps.back();
|
auto innermostForOp = nestedForOps.back();
|
||||||
innermostForOp.region().getBlocks().clear();
|
innermostForOp.region().getBlocks().clear();
|
||||||
rewriter.inlineRegionBefore(iterateOp.bodyRegion(), innermostForOp.region(),
|
rewriter.inlineRegionBefore(iterateOp.bodyRegion(), innermostForOp.region(),
|
||||||
innermostForOp.region().end());
|
innermostForOp.region().end());
|
||||||
|
|
||||||
rewriter.eraseOp(iterateOp);
|
rewriter.eraseOp(iterateOp);
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
|
@ -80,11 +80,11 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> {
|
class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern;
|
using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern;
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(
|
PatternMatchResult matchAndRewrite(KrnlTerminatorOp op,
|
||||||
KrnlTerminatorOp op, PatternRewriter& rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op);
|
rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op);
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
|
@ -95,11 +95,11 @@ class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> {
|
class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern;
|
using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern;
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(
|
PatternMatchResult matchAndRewrite(KrnlDefineLoopsOp op,
|
||||||
KrnlDefineLoopsOp op, PatternRewriter& rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
|
@ -110,11 +110,11 @@ class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
class KrnlOptimizeLoopsLowering : public OpRewritePattern<KrnlOptimizeLoopsOp> {
|
class KrnlOptimizeLoopsLowering : public OpRewritePattern<KrnlOptimizeLoopsOp> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern<KrnlOptimizeLoopsOp>::OpRewritePattern;
|
using OpRewritePattern<KrnlOptimizeLoopsOp>::OpRewritePattern;
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(
|
PatternMatchResult matchAndRewrite(KrnlOptimizeLoopsOp op,
|
||||||
KrnlOptimizeLoopsOp op, PatternRewriter& rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
|
@ -132,7 +132,7 @@ struct KrnlToAffineLoweringPass
|
||||||
: public FunctionPass<KrnlToAffineLoweringPass> {
|
: public FunctionPass<KrnlToAffineLoweringPass> {
|
||||||
void runOnFunction() final;
|
void runOnFunction() final;
|
||||||
};
|
};
|
||||||
} // end anonymous namespace.
|
} // end anonymous namespace.
|
||||||
|
|
||||||
void KrnlToAffineLoweringPass::runOnFunction() {
|
void KrnlToAffineLoweringPass::runOnFunction() {
|
||||||
auto function = getFunction();
|
auto function = getFunction();
|
||||||
|
@ -146,17 +146,18 @@ void KrnlToAffineLoweringPass::runOnFunction() {
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
|
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
|
||||||
KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(&getContext());
|
KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(
|
||||||
|
&getContext());
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getFunction(), target, patterns)))
|
if (failed(applyPartialConversion(getFunction(), target, patterns)))
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
|
std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
|
||||||
return std::make_unique<KrnlToAffineLoweringPass>();
|
return std::make_unique<KrnlToAffineLoweringPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
static PassRegistration<KrnlToAffineLoweringPass> pass(
|
static PassRegistration<KrnlToAffineLoweringPass> pass("lower-krnl",
|
||||||
"lower-krnl", "Lower Krnl dialect.");
|
"Lower Krnl dialect.");
|
14
src/main.cpp
14
src/main.cpp
|
@ -76,7 +76,7 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext& context,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int ac, char* av[]) {
|
int main(int ac, char *av[]) {
|
||||||
namespace po = boost::program_options;
|
namespace po = boost::program_options;
|
||||||
|
|
||||||
po::options_description desc("ONNF available options");
|
po::options_description desc("ONNF available options");
|
||||||
|
@ -91,8 +91,8 @@ int main(int ac, char* av[]) {
|
||||||
po::positional_options_description p;
|
po::positional_options_description p;
|
||||||
p.add("onnx-model", -1);
|
p.add("onnx-model", -1);
|
||||||
po::variables_map vm;
|
po::variables_map vm;
|
||||||
po::store(
|
po::store(po::command_line_parser(ac, av).options(desc).positional(p).run(),
|
||||||
po::command_line_parser(ac, av).options(desc).positional(p).run(), vm);
|
vm);
|
||||||
|
|
||||||
// TODO: allow multiple input files
|
// TODO: allow multiple input files
|
||||||
assert(vm.count("onnx-model") < 2 && "At most one input file can be provided!");
|
assert(vm.count("onnx-model") < 2 && "At most one input file can be provided!");
|
||||||
|
@ -137,10 +137,10 @@ int main(int ac, char* av[]) {
|
||||||
|
|
||||||
// Write LLVM bitcode to disk.
|
// Write LLVM bitcode to disk.
|
||||||
std::error_code EC;
|
std::error_code EC;
|
||||||
llvm::raw_fd_ostream moduleBitcodeStream(
|
llvm::raw_fd_ostream moduleBitcodeStream("model.bc", EC,
|
||||||
"model.bc", EC, llvm::sys::fs::F_None);
|
llvm::sys::fs::F_None);
|
||||||
llvm::WriteBitcodeToFile(
|
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
|
||||||
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
|
moduleBitcodeStream);
|
||||||
moduleBitcodeStream.flush();
|
moduleBitcodeStream.flush();
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
|
|
Loading…
Reference in New Issue