[MLIR] Krnl Dialect Definition (#357)

* a complete, roud-trippable Krnl dialect operation definition

* remove old dialect definition files, edit build files

* register dialect

* check in src for onnf_opt and dimension handler types

* re-trigger jenkins

* fix build

* clarify operation semantics

* add verifier for krnl.iterate

* refactor to make things clear

* do not hard code types

* nit and add comments

* fix rebase

* update op implementation

* fix merge

* update kernel dialect definition

* more comment on how to use the builder for krnl.iterate operation

* ammend the comment

* can parse krnl.iterate

* can parse and print if bounds are not SSA values

* address comments

* better error handling

* Update CMakeLists.txt

* update comment

* reflow comments
This commit is contained in:
Tian Jin 2019-11-11 21:31:56 -05:00 committed by Doru Bercea
parent 03be41f7df
commit 780e6f0aa0
14 changed files with 833 additions and 78 deletions

View File

@ -3,3 +3,5 @@ add_executable(onnf main.cpp)
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
target_link_libraries(onnf builder compiler ${Boost_LIBRARIES})
install(TARGETS onnf DESTINATION bin)

View File

@ -1,9 +1,13 @@
add_library(
compiler
ir/knl/knl_ops.cpp
ir/knl/knl_ops.hpp
dialect/krnl/krnl_ops.cpp
dialect/krnl/krnl_ops.hpp
dialect/krnl/krnl_types.cpp
dialect/krnl/krnl_types.hpp
dialect/onnx/onnx_ops.cpp
dialect/onnx/onnx_ops.hpp
dialect/krnl/parser_helper.cpp
dialect/krnl/parser_helper.hpp
pass/shape_inference_pass.cpp
pass/shape_inference_interface.hpp
pass/passes.hpp)
@ -25,7 +29,7 @@ find_package(Boost 1.54.0
# target_link_libraries(compiler isl inja ${Boost_LIBRARIES})
target_link_libraries(compiler
${Boost_LIBRARIES}
)
${MLIRLIBS} curses)
add_executable(onnf-opt
tool/onnf_opt/onnf_opt.cpp)
@ -34,12 +38,6 @@ target_link_libraries(onnf-opt ${Boost_LIBRARIES} ${MLIRLIBS} curses compiler)
target_include_directories(onnf-opt PRIVATE ../..)
target_include_directories(onnf-opt PRIVATE ${CMAKE_BINARY_DIR})
set(LLVM_TARGET_DEFINITIONS ir/knl/knl.td)
onnf_tablegen(knl.hpp.inc -gen-op-decls)
onnf_tablegen(knl.cpp.inc -gen-op-defs)
add_public_tablegen_target(gen_kir)
add_dependencies(compiler gen_kir)
set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
@ -51,3 +49,10 @@ onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
add_public_tablegen_target(gen_onnx)
add_dependencies(compiler gen_onnx)
set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td)
onnf_tablegen(krnl.hpp.inc -gen-op-decls)
onnf_tablegen(krnl.cpp.inc -gen-op-defs)
add_public_tablegen_target(gen_krnl_ops)
add_dependencies(compiler gen_krnl_ops)
add_dependencies(onnf-opt gen_krnl_ops)

View File

