[MLIR] Krnl Dialect Definition (#357)
* a complete, roud-trippable Krnl dialect operation definition * remove old dialect definition files, edit build files * register dialect * check in src for onnf_opt and dimension handler types * re-trigger jenkins * fix build * clarify operation semantics * add verifier for krnl.iterate * refactor to make things clear * do not hard code types * nit and add comments * fix rebase * update op implementation * fix merge * update kernel dialect definition * more comment on how to use the builder for krnl.iterate operation * ammend the comment * can parse krnl.iterate * can parse and print if bounds are not SSA values * address comments * better error handling * Update CMakeLists.txt * update comment * reflow comments
This commit is contained in:
parent
03be41f7df
commit
780e6f0aa0
|
@ -3,3 +3,5 @@ add_executable(onnf main.cpp)
|
|||
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
|
||||
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
|
||||
target_link_libraries(onnf builder compiler ${Boost_LIBRARIES})
|
||||
|
||||
install(TARGETS onnf DESTINATION bin)
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
add_library(
|
||||
compiler
|
||||
ir/knl/knl_ops.cpp
|
||||
ir/knl/knl_ops.hpp
|
||||
dialect/krnl/krnl_ops.cpp
|
||||
dialect/krnl/krnl_ops.hpp
|
||||
dialect/krnl/krnl_types.cpp
|
||||
dialect/krnl/krnl_types.hpp
|
||||
dialect/onnx/onnx_ops.cpp
|
||||
dialect/onnx/onnx_ops.hpp
|
||||
dialect/krnl/parser_helper.cpp
|
||||
dialect/krnl/parser_helper.hpp
|
||||
pass/shape_inference_pass.cpp
|
||||
pass/shape_inference_interface.hpp
|
||||
pass/passes.hpp)
|
||||
|
@ -25,7 +29,7 @@ find_package(Boost 1.54.0
|
|||
# target_link_libraries(compiler isl inja ${Boost_LIBRARIES})
|
||||
target_link_libraries(compiler
|
||||
${Boost_LIBRARIES}
|
||||
)
|
||||
${MLIRLIBS} curses)
|
||||
|
||||
add_executable(onnf-opt
|
||||
tool/onnf_opt/onnf_opt.cpp)
|
||||
|
@ -34,12 +38,6 @@ target_link_libraries(onnf-opt ${Boost_LIBRARIES} ${MLIRLIBS} curses compiler)
|
|||
target_include_directories(onnf-opt PRIVATE ../..)
|
||||
target_include_directories(onnf-opt PRIVATE ${CMAKE_BINARY_DIR})
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS ir/knl/knl.td)
|
||||
onnf_tablegen(knl.hpp.inc -gen-op-decls)
|
||||
onnf_tablegen(knl.cpp.inc -gen-op-defs)
|
||||
add_public_tablegen_target(gen_kir)
|
||||
add_dependencies(compiler gen_kir)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
|
||||
onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
|
||||
onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
|
||||
|
@ -51,3 +49,10 @@ onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
|||
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||
add_public_tablegen_target(gen_onnx)
|
||||
add_dependencies(compiler gen_onnx)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td)
|
||||
onnf_tablegen(krnl.hpp.inc -gen-op-decls)
|
||||
onnf_tablegen(krnl.cpp.inc -gen-op-defs)
|
||||
add_public_tablegen_target(gen_krnl_ops)
|
||||
add_dependencies(compiler gen_krnl_ops)
|
||||
add_dependencies(onnf-opt gen_krnl_ops)
|
||||
|
|
|
@ -0,0 +1,398 @@
|
|||
//===--------------------- krnl_ops.cpp - MLIR Operations -----------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <iostream>
|
||||
#include <queue>
|
||||
|
||||
#include "src/compiler/dialect/krnl/parser_helper.hpp"
|
||||
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "krnl_ops.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace mlir {
|
||||
KrnlOpsDialect::KrnlOpsDialect(MLIRContext* context)
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "src/compiler/krnl.cpp.inc"
|
||||
>();
|
||||
addTypes<LoopType>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// KrnlDefineLoopsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void KrnlDefineLoopsOp::build(
|
||||
Builder* builder, OperationState& result, int64_t num_loops) {
|
||||
// Create the same number of dimension handlers as the number of
|
||||
// dimensions in the associated integer set.
|
||||
result.types.append(num_loops, LoopType::get(builder->getContext()));
|
||||
result.addAttribute(
|
||||
getNumLoopsAttrName(), builder->getI32IntegerAttr(num_loops));
|
||||
}
|
||||
|
||||
void print(OpAsmPrinter& p, KrnlDefineLoopsOp& op) {
|
||||
auto num_loop_attr = op.getAttrOfType<IntegerAttr>(op.getNumLoopsAttrName());
|
||||
p << "krnl.define_loops " << num_loop_attr.getValue().getSExtValue();
|
||||
}
|
||||
|
||||
ParseResult parseKrnlDefineLoopsOp(
|
||||
OpAsmParser& parser, OperationState& result) {
|
||||
// Parse the attribute indicating number of loops defined.
|
||||
IntegerAttr num_loops;
|
||||
auto& builder = parser.getBuilder();
|
||||
auto int32_type = builder.getIntegerType(64);
|
||||
if (parser.parseAttribute(num_loops, int32_type,
|
||||
KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes))
|
||||
return failure();
|
||||
|
||||
auto loop_types = llvm::SmallVector<Type, 4>(
|
||||
num_loops.getValue().getSExtValue(), LoopType::get(builder.getContext()));
|
||||
if (parser.addTypesToList(loop_types, result.types))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// KrnlOptimizeLoopsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void KrnlOptimizeLoopsOp::build(
|
||||
Builder* builder, OperationState& result, int num_optimized_loops) {
|
||||
result.types.append(
|
||||
num_optimized_loops, LoopType::get(builder->getContext()));
|
||||
// Create a region and a block for the body.
|
||||
// Schedule intrinsics will be placed into this region.
|
||||
Region* region = result.addRegion();
|
||||
auto* body = new Block();
|
||||
region->push_back(body);
|
||||
}
|
||||
|
||||
void print(OpAsmPrinter& p, KrnlOptimizeLoopsOp& op) {
|
||||
p << "krnl.optimize_loops ";
|
||||
p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/true);
|
||||
p << " : ";
|
||||
p.printFunctionalType(op);
|
||||
}
|
||||
|
||||
ParseResult parseKrnlOptimizeLoopsOp(
|
||||
OpAsmParser& parser, OperationState& result) {
|
||||
// Parse the schedule body region.
|
||||
Region* region = result.addRegion();
|
||||
if (parser.parseRegion(*region, llvm::None, llvm::None))
|
||||
return failure();
|
||||
|
||||
// Parse the function type for the schedule operation.
|
||||
// Then following the hint of this parsed function type, parse the
|
||||
// returned timestamp space dimension handlers.
|
||||
FunctionType schedule_func_type;
|
||||
if (parser.parseColonType(schedule_func_type) ||
|
||||
parser.addTypesToList(schedule_func_type.getResults(), result.types)) {
|
||||
failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// KrnlIterateOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/*!
|
||||
* Build a Krnl Dialect iterate operation.
|
||||
* input_loops: a collection of input krnl.loops being optimized.
|
||||
* optimized_loops: a collection of optimized (scheduled) krnl.loops.
|
||||
* operand_bounds: a collection of SSA value bounds.
|
||||
* const_bounds: a collection of constant bounds.
|
||||
* bound_types: a collection of integer values indicating how bounds are given.
|
||||
* 0 : bound is given as an integer in const_bounds.
|
||||
* 1 : bound is given as an operand in operand_bounds.
|
||||
* 2 : bound is given as an affine map. (TODO).
|
||||
*
|
||||
* The following example illustrates how induction variable bounds are parsed
|
||||
* from builder function inputs:
|
||||
*
|
||||
* - operand_bounds = [N, M]
|
||||
* - const_bounds = [10, 20]
|
||||
* - bound_types = [0, 1, 1, 0]
|
||||
*
|
||||
* Then the bounds will be parsed as:
|
||||
* %i0 = 10 to N : %i1 = M to 20
|
||||
*/
|
||||
void KrnlIterateOp::build(Builder* builder, OperationState& result,
|
||||
ArrayRef<Value*> input_loops, ArrayRef<Value*> optimized_loops,
|
||||
ArrayRef<Value*> operand_bounds, ArrayRef<int64_t> const_bounds,
|
||||
ArrayRef<int> bound_types) {
|
||||
// Record optimized loops and the number of such loops.
|
||||
result.addOperands(optimized_loops);
|
||||
result.addAttribute(getNumOptimizedLoopsAttrName(),
|
||||
builder->getI64IntegerAttr(optimized_loops.size()));
|
||||
|
||||
// Record input loops and the number of such loops.
|
||||
result.addOperands(input_loops);
|
||||
result.addAttribute(getNumInputLoopsAttrName(),
|
||||
builder->getI64IntegerAttr(input_loops.size()));
|
||||
|
||||
// Record bound either as attribute or from operand list.
|
||||
auto next_operand_bound = operand_bounds.begin();
|
||||
auto next_const_bound = const_bounds.begin();
|
||||
for (size_t i = 0; i < bound_types.size(); i++) {
|
||||
auto bound_type = bound_types[i];
|
||||
if (bound_type == 0) {
|
||||
// Constant bound.
|
||||
result.addAttribute(getBoundAttrName(i / 2, i % 2),
|
||||
builder->getI64IntegerAttr(*next_const_bound));
|
||||
next_const_bound = std::next(next_const_bound);
|
||||
} else {
|
||||
// Operand bound.
|
||||
result.addOperands(*next_operand_bound);
|
||||
next_operand_bound = std::next(next_operand_bound);
|
||||
}
|
||||
}
|
||||
|
||||
// Record bound types as attribute:
|
||||
result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(),
|
||||
builder->getI32ArrayAttr(bound_types));
|
||||
|
||||
// Create a region and a block for the body. The arguments of the region is
|
||||
// the loop induction variables; there can be multiple induction variables
|
||||
// associated with the same krnl.iterate operation.
|
||||
Region* bodyRegion = result.addRegion();
|
||||
auto* body = new Block();
|
||||
auto body_args = llvm::SmallVector<Type, 4>(
|
||||
input_loops.size(), IndexType::get(builder->getContext()));
|
||||
body->addArguments(body_args);
|
||||
bodyRegion->push_back(body);
|
||||
|
||||
ensureTerminator(*bodyRegion, *builder, result.location);
|
||||
}
|
||||
|
||||
void print(OpAsmPrinter& p, KrnlIterateOp& op) {
|
||||
p << "krnl.iterate(";
|
||||
// Print optimized loops:
|
||||
auto num_optimized_loops = op.getNumOptimizedLoops();
|
||||
p.printOperands(op.operand_begin(), op.operand_begin() + num_optimized_loops);
|
||||
p << ") with (";
|
||||
|
||||
// Set up iterator to input loops:
|
||||
auto num_input_loops = op.getNumInputLoops();
|
||||
auto input_loop_begin = op.operand_begin() + num_optimized_loops;
|
||||
|
||||
// Set up iterators to operand bounds.
|
||||
auto next_operand_bound = input_loop_begin + num_input_loops;
|
||||
|
||||
// Function to print a lower or upper bound.
|
||||
auto print_bound = [&](ArrayRef<Attribute> bound_types, size_t idx) {
|
||||
IntegerAttr type = bound_types[idx].dyn_cast<IntegerAttr>();
|
||||
if (type.getValue().getSExtValue() == 0) {
|
||||
// Bound is an operand.
|
||||
p.printOperand(*next_operand_bound);
|
||||
next_operand_bound = std::next(next_operand_bound);
|
||||
} else {
|
||||
// Bound is an integer attribute.
|
||||
auto bound_idx = idx / 2;
|
||||
auto is_ub = idx % 2;
|
||||
IntegerAttr bound = op.getAttrOfType<IntegerAttr>(
|
||||
KrnlIterateOp::getBoundAttrName(bound_idx, is_ub));
|
||||
p << bound.getValue().getSExtValue();
|
||||
}
|
||||
};
|
||||
|
||||
auto induction_variables = op.bodyRegion().front().getArguments();
|
||||
ArrayRef<Attribute> bound_types =
|
||||
op.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundTypesAttrName())
|
||||
.getValue();
|
||||
|
||||
// Print input loop operands, induction variables and their ranges.
|
||||
for (size_t i = 0; i < num_input_loops; i++) {
|
||||
if (i != 0)
|
||||
p << ", ";
|
||||
|
||||
p.printOperand(*std::next(input_loop_begin, i));
|
||||
p << " -> ";
|
||||
|
||||
// Print induction variable block argument.
|
||||
p.printOperand(induction_variables[i]);
|
||||
p << " = ";
|
||||
|
||||
print_bound(bound_types, 2 * i); // Print lower bound.
|
||||
p << " to ";
|
||||
print_bound(bound_types, 2 * i + 1); // Print upper bound.
|
||||
}
|
||||
p << ")";
|
||||
|
||||
p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
}
|
||||
|
||||
ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
|
||||
auto builder = parser.getBuilder();
|
||||
auto context = builder.getContext();
|
||||
onnf::KrnlDialectOperandParser operand_parser(parser);
|
||||
|
||||
// Parse optimized loops:
|
||||
SmallVector<OpAsmParser::OperandType, 4> num_optimized_loops;
|
||||
if (parser.parseOperandList(
|
||||
num_optimized_loops, OpAsmParser::Delimiter::Paren) ||
|
||||
parser.resolveOperands(num_optimized_loops,
|
||||
LoopType::get(result.getContext()), result.operands))
|
||||
return failure();
|
||||
|
||||
// Record how many optimized loops did we parse.
|
||||
result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(),
|
||||
builder.getI64IntegerAttr(num_optimized_loops.size()));
|
||||
|
||||
// Parse input loops and their lower and upper bounds.
|
||||
SmallVector<OpAsmParser::OperandType, 4> in_loop_refs, induction_var_refs;
|
||||
SmallVector<Value*, 4> in_loop_operands, operand_bounds;
|
||||
SmallVector<Attribute, 4> bound_types;
|
||||
SmallVector<IntegerAttr, 4> const_bounds;
|
||||
|
||||
if (parser.parseKeyword("with") || parser.parseLParen())
|
||||
return failure();
|
||||
|
||||
// A function to parse a lower or upper bound.
|
||||
auto parse_bound = [&result, &builder, &operand_parser, &parser, &bound_types,
|
||||
&operand_bounds, &const_bounds](
|
||||
bool is_ub, size_t bound_pair_count) -> ParseResult {
|
||||
// Try parse an SSA operand.
|
||||
Value* bound;
|
||||
operand_parser.ParseOptionalOperand(builder.getIndexType(), bound);
|
||||
|
||||
if (bound != nullptr) {
|
||||
// Parsed an SSA id as bound.
|
||||
operand_bounds.emplace_back(bound);
|
||||
// Record bound_type as an operand type.
|
||||
bound_types.emplace_back(builder.getI32IntegerAttr(0));
|
||||
} else {
|
||||
// Bound is not an SSA id, then it must be an integer.
|
||||
// Parse an integer constant attribute.
|
||||
IntegerAttr boundAttr;
|
||||
if (parser.parseAttribute(boundAttr, builder.getIndexType(),
|
||||
KrnlIterateOp::getBoundAttrName(bound_pair_count, is_ub),
|
||||
result.attributes))
|
||||
return failure();
|
||||
const_bounds.emplace_back(
|
||||
builder.getIntegerAttr(builder.getIndexType(), boundAttr.getValue()));
|
||||
|
||||
// Record that the bound_type is a constant integer attribute.
|
||||
bound_types.emplace_back(builder.getI32IntegerAttr(1));
|
||||
}
|
||||
};
|
||||
|
||||
bool keep_parsing; // Do we keep parsing loops/bounds?
|
||||
size_t bound_pair_count = 0; // Record the number of bound pairs parsed.
|
||||
do {
|
||||
// Parse an input loop operand;
|
||||
Value* in_loop_operand;
|
||||
operand_parser.ParseOperand(LoopType::get(context), in_loop_operand);
|
||||
in_loop_operands.emplace_back(in_loop_operand);
|
||||
|
||||
parser.parseArrow();
|
||||
|
||||
// Parse induction variable.
|
||||
OpAsmParser::OperandType induction_var;
|
||||
if (parser.parseRegionArgument(induction_var) || parser.parseEqual())
|
||||
return failure();
|
||||
induction_var_refs.emplace_back(induction_var);
|
||||
|
||||
// Parse bound par (min to max).
|
||||
if (parse_bound(false, bound_pair_count) || parser.parseKeyword("to") ||
|
||||
parse_bound(true, bound_pair_count))
|
||||
return failure();
|
||||
|
||||
bound_pair_count++;
|
||||
// We may fail to parse a comma if an operand bound is followed by
|
||||
// a comma and the next input loop operand, in which case
|
||||
// the entire "{operand bound}, {input_loop_operand}" sequence will
|
||||
// be parsed as an operand list.
|
||||
parser.parseOptionalComma();
|
||||
|
||||
// If we don't see a RParen token, we keep parsing.
|
||||
keep_parsing = failed(parser.parseOptionalRParen());
|
||||
} while (keep_parsing);
|
||||
|
||||
// At this point, there shouldn't be any operands left to parse.
|
||||
if (operand_parser.has_operand_left())
|
||||
return parser.emitError(parser.getCurrentLocation());
|
||||
|
||||
// Record how many input loops did we parse.
|
||||
result.addOperands(in_loop_operands);
|
||||
result.addAttribute(KrnlIterateOp::getNumInputLoopsAttrName(),
|
||||
builder.getI64IntegerAttr(in_loop_operands.size()));
|
||||
|
||||
// Add operand bounds to the list of operands of current operation.
|
||||
result.addOperands(operand_bounds);
|
||||
|
||||
// A list of 2N elements where the (2n) and (2n+1) th element specifies
|
||||
// whether the lower and upper bound of the n'th induction variable is stored
|
||||
// as an operand or as an attribute. N being the number of input loops
|
||||
// specified in this krnl.iterate operation.
|
||||
result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(),
|
||||
builder.getArrayAttr(bound_types));
|
||||
|
||||
// Parse the schedule body region.
|
||||
Region* region = result.addRegion();
|
||||
SmallVector<Type, 4> induction_var_types(
|
||||
induction_var_refs.size(), builder.getIndexType());
|
||||
if (parser.parseRegion(*region, induction_var_refs, induction_var_types))
|
||||
return failure();
|
||||
|
||||
// Ensure iterate region is closed off with krnl.terminate.
|
||||
KrnlIterateOp::ensureTerminator(
|
||||
*region, parser.getBuilder(), result.location);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(KrnlIterateOp op) {
|
||||
// TODO: Verify number of induction variable bounds matches the number of
|
||||
// input loops.
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// KrnlReturnLoopsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void print(OpAsmPrinter& p, KrnlReturnLoopsOp& op) {
|
||||
p << "krnl.return_loops ";
|
||||
p.printOperands(op.operand_begin(), op.operand_end());
|
||||
}
|
||||
|
||||
ParseResult parseKrnlReturnLoopsOp(
|
||||
OpAsmParser& parser, OperationState& result) {
|
||||
// Parse the loops to return.
|
||||
SmallVector<OpAsmParser::OperandType, 4> timestamp_dim_handlers;
|
||||
if (parser.parseOperandList(timestamp_dim_handlers) ||
|
||||
parser.resolveOperands(timestamp_dim_handlers,
|
||||
LoopType::get(result.getContext()), result.operands))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "src/compiler/krnl.cpp.inc"
|
||||
} // namespace mlir
|
|
@ -0,0 +1,51 @@
|
|||
//===--------------------- krnl_ops.hpp - MLIR Operations -----------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
#include "src/compiler/dialect/krnl/krnl_types.hpp"
|
||||
|
||||
namespace mlir {
|
||||
class KrnlOpsDialect : public Dialect {
|
||||
public:
|
||||
KrnlOpsDialect(MLIRContext* context);
|
||||
static StringRef getDialectNamespace() { return "krnl"; }
|
||||
|
||||
/// Parse a type registered to this dialect. Overriding this method is
|
||||
/// required for dialects that have custom types.
|
||||
/// Technically this is only needed to be able to round-trip to textual IR.
|
||||
mlir::Type parseType(
|
||||
llvm::StringRef tyData, mlir::Location loc) const override {
|
||||
MLIRContext* context = getContext();
|
||||
|
||||
if (tyData.consume_front("loop"))
|
||||
return LoopType::get(context);
|
||||
else
|
||||
return (emitError(loc, "Unexpected type: " + tyData), Type());
|
||||
}
|
||||
|
||||
/// Print a type registered to this dialect. Overriding this method is
|
||||
/// only required for dialects that have custom types.
|
||||
/// Technically this is only needed to be able to round-trip to textual IR.
|
||||
void printType(mlir::Type type, llvm::raw_ostream& os) const override {
|
||||
switch (type.getKind()) {
|
||||
case KrnlTypes::Loop:
|
||||
os << "loop";
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "src/compiler/krnl.hpp.inc"
|
||||
} // namespace mlir
|
|
@ -0,0 +1,209 @@
|
|||
//===--------------------- krnl_ops.td - MLIR Operations ------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Krnl_Dialect : Dialect {
|
||||
let name = "krnl";
|
||||
let cppNamespace = "";
|
||||
}
|
||||
|
||||
// Require regions to have krnl.terminate terminator operation.
|
||||
def ImplicitKrnlTerminator
|
||||
: SingleBlockImplicitTerminator<"KrnlTerminatorOp">;
|
||||
|
||||
def KrnlDefineLoopsOp : Op<Krnl_Dialect, "define_loops"> {
|
||||
let summary = "define_loops operation";
|
||||
let description = [{
|
||||
|
||||
The "krnl.define_loops" operation is used to define input loops,
|
||||
those are the for loops appearing in the input program that we
|
||||
intend to optimize.
|
||||
|
||||
}];
|
||||
|
||||
let arguments = (ins);
|
||||
let results = (outs Variadic<AnyType>);
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState &result,"
|
||||
"int64_t num_loops">
|
||||
];
|
||||
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getNumLoopsAttrName() { return "num_loops"; }
|
||||
|
||||
// Helper function to extract the number of loops being defined.
|
||||
int64_t getNumLoops() {
|
||||
auto num_loops =
|
||||
getAttrOfType<IntegerAttr>(
|
||||
getNumLoopsAttrName())
|
||||
.getValue()
|
||||
.getSExtValue();
|
||||
return num_loops;
|
||||
}
|
||||
}];
|
||||
|
||||
|
||||
}
|
||||
|
||||
def KrnlOptimizeLoopsOp : Op<Krnl_Dialect, "optimize_loops"> {
|
||||
let summary = "optimize_loops operation";
|
||||
let description = [{
|
||||
|
||||
The "krnl.optimize_loops" operation is essentially a cosmetic operation
|
||||
which exists to encapsulate a region where loops are being scheduled/optimized.
|
||||
|
||||
The optimized loops are returned at the end of the
|
||||
region associated with the krnl.optimize_loops operation.
|
||||
|
||||
For example:
|
||||
TBD once we have actual schedule intrinsics.
|
||||
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyType>);
|
||||
let results = (outs Variadic<AnyType>);
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState &result, "
|
||||
"int timestamp_space_rank">
|
||||
];
|
||||
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
|
||||
let summary = "iterate operation";
|
||||
let description = [{
|
||||
|
||||
The "krnl.iterate" operation is conceptually equivalent to a nested for loops.
|
||||
|
||||
For instance, say we have the following two
|
||||
%l0, %l1 = krnl.define_loops 2
|
||||
%o0, %o1 = krnl.optimize_loops {
|
||||
// Identity schedule.
|
||||
krnl.return_loops %l0, %l1
|
||||
}
|
||||
|
||||
Then, consider the following krnl.iterate operation:
|
||||
krnl.iterate (%o0, %o1) with (%l0 -> %i0 = 0 to 10, %l1 -> %i1 = 0 to 10) {
|
||||
// Some operations.
|
||||
}
|
||||
|
||||
It is equivalent to:
|
||||
for (i0=0; i0<10; i0++)
|
||||
for (i1=0; i1<10; i1++)
|
||||
// Some operations.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyType>);
|
||||
|
||||
let regions = (region SizedRegion<1>:$bodyRegion);
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState &result, "
|
||||
"ArrayRef<Value*> input_loops, ArrayRef<Value*> optimized_loops, "
|
||||
"ArrayRef<Value*> operand_bounds, ArrayRef<int64_t> const_bounds, "
|
||||
"ArrayRef<int> bound_types">
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
||||
// In krnl.iterate operation, three types of SSA values are stored:
|
||||
// - Optimized krnl.loops.
|
||||
// - Input krnl.loops.
|
||||
// - SSA value based induction variable bound (parametric bound).
|
||||
// We record the number of optimized and input loops to separate these three
|
||||
// group of operands out.
|
||||
static StringRef getNumOptimizedLoopsAttrName() { return "num_optimized_loops"; }
|
||||
|
||||
int64_t getNumOptimizedLoops() {
|
||||
auto num_optimized_loops =
|
||||
getAttrOfType<IntegerAttr>(
|
||||
getNumOptimizedLoopsAttrName())
|
||||
.getValue()
|
||||
.getSExtValue();
|
||||
return num_optimized_loops;
|
||||
}
|
||||
|
||||
static StringRef getNumInputLoopsAttrName() { return "num_input_loops"; }
|
||||
|
||||
int64_t getNumInputLoops() {
|
||||
auto num_loops =
|
||||
getAttrOfType<IntegerAttr>(
|
||||
getNumInputLoopsAttrName())
|
||||
.getValue()
|
||||
.getSExtValue();
|
||||
return num_loops;
|
||||
}
|
||||
|
||||
// Constant bounds are stored here as a list attribute.
|
||||
static StringRef getConstantBoundsAttrName() { return "constant_bounds"; }
|
||||
|
||||
// Store type of each bound as three types:
|
||||
// - 0 = constant attribute.
|
||||
// - 1 = operand type.
|
||||
// - 2 = affine maps (TODO).
|
||||
static StringRef getBoundTypesAttrName() { return "bound_types"; }
|
||||
|
||||
// Get dynamic attribute name for the i-th lower and upper bound.
|
||||
static std::string getBoundAttrName(int64_t i, bool is_ub) {
|
||||
std::string bound_type = is_ub ? "_ub" : "_lb";
|
||||
std::string bound_idx = std::to_string(i);
|
||||
return "__bound_" + bound_idx + bound_type;
|
||||
}
|
||||
}];
|
||||
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def KrnlReturnLoopsOp : Op<Krnl_Dialect, "return_loops", [Terminator]> {
|
||||
let summary = "Krnl return handler operation";
|
||||
let description = [{
|
||||
Krnl return_loops operation is a terminator operation for returning
|
||||
scheduled dimension handlers in the krnl.optimize_loops region.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyType>);
|
||||
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {
|
||||
let summary = "Krnl terminator operation";
|
||||
let description = [{
|
||||
Krnl terminator is a special terminator operation for blocks inside krnl
|
||||
iterate operations. It unconditionally transmits the control flow to the
|
||||
successor of the operation enclosing the region.
|
||||
|
||||
This operation does _not_ have a custom syntax. However, krnl control
|
||||
operations omit the terminator in their custom syntax for brevity.
|
||||
}];
|
||||
|
||||
// No custom parsing/printing form.
|
||||
let parser = ?;
|
||||
let printer = ?;
|
||||
|
||||
// Fully specified by traits.
|
||||
let verifier = ?;
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
//===--------------------- krnl_types.cpp - MLIR Operations ---------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "krnl_types.hpp"
|
|
@ -0,0 +1,36 @@
|
|||
//===--------------------- krnl_types.hpp - MLIR Operations ---------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <mlir/IR/Types.h>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace KrnlTypes {
|
||||
enum Kinds {
|
||||
// A krnl.loop is simply a reference to a for loop and will be used to:
|
||||
// - Indicate the presence of a for loop in krnl.iterate.
|
||||
// - Identify the loop in optimization intrinsics.
|
||||
Loop = mlir::Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
|
||||
};
|
||||
}
|
||||
|
||||
class LoopType : public mlir::Type::TypeBase<LoopType, mlir::Type> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
// Support type inquiry through isa, cast and dyn_cast.
|
||||
static bool kindof(unsigned kind) { return kind == KrnlTypes::Loop; }
|
||||
|
||||
// Get a unique instance of Loop type.
|
||||
static LoopType get(mlir::MLIRContext* context) {
|
||||
return Base::get(context, KrnlTypes::Loop);
|
||||
}
|
||||
};
|
||||
} // namespace mlir
|
|
@ -0,0 +1,52 @@
|
|||
//===------------------ parser_helper.cpp - MLIR Operations ---------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "parser_helper.hpp"
|
||||
|
||||
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||
|
||||
namespace onnf {
|
||||
|
||||
mlir::ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
||||
mlir::Type operand_type, mlir::Value*& operand) {
|
||||
// If operand queue is empty, parse more operands and cache them.
|
||||
if (_operand_ref_queue.empty()) {
|
||||
// Parse operand types:
|
||||
llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> operand_refs;
|
||||
_parser.parseOperandList(operand_refs);
|
||||
|
||||
// Record operands:
|
||||
for (auto& operand_ref : operand_refs)
|
||||
_operand_ref_queue.emplace(operand_ref);
|
||||
}
|
||||
|
||||
// If we parsed some operand reference(s), resolve the ref to an operand:
|
||||
if (!_operand_ref_queue.empty()) {
|
||||
auto operand_ref = _operand_ref_queue.front();
|
||||
_operand_ref_queue.pop();
|
||||
|
||||
llvm::SmallVector<mlir::Value*, 1> operands;
|
||||
_parser.resolveOperand(operand_ref, operand_type, operands);
|
||||
operand = operands.front();
|
||||
return mlir::success();
|
||||
} else {
|
||||
operand = nullptr;
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
mlir::ParseResult KrnlDialectOperandParser::ParseOperand(
|
||||
mlir::Type operand_type, mlir::Value*& operand) {
|
||||
ParseOptionalOperand(operand_type, operand);
|
||||
if (operand == nullptr)
|
||||
return _parser.emitError(
|
||||
_parser.getCurrentLocation(), "Expecting an operand.");
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
} // namespace onnf
|
|
@ -0,0 +1,46 @@
|
|||
//===------------------ parser_helper.hpp - MLIR Operations ---------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <queue>
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
namespace onnf {
|
||||
|
||||
class KrnlDialectOperandParser {
|
||||
public:
|
||||
KrnlDialectOperandParser(mlir::OpAsmParser& parser)
|
||||
: _parser(parser), _builder(parser.getBuilder()){};
|
||||
|
||||
// Parse an optional operand.
|
||||
mlir::ParseResult ParseOptionalOperand(
|
||||
mlir::Type operand_type, mlir::Value*& operand);
|
||||
|
||||
// Parse a required operand.
|
||||
mlir::ParseResult ParseOperand(
|
||||
mlir::Type operand_type, mlir::Value*& operand);
|
||||
|
||||
// Do we have more operands to parse?
|
||||
bool has_operand_left() { return !_operand_ref_queue.empty(); }
|
||||
|
||||
private:
|
||||
mlir::OpAsmParser& _parser;
|
||||
|
||||
mlir::Builder& _builder;
|
||||
|
||||
// A queue storing the parsed SSA id references.
|
||||
std::queue<mlir::OpAsmParser::OperandType> _operand_ref_queue;
|
||||
};
|
||||
|
||||
} // namespace onnf
|
|
@ -1,27 +0,0 @@
|
|||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Knl_Dialect : Dialect {
|
||||
let name = "knl";
|
||||
let cppNamespace = "";
|
||||
}
|
||||
|
||||
def KnlIterate : Op<Knl_Dialect, "iterate"> {
|
||||
let summary = "iterate operation";
|
||||
let description = [{
|
||||
|
||||
The "knl.iterate" operation is conceptually equivalent to a nested for loop
|
||||
in that it represents ordered interation of integer coordinates within an
|
||||
affine integer set.
|
||||
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyType>);
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState &result, "
|
||||
"IntegerSet set, ArrayRef<Value *> args">
|
||||
];
|
||||
}
|
|
@ -1,23 +0,0 @@
|
|||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "knl_ops.hpp"
|
||||
|
||||
namespace mlir {
|
||||
KnlOpsDialect::KnlOpsDialect(MLIRContext* context)
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "src/compiler/knl.cpp.inc"
|
||||
>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
namespace onnf {}
|
|
@ -1,19 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
namespace mlir {
|
||||
class KnlOpsDialect : public Dialect {
|
||||
public:
|
||||
KnlOpsDialect(MLIRContext* context);
|
||||
static StringRef getDialectNamespace() { return "knl"; }
|
||||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "src/compiler/knl.hpp.inc"
|
||||
} // namespace mlir
|
||||
|
||||
namespace onnf {}
|
|
@ -16,6 +16,11 @@
|
|||
#include <mlir/Support/MlirOptMain.h>
|
||||
#include <mlir/Dialect/StandardOps/Ops.h>
|
||||
|
||||
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||
#include "src/compiler/helper.hpp"
|
||||
|
||||
using namespace onnf;
|
||||
|
||||
static llvm::cl::opt<std::string> input_filename(
|
||||
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"));
|
||||
|
||||
|
@ -52,6 +57,7 @@ int main(int argc, char** argv) {
|
|||
|
||||
auto output = mlir::openOutputFile(output_filename, &error_message);
|
||||
|
||||
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
|
||||
return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
|
||||
|
|
10
src/main.cpp
10
src/main.cpp
|
@ -1,3 +1,11 @@
|
|||
//===--------------------------- main.cpp ---------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
|
@ -21,6 +29,7 @@
|
|||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "src/builder/frontend_dialect_transformer.hpp"
|
||||
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||
#include "src/compiler/pass/passes.hpp"
|
||||
|
||||
|
@ -57,6 +66,7 @@ int main(int ac, char* av[]) {
|
|||
}
|
||||
|
||||
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module;
|
||||
|
|
Loading…
Reference in New Issue