Format Key Files using LLVM Style (#403)

* format using llvm style

* merge and format
Authored by Tian Jin on 2019-12-19 13:27:15 -05:00; committed by Tian Jin
parent 06a968d4a1
commit a6a40cf989
6 changed files with 277 additions and 251 deletions
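Nearly every hunk below makes the same two changes: pointer and reference sigils move from the type onto the name (Builder *builder rather than Builder* builder), and statements longer than 80 columns are re-wrapped with continuation arguments aligned under the opening parenthesis. Each hunk shows the earlier layout followed by its reformatted replacement. The sketch that follows is a minimal, self-contained illustration of those two conventions, assuming the files were run through clang-format's built-in LLVM preset (for example clang-format -style=llvm -i); the exact command and any .clang-format file are not part of this commit, and the Builder and OperationState types here are hypothetical stand-ins rather than the MLIR classes touched by the diff.

#include <cstdint>

struct Builder {};        // hypothetical stand-in for mlir::Builder
struct OperationState {}; // hypothetical stand-in for mlir::OperationState

// Earlier layout:
//   void buildExample(
//       Builder* builder, OperationState& result, int64_t num_loops);
//
// LLVM style: '*' and '&' bind to the parameter name, and parameters that
// overflow the 80-column limit wrap aligned under the first parameter.
void buildExample(Builder *builder, OperationState &result,
                  int64_t num_loops) {
  (void)builder; // parameters unused; this sketch only demonstrates layout
  (void)result;
  (void)num_loops;
}

int main() {
  Builder builder;
  OperationState state;
  buildExample(&builder, state, 4);
  return 0;
}

The include-block changes in two of the files (llvm/ADT headers moving below the mlir/ headers) also appear consistent with the LLVM preset, whose include categories sort llvm/ headers after other quoted project headers within a block.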


@@ -540,7 +540,8 @@ private:
attributes.push_back(attr);
} else {
// TODO: the attributes need special handling
//std::cout << "missing " << node.op_type() << " " << attr_name << std::endl;
// std::cout << "missing " << node.op_type() << " " << attr_name <<
// std::endl;
}
}
@@ -585,7 +586,8 @@ private:
attributes.push_back(attr);
} else {
// TODO: the attributes need special handling
//std::cout << "missing " << node.op_type() << " " << attr_name << std::endl;
// std::cout << "missing " << node.op_type() << " " << attr_name <<
// std::endl;
}
}


@@ -9,8 +9,6 @@
#include <iostream>
#include <queue>
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Block.h"
@@ -23,6 +21,8 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "src/compiler/dialect/krnl/krnl_helper.hpp"
@@ -44,13 +44,13 @@ KrnlOpsDialect::KrnlOpsDialect(MLIRContext* context)
// KrnlDefineLoopsOp
//===----------------------------------------------------------------------===//
void KrnlDefineLoopsOp::build(
Builder* builder, OperationState& result, int64_t num_loops) {
void KrnlDefineLoopsOp::build(Builder *builder, OperationState &result,
int64_t num_loops) {
// Create the same number of dimension handlers as the number of
// dimensions in the associated integer set.
result.types.append(num_loops, LoopType::get(builder->getContext()));
result.addAttribute(
getNumLoopsAttrName(), builder->getI32IntegerAttr(num_loops));
result.addAttribute(getNumLoopsAttrName(),
builder->getI32IntegerAttr(num_loops));
}
void print(OpAsmPrinter &p, KrnlDefineLoopsOp &op) {
@@ -59,14 +59,15 @@ void print(OpAsmPrinter& p, KrnlDefineLoopsOp& op) {
p << "krnl.define_loops " << numLoopAttr.getValue().getSExtValue();
}
ParseResult parseKrnlDefineLoopsOp(
OpAsmParser& parser, OperationState& result) {
ParseResult parseKrnlDefineLoopsOp(OpAsmParser &parser,
OperationState &result) {
// Parse the attribute indicating number of loops defined.
IntegerAttr numLoops;
auto &builder = parser.getBuilder();
auto intType = builder.getIntegerType(64);
if (parser.parseAttribute(numLoops, intType,
KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes))
KrnlDefineLoopsOp::getNumLoopsAttrName(),
result.attributes))
return failure();
auto loopTypes = llvm::SmallVector<Type, 4>(
@@ -79,10 +80,10 @@ ParseResult parseKrnlDefineLoopsOp(
// KrnlOptimizeLoopsOp
//===----------------------------------------------------------------------===//
void KrnlOptimizeLoopsOp::build(
Builder* builder, OperationState& result, int num_optimized_loops) {
result.types.append(
num_optimized_loops, LoopType::get(builder->getContext()));
void KrnlOptimizeLoopsOp::build(Builder *builder, OperationState &result,
int num_optimized_loops) {
result.types.append(num_optimized_loops,
LoopType::get(builder->getContext()));
// Create a region and a block for the body.
// Schedule intrinsics will be placed into this region.
Region *region = result.addRegion();
@@ -98,8 +99,8 @@ void print(OpAsmPrinter& p, KrnlOptimizeLoopsOp& op) {
p.printFunctionalType(op);
}
ParseResult parseKrnlOptimizeLoopsOp(
OpAsmParser& parser, OperationState& result) {
ParseResult parseKrnlOptimizeLoopsOp(OpAsmParser &parser,
OperationState &result) {
// Parse the schedule body region.
Region *region = result.addRegion();
if (parser.parseRegion(*region, llvm::None, llvm::None))
@@ -146,10 +147,11 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result,
KrnlIterateOperandPack operandPack) {
// Record optimized loops and the number of such loops.
result.addOperands(operandPack.getOperands());
result.addAttribute(
KrnlIterateOp::getBoundsAttrName(), operandPack.getAttributes());
result.addAttribute(KrnlIterateOp::getBoundsAttrName(),
operandPack.getAttributes());
result.addAttribute(getNumOptimizedLoopsAttrName(),
result.addAttribute(
getNumOptimizedLoopsAttrName(),
builder->getI64IntegerAttr(operandPack.getNumOptimizedLoops()));
// Create a region and a block for the body. The arguments of the region are
@@ -204,10 +206,11 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
// Parse optimized loops:
SmallVector<OpAsmParser::OperandType, 4> optimizedLoopRefs;
if (parser.parseOperandList(
optimizedLoopRefs, OpAsmParser::Delimiter::Paren) ||
if (parser.parseOperandList(optimizedLoopRefs,
OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(optimizedLoopRefs,
LoopType::get(result.getContext()), result.operands))
LoopType::get(result.getContext()),
result.operands))
return failure();
// Record how many optimized loops did we parse.
@@ -222,16 +225,16 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
return failure();
// A function to parse a lower or upper bound.
auto parseBound = [&result, &builder, &parser, &operandParser, &boundMaps](
bool isUpper) -> ParseResult {
auto parseBound = [&result, &builder, &parser, &operandParser,
&boundMaps](bool isUpper) -> ParseResult {
// 'min' / 'max' prefixes are generally syntactic sugar, but are required if
// the map has multiple results.
bool failedToParsedMinMax =
failed(parser.parseOptionalKeyword(isUpper ? "min" : "max"));
// Try parse an SSA operand.
if (succeeded(operandParser.ParseOptionalOperand(
builder.getIndexType(), result.operands))) {
if (succeeded(operandParser.ParseOptionalOperand(builder.getIndexType(),
result.operands))) {
AffineMap map = builder.getSymbolIdentityMap();
boundMaps.emplace_back(AffineMapAttr::get(map));
return success();
@@ -243,8 +246,8 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
llvm::SMLoc attrLoc = parser.getCurrentLocation();
Attribute boundAttr;
llvm::SmallVector<NamedAttribute, 1> tempBoundAttrContainer;
if (parser.parseAttribute(
boundAttr, builder.getIndexType(), "temp", tempBoundAttrContainer))
if (parser.parseAttribute(boundAttr, builder.getIndexType(), "temp",
tempBoundAttrContainer))
return failure();
if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
@@ -255,13 +258,15 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
auto map = affineMapAttr.getValue();
if (map.getNumDims() != numDims)
return parser.emitError(parser.getNameLoc(),
return parser.emitError(
parser.getNameLoc(),
"dim operand count and integer set dim count must match");
unsigned numDimAndSymbolOperands =
result.operands.size() - currentNumOperands;
if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
return parser.emitError(parser.getNameLoc(),
return parser.emitError(
parser.getNameLoc(),
"symbol operand count and integer set symbol count must match");
// If the map has multiple results, make sure that we parsed the min/max
@@ -316,18 +321,18 @@ ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
// At this point, there shouldn't be any operands left to parse.
if (operandParser.hasOperandLeft())
return parser.emitError(parser.getCurrentLocation());
result.addAttribute(
KrnlIterateOp::getBoundsAttrName(), builder.getArrayAttr(boundMaps));
result.addAttribute(KrnlIterateOp::getBoundsAttrName(),
builder.getArrayAttr(boundMaps));
Region *region = result.addRegion();
SmallVector<Type, 4> inductionVarTypes(
inductionVarRefs.size(), builder.getIndexType());
SmallVector<Type, 4> inductionVarTypes(inductionVarRefs.size(),
builder.getIndexType());
if (parser.parseRegion(*region, inductionVarRefs, inductionVarTypes))
return failure();
// Ensure iterate region is closed off with krnl.terminate.
KrnlIterateOp::ensureTerminator(
*region, parser.getBuilder(), result.location);
KrnlIterateOp::ensureTerminator(*region, parser.getBuilder(),
result.location);
return success();
}
@@ -346,13 +351,14 @@ void print(OpAsmPrinter& p, KrnlReturnLoopsOp& op) {
p.printOperands(op.operand_begin(), op.operand_end());
}
ParseResult parseKrnlReturnLoopsOp(
OpAsmParser& parser, OperationState& result) {
ParseResult parseKrnlReturnLoopsOp(OpAsmParser &parser,
OperationState &result) {
// Parse the loops to return.
SmallVector<OpAsmParser::OperandType, 4> timestamp_dim_handlers;
if (parser.parseOperandList(timestamp_dim_handlers) ||
parser.resolveOperands(timestamp_dim_handlers,
LoopType::get(result.getContext()), result.operands))
LoopType::get(result.getContext()),
result.operands))
return failure();
return success();


@@ -11,17 +11,16 @@
#include <map>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "src/compiler/dialect/krnl/krnl_helper.hpp"
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
#include "src/compiler/pass/passes.hpp"
using namespace mlir;
@@ -218,8 +217,8 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
// If its value is 1, it is broadcasted dimension.
// Otherwise, non-broadcasted dimension.
auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx],
zero, loopIVs[loopIdx]);
auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
loopIVs[loopIdx]);
newLoopIVs.insert(newLoopIVs.begin(), idx);
} else {
// Non-broadcasted dimension
@@ -298,17 +297,18 @@ using ScalarIOp = typename ScalarOp<ElementwiseNaryOp>::IOp;
//===----------------------------------------------------------------------===//
template <typename UnaryOp>
Value *mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
/* Lower UnaryOp to Ops in the Standard dialect.
*/
auto loc = op->getLoc();
Type element_type = operands.front()->getType();
if (element_type.isa<IntegerType>()) {
return rewriter.create<ScalarIOp<UnaryOp>>(
loc, result_types, operands, mlir::None);
return rewriter.create<ScalarIOp<UnaryOp>>(loc, result_types, operands,
mlir::None);
} else if (element_type.isa<FloatType>()) {
return rewriter.create<ScalarFOp<UnaryOp>>(
loc, result_types, operands, mlir::None);
return rewriter.create<ScalarFOp<UnaryOp>>(loc, result_types, operands,
mlir::None);
} else {
emitError(loc, "unsupported element type");
return nullptr;
@@ -320,7 +320,8 @@ Value* mapToLowerScalarOp(Operation* op, ArrayRef<Type> result_types,
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXTanhOp>(Operation *op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
@@ -343,7 +344,8 @@ Value* mapToLowerScalarOp<ONNXTanhOp>(Operation* op,
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXSinhOp>(Operation *op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// ConstantOp 2)
@@ -366,7 +368,8 @@ Value* mapToLowerScalarOp<ONNXSinhOp>(Operation* op,
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXCoshOp>(Operation *op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
// ConstantOp 2)
@@ -389,7 +392,8 @@ Value* mapToLowerScalarOp<ONNXCoshOp>(Operation* op,
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
// AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
@@ -410,8 +414,8 @@ Value* mapToLowerScalarOp<ONNXSigmoidOp>(Operation* op,
// Scalar unary ops for lowering ONNXHardSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value* mapToLowerScalarOp<ONNXHardSigmoidOp>(Operation* op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
Value *mapToLowerScalarOp<ONNXHardSigmoidOp>(
Operation *op, ArrayRef<Type> result_types, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// %Y = AddFOp(MulFOp(alpha, %X), beta)
// %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
@@ -447,7 +451,8 @@ Value* mapToLowerScalarOp<ONNXHardSigmoidOp>(Operation* op,
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
// %X)
@@ -461,9 +466,10 @@ Value* mapToLowerScalarOp<ONNXEluOp>(Operation* op, ArrayRef<Type> result_types,
auto exp = rewriter.create<ExpOp>(loc, operand);
auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
auto result = rewriter.create<SelectOp>(loc, lessThanZero,
rewriter.create<MulFOp>(
loc, alpha, rewriter.create<SubFOp>(loc, exp, one)),
auto result = rewriter.create<SelectOp>(
loc, lessThanZero,
rewriter.create<MulFOp>(loc, alpha,
rewriter.create<SubFOp>(loc, exp, one)),
operand);
return result;
@@ -474,7 +480,8 @@ Value* mapToLowerScalarOp<ONNXEluOp>(Operation* op, ArrayRef<Type> result_types,
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXReluOp>(Operation *op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// ConstantOp 0,
@@ -494,8 +501,9 @@ Value* mapToLowerScalarOp<ONNXReluOp>(Operation* op,
// Scalar unary ops for lowering ONNXLeakyReluOp
//===----------------------------------------------------------------------===//
template <>
Value* mapToLowerScalarOp<ONNXLeakyReluOp>(Operation* op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
Value *
mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
// MulFOp(alpha, %X),
@@ -519,7 +527,8 @@ Value* mapToLowerScalarOp<ONNXLeakyReluOp>(Operation* op,
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXSeluOp>(Operation *op,
ArrayRef<Type> result_types, ArrayRef<Value*> operands,
ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
// MulFOp(gamma, %X),
@@ -537,9 +546,10 @@ Value* mapToLowerScalarOp<ONNXSeluOp>(Operation* op,
auto exp = rewriter.create<ExpOp>(loc, operand);
auto greaterThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
auto select = rewriter.create<SelectOp>(loc, greaterThanZero, operand,
rewriter.create<SubFOp>(
loc, rewriter.create<MulFOp>(loc, alpha, exp), alpha));
auto select = rewriter.create<SelectOp>(
loc, greaterThanZero, operand,
rewriter.create<SubFOp>(loc, rewriter.create<MulFOp>(loc, alpha, exp),
alpha));
auto result = rewriter.create<MulFOp>(loc, gamma, select);
return result;
@@ -549,8 +559,10 @@ Value* mapToLowerScalarOp<ONNXSeluOp>(Operation* op,
// Scalar unary ops for lowering ONNXReciprocalOp
//===----------------------------------------------------------------------===//
template <>
Value* mapToLowerScalarOp<ONNXReciprocalOp>(Operation* op, ArrayRef<Type> result_types,
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
Value *
mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
auto loc = op->getLoc();
Value *operand = operands[0];
@@ -566,7 +578,8 @@ Value* mapToLowerScalarOp<ONNXReciprocalOp>(Operation* op, ArrayRef<Type> result
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
// %X,
// %Y)
@@ -583,7 +596,8 @@ Value* mapToLowerScalarOp<ONNXMaxOp>(Operation* op, ArrayRef<Type> result_types,
//===----------------------------------------------------------------------===//
template <>
Value *mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) {
// ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
// %X,
// %Y)
@@ -601,7 +615,8 @@ template <typename ElementwiseUnaryOp>
struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
: ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const final {
// TODO: Check that the types are valid.
// An element-wise unary operation must have all operands and the result of
@@ -624,8 +639,8 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, {operands[0]});
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
{operands[0]});
// Number of loops
auto memRefShape = memRefType.getShape();
@@ -703,7 +718,8 @@ template <typename ElementwiseVariadicOp>
struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
: ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const final {
// TODO: Check that the types are valid.
// An element-wise variadic operation must have all operands and the result
@@ -725,8 +741,8 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, operands);
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
operands);
// Number of loops
auto memRefShape = memRefType.getShape();
@@ -815,7 +831,8 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
ONNXReshapeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
@@ -825,9 +842,9 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
Value *alloc;
// Compute size in bytes.
Value* tensorSize = rewriter.create<ConstantOp>(loc,
rewriter.getIntegerAttr(
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType)));
Value *tensorSize = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
getMemRefEltSizeInBytes(memRefType)));
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType)) {
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
@@ -889,8 +906,8 @@ struct TensorTypeConverter : public TypeConverter {
/// inputs. Once unranked results can be handled gracefully this
/// override needs to be removed in favour of the original MLIR one.]
bool isSignatureLegal(FunctionType funcType) {
return llvm::all_of(
funcType.getInputs(), [this](Type type) { return isLegal(type); });
return llvm::all_of(funcType.getInputs(),
[this](Type type) { return isLegal(type); });
}
};
@@ -943,8 +960,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
// Type conversion for function signatures.
// Call MLIR FuncOp signature conversion when result type is
// a ranked tensor.
populateFuncOpTypeConversionPattern(
patterns, &getContext(), tensor_to_memref_converter);
populateFuncOpTypeConversionPattern(patterns, &getContext(),
tensor_to_memref_converter);
// Frontent operation lowering.
patterns.insert<ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>,
@@ -981,5 +998,5 @@ std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
return std::make_unique<FrontendToKrnlLoweringPass>();
}
static PassRegistration<FrontendToKrnlLoweringPass> pass(
"lower-frontend", "Lower frontend ops to Krnl dialect.");
static PassRegistration<FrontendToKrnlLoweringPass>
pass("lower-frontend", "Lower frontend ops to Krnl dialect.");


@@ -17,8 +17,8 @@ namespace {
struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
using OpRewritePattern<KrnlIterateOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(
KrnlIterateOp iterateOp, PatternRewriter& rewriter) const override {
PatternMatchResult matchAndRewrite(KrnlIterateOp iterateOp,
PatternRewriter &rewriter) const override {
auto boundMapAttrs =
iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
.getValue();
@@ -38,8 +38,8 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
map = boundMapAttrs[boundIdx + boundType]
.cast<AffineMapAttr>()
.getValue();
operands.insert(
operands.end(), operandItr, operandItr + map.getNumInputs());
operands.insert(operands.end(), operandItr,
operandItr + map.getNumInputs());
std::advance(operandItr, map.getNumInputs());
}
@@ -83,8 +83,8 @@ class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> {
public:
using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(
KrnlTerminatorOp op, PatternRewriter& rewriter) const override {
PatternMatchResult matchAndRewrite(KrnlTerminatorOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op);
return matchSuccess();
}
@@ -98,8 +98,8 @@ class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> {
public:
using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(
KrnlDefineLoopsOp op, PatternRewriter& rewriter) const override {
PatternMatchResult matchAndRewrite(KrnlDefineLoopsOp op,
PatternRewriter &rewriter) const override {
rewriter.eraseOp(op);
return matchSuccess();
}
@@ -113,8 +113,8 @@ class KrnlOptimizeLoopsLowering : public OpRewritePattern<KrnlOptimizeLoopsOp> {
public:
using OpRewritePattern<KrnlOptimizeLoopsOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(
KrnlOptimizeLoopsOp op, PatternRewriter& rewriter) const override {
PatternMatchResult matchAndRewrite(KrnlOptimizeLoopsOp op,
PatternRewriter &rewriter) const override {
rewriter.eraseOp(op);
return matchSuccess();
}
@@ -146,7 +146,8 @@ void KrnlToAffineLoweringPass::runOnFunction() {
OwningRewritePatternList patterns;
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(&getContext());
KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(
&getContext());
if (failed(applyPartialConversion(getFunction(), target, patterns)))
signalPassFailure();
@@ -158,5 +159,5 @@ std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
return std::make_unique<KrnlToAffineLoweringPass>();
}
static PassRegistration<KrnlToAffineLoweringPass> pass(
"lower-krnl", "Lower Krnl dialect.");
static PassRegistration<KrnlToAffineLoweringPass> pass("lower-krnl",
"Lower Krnl dialect.");


@@ -91,8 +91,8 @@ int main(int ac, char* av[]) {
po::positional_options_description p;
p.add("onnx-model", -1);
po::variables_map vm;
po::store(
po::command_line_parser(ac, av).options(desc).positional(p).run(), vm);
po::store(po::command_line_parser(ac, av).options(desc).positional(p).run(),
vm);
// TODO: allow multiple input files
assert(vm.count("onnx-model") < 2 && "At most one input file can be provided!");
@@ -137,10 +137,10 @@ int main(int ac, char* av[]) {
// Write LLVM bitcode to disk.
std::error_code EC;
llvm::raw_fd_ostream moduleBitcodeStream(
"model.bc", EC, llvm::sys::fs::F_None);
llvm::WriteBitcodeToFile(
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
llvm::raw_fd_ostream moduleBitcodeStream("model.bc", EC,
llvm::sys::fs::F_None);
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
moduleBitcodeStream);
moduleBitcodeStream.flush();
return 0;