@ -0,0 +1,398 @@
//===--------------------- krnl_ops.cpp - MLIR Operations -----------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
//===----------------------------------------------------------------------===//
#include <iostream>
#include <queue>
#include "src/compiler/dialect/krnl/parser_helper.hpp"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "krnl_ops.hpp"
using namespace mlir;
namespace mlir {
KrnlOpsDialect::KrnlOpsDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context) {
addOperations<
#define GET_OP_LIST
#include "src/compiler/krnl.cpp.inc"
>();
addTypes<LoopType>();
}
//===----------------------------------------------------------------------===//
// KrnlDefineLoopsOp
//===----------------------------------------------------------------------===//
void KrnlDefineLoopsOp::build(
Builder* builder, OperationState& result, int64_t num_loops) {
// Create the same number of dimension handlers as the number of
// dimensions in the associated integer set.
result.types.append(num_loops, LoopType::get(builder->getContext()));
result.addAttribute(
getNumLoopsAttrName(), builder->getI32IntegerAttr(num_loops));
}
void print(OpAsmPrinter& p, KrnlDefineLoopsOp& op) {
auto num_loop_attr = op.getAttrOfType<IntegerAttr>(op.getNumLoopsAttrName());
p << "krnl.define_loops " << num_loop_attr.getValue().getSExtValue();
}
ParseResult parseKrnlDefineLoopsOp(
OpAsmParser& parser, OperationState& result) {
// Parse the attribute indicating number of loops defined.
IntegerAttr num_loops;
auto& builder = parser.getBuilder();
auto int32_type = builder.getIntegerType(64);
if (parser.parseAttribute(num_loops, int32_type,
KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes))
return failure();
auto loop_types = llvm::SmallVector<Type, 4>(
num_loops.getValue().getSExtValue(), LoopType::get(builder.getContext()));
if (parser.addTypesToList(loop_types, result.types))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// KrnlOptimizeLoopsOp
//===----------------------------------------------------------------------===//
void KrnlOptimizeLoopsOp::build(
Builder* builder, OperationState& result, int num_optimized_loops) {
result.types.append(
num_optimized_loops, LoopType::get(builder->getContext()));
// Create a region and a block for the body.
// Schedule intrinsics will be placed into this region.
Region* region = result.addRegion();
auto* body = new Block();
region->push_back(body);
}
void print(OpAsmPrinter& p, KrnlOptimizeLoopsOp& op) {
p << "krnl.optimize_loops ";
p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
p << " : ";
p.printFunctionalType(op);
}
ParseResult parseKrnlOptimizeLoopsOp(
OpAsmParser& parser, OperationState& result) {
// Parse the schedule body region.
Region* region = result.addRegion();
if (parser.parseRegion(*region, llvm::None, llvm::None))
return failure();
// Parse the function type for the schedule operation.
// Then following the hint of this parsed function type, parse the
// returned timestamp space dimension handlers.
FunctionType schedule_func_type;
if (parser.parseColonType(schedule_func_type) ||
parser.addTypesToList(schedule_func_type.getResults(), result.types)) {
failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// KrnlIterateOp
//===----------------------------------------------------------------------===//
/*!
* Build a Krnl Dialect iterate operation.
* input_loops: a collection of input krnl.loops being optimized.
* optimized_loops: a collection of optimized (scheduled) krnl.loops.
* operand_bounds: a collection of SSA value bounds.
* const_bounds: a collection of constant bounds.
* bound_types: a collection of integer values indicating how bounds are given.
* 0 : bound is given as an integer in const_bounds.
* 1 : bound is given as an operand in operand_bounds.
* 2 : bound is given as an affine map. (TODO).
*
* The following example illustrates how induction variable bounds are parsed
* from builder function inputs:
*
* - operand_bounds = [N, M]
* - const_bounds = [10, 20]
* - bound_types = [0, 1, 1, 0]
*
* Then the bounds will be parsed as:
* %i0 = 10 to N : %i1 = M to 20
*/
void KrnlIterateOp::build(Builder* builder, OperationState& result,
ArrayRef<Value*> input_loops, ArrayRef<Value*> optimized_loops,
ArrayRef<Value*> operand_bounds, ArrayRef<int64_t> const_bounds,
ArrayRef<int> bound_types) {
// Record optimized loops and the number of such loops.
result.addOperands(optimized_loops);
result.addAttribute(getNumOptimizedLoopsAttrName(),
builder->getI64IntegerAttr(optimized_loops.size()));
// Record input loops and the number of such loops.
result.addOperands(input_loops);
result.addAttribute(getNumInputLoopsAttrName(),
builder->getI64IntegerAttr(input_loops.size()));
// Record bound either as attribute or from operand list.
auto next_operand_bound = operand_bounds.begin();
auto next_const_bound = const_bounds.begin();
for (size_t i = 0; i < bound_types.size(); i++) {
auto bound_type = bound_types[i];
if (bound_type == 0) {
// Constant bound.
result.addAttribute(getBoundAttrName(i / 2, i % 2),
builder->getI64IntegerAttr(*next_const_bound));
next_const_bound = std::next(next_const_bound);
} else {
// Operand bound.
result.addOperands(*next_operand_bound);
next_operand_bound = std::next(next_operand_bound);
}
}
// Record bound types as attribute:
result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(),
builder->getI32ArrayAttr(bound_types));
// Create a region and a block for the body. The arguments of the region is
// the loop induction variables; there can be multiple induction variables
// associated with the same krnl.iterate operation.
Region* bodyRegion = result.addRegion();
auto* body = new Block();
auto body_args = llvm::SmallVector<Type, 4>(
input_loops.size(), IndexType::get(builder->getContext()));
body->addArguments(body_args);
bodyRegion->push_back(body);
ensureTerminator(*bodyRegion, *builder, result.location);
}
void print(OpAsmPrinter& p, KrnlIterateOp& op) {
p << "krnl.iterate(";
// Print optimized loops:
auto num_optimized_loops = op.getNumOptimizedLoops();
p.printOperands(op.operand_begin(), op.operand_begin() + num_optimized_loops);
p << ") with (";
// Set up iterator to input loops:
auto num_input_loops = op.getNumInputLoops();
auto input_loop_begin = op.operand_begin() + num_optimized_loops;
// Set up iterators to operand bounds.
auto next_operand_bound = input_loop_begin + num_input_loops;
// Function to print a lower or upper bound.
auto print_bound = [&](ArrayRef<Attribute> bound_types, size_t idx) {
IntegerAttr type = bound_types[idx].dyn_cast<IntegerAttr>();
if (type.getValue().getSExtValue() == 0) {
// Bound is an operand.
p.printOperand(*next_operand_bound);
next_operand_bound = std::next(next_operand_bound);
} else {
// Bound is an integer attribute.
auto bound_idx = idx / 2;
auto is_ub = idx % 2;
IntegerAttr bound = op.getAttrOfType<IntegerAttr>(
KrnlIterateOp::getBoundAttrName(bound_idx, is_ub));
p << bound.getValue().getSExtValue();
}
};
auto induction_variables = op.bodyRegion().front().getArguments();
ArrayRef<Attribute> bound_types =
op.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundTypesAttrName())
.getValue();
// Print input loop operands, induction variables and their ranges.
for (size_t i = 0; i < num_input_loops; i++) {
if (i != 0)
p << ", ";
p.printOperand(*std::next(input_loop_begin, i));
p << " -> ";
// Print induction variable block argument.
p.printOperand(induction_variables[i]);
p << " = ";
print_bound(bound_types, 2 * i); // Print lower bound.
p << " to ";
print_bound(bound_types, 2 * i + 1); // Print upper bound.
}
p << ")";
p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
}
ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
auto builder = parser.getBuilder();
auto context = builder.getContext();
onnf::KrnlDialectOperandParser operand_parser(parser);
// Parse optimized loops:
SmallVector<OpAsmParser::OperandType, 4> num_optimized_loops;
if (parser.parseOperandList(
num_optimized_loops, OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(num_optimized_loops,
LoopType::get(result.getContext()), result.operands))
return failure();
// Record how many optimized loops did we parse.
result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(),
builder.getI64IntegerAttr(num_optimized_loops.size()));
// Parse input loops and their lower and upper bounds.
SmallVector<OpAsmParser::OperandType, 4> in_loop_refs, induction_var_refs;
SmallVector<Value*, 4> in_loop_operands, operand_bounds;
SmallVector<Attribute, 4> bound_types;
SmallVector<IntegerAttr, 4> const_bounds;
if (parser.parseKeyword("with") || parser.parseLParen())
return failure();
// A function to parse a lower or upper bound.
auto parse_bound = [&result, &builder, &operand_parser, &parser, &bound_types,
&operand_bounds, &const_bounds](
bool is_ub, size_t bound_pair_count) -> ParseResult {
// Try parse an SSA operand.
Value* bound;
operand_parser.ParseOptionalOperand(builder.getIndexType(), bound);
if (bound != nullptr) {
// Parsed an SSA id as bound.
operand_bounds.emplace_back(bound);
// Record bound_type as an operand type.
bound_types.emplace_back(builder.getI32IntegerAttr(0));
} else {
// Bound is not an SSA id, then it must be an integer.
// Parse an integer constant attribute.
IntegerAttr boundAttr;
if (parser.parseAttribute(boundAttr, builder.getIndexType(),
KrnlIterateOp::getBoundAttrName(bound_pair_count, is_ub),
result.attributes))
return failure();
const_bounds.emplace_back(
builder.getIntegerAttr(builder.getIndexType(), boundAttr.getValue()));
// Record that the bound_type is a constant integer attribute.
bound_types.emplace_back(builder.getI32IntegerAttr(1));
}
};
bool keep_parsing; // Do we keep parsing loops/bounds?
size_t bound_pair_count = 0; // Record the number of bound pairs parsed.
do {
// Parse an input loop operand;
Value* in_loop_operand;
operand_parser.ParseOperand(LoopType::get(context), in_loop_operand);
in_loop_operands.emplace_back(in_loop_operand);
parser.parseArrow();
// Parse induction variable.
OpAsmParser::OperandType induction_var;
if (parser.parseRegionArgument(induction_var) || parser.parseEqual())
return failure();
induction_var_refs.emplace_back(induction_var);
// Parse bound par (min to max).
if (parse_bound(false, bound_pair_count) || parser.parseKeyword("to") ||
parse_bound(true, bound_pair_count))
return failure();
bound_pair_count++;
// We may fail to parse a comma if an operand bound is followed by
// a comma and the next input loop operand, in which case
// the entire "{operand bound}, {input_loop_operand}" sequence will
// be parsed as an operand list.
parser.parseOptionalComma();
// If we don't see a RParen token, we keep parsing.
keep_parsing = failed(parser.parseOptionalRParen());
} while (keep_parsing);
// At this point, there shouldn't be any operands left to parse.
if (operand_parser.has_operand_left())
return parser.emitError(parser.getCurrentLocation());
// Record how many input loops did we parse.
result.addOperands(in_loop_operands);
result.addAttribute(KrnlIterateOp::getNumInputLoopsAttrName(),
builder.getI64IntegerAttr(in_loop_operands.size()));
// Add operand bounds to the list of operands of current operation.
result.addOperands(operand_bounds);
// A list of 2N elements where the (2n) and (2n+1) th element specifies
// whether the lower and upper bound of the n'th induction variable is stored
// as an operand or as an attribute. N being the number of input loops
// specified in this krnl.iterate operation.
result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(),
builder.getArrayAttr(bound_types));
// Parse the schedule body region.
Region* region = result.addRegion();
SmallVector<Type, 4> induction_var_types(
induction_var_refs.size(), builder.getIndexType());
if (parser.parseRegion(*region, induction_var_refs, induction_var_types))
return failure();
// Ensure iterate region is closed off with krnl.terminate.
KrnlIterateOp::ensureTerminator(
*region, parser.getBuilder(), result.location);
return success();
}
static LogicalResult verify(KrnlIterateOp op) {
// TODO: Verify number of induction variable bounds matches the number of
// input loops.
}
//===----------------------------------------------------------------------===//
// KrnlReturnLoopsOp
//===----------------------------------------------------------------------===//
void print(OpAsmPrinter& p, KrnlReturnLoopsOp& op) {
p << "krnl.return_loops ";
p.printOperands(op.operand_begin(), op.operand_end());
}
ParseResult parseKrnlReturnLoopsOp(
OpAsmParser& parser, OperationState& result) {
// Parse the loops to return.
SmallVector<OpAsmParser::OperandType, 4> timestamp_dim_handlers;
if (parser.parseOperandList(timestamp_dim_handlers) ||
parser.resolveOperands(timestamp_dim_handlers,
LoopType::get(result.getContext()), result.operands))
return failure();
return success();
}
#define GET_OP_CLASSES
#include "src/compiler/krnl.cpp.inc"
} // namespace mlir

