diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7085399..5f24f72 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,3 +3,5 @@ add_executable(onnf main.cpp) target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR}) target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR}) target_link_libraries(onnf builder compiler ${Boost_LIBRARIES}) + +install(TARGETS onnf DESTINATION bin) diff --git a/src/compiler/CMakeLists.txt b/src/compiler/CMakeLists.txt index e7e069c..417b60f 100644 --- a/src/compiler/CMakeLists.txt +++ b/src/compiler/CMakeLists.txt @@ -1,9 +1,13 @@ add_library( compiler - ir/knl/knl_ops.cpp - ir/knl/knl_ops.hpp + dialect/krnl/krnl_ops.cpp + dialect/krnl/krnl_ops.hpp + dialect/krnl/krnl_types.cpp + dialect/krnl/krnl_types.hpp dialect/onnx/onnx_ops.cpp dialect/onnx/onnx_ops.hpp + dialect/krnl/parser_helper.cpp + dialect/krnl/parser_helper.hpp pass/shape_inference_pass.cpp pass/shape_inference_interface.hpp pass/passes.hpp) @@ -25,7 +29,7 @@ find_package(Boost 1.54.0 # target_link_libraries(compiler isl inja ${Boost_LIBRARIES}) target_link_libraries(compiler ${Boost_LIBRARIES} - ) + ${MLIRLIBS} curses) add_executable(onnf-opt tool/onnf_opt/onnf_opt.cpp) @@ -34,12 +38,6 @@ target_link_libraries(onnf-opt ${Boost_LIBRARIES} ${MLIRLIBS} curses compiler) target_include_directories(onnf-opt PRIVATE ../..) target_include_directories(onnf-opt PRIVATE ${CMAKE_BINARY_DIR}) -set(LLVM_TARGET_DEFINITIONS ir/knl/knl.td) -onnf_tablegen(knl.hpp.inc -gen-op-decls) -onnf_tablegen(knl.cpp.inc -gen-op-defs) -add_public_tablegen_target(gen_kir) -add_dependencies(compiler gen_kir) - set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td) onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls) onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs) @@ -51,3 +49,10 @@ onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass") onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass") add_public_tablegen_target(gen_onnx) add_dependencies(compiler gen_onnx) + +set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td) +onnf_tablegen(krnl.hpp.inc -gen-op-decls) +onnf_tablegen(krnl.cpp.inc -gen-op-defs) +add_public_tablegen_target(gen_krnl_ops) +add_dependencies(compiler gen_krnl_ops) +add_dependencies(onnf-opt gen_krnl_ops) diff --git a/src/compiler/dialect/krnl/krnl_ops.cpp b/src/compiler/dialect/krnl/krnl_ops.cpp new file mode 100644 index 0000000..3ba483d --- /dev/null +++ b/src/compiler/dialect/krnl/krnl_ops.cpp @@ -0,0 +1,398 @@ +//===--------------------- krnl_ops.cpp - MLIR Operations -----------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "src/compiler/dialect/krnl/parser_helper.hpp" + +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallBitVector.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "krnl_ops.hpp" + +using namespace mlir; + +namespace mlir { +KrnlOpsDialect::KrnlOpsDialect(MLIRContext* context) + : Dialect(getDialectNamespace(), context) { + addOperations< +#define GET_OP_LIST +#include "src/compiler/krnl.cpp.inc" + >(); + addTypes(); +} + +//===----------------------------------------------------------------------===// +// KrnlDefineLoopsOp +//===----------------------------------------------------------------------===// + +void KrnlDefineLoopsOp::build( + Builder* builder, OperationState& result, int64_t num_loops) { + // Create the same number of dimension handlers as the number of + // dimensions in the associated integer set. + result.types.append(num_loops, LoopType::get(builder->getContext())); + result.addAttribute( + getNumLoopsAttrName(), builder->getI32IntegerAttr(num_loops)); +} + +void print(OpAsmPrinter& p, KrnlDefineLoopsOp& op) { + auto num_loop_attr = op.getAttrOfType(op.getNumLoopsAttrName()); + p << "krnl.define_loops " << num_loop_attr.getValue().getSExtValue(); +} + +ParseResult parseKrnlDefineLoopsOp( + OpAsmParser& parser, OperationState& result) { + // Parse the attribute indicating number of loops defined. + IntegerAttr num_loops; + auto& builder = parser.getBuilder(); + auto int32_type = builder.getIntegerType(64); + if (parser.parseAttribute(num_loops, int32_type, + KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes)) + return failure(); + + auto loop_types = llvm::SmallVector( + num_loops.getValue().getSExtValue(), LoopType::get(builder.getContext())); + if (parser.addTypesToList(loop_types, result.types)) + return failure(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// KrnlOptimizeLoopsOp +//===----------------------------------------------------------------------===// + +void KrnlOptimizeLoopsOp::build( + Builder* builder, OperationState& result, int num_optimized_loops) { + result.types.append( + num_optimized_loops, LoopType::get(builder->getContext())); + // Create a region and a block for the body. + // Schedule intrinsics will be placed into this region. + Region* region = result.addRegion(); + auto* body = new Block(); + region->push_back(body); +} + +void print(OpAsmPrinter& p, KrnlOptimizeLoopsOp& op) { + p << "krnl.optimize_loops "; + p.printRegion(op.region(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + p << " : "; + p.printFunctionalType(op); +} + +ParseResult parseKrnlOptimizeLoopsOp( + OpAsmParser& parser, OperationState& result) { + // Parse the schedule body region. + Region* region = result.addRegion(); + if (parser.parseRegion(*region, llvm::None, llvm::None)) + return failure(); + + // Parse the function type for the schedule operation. + // Then following the hint of this parsed function type, parse the + // returned timestamp space dimension handlers. + FunctionType schedule_func_type; + if (parser.parseColonType(schedule_func_type) || + parser.addTypesToList(schedule_func_type.getResults(), result.types)) { + failure(); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// KrnlIterateOp +//===----------------------------------------------------------------------===// + +/*! + * Build a Krnl Dialect iterate operation. + * input_loops: a collection of input krnl.loops being optimized. + * optimized_loops: a collection of optimized (scheduled) krnl.loops. + * operand_bounds: a collection of SSA value bounds. + * const_bounds: a collection of constant bounds. + * bound_types: a collection of integer values indicating how bounds are given. + * 0 : bound is given as an integer in const_bounds. + * 1 : bound is given as an operand in operand_bounds. + * 2 : bound is given as an affine map. (TODO). + * + * The following example illustrates how induction variable bounds are parsed + * from builder function inputs: + * + * - operand_bounds = [N, M] + * - const_bounds = [10, 20] + * - bound_types = [0, 1, 1, 0] + * + * Then the bounds will be parsed as: + * %i0 = 10 to N : %i1 = M to 20 + */ +void KrnlIterateOp::build(Builder* builder, OperationState& result, + ArrayRef input_loops, ArrayRef optimized_loops, + ArrayRef operand_bounds, ArrayRef const_bounds, + ArrayRef bound_types) { + // Record optimized loops and the number of such loops. + result.addOperands(optimized_loops); + result.addAttribute(getNumOptimizedLoopsAttrName(), + builder->getI64IntegerAttr(optimized_loops.size())); + + // Record input loops and the number of such loops. + result.addOperands(input_loops); + result.addAttribute(getNumInputLoopsAttrName(), + builder->getI64IntegerAttr(input_loops.size())); + + // Record bound either as attribute or from operand list. + auto next_operand_bound = operand_bounds.begin(); + auto next_const_bound = const_bounds.begin(); + for (size_t i = 0; i < bound_types.size(); i++) { + auto bound_type = bound_types[i]; + if (bound_type == 0) { + // Constant bound. + result.addAttribute(getBoundAttrName(i / 2, i % 2), + builder->getI64IntegerAttr(*next_const_bound)); + next_const_bound = std::next(next_const_bound); + } else { + // Operand bound. + result.addOperands(*next_operand_bound); + next_operand_bound = std::next(next_operand_bound); + } + } + + // Record bound types as attribute: + result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(), + builder->getI32ArrayAttr(bound_types)); + + // Create a region and a block for the body. The arguments of the region is + // the loop induction variables; there can be multiple induction variables + // associated with the same krnl.iterate operation. + Region* bodyRegion = result.addRegion(); + auto* body = new Block(); + auto body_args = llvm::SmallVector( + input_loops.size(), IndexType::get(builder->getContext())); + body->addArguments(body_args); + bodyRegion->push_back(body); + + ensureTerminator(*bodyRegion, *builder, result.location); +} + +void print(OpAsmPrinter& p, KrnlIterateOp& op) { + p << "krnl.iterate("; + // Print optimized loops: + auto num_optimized_loops = op.getNumOptimizedLoops(); + p.printOperands(op.operand_begin(), op.operand_begin() + num_optimized_loops); + p << ") with ("; + + // Set up iterator to input loops: + auto num_input_loops = op.getNumInputLoops(); + auto input_loop_begin = op.operand_begin() + num_optimized_loops; + + // Set up iterators to operand bounds. + auto next_operand_bound = input_loop_begin + num_input_loops; + + // Function to print a lower or upper bound. + auto print_bound = [&](ArrayRef bound_types, size_t idx) { + IntegerAttr type = bound_types[idx].dyn_cast(); + if (type.getValue().getSExtValue() == 0) { + // Bound is an operand. + p.printOperand(*next_operand_bound); + next_operand_bound = std::next(next_operand_bound); + } else { + // Bound is an integer attribute. + auto bound_idx = idx / 2; + auto is_ub = idx % 2; + IntegerAttr bound = op.getAttrOfType( + KrnlIterateOp::getBoundAttrName(bound_idx, is_ub)); + p << bound.getValue().getSExtValue(); + } + }; + + auto induction_variables = op.bodyRegion().front().getArguments(); + ArrayRef bound_types = + op.getAttrOfType(KrnlIterateOp::getBoundTypesAttrName()) + .getValue(); + + // Print input loop operands, induction variables and their ranges. + for (size_t i = 0; i < num_input_loops; i++) { + if (i != 0) + p << ", "; + + p.printOperand(*std::next(input_loop_begin, i)); + p << " -> "; + + // Print induction variable block argument. + p.printOperand(induction_variables[i]); + p << " = "; + + print_bound(bound_types, 2 * i); // Print lower bound. + p << " to "; + print_bound(bound_types, 2 * i + 1); // Print upper bound. + } + p << ")"; + + p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); +} + +ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) { + auto builder = parser.getBuilder(); + auto context = builder.getContext(); + onnf::KrnlDialectOperandParser operand_parser(parser); + + // Parse optimized loops: + SmallVector num_optimized_loops; + if (parser.parseOperandList( + num_optimized_loops, OpAsmParser::Delimiter::Paren) || + parser.resolveOperands(num_optimized_loops, + LoopType::get(result.getContext()), result.operands)) + return failure(); + + // Record how many optimized loops did we parse. + result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(), + builder.getI64IntegerAttr(num_optimized_loops.size())); + + // Parse input loops and their lower and upper bounds. + SmallVector in_loop_refs, induction_var_refs; + SmallVector in_loop_operands, operand_bounds; + SmallVector bound_types; + SmallVector const_bounds; + + if (parser.parseKeyword("with") || parser.parseLParen()) + return failure(); + + // A function to parse a lower or upper bound. + auto parse_bound = [&result, &builder, &operand_parser, &parser, &bound_types, + &operand_bounds, &const_bounds]( + bool is_ub, size_t bound_pair_count) -> ParseResult { + // Try parse an SSA operand. + Value* bound; + operand_parser.ParseOptionalOperand(builder.getIndexType(), bound); + + if (bound != nullptr) { + // Parsed an SSA id as bound. + operand_bounds.emplace_back(bound); + // Record bound_type as an operand type. + bound_types.emplace_back(builder.getI32IntegerAttr(0)); + } else { + // Bound is not an SSA id, then it must be an integer. + // Parse an integer constant attribute. + IntegerAttr boundAttr; + if (parser.parseAttribute(boundAttr, builder.getIndexType(), + KrnlIterateOp::getBoundAttrName(bound_pair_count, is_ub), + result.attributes)) + return failure(); + const_bounds.emplace_back( + builder.getIntegerAttr(builder.getIndexType(), boundAttr.getValue())); + + // Record that the bound_type is a constant integer attribute. + bound_types.emplace_back(builder.getI32IntegerAttr(1)); + } + }; + + bool keep_parsing; // Do we keep parsing loops/bounds? + size_t bound_pair_count = 0; // Record the number of bound pairs parsed. + do { + // Parse an input loop operand; + Value* in_loop_operand; + operand_parser.ParseOperand(LoopType::get(context), in_loop_operand); + in_loop_operands.emplace_back(in_loop_operand); + + parser.parseArrow(); + + // Parse induction variable. + OpAsmParser::OperandType induction_var; + if (parser.parseRegionArgument(induction_var) || parser.parseEqual()) + return failure(); + induction_var_refs.emplace_back(induction_var); + + // Parse bound par (min to max). + if (parse_bound(false, bound_pair_count) || parser.parseKeyword("to") || + parse_bound(true, bound_pair_count)) + return failure(); + + bound_pair_count++; + // We may fail to parse a comma if an operand bound is followed by + // a comma and the next input loop operand, in which case + // the entire "{operand bound}, {input_loop_operand}" sequence will + // be parsed as an operand list. + parser.parseOptionalComma(); + + // If we don't see a RParen token, we keep parsing. + keep_parsing = failed(parser.parseOptionalRParen()); + } while (keep_parsing); + + // At this point, there shouldn't be any operands left to parse. + if (operand_parser.has_operand_left()) + return parser.emitError(parser.getCurrentLocation()); + + // Record how many input loops did we parse. + result.addOperands(in_loop_operands); + result.addAttribute(KrnlIterateOp::getNumInputLoopsAttrName(), + builder.getI64IntegerAttr(in_loop_operands.size())); + + // Add operand bounds to the list of operands of current operation. + result.addOperands(operand_bounds); + + // A list of 2N elements where the (2n) and (2n+1) th element specifies + // whether the lower and upper bound of the n'th induction variable is stored + // as an operand or as an attribute. N being the number of input loops + // specified in this krnl.iterate operation. + result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(), + builder.getArrayAttr(bound_types)); + + // Parse the schedule body region. + Region* region = result.addRegion(); + SmallVector induction_var_types( + induction_var_refs.size(), builder.getIndexType()); + if (parser.parseRegion(*region, induction_var_refs, induction_var_types)) + return failure(); + + // Ensure iterate region is closed off with krnl.terminate. + KrnlIterateOp::ensureTerminator( + *region, parser.getBuilder(), result.location); + + return success(); +} + +static LogicalResult verify(KrnlIterateOp op) { + // TODO: Verify number of induction variable bounds matches the number of + // input loops. +} + +//===----------------------------------------------------------------------===// +// KrnlReturnLoopsOp +//===----------------------------------------------------------------------===// + +void print(OpAsmPrinter& p, KrnlReturnLoopsOp& op) { + p << "krnl.return_loops "; + p.printOperands(op.operand_begin(), op.operand_end()); +} + +ParseResult parseKrnlReturnLoopsOp( + OpAsmParser& parser, OperationState& result) { + // Parse the loops to return. + SmallVector timestamp_dim_handlers; + if (parser.parseOperandList(timestamp_dim_handlers) || + parser.resolveOperands(timestamp_dim_handlers, + LoopType::get(result.getContext()), result.operands)) + return failure(); + + return success(); +} + +#define GET_OP_CLASSES +#include "src/compiler/krnl.cpp.inc" +} // namespace mlir diff --git a/src/compiler/dialect/krnl/krnl_ops.hpp b/src/compiler/dialect/krnl/krnl_ops.hpp new file mode 100644 index 0000000..1d41dd5 --- /dev/null +++ b/src/compiler/dialect/krnl/krnl_ops.hpp @@ -0,0 +1,51 @@ +//===--------------------- krnl_ops.hpp - MLIR Operations -----------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +#include "src/compiler/dialect/krnl/krnl_types.hpp" + +namespace mlir { +class KrnlOpsDialect : public Dialect { + public: + KrnlOpsDialect(MLIRContext* context); + static StringRef getDialectNamespace() { return "krnl"; } + + /// Parse a type registered to this dialect. Overriding this method is + /// required for dialects that have custom types. + /// Technically this is only needed to be able to round-trip to textual IR. + mlir::Type parseType( + llvm::StringRef tyData, mlir::Location loc) const override { + MLIRContext* context = getContext(); + + if (tyData.consume_front("loop")) + return LoopType::get(context); + else + return (emitError(loc, "Unexpected type: " + tyData), Type()); + } + + /// Print a type registered to this dialect. Overriding this method is + /// only required for dialects that have custom types. + /// Technically this is only needed to be able to round-trip to textual IR. + void printType(mlir::Type type, llvm::raw_ostream& os) const override { + switch (type.getKind()) { + case KrnlTypes::Loop: + os << "loop"; + return; + } + } +}; + +#define GET_OP_CLASSES +#include "src/compiler/krnl.hpp.inc" +} // namespace mlir diff --git a/src/compiler/dialect/krnl/krnl_ops.td b/src/compiler/dialect/krnl/krnl_ops.td new file mode 100644 index 0000000..e6a4830 --- /dev/null +++ b/src/compiler/dialect/krnl/krnl_ops.td @@ -0,0 +1,209 @@ +//===--------------------- krnl_ops.td - MLIR Operations ------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +//===----------------------------------------------------------------------===// + +include "mlir/IR/OpBase.td" + +def Krnl_Dialect : Dialect { + let name = "krnl"; + let cppNamespace = ""; +} + +// Require regions to have krnl.terminate terminator operation. +def ImplicitKrnlTerminator + : SingleBlockImplicitTerminator<"KrnlTerminatorOp">; + +def KrnlDefineLoopsOp : Op { + let summary = "define_loops operation"; + let description = [{ + + The "krnl.define_loops" operation is used to define input loops, + those are the for loops appearing in the input program that we + intend to optimize. + + }]; + + let arguments = (ins); + let results = (outs Variadic); + + let skipDefaultBuilders = 1; + + let builders = [ + OpBuilder<"Builder *builder, OperationState &result," + "int64_t num_loops"> + ]; + + let printer = [{ return ::print(p, *this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; + + let extraClassDeclaration = [{ + static StringRef getNumLoopsAttrName() { return "num_loops"; } + + // Helper function to extract the number of loops being defined. + int64_t getNumLoops() { + auto num_loops = + getAttrOfType( + getNumLoopsAttrName()) + .getValue() + .getSExtValue(); + return num_loops; + } + }]; + + +} + +def KrnlOptimizeLoopsOp : Op { + let summary = "optimize_loops operation"; + let description = [{ + + The "krnl.optimize_loops" operation is essentially a cosmetic operation + which exists to encapsulate a region where loops are being scheduled/optimized. + + The optimized loops are returned at the end of the + region associated with the krnl.optimize_loops operation. + + For example: + TBD once we have actual schedule intrinsics. + + }]; + + let arguments = (ins Variadic); + let results = (outs Variadic); + let regions = (region SizedRegion<1>:$region); + + let skipDefaultBuilders = 1; + + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, " + "int timestamp_space_rank"> + ]; + + let printer = [{ return ::print(p, *this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + +def KrnlIterateOp : Op { + let summary = "iterate operation"; + let description = [{ + + The "krnl.iterate" operation is conceptually equivalent to a nested for loops. + + For instance, say we have the following two + %l0, %l1 = krnl.define_loops 2 + %o0, %o1 = krnl.optimize_loops { + // Identity schedule. + krnl.return_loops %l0, %l1 + } + + Then, consider the following krnl.iterate operation: + krnl.iterate (%o0, %o1) with (%l0 -> %i0 = 0 to 10, %l1 -> %i1 = 0 to 10) { + // Some operations. + } + + It is equivalent to: + for (i0=0; i0<10; i0++) + for (i1=0; i1<10; i1++) + // Some operations. + }]; + + let arguments = (ins Variadic); + + let regions = (region SizedRegion<1>:$bodyRegion); + + let skipDefaultBuilders = 1; + + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, " + "ArrayRef input_loops, ArrayRef optimized_loops, " + "ArrayRef operand_bounds, ArrayRef const_bounds, " + "ArrayRef bound_types"> + ]; + + let extraClassDeclaration = [{ + + // In krnl.iterate operation, three types of SSA values are stored: + // - Optimized krnl.loops. + // - Input krnl.loops. + // - SSA value based induction variable bound (parametric bound). + // We record the number of optimized and input loops to separate these three + // group of operands out. + static StringRef getNumOptimizedLoopsAttrName() { return "num_optimized_loops"; } + + int64_t getNumOptimizedLoops() { + auto num_optimized_loops = + getAttrOfType( + getNumOptimizedLoopsAttrName()) + .getValue() + .getSExtValue(); + return num_optimized_loops; + } + + static StringRef getNumInputLoopsAttrName() { return "num_input_loops"; } + + int64_t getNumInputLoops() { + auto num_loops = + getAttrOfType( + getNumInputLoopsAttrName()) + .getValue() + .getSExtValue(); + return num_loops; + } + + // Constant bounds are stored here as a list attribute. + static StringRef getConstantBoundsAttrName() { return "constant_bounds"; } + + // Store type of each bound as three types: + // - 0 = constant attribute. + // - 1 = operand type. + // - 2 = affine maps (TODO). + static StringRef getBoundTypesAttrName() { return "bound_types"; } + + // Get dynamic attribute name for the i-th lower and upper bound. + static std::string getBoundAttrName(int64_t i, bool is_ub) { + std::string bound_type = is_ub ? "_ub" : "_lb"; + std::string bound_idx = std::to_string(i); + return "__bound_" + bound_idx + bound_type; + } + }]; + + let printer = [{ return ::print(p, *this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; + let verifier = [{ return ::verify(*this); }]; +} + +def KrnlReturnLoopsOp : Op { + let summary = "Krnl return handler operation"; + let description = [{ + Krnl return_loops operation is a terminator operation for returning + scheduled dimension handlers in the krnl.optimize_loops region. + }]; + + let arguments = (ins Variadic); + + let printer = [{ return ::print(p, *this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + +def KrnlTerminatorOp : Op { + let summary = "Krnl terminator operation"; + let description = [{ + Krnl terminator is a special terminator operation for blocks inside krnl + iterate operations. It unconditionally transmits the control flow to the + successor of the operation enclosing the region. + + This operation does _not_ have a custom syntax. However, krnl control + operations omit the terminator in their custom syntax for brevity. + }]; + + // No custom parsing/printing form. + let parser = ?; + let printer = ?; + + // Fully specified by traits. + let verifier = ?; +} diff --git a/src/compiler/dialect/krnl/krnl_types.cpp b/src/compiler/dialect/krnl/krnl_types.cpp new file mode 100644 index 0000000..48ac166 --- /dev/null +++ b/src/compiler/dialect/krnl/krnl_types.cpp @@ -0,0 +1,9 @@ +//===--------------------- krnl_types.cpp - MLIR Operations ---------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +//===----------------------------------------------------------------------===// + +#include "krnl_types.hpp" diff --git a/src/compiler/dialect/krnl/krnl_types.hpp b/src/compiler/dialect/krnl/krnl_types.hpp new file mode 100644 index 0000000..2e26d95 --- /dev/null +++ b/src/compiler/dialect/krnl/krnl_types.hpp @@ -0,0 +1,36 @@ +//===--------------------- krnl_types.hpp - MLIR Operations ---------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace mlir { + +namespace KrnlTypes { +enum Kinds { + // A krnl.loop is simply a reference to a for loop and will be used to: + // - Indicate the presence of a for loop in krnl.iterate. + // - Identify the loop in optimization intrinsics. + Loop = mlir::Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE, +}; +} + +class LoopType : public mlir::Type::TypeBase { + public: + using Base::Base; + + // Support type inquiry through isa, cast and dyn_cast. + static bool kindof(unsigned kind) { return kind == KrnlTypes::Loop; } + + // Get a unique instance of Loop type. + static LoopType get(mlir::MLIRContext* context) { + return Base::get(context, KrnlTypes::Loop); + } +}; +} // namespace mlir diff --git a/src/compiler/dialect/krnl/parser_helper.cpp b/src/compiler/dialect/krnl/parser_helper.cpp new file mode 100644 index 0000000..814b0fa --- /dev/null +++ b/src/compiler/dialect/krnl/parser_helper.cpp @@ -0,0 +1,52 @@ +//===------------------ parser_helper.cpp - MLIR Operations ---------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +//===----------------------------------------------------------------------===// + +#include "parser_helper.hpp" + +#include "src/compiler/dialect/krnl/krnl_ops.hpp" + +namespace onnf { + +mlir::ParseResult KrnlDialectOperandParser::ParseOptionalOperand( + mlir::Type operand_type, mlir::Value*& operand) { + // If operand queue is empty, parse more operands and cache them. + if (_operand_ref_queue.empty()) { + // Parse operand types: + llvm::SmallVector operand_refs; + _parser.parseOperandList(operand_refs); + + // Record operands: + for (auto& operand_ref : operand_refs) + _operand_ref_queue.emplace(operand_ref); + } + + // If we parsed some operand reference(s), resolve the ref to an operand: + if (!_operand_ref_queue.empty()) { + auto operand_ref = _operand_ref_queue.front(); + _operand_ref_queue.pop(); + + llvm::SmallVector operands; + _parser.resolveOperand(operand_ref, operand_type, operands); + operand = operands.front(); + return mlir::success(); + } else { + operand = nullptr; + return mlir::failure(); + } +} + +mlir::ParseResult KrnlDialectOperandParser::ParseOperand( + mlir::Type operand_type, mlir::Value*& operand) { + ParseOptionalOperand(operand_type, operand); + if (operand == nullptr) + return _parser.emitError( + _parser.getCurrentLocation(), "Expecting an operand."); + return mlir::success(); +} + +} // namespace onnf diff --git a/src/compiler/dialect/krnl/parser_helper.hpp b/src/compiler/dialect/krnl/parser_helper.hpp new file mode 100644 index 0000000..e1928fd --- /dev/null +++ b/src/compiler/dialect/krnl/parser_helper.hpp @@ -0,0 +1,46 @@ +//===------------------ parser_helper.hpp - MLIR Operations ---------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" + +namespace onnf { + +class KrnlDialectOperandParser { + public: + KrnlDialectOperandParser(mlir::OpAsmParser& parser) + : _parser(parser), _builder(parser.getBuilder()){}; + + // Parse an optional operand. + mlir::ParseResult ParseOptionalOperand( + mlir::Type operand_type, mlir::Value*& operand); + + // Parse a required operand. + mlir::ParseResult ParseOperand( + mlir::Type operand_type, mlir::Value*& operand); + + // Do we have more operands to parse? + bool has_operand_left() { return !_operand_ref_queue.empty(); } + + private: + mlir::OpAsmParser& _parser; + + mlir::Builder& _builder; + + // A queue storing the parsed SSA id references. + std::queue _operand_ref_queue; +}; + +} // namespace onnf diff --git a/src/compiler/ir/knl/knl.td b/src/compiler/ir/knl/knl.td deleted file mode 100644 index 9ff4758..0000000 --- a/src/compiler/ir/knl/knl.td +++ /dev/null @@ -1,27 +0,0 @@ -include "mlir/IR/OpBase.td" - -def Knl_Dialect : Dialect { - let name = "knl"; - let cppNamespace = ""; -} - -def KnlIterate : Op { - let summary = "iterate operation"; - let description = [{ - - The "knl.iterate" operation is conceptually equivalent to a nested for loop - in that it represents ordered interation of integer coordinates within an - affine integer set. - - }]; - - let arguments = (ins Variadic); - let regions = (region SizedRegion<1>:$region); - - let skipDefaultBuilders = 1; - - let builders = [ - OpBuilder<"Builder *builder, OperationState &result, " - "IntegerSet set, ArrayRef args"> - ]; -} \ No newline at end of file diff --git a/src/compiler/ir/knl/knl_ops.cpp b/src/compiler/ir/knl/knl_ops.cpp deleted file mode 100644 index 10ad6eb..0000000 --- a/src/compiler/ir/knl/knl_ops.cpp +++ /dev/null @@ -1,23 +0,0 @@ -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallBitVector.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/IntegerSet.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/PatternMatch.h" - -#include "knl_ops.hpp" - -namespace mlir { -KnlOpsDialect::KnlOpsDialect(MLIRContext* context) - : Dialect(getDialectNamespace(), context) { - addOperations< -#define GET_OP_LIST -#include "src/compiler/knl.cpp.inc" - >(); -} -} // namespace mlir - -namespace onnf {} diff --git a/src/compiler/ir/knl/knl_ops.hpp b/src/compiler/ir/knl/knl_ops.hpp deleted file mode 100644 index aaa55be..0000000 --- a/src/compiler/ir/knl/knl_ops.hpp +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include "mlir/IR/Builders.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/StandardTypes.h" - -namespace mlir { -class KnlOpsDialect : public Dialect { - public: - KnlOpsDialect(MLIRContext* context); - static StringRef getDialectNamespace() { return "knl"; } -}; - -#define GET_OP_CLASSES -#include "src/compiler/knl.hpp.inc" -} // namespace mlir - -namespace onnf {} diff --git a/src/compiler/tool/onnf_opt/onnf_opt.cpp b/src/compiler/tool/onnf_opt/onnf_opt.cpp index b91d410..081faa9 100644 --- a/src/compiler/tool/onnf_opt/onnf_opt.cpp +++ b/src/compiler/tool/onnf_opt/onnf_opt.cpp @@ -16,6 +16,11 @@ #include #include +#include "src/compiler/dialect/krnl/krnl_ops.hpp" +#include "src/compiler/helper.hpp" + +using namespace onnf; + static llvm::cl::opt input_filename( llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-")); @@ -52,6 +57,7 @@ int main(int argc, char** argv) { auto output = mlir::openOutputFile(output_filename, &error_message); + mlir::registerDialect(); mlir::registerDialect(); return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline, diff --git a/src/main.cpp b/src/main.cpp index bcbcd4b..a7c05f8 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,3 +1,11 @@ +//===--------------------------- main.cpp ---------------------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +//===----------------------------------------------------------------------===// + #include #include #include @@ -21,6 +29,7 @@ #include #include "src/builder/frontend_dialect_transformer.hpp" +#include "src/compiler/dialect/krnl/krnl_ops.hpp" #include "src/compiler/dialect/onnx/onnx_ops.hpp" #include "src/compiler/pass/passes.hpp" @@ -57,6 +66,7 @@ int main(int ac, char* av[]) { } mlir::registerDialect(); + mlir::registerDialect(); mlir::MLIRContext context; mlir::OwningModuleRef module;