View File

@ -0,0 +1,51 @@
//===--------------------- krnl_ops.hpp - MLIR Operations -----------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
//===----------------------------------------------------------------------===//
#pragma once
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "src/compiler/dialect/krnl/krnl_types.hpp"
namespace mlir {
class KrnlOpsDialect : public Dialect {
public:
KrnlOpsDialect(MLIRContext* context);
static StringRef getDialectNamespace() { return "krnl"; }
/// Parse a type registered to this dialect. Overriding this method is
/// required for dialects that have custom types.
/// Technically this is only needed to be able to round-trip to textual IR.
mlir::Type parseType(
llvm::StringRef tyData, mlir::Location loc) const override {
MLIRContext* context = getContext();
if (tyData.consume_front("loop"))
return LoopType::get(context);
else
return (emitError(loc, "Unexpected type: " + tyData), Type());
}
/// Print a type registered to this dialect. Overriding this method is
/// only required for dialects that have custom types.
/// Technically this is only needed to be able to round-trip to textual IR.
void printType(mlir::Type type, llvm::raw_ostream& os) const override {
switch (type.getKind()) {
case KrnlTypes::Loop:
os << "loop";
return;
}
}
};
#define GET_OP_CLASSES
#include "src/compiler/krnl.hpp.inc"
} // namespace mlir

View File

@ -0,0 +1,209 @@
//===--------------------- krnl_ops.td - MLIR Operations ------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
//===----------------------------------------------------------------------===//
include "mlir/IR/OpBase.td"
def Krnl_Dialect : Dialect {
let name = "krnl";
let cppNamespace = "";
}
// Require regions to have krnl.terminate terminator operation.
def ImplicitKrnlTerminator
: SingleBlockImplicitTerminator<"KrnlTerminatorOp">;
def KrnlDefineLoopsOp : Op<Krnl_Dialect, "define_loops"> {
let summary = "define_loops operation";
let description = [{
The "krnl.define_loops" operation is used to define input loops,
those are the for loops appearing in the input program that we
intend to optimize.
}];
let arguments = (ins);
let results = (outs Variadic<AnyType>);
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"Builder *builder, OperationState &result,"
"int64_t num_loops">
];
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
let extraClassDeclaration = [{
static StringRef getNumLoopsAttrName() { return "num_loops"; }
// Helper function to extract the number of loops being defined.
int64_t getNumLoops() {
auto num_loops =
getAttrOfType<IntegerAttr>(
getNumLoopsAttrName())
.getValue()
.getSExtValue();
return num_loops;
}
}];
}
def KrnlOptimizeLoopsOp : Op<Krnl_Dialect, "optimize_loops"> {
let summary = "optimize_loops operation";
let description = [{
The "krnl.optimize_loops" operation is essentially a cosmetic operation
which exists to encapsulate a region where loops are being scheduled/optimized.
The optimized loops are returned at the end of the
region associated with the krnl.optimize_loops operation.
For example:
TBD once we have actual schedule intrinsics.
}];
let arguments = (ins Variadic<AnyType>);
let results = (outs Variadic<AnyType>);
let regions = (region SizedRegion<1>:$region);
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"Builder *builder, OperationState &result, "
"int timestamp_space_rank">
];
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
let summary = "iterate operation";
let description = [{
The "krnl.iterate" operation is conceptually equivalent to a nested for loops.
For instance, say we have the following two
%l0, %l1 = krnl.define_loops 2
%o0, %o1 = krnl.optimize_loops {
// Identity schedule.
krnl.return_loops %l0, %l1
}
Then, consider the following krnl.iterate operation:
krnl.iterate (%o0, %o1) with (%l0 -> %i0 = 0 to 10, %l1 -> %i1 = 0 to 10) {
// Some operations.
}
It is equivalent to:
for (i0=0; i0<10; i0++)
for (i1=0; i1<10; i1++)
// Some operations.
}];
let arguments = (ins Variadic<AnyType>);
let regions = (region SizedRegion<1>:$bodyRegion);
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"Builder *builder, OperationState &result, "
"ArrayRef<Value*> input_loops, ArrayRef<Value*> optimized_loops, "
"ArrayRef<Value*> operand_bounds, ArrayRef<int64_t> const_bounds, "
"ArrayRef<int> bound_types">
];
let extraClassDeclaration = [{
// In krnl.iterate operation, three types of SSA values are stored:
// - Optimized krnl.loops.
// - Input krnl.loops.
// - SSA value based induction variable bound (parametric bound).
// We record the number of optimized and input loops to separate these three
// group of operands out.
static StringRef getNumOptimizedLoopsAttrName() { return "num_optimized_loops"; }
int64_t getNumOptimizedLoops() {
auto num_optimized_loops =
getAttrOfType<IntegerAttr>(
getNumOptimizedLoopsAttrName())
.getValue()
.getSExtValue();
return num_optimized_loops;
}
static StringRef getNumInputLoopsAttrName() { return "num_input_loops"; }
int64_t getNumInputLoops() {
auto num_loops =
getAttrOfType<IntegerAttr>(
getNumInputLoopsAttrName())
.getValue()
.getSExtValue();
return num_loops;
}
// Constant bounds are stored here as a list attribute.
static StringRef getConstantBoundsAttrName() { return "constant_bounds"; }
// Store type of each bound as three types:
// - 0 = constant attribute.
// - 1 = operand type.
// - 2 = affine maps (TODO).
static StringRef getBoundTypesAttrName() { return "bound_types"; }
// Get dynamic attribute name for the i-th lower and upper bound.
static std::string getBoundAttrName(int64_t i, bool is_ub) {
std::string bound_type = is_ub ? "_ub" : "_lb";
std::string bound_idx = std::to_string(i);
return "__bound_" + bound_idx + bound_type;
}
}];
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
let verifier = [{ return ::verify(*this); }];
}
def KrnlReturnLoopsOp : Op<Krnl_Dialect, "return_loops", [Terminator]> {
let summary = "Krnl return handler operation";
let description = [{
Krnl return_loops operation is a terminator operation for returning
scheduled dimension handlers in the krnl.optimize_loops region.
}];
let arguments = (ins Variadic<AnyType>);
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {
let summary = "Krnl terminator operation";
let description = [{
Krnl terminator is a special terminator operation for blocks inside krnl
iterate operations. It unconditionally transmits the control flow to the
successor of the operation enclosing the region.
This operation does _not_ have a custom syntax. However, krnl control
operations omit the terminator in their custom syntax for brevity.
}];
// No custom parsing/printing form.
let parser = ?;
let printer = ?;
// Fully specified by traits.
let verifier = ?;
}

View File

@ -0,0 +1,9 @@
//===--------------------- krnl_types.cpp - MLIR Operations ---------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
//===----------------------------------------------------------------------===//
#include "krnl_types.hpp"

View File

@ -0,0 +1,36 @@
//===--------------------- krnl_types.hpp - MLIR Operations ---------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
//===----------------------------------------------------------------------===//
#pragma once
#include <mlir/IR/Types.h>
namespace mlir {
namespace KrnlTypes {
enum Kinds {
// A krnl.loop is simply a reference to a for loop and will be used to:
// - Indicate the presence of a for loop in krnl.iterate.
// - Identify the loop in optimization intrinsics.
Loop = mlir::Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
};
}
class LoopType : public mlir::Type::TypeBase<LoopType, mlir::Type> {
public:
using Base::Base;
// Support type inquiry through isa, cast and dyn_cast.
static bool kindof(unsigned kind) { return kind == KrnlTypes::Loop; }
// Get a unique instance of Loop type.
static LoopType get(mlir::MLIRContext* context) {
return Base::get(context, KrnlTypes::Loop);
}
};
} // namespace mlir

View File

@ -0,0 +1,52 @@
//===------------------ parser_helper.cpp - MLIR Operations ---------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
//===----------------------------------------------------------------------===//
#include "parser_helper.hpp"
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
namespace onnf {
mlir::ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
mlir::Type operand_type, mlir::Value*& operand) {
// If operand queue is empty, parse more operands and cache them.
if (_operand_ref_queue.empty()) {
// Parse operand types:
llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> operand_refs;
_parser.parseOperandList(operand_refs);
// Record operands:
for (auto& operand_ref : operand_refs)
_operand_ref_queue.emplace(operand_ref);
}
// If we parsed some operand reference(s), resolve the ref to an operand:
if (!_operand_ref_queue.empty()) {
auto operand_ref = _operand_ref_queue.front();
_operand_ref_queue.pop();
llvm::SmallVector<mlir::Value*, 1> operands;
_parser.resolveOperand(operand_ref, operand_type, operands);
operand = operands.front();
return mlir::success();
} else {
operand = nullptr;
return mlir::failure();
}
}
mlir::ParseResult KrnlDialectOperandParser::ParseOperand(
mlir::Type operand_type, mlir::Value*& operand) {
ParseOptionalOperand(operand_type, operand);
if (operand == nullptr)
return _parser.emitError(
_parser.getCurrentLocation(), "Expecting an operand.");
return mlir::success();
}
} // namespace onnf

View File

@ -0,0 +1,46 @@
//===------------------ parser_helper.hpp - MLIR Operations ---------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
//===----------------------------------------------------------------------===//
#pragma once
#include <queue>
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
namespace onnf {
class KrnlDialectOperandParser {
public:
KrnlDialectOperandParser(mlir::OpAsmParser& parser)
: _parser(parser), _builder(parser.getBuilder()){};
// Parse an optional operand.
mlir::ParseResult ParseOptionalOperand(
mlir::Type operand_type, mlir::Value*& operand);
// Parse a required operand.
mlir::ParseResult ParseOperand(
mlir::Type operand_type, mlir::Value*& operand);
// Do we have more operands to parse?
bool has_operand_left() { return !_operand_ref_queue.empty(); }
private:
mlir::OpAsmParser& _parser;
mlir::Builder& _builder;
// A queue storing the parsed SSA id references.
std::queue<mlir::OpAsmParser::OperandType> _operand_ref_queue;
};
} // namespace onnf

View File

@ -1,27 +0,0 @@
include "mlir/IR/OpBase.td"
def Knl_Dialect : Dialect {
let name = "knl";
let cppNamespace = "";
}
def KnlIterate : Op<Knl_Dialect, "iterate"> {
let summary = "iterate operation";
let description = [{
The "knl.iterate" operation is conceptually equivalent to a nested for loop
in that it represents ordered interation of integer coordinates within an
affine integer set.
}];
let arguments = (ins Variadic<AnyType>);
let regions = (region SizedRegion<1>:$region);
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"Builder *builder, OperationState &result, "
"IntegerSet set, ArrayRef<Value *> args">
];
}

View File

@ -1,23 +0,0 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "knl_ops.hpp"
namespace mlir {
KnlOpsDialect::KnlOpsDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context) {
addOperations<
#define GET_OP_LIST
#include "src/compiler/knl.cpp.inc"
>();
}
} // namespace mlir
namespace onnf {}

View File

@ -1,19 +0,0 @@
#pragma once
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
namespace mlir {
class KnlOpsDialect : public Dialect {
public:
KnlOpsDialect(MLIRContext* context);
static StringRef getDialectNamespace() { return "knl"; }
};
#define GET_OP_CLASSES
#include "src/compiler/knl.hpp.inc"
} // namespace mlir
namespace onnf {}

View File

@ -16,6 +16,11 @@
#include <mlir/Support/MlirOptMain.h>
#include <mlir/Dialect/StandardOps/Ops.h>
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
#include "src/compiler/helper.hpp"
using namespace onnf;
static llvm::cl::opt<std::string> input_filename(
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"));
@ -52,6 +57,7 @@ int main(int argc, char** argv) {
auto output = mlir::openOutputFile(output_filename, &error_message);
mlir::registerDialect<mlir::KrnlOpsDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,

View File

@ -1,3 +1,11 @@
//===--------------------------- main.cpp ---------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
//===----------------------------------------------------------------------===//
#include <cmath>
#include <cstdlib>
#include <iostream>
@ -21,6 +29,7 @@
#include <boost/program_options.hpp>
#include "src/builder/frontend_dialect_transformer.hpp"
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
#include "src/compiler/pass/passes.hpp"
@ -57,6 +66,7 @@ int main(int ac, char* av[]) {
}
mlir::registerDialect<mlir::ONNXOpsDialect>();
mlir::registerDialect<mlir::KrnlOpsDialect>();
mlir::MLIRContext context;
mlir::OwningModuleRef module;