[MLIR] Krnl Dialect Definition (#357)
* a complete, roud-trippable Krnl dialect operation definition * remove old dialect definition files, edit build files * register dialect * check in src for onnf_opt and dimension handler types * re-trigger jenkins * fix build * clarify operation semantics * add verifier for krnl.iterate * refactor to make things clear * do not hard code types * nit and add comments * fix rebase * update op implementation * fix merge * update kernel dialect definition * more comment on how to use the builder for krnl.iterate operation * ammend the comment * can parse krnl.iterate * can parse and print if bounds are not SSA values * address comments * better error handling * Update CMakeLists.txt * update comment * reflow comments
This commit is contained in:
parent
03be41f7df
commit
780e6f0aa0
|
@ -3,3 +3,5 @@ add_executable(onnf main.cpp)
|
||||||
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
|
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
|
||||||
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
|
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
|
||||||
target_link_libraries(onnf builder compiler ${Boost_LIBRARIES})
|
target_link_libraries(onnf builder compiler ${Boost_LIBRARIES})
|
||||||
|
|
||||||
|
install(TARGETS onnf DESTINATION bin)
|
||||||
|
|
|
@ -1,9 +1,13 @@
|
||||||
add_library(
|
add_library(
|
||||||
compiler
|
compiler
|
||||||
ir/knl/knl_ops.cpp
|
dialect/krnl/krnl_ops.cpp
|
||||||
ir/knl/knl_ops.hpp
|
dialect/krnl/krnl_ops.hpp
|
||||||
|
dialect/krnl/krnl_types.cpp
|
||||||
|
dialect/krnl/krnl_types.hpp
|
||||||
dialect/onnx/onnx_ops.cpp
|
dialect/onnx/onnx_ops.cpp
|
||||||
dialect/onnx/onnx_ops.hpp
|
dialect/onnx/onnx_ops.hpp
|
||||||
|
dialect/krnl/parser_helper.cpp
|
||||||
|
dialect/krnl/parser_helper.hpp
|
||||||
pass/shape_inference_pass.cpp
|
pass/shape_inference_pass.cpp
|
||||||
pass/shape_inference_interface.hpp
|
pass/shape_inference_interface.hpp
|
||||||
pass/passes.hpp)
|
pass/passes.hpp)
|
||||||
|
@ -25,7 +29,7 @@ find_package(Boost 1.54.0
|
||||||
# target_link_libraries(compiler isl inja ${Boost_LIBRARIES})
|
# target_link_libraries(compiler isl inja ${Boost_LIBRARIES})
|
||||||
target_link_libraries(compiler
|
target_link_libraries(compiler
|
||||||
${Boost_LIBRARIES}
|
${Boost_LIBRARIES}
|
||||||
)
|
${MLIRLIBS} curses)
|
||||||
|
|
||||||
add_executable(onnf-opt
|
add_executable(onnf-opt
|
||||||
tool/onnf_opt/onnf_opt.cpp)
|
tool/onnf_opt/onnf_opt.cpp)
|
||||||
|
@ -34,12 +38,6 @@ target_link_libraries(onnf-opt ${Boost_LIBRARIES} ${MLIRLIBS} curses compiler)
|
||||||
target_include_directories(onnf-opt PRIVATE ../..)
|
target_include_directories(onnf-opt PRIVATE ../..)
|
||||||
target_include_directories(onnf-opt PRIVATE ${CMAKE_BINARY_DIR})
|
target_include_directories(onnf-opt PRIVATE ${CMAKE_BINARY_DIR})
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS ir/knl/knl.td)
|
|
||||||
onnf_tablegen(knl.hpp.inc -gen-op-decls)
|
|
||||||
onnf_tablegen(knl.cpp.inc -gen-op-defs)
|
|
||||||
add_public_tablegen_target(gen_kir)
|
|
||||||
add_dependencies(compiler gen_kir)
|
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
|
set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
|
||||||
onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
|
onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
|
||||||
onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
|
onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
|
||||||
|
@ -51,3 +49,10 @@ onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||||
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||||
add_public_tablegen_target(gen_onnx)
|
add_public_tablegen_target(gen_onnx)
|
||||||
add_dependencies(compiler gen_onnx)
|
add_dependencies(compiler gen_onnx)
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td)
|
||||||
|
onnf_tablegen(krnl.hpp.inc -gen-op-decls)
|
||||||
|
onnf_tablegen(krnl.cpp.inc -gen-op-defs)
|
||||||
|
add_public_tablegen_target(gen_krnl_ops)
|
||||||
|
add_dependencies(compiler gen_krnl_ops)
|
||||||
|
add_dependencies(onnf-opt gen_krnl_ops)
|
||||||
|
|
|
@ -0,0 +1,398 @@
|
||||||
|
//===--------------------- krnl_ops.cpp - MLIR Operations -----------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
|
#include "src/compiler/dialect/krnl/parser_helper.hpp"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
|
#include "mlir/IR/Block.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/Function.h"
|
||||||
|
#include "mlir/IR/IntegerSet.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/Module.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "krnl_ops.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
KrnlOpsDialect::KrnlOpsDialect(MLIRContext* context)
|
||||||
|
: Dialect(getDialectNamespace(), context) {
|
||||||
|
addOperations<
|
||||||
|
#define GET_OP_LIST
|
||||||
|
#include "src/compiler/krnl.cpp.inc"
|
||||||
|
>();
|
||||||
|
addTypes<LoopType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// KrnlDefineLoopsOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void KrnlDefineLoopsOp::build(
|
||||||
|
Builder* builder, OperationState& result, int64_t num_loops) {
|
||||||
|
// Create the same number of dimension handlers as the number of
|
||||||
|
// dimensions in the associated integer set.
|
||||||
|
result.types.append(num_loops, LoopType::get(builder->getContext()));
|
||||||
|
result.addAttribute(
|
||||||
|
getNumLoopsAttrName(), builder->getI32IntegerAttr(num_loops));
|
||||||
|
}
|
||||||
|
|
||||||
|
void print(OpAsmPrinter& p, KrnlDefineLoopsOp& op) {
|
||||||
|
auto num_loop_attr = op.getAttrOfType<IntegerAttr>(op.getNumLoopsAttrName());
|
||||||
|
p << "krnl.define_loops " << num_loop_attr.getValue().getSExtValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult parseKrnlDefineLoopsOp(
|
||||||
|
OpAsmParser& parser, OperationState& result) {
|
||||||
|
// Parse the attribute indicating number of loops defined.
|
||||||
|
IntegerAttr num_loops;
|
||||||
|
auto& builder = parser.getBuilder();
|
||||||
|
auto int32_type = builder.getIntegerType(64);
|
||||||
|
if (parser.parseAttribute(num_loops, int32_type,
|
||||||
|
KrnlDefineLoopsOp::getNumLoopsAttrName(), result.attributes))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto loop_types = llvm::SmallVector<Type, 4>(
|
||||||
|
num_loops.getValue().getSExtValue(), LoopType::get(builder.getContext()));
|
||||||
|
if (parser.addTypesToList(loop_types, result.types))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// KrnlOptimizeLoopsOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void KrnlOptimizeLoopsOp::build(
|
||||||
|
Builder* builder, OperationState& result, int num_optimized_loops) {
|
||||||
|
result.types.append(
|
||||||
|
num_optimized_loops, LoopType::get(builder->getContext()));
|
||||||
|
// Create a region and a block for the body.
|
||||||
|
// Schedule intrinsics will be placed into this region.
|
||||||
|
Region* region = result.addRegion();
|
||||||
|
auto* body = new Block();
|
||||||
|
region->push_back(body);
|
||||||
|
}
|
||||||
|
|
||||||
|
void print(OpAsmPrinter& p, KrnlOptimizeLoopsOp& op) {
|
||||||
|
p << "krnl.optimize_loops ";
|
||||||
|
p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
|
||||||
|
/*printBlockTerminators=*/true);
|
||||||
|
p << " : ";
|
||||||
|
p.printFunctionalType(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult parseKrnlOptimizeLoopsOp(
|
||||||
|
OpAsmParser& parser, OperationState& result) {
|
||||||
|
// Parse the schedule body region.
|
||||||
|
Region* region = result.addRegion();
|
||||||
|
if (parser.parseRegion(*region, llvm::None, llvm::None))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Parse the function type for the schedule operation.
|
||||||
|
// Then following the hint of this parsed function type, parse the
|
||||||
|
// returned timestamp space dimension handlers.
|
||||||
|
FunctionType schedule_func_type;
|
||||||
|
if (parser.parseColonType(schedule_func_type) ||
|
||||||
|
parser.addTypesToList(schedule_func_type.getResults(), result.types)) {
|
||||||
|
failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// KrnlIterateOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Build a Krnl Dialect iterate operation.
|
||||||
|
* input_loops: a collection of input krnl.loops being optimized.
|
||||||
|
* optimized_loops: a collection of optimized (scheduled) krnl.loops.
|
||||||
|
* operand_bounds: a collection of SSA value bounds.
|
||||||
|
* const_bounds: a collection of constant bounds.
|
||||||
|
* bound_types: a collection of integer values indicating how bounds are given.
|
||||||
|
* 0 : bound is given as an integer in const_bounds.
|
||||||
|
* 1 : bound is given as an operand in operand_bounds.
|
||||||
|
* 2 : bound is given as an affine map. (TODO).
|
||||||
|
*
|
||||||
|
* The following example illustrates how induction variable bounds are parsed
|
||||||
|
* from builder function inputs:
|
||||||
|
*
|
||||||
|
* - operand_bounds = [N, M]
|
||||||
|
* - const_bounds = [10, 20]
|
||||||
|
* - bound_types = [0, 1, 1, 0]
|
||||||
|
*
|
||||||
|
* Then the bounds will be parsed as:
|
||||||
|
* %i0 = 10 to N : %i1 = M to 20
|
||||||
|
*/
|
||||||
|
void KrnlIterateOp::build(Builder* builder, OperationState& result,
|
||||||
|
ArrayRef<Value*> input_loops, ArrayRef<Value*> optimized_loops,
|
||||||
|
ArrayRef<Value*> operand_bounds, ArrayRef<int64_t> const_bounds,
|
||||||
|
ArrayRef<int> bound_types) {
|
||||||
|
// Record optimized loops and the number of such loops.
|
||||||
|
result.addOperands(optimized_loops);
|
||||||
|
result.addAttribute(getNumOptimizedLoopsAttrName(),
|
||||||
|
builder->getI64IntegerAttr(optimized_loops.size()));
|
||||||
|
|
||||||
|
// Record input loops and the number of such loops.
|
||||||
|
result.addOperands(input_loops);
|
||||||
|
result.addAttribute(getNumInputLoopsAttrName(),
|
||||||
|
builder->getI64IntegerAttr(input_loops.size()));
|
||||||
|
|
||||||
|
// Record bound either as attribute or from operand list.
|
||||||
|
auto next_operand_bound = operand_bounds.begin();
|
||||||
|
auto next_const_bound = const_bounds.begin();
|
||||||
|
for (size_t i = 0; i < bound_types.size(); i++) {
|
||||||
|
auto bound_type = bound_types[i];
|
||||||
|
if (bound_type == 0) {
|
||||||
|
// Constant bound.
|
||||||
|
result.addAttribute(getBoundAttrName(i / 2, i % 2),
|
||||||
|
builder->getI64IntegerAttr(*next_const_bound));
|
||||||
|
next_const_bound = std::next(next_const_bound);
|
||||||
|
} else {
|
||||||
|
// Operand bound.
|
||||||
|
result.addOperands(*next_operand_bound);
|
||||||
|
next_operand_bound = std::next(next_operand_bound);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record bound types as attribute:
|
||||||
|
result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(),
|
||||||
|
builder->getI32ArrayAttr(bound_types));
|
||||||
|
|
||||||
|
// Create a region and a block for the body. The arguments of the region is
|
||||||
|
// the loop induction variables; there can be multiple induction variables
|
||||||
|
// associated with the same krnl.iterate operation.
|
||||||
|
Region* bodyRegion = result.addRegion();
|
||||||
|
auto* body = new Block();
|
||||||
|
auto body_args = llvm::SmallVector<Type, 4>(
|
||||||
|
input_loops.size(), IndexType::get(builder->getContext()));
|
||||||
|
body->addArguments(body_args);
|
||||||
|
bodyRegion->push_back(body);
|
||||||
|
|
||||||
|
ensureTerminator(*bodyRegion, *builder, result.location);
|
||||||
|
}
|
||||||
|
|
||||||
|
void print(OpAsmPrinter& p, KrnlIterateOp& op) {
|
||||||
|
p << "krnl.iterate(";
|
||||||
|
// Print optimized loops:
|
||||||
|
auto num_optimized_loops = op.getNumOptimizedLoops();
|
||||||
|
p.printOperands(op.operand_begin(), op.operand_begin() + num_optimized_loops);
|
||||||
|
p << ") with (";
|
||||||
|
|
||||||
|
// Set up iterator to input loops:
|
||||||
|
auto num_input_loops = op.getNumInputLoops();
|
||||||
|
auto input_loop_begin = op.operand_begin() + num_optimized_loops;
|
||||||
|
|
||||||
|
// Set up iterators to operand bounds.
|
||||||
|
auto next_operand_bound = input_loop_begin + num_input_loops;
|
||||||
|
|
||||||
|
// Function to print a lower or upper bound.
|
||||||
|
auto print_bound = [&](ArrayRef<Attribute> bound_types, size_t idx) {
|
||||||
|
IntegerAttr type = bound_types[idx].dyn_cast<IntegerAttr>();
|
||||||
|
if (type.getValue().getSExtValue() == 0) {
|
||||||
|
// Bound is an operand.
|
||||||
|
p.printOperand(*next_operand_bound);
|
||||||
|
next_operand_bound = std::next(next_operand_bound);
|
||||||
|
} else {
|
||||||
|
// Bound is an integer attribute.
|
||||||
|
auto bound_idx = idx / 2;
|
||||||
|
auto is_ub = idx % 2;
|
||||||
|
IntegerAttr bound = op.getAttrOfType<IntegerAttr>(
|
||||||
|
KrnlIterateOp::getBoundAttrName(bound_idx, is_ub));
|
||||||
|
p << bound.getValue().getSExtValue();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto induction_variables = op.bodyRegion().front().getArguments();
|
||||||
|
ArrayRef<Attribute> bound_types =
|
||||||
|
op.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundTypesAttrName())
|
||||||
|
.getValue();
|
||||||
|
|
||||||
|
// Print input loop operands, induction variables and their ranges.
|
||||||
|
for (size_t i = 0; i < num_input_loops; i++) {
|
||||||
|
if (i != 0)
|
||||||
|
p << ", ";
|
||||||
|
|
||||||
|
p.printOperand(*std::next(input_loop_begin, i));
|
||||||
|
p << " -> ";
|
||||||
|
|
||||||
|
// Print induction variable block argument.
|
||||||
|
p.printOperand(induction_variables[i]);
|
||||||
|
p << " = ";
|
||||||
|
|
||||||
|
print_bound(bound_types, 2 * i); // Print lower bound.
|
||||||
|
p << " to ";
|
||||||
|
print_bound(bound_types, 2 * i + 1); // Print upper bound.
|
||||||
|
}
|
||||||
|
p << ")";
|
||||||
|
|
||||||
|
p.printRegion(op.bodyRegion(), /*printEntryBlockArgs=*/false,
|
||||||
|
/*printBlockTerminators=*/false);
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult parseKrnlIterateOp(OpAsmParser& parser, OperationState& result) {
|
||||||
|
auto builder = parser.getBuilder();
|
||||||
|
auto context = builder.getContext();
|
||||||
|
onnf::KrnlDialectOperandParser operand_parser(parser);
|
||||||
|
|
||||||
|
// Parse optimized loops:
|
||||||
|
SmallVector<OpAsmParser::OperandType, 4> num_optimized_loops;
|
||||||
|
if (parser.parseOperandList(
|
||||||
|
num_optimized_loops, OpAsmParser::Delimiter::Paren) ||
|
||||||
|
parser.resolveOperands(num_optimized_loops,
|
||||||
|
LoopType::get(result.getContext()), result.operands))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Record how many optimized loops did we parse.
|
||||||
|
result.addAttribute(KrnlIterateOp::getNumOptimizedLoopsAttrName(),
|
||||||
|
builder.getI64IntegerAttr(num_optimized_loops.size()));
|
||||||
|
|
||||||
|
// Parse input loops and their lower and upper bounds.
|
||||||
|
SmallVector<OpAsmParser::OperandType, 4> in_loop_refs, induction_var_refs;
|
||||||
|
SmallVector<Value*, 4> in_loop_operands, operand_bounds;
|
||||||
|
SmallVector<Attribute, 4> bound_types;
|
||||||
|
SmallVector<IntegerAttr, 4> const_bounds;
|
||||||
|
|
||||||
|
if (parser.parseKeyword("with") || parser.parseLParen())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// A function to parse a lower or upper bound.
|
||||||
|
auto parse_bound = [&result, &builder, &operand_parser, &parser, &bound_types,
|
||||||
|
&operand_bounds, &const_bounds](
|
||||||
|
bool is_ub, size_t bound_pair_count) -> ParseResult {
|
||||||
|
// Try parse an SSA operand.
|
||||||
|
Value* bound;
|
||||||
|
operand_parser.ParseOptionalOperand(builder.getIndexType(), bound);
|
||||||
|
|
||||||
|
if (bound != nullptr) {
|
||||||
|
// Parsed an SSA id as bound.
|
||||||
|
operand_bounds.emplace_back(bound);
|
||||||
|
// Record bound_type as an operand type.
|
||||||
|
bound_types.emplace_back(builder.getI32IntegerAttr(0));
|
||||||
|
} else {
|
||||||
|
// Bound is not an SSA id, then it must be an integer.
|
||||||
|
// Parse an integer constant attribute.
|
||||||
|
IntegerAttr boundAttr;
|
||||||
|
if (parser.parseAttribute(boundAttr, builder.getIndexType(),
|
||||||
|
KrnlIterateOp::getBoundAttrName(bound_pair_count, is_ub),
|
||||||
|
result.attributes))
|
||||||
|
return failure();
|
||||||
|
const_bounds.emplace_back(
|
||||||
|
builder.getIntegerAttr(builder.getIndexType(), boundAttr.getValue()));
|
||||||
|
|
||||||
|
// Record that the bound_type is a constant integer attribute.
|
||||||
|
bound_types.emplace_back(builder.getI32IntegerAttr(1));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
bool keep_parsing; // Do we keep parsing loops/bounds?
|
||||||
|
size_t bound_pair_count = 0; // Record the number of bound pairs parsed.
|
||||||
|
do {
|
||||||
|
// Parse an input loop operand;
|
||||||
|
Value* in_loop_operand;
|
||||||
|
operand_parser.ParseOperand(LoopType::get(context), in_loop_operand);
|
||||||
|
in_loop_operands.emplace_back(in_loop_operand);
|
||||||
|
|
||||||
|
parser.parseArrow();
|
||||||
|
|
||||||
|
// Parse induction variable.
|
||||||
|
OpAsmParser::OperandType induction_var;
|
||||||
|
if (parser.parseRegionArgument(induction_var) || parser.parseEqual())
|
||||||
|
return failure();
|
||||||
|
induction_var_refs.emplace_back(induction_var);
|
||||||
|
|
||||||
|
// Parse bound par (min to max).
|
||||||
|
if (parse_bound(false, bound_pair_count) || parser.parseKeyword("to") ||
|
||||||
|
parse_bound(true, bound_pair_count))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
bound_pair_count++;
|
||||||
|
// We may fail to parse a comma if an operand bound is followed by
|
||||||
|
// a comma and the next input loop operand, in which case
|
||||||
|
// the entire "{operand bound}, {input_loop_operand}" sequence will
|
||||||
|
// be parsed as an operand list.
|
||||||
|
parser.parseOptionalComma();
|
||||||
|
|
||||||
|
// If we don't see a RParen token, we keep parsing.
|
||||||
|
keep_parsing = failed(parser.parseOptionalRParen());
|
||||||
|
} while (keep_parsing);
|
||||||
|
|
||||||
|
// At this point, there shouldn't be any operands left to parse.
|
||||||
|
if (operand_parser.has_operand_left())
|
||||||
|
return parser.emitError(parser.getCurrentLocation());
|
||||||
|
|
||||||
|
// Record how many input loops did we parse.
|
||||||
|
result.addOperands(in_loop_operands);
|
||||||
|
result.addAttribute(KrnlIterateOp::getNumInputLoopsAttrName(),
|
||||||
|
builder.getI64IntegerAttr(in_loop_operands.size()));
|
||||||
|
|
||||||
|
// Add operand bounds to the list of operands of current operation.
|
||||||
|
result.addOperands(operand_bounds);
|
||||||
|
|
||||||
|
// A list of 2N elements where the (2n) and (2n+1) th element specifies
|
||||||
|
// whether the lower and upper bound of the n'th induction variable is stored
|
||||||
|
// as an operand or as an attribute. N being the number of input loops
|
||||||
|
// specified in this krnl.iterate operation.
|
||||||
|
result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(),
|
||||||
|
builder.getArrayAttr(bound_types));
|
||||||
|
|
||||||
|
// Parse the schedule body region.
|
||||||
|
Region* region = result.addRegion();
|
||||||
|
SmallVector<Type, 4> induction_var_types(
|
||||||
|
induction_var_refs.size(), builder.getIndexType());
|
||||||
|
if (parser.parseRegion(*region, induction_var_refs, induction_var_types))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Ensure iterate region is closed off with krnl.terminate.
|
||||||
|
KrnlIterateOp::ensureTerminator(
|
||||||
|
*region, parser.getBuilder(), result.location);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verify(KrnlIterateOp op) {
|
||||||
|
// TODO: Verify number of induction variable bounds matches the number of
|
||||||
|
// input loops.
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// KrnlReturnLoopsOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void print(OpAsmPrinter& p, KrnlReturnLoopsOp& op) {
|
||||||
|
p << "krnl.return_loops ";
|
||||||
|
p.printOperands(op.operand_begin(), op.operand_end());
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult parseKrnlReturnLoopsOp(
|
||||||
|
OpAsmParser& parser, OperationState& result) {
|
||||||
|
// Parse the loops to return.
|
||||||
|
SmallVector<OpAsmParser::OperandType, 4> timestamp_dim_handlers;
|
||||||
|
if (parser.parseOperandList(timestamp_dim_handlers) ||
|
||||||
|
parser.resolveOperands(timestamp_dim_handlers,
|
||||||
|
LoopType::get(result.getContext()), result.operands))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "src/compiler/krnl.cpp.inc"
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,51 @@
|
||||||
|
//===--------------------- krnl_ops.hpp - MLIR Operations -----------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
|
#include "src/compiler/dialect/krnl/krnl_types.hpp"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class KrnlOpsDialect : public Dialect {
|
||||||
|
public:
|
||||||
|
KrnlOpsDialect(MLIRContext* context);
|
||||||
|
static StringRef getDialectNamespace() { return "krnl"; }
|
||||||
|
|
||||||
|
/// Parse a type registered to this dialect. Overriding this method is
|
||||||
|
/// required for dialects that have custom types.
|
||||||
|
/// Technically this is only needed to be able to round-trip to textual IR.
|
||||||
|
mlir::Type parseType(
|
||||||
|
llvm::StringRef tyData, mlir::Location loc) const override {
|
||||||
|
MLIRContext* context = getContext();
|
||||||
|
|
||||||
|
if (tyData.consume_front("loop"))
|
||||||
|
return LoopType::get(context);
|
||||||
|
else
|
||||||
|
return (emitError(loc, "Unexpected type: " + tyData), Type());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Print a type registered to this dialect. Overriding this method is
|
||||||
|
/// only required for dialects that have custom types.
|
||||||
|
/// Technically this is only needed to be able to round-trip to textual IR.
|
||||||
|
void printType(mlir::Type type, llvm::raw_ostream& os) const override {
|
||||||
|
switch (type.getKind()) {
|
||||||
|
case KrnlTypes::Loop:
|
||||||
|
os << "loop";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "src/compiler/krnl.hpp.inc"
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,209 @@
|
||||||
|
//===--------------------- krnl_ops.td - MLIR Operations ------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
def Krnl_Dialect : Dialect {
|
||||||
|
let name = "krnl";
|
||||||
|
let cppNamespace = "";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Require regions to have krnl.terminate terminator operation.
|
||||||
|
def ImplicitKrnlTerminator
|
||||||
|
: SingleBlockImplicitTerminator<"KrnlTerminatorOp">;
|
||||||
|
|
||||||
|
def KrnlDefineLoopsOp : Op<Krnl_Dialect, "define_loops"> {
|
||||||
|
let summary = "define_loops operation";
|
||||||
|
let description = [{
|
||||||
|
|
||||||
|
The "krnl.define_loops" operation is used to define input loops,
|
||||||
|
those are the for loops appearing in the input program that we
|
||||||
|
intend to optimize.
|
||||||
|
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins);
|
||||||
|
let results = (outs Variadic<AnyType>);
|
||||||
|
|
||||||
|
let skipDefaultBuilders = 1;
|
||||||
|
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"Builder *builder, OperationState &result,"
|
||||||
|
"int64_t num_loops">
|
||||||
|
];
|
||||||
|
|
||||||
|
let printer = [{ return ::print(p, *this); }];
|
||||||
|
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static StringRef getNumLoopsAttrName() { return "num_loops"; }
|
||||||
|
|
||||||
|
// Helper function to extract the number of loops being defined.
|
||||||
|
int64_t getNumLoops() {
|
||||||
|
auto num_loops =
|
||||||
|
getAttrOfType<IntegerAttr>(
|
||||||
|
getNumLoopsAttrName())
|
||||||
|
.getValue()
|
||||||
|
.getSExtValue();
|
||||||
|
return num_loops;
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
def KrnlOptimizeLoopsOp : Op<Krnl_Dialect, "optimize_loops"> {
|
||||||
|
let summary = "optimize_loops operation";
|
||||||
|
let description = [{
|
||||||
|
|
||||||
|
The "krnl.optimize_loops" operation is essentially a cosmetic operation
|
||||||
|
which exists to encapsulate a region where loops are being scheduled/optimized.
|
||||||
|
|
||||||
|
The optimized loops are returned at the end of the
|
||||||
|
region associated with the krnl.optimize_loops operation.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
TBD once we have actual schedule intrinsics.
|
||||||
|
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins Variadic<AnyType>);
|
||||||
|
let results = (outs Variadic<AnyType>);
|
||||||
|
let regions = (region SizedRegion<1>:$region);
|
||||||
|
|
||||||
|
let skipDefaultBuilders = 1;
|
||||||
|
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"Builder *builder, OperationState &result, "
|
||||||
|
"int timestamp_space_rank">
|
||||||
|
];
|
||||||
|
|
||||||
|
let printer = [{ return ::print(p, *this); }];
|
||||||
|
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||||
|
}
|
||||||
|
|
||||||
|
def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator]> {
|
||||||
|
let summary = "iterate operation";
|
||||||
|
let description = [{
|
||||||
|
|
||||||
|
The "krnl.iterate" operation is conceptually equivalent to a nested for loops.
|
||||||
|
|
||||||
|
For instance, say we have the following two
|
||||||
|
%l0, %l1 = krnl.define_loops 2
|
||||||
|
%o0, %o1 = krnl.optimize_loops {
|
||||||
|
// Identity schedule.
|
||||||
|
krnl.return_loops %l0, %l1
|
||||||
|
}
|
||||||
|
|
||||||
|
Then, consider the following krnl.iterate operation:
|
||||||
|
krnl.iterate (%o0, %o1) with (%l0 -> %i0 = 0 to 10, %l1 -> %i1 = 0 to 10) {
|
||||||
|
// Some operations.
|
||||||
|
}
|
||||||
|
|
||||||
|
It is equivalent to:
|
||||||
|
for (i0=0; i0<10; i0++)
|
||||||
|
for (i1=0; i1<10; i1++)
|
||||||
|
// Some operations.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins Variadic<AnyType>);
|
||||||
|
|
||||||
|
let regions = (region SizedRegion<1>:$bodyRegion);
|
||||||
|
|
||||||
|
let skipDefaultBuilders = 1;
|
||||||
|
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"Builder *builder, OperationState &result, "
|
||||||
|
"ArrayRef<Value*> input_loops, ArrayRef<Value*> optimized_loops, "
|
||||||
|
"ArrayRef<Value*> operand_bounds, ArrayRef<int64_t> const_bounds, "
|
||||||
|
"ArrayRef<int> bound_types">
|
||||||
|
];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
|
||||||
|
// In krnl.iterate operation, three types of SSA values are stored:
|
||||||
|
// - Optimized krnl.loops.
|
||||||
|
// - Input krnl.loops.
|
||||||
|
// - SSA value based induction variable bound (parametric bound).
|
||||||
|
// We record the number of optimized and input loops to separate these three
|
||||||
|
// group of operands out.
|
||||||
|
static StringRef getNumOptimizedLoopsAttrName() { return "num_optimized_loops"; }
|
||||||
|
|
||||||
|
int64_t getNumOptimizedLoops() {
|
||||||
|
auto num_optimized_loops =
|
||||||
|
getAttrOfType<IntegerAttr>(
|
||||||
|
getNumOptimizedLoopsAttrName())
|
||||||
|
.getValue()
|
||||||
|
.getSExtValue();
|
||||||
|
return num_optimized_loops;
|
||||||
|
}
|
||||||
|
|
||||||
|
static StringRef getNumInputLoopsAttrName() { return "num_input_loops"; }
|
||||||
|
|
||||||
|
int64_t getNumInputLoops() {
|
||||||
|
auto num_loops =
|
||||||
|
getAttrOfType<IntegerAttr>(
|
||||||
|
getNumInputLoopsAttrName())
|
||||||
|
.getValue()
|
||||||
|
.getSExtValue();
|
||||||
|
return num_loops;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Constant bounds are stored here as a list attribute.
|
||||||
|
static StringRef getConstantBoundsAttrName() { return "constant_bounds"; }
|
||||||
|
|
||||||
|
// Store type of each bound as three types:
|
||||||
|
// - 0 = constant attribute.
|
||||||
|
// - 1 = operand type.
|
||||||
|
// - 2 = affine maps (TODO).
|
||||||
|
static StringRef getBoundTypesAttrName() { return "bound_types"; }
|
||||||
|
|
||||||
|
// Get dynamic attribute name for the i-th lower and upper bound.
|
||||||
|
static std::string getBoundAttrName(int64_t i, bool is_ub) {
|
||||||
|
std::string bound_type = is_ub ? "_ub" : "_lb";
|
||||||
|
std::string bound_idx = std::to_string(i);
|
||||||
|
return "__bound_" + bound_idx + bound_type;
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let printer = [{ return ::print(p, *this); }];
|
||||||
|
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||||
|
let verifier = [{ return ::verify(*this); }];
|
||||||
|
}
|
||||||
|
|
||||||
|
def KrnlReturnLoopsOp : Op<Krnl_Dialect, "return_loops", [Terminator]> {
|
||||||
|
let summary = "Krnl return handler operation";
|
||||||
|
let description = [{
|
||||||
|
Krnl return_loops operation is a terminator operation for returning
|
||||||
|
scheduled dimension handlers in the krnl.optimize_loops region.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins Variadic<AnyType>);
|
||||||
|
|
||||||
|
let printer = [{ return ::print(p, *this); }];
|
||||||
|
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||||
|
}
|
||||||
|
|
||||||
|
def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {
|
||||||
|
let summary = "Krnl terminator operation";
|
||||||
|
let description = [{
|
||||||
|
Krnl terminator is a special terminator operation for blocks inside krnl
|
||||||
|
iterate operations. It unconditionally transmits the control flow to the
|
||||||
|
successor of the operation enclosing the region.
|
||||||
|
|
||||||
|
This operation does _not_ have a custom syntax. However, krnl control
|
||||||
|
operations omit the terminator in their custom syntax for brevity.
|
||||||
|
}];
|
||||||
|
|
||||||
|
// No custom parsing/printing form.
|
||||||
|
let parser = ?;
|
||||||
|
let printer = ?;
|
||||||
|
|
||||||
|
// Fully specified by traits.
|
||||||
|
let verifier = ?;
|
||||||
|
}
|
|
@ -0,0 +1,9 @@
|
||||||
|
//===--------------------- krnl_types.cpp - MLIR Operations ---------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "krnl_types.hpp"
|
|
@ -0,0 +1,36 @@
|
||||||
|
//===--------------------- krnl_types.hpp - MLIR Operations ---------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <mlir/IR/Types.h>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
namespace KrnlTypes {
|
||||||
|
enum Kinds {
|
||||||
|
// A krnl.loop is simply a reference to a for loop and will be used to:
|
||||||
|
// - Indicate the presence of a for loop in krnl.iterate.
|
||||||
|
// - Identify the loop in optimization intrinsics.
|
||||||
|
Loop = mlir::Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
class LoopType : public mlir::Type::TypeBase<LoopType, mlir::Type> {
|
||||||
|
public:
|
||||||
|
using Base::Base;
|
||||||
|
|
||||||
|
// Support type inquiry through isa, cast and dyn_cast.
|
||||||
|
static bool kindof(unsigned kind) { return kind == KrnlTypes::Loop; }
|
||||||
|
|
||||||
|
// Get a unique instance of Loop type.
|
||||||
|
static LoopType get(mlir::MLIRContext* context) {
|
||||||
|
return Base::get(context, KrnlTypes::Loop);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,52 @@
|
||||||
|
//===------------------ parser_helper.cpp - MLIR Operations ---------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "parser_helper.hpp"
|
||||||
|
|
||||||
|
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||||
|
|
||||||
|
namespace onnf {
|
||||||
|
|
||||||
|
mlir::ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
||||||
|
mlir::Type operand_type, mlir::Value*& operand) {
|
||||||
|
// If operand queue is empty, parse more operands and cache them.
|
||||||
|
if (_operand_ref_queue.empty()) {
|
||||||
|
// Parse operand types:
|
||||||
|
llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> operand_refs;
|
||||||
|
_parser.parseOperandList(operand_refs);
|
||||||
|
|
||||||
|
// Record operands:
|
||||||
|
for (auto& operand_ref : operand_refs)
|
||||||
|
_operand_ref_queue.emplace(operand_ref);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we parsed some operand reference(s), resolve the ref to an operand:
|
||||||
|
if (!_operand_ref_queue.empty()) {
|
||||||
|
auto operand_ref = _operand_ref_queue.front();
|
||||||
|
_operand_ref_queue.pop();
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Value*, 1> operands;
|
||||||
|
_parser.resolveOperand(operand_ref, operand_type, operands);
|
||||||
|
operand = operands.front();
|
||||||
|
return mlir::success();
|
||||||
|
} else {
|
||||||
|
operand = nullptr;
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::ParseResult KrnlDialectOperandParser::ParseOperand(
|
||||||
|
mlir::Type operand_type, mlir::Value*& operand) {
|
||||||
|
ParseOptionalOperand(operand_type, operand);
|
||||||
|
if (operand == nullptr)
|
||||||
|
return _parser.emitError(
|
||||||
|
_parser.getCurrentLocation(), "Expecting an operand.");
|
||||||
|
return mlir::success();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnf
|
|
@ -0,0 +1,46 @@
|
||||||
|
//===------------------ parser_helper.hpp - MLIR Operations ---------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
|
namespace onnf {
|
||||||
|
|
||||||
|
class KrnlDialectOperandParser {
|
||||||
|
public:
|
||||||
|
KrnlDialectOperandParser(mlir::OpAsmParser& parser)
|
||||||
|
: _parser(parser), _builder(parser.getBuilder()){};
|
||||||
|
|
||||||
|
// Parse an optional operand.
|
||||||
|
mlir::ParseResult ParseOptionalOperand(
|
||||||
|
mlir::Type operand_type, mlir::Value*& operand);
|
||||||
|
|
||||||
|
// Parse a required operand.
|
||||||
|
mlir::ParseResult ParseOperand(
|
||||||
|
mlir::Type operand_type, mlir::Value*& operand);
|
||||||
|
|
||||||
|
// Do we have more operands to parse?
|
||||||
|
bool has_operand_left() { return !_operand_ref_queue.empty(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
mlir::OpAsmParser& _parser;
|
||||||
|
|
||||||
|
mlir::Builder& _builder;
|
||||||
|
|
||||||
|
// A queue storing the parsed SSA id references.
|
||||||
|
std::queue<mlir::OpAsmParser::OperandType> _operand_ref_queue;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace onnf
|
|
@ -1,27 +0,0 @@
|
||||||
include "mlir/IR/OpBase.td"
|
|
||||||
|
|
||||||
def Knl_Dialect : Dialect {
|
|
||||||
let name = "knl";
|
|
||||||
let cppNamespace = "";
|
|
||||||
}
|
|
||||||
|
|
||||||
def KnlIterate : Op<Knl_Dialect, "iterate"> {
|
|
||||||
let summary = "iterate operation";
|
|
||||||
let description = [{
|
|
||||||
|
|
||||||
The "knl.iterate" operation is conceptually equivalent to a nested for loop
|
|
||||||
in that it represents ordered interation of integer coordinates within an
|
|
||||||
affine integer set.
|
|
||||||
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins Variadic<AnyType>);
|
|
||||||
let regions = (region SizedRegion<1>:$region);
|
|
||||||
|
|
||||||
let skipDefaultBuilders = 1;
|
|
||||||
|
|
||||||
let builders = [
|
|
||||||
OpBuilder<"Builder *builder, OperationState &result, "
|
|
||||||
"IntegerSet set, ArrayRef<Value *> args">
|
|
||||||
];
|
|
||||||
}
|
|
|
@ -1,23 +0,0 @@
|
||||||
#include "llvm/ADT/SetVector.h"
|
|
||||||
#include "llvm/ADT/SmallBitVector.h"
|
|
||||||
#include "mlir/IR/Block.h"
|
|
||||||
#include "mlir/IR/Builders.h"
|
|
||||||
#include "mlir/IR/Function.h"
|
|
||||||
#include "mlir/IR/IntegerSet.h"
|
|
||||||
#include "mlir/IR/Matchers.h"
|
|
||||||
#include "mlir/IR/OpImplementation.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
|
|
||||||
#include "knl_ops.hpp"
|
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
KnlOpsDialect::KnlOpsDialect(MLIRContext* context)
|
|
||||||
: Dialect(getDialectNamespace(), context) {
|
|
||||||
addOperations<
|
|
||||||
#define GET_OP_LIST
|
|
||||||
#include "src/compiler/knl.cpp.inc"
|
|
||||||
>();
|
|
||||||
}
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
namespace onnf {}
|
|
|
@ -1,19 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/Builders.h"
|
|
||||||
#include "mlir/IR/Dialect.h"
|
|
||||||
#include "mlir/IR/OpDefinition.h"
|
|
||||||
#include "mlir/IR/StandardTypes.h"
|
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
class KnlOpsDialect : public Dialect {
|
|
||||||
public:
|
|
||||||
KnlOpsDialect(MLIRContext* context);
|
|
||||||
static StringRef getDialectNamespace() { return "knl"; }
|
|
||||||
};
|
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
|
||||||
#include "src/compiler/knl.hpp.inc"
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
namespace onnf {}
|
|
|
@ -16,6 +16,11 @@
|
||||||
#include <mlir/Support/MlirOptMain.h>
|
#include <mlir/Support/MlirOptMain.h>
|
||||||
#include <mlir/Dialect/StandardOps/Ops.h>
|
#include <mlir/Dialect/StandardOps/Ops.h>
|
||||||
|
|
||||||
|
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||||
|
#include "src/compiler/helper.hpp"
|
||||||
|
|
||||||
|
using namespace onnf;
|
||||||
|
|
||||||
static llvm::cl::opt<std::string> input_filename(
|
static llvm::cl::opt<std::string> input_filename(
|
||||||
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"));
|
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"));
|
||||||
|
|
||||||
|
@ -52,6 +57,7 @@ int main(int argc, char** argv) {
|
||||||
|
|
||||||
auto output = mlir::openOutputFile(output_filename, &error_message);
|
auto output = mlir::openOutputFile(output_filename, &error_message);
|
||||||
|
|
||||||
|
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
||||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||||
|
|
||||||
return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
|
return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
|
||||||
|
|
10
src/main.cpp
10
src/main.cpp
|
@ -1,3 +1,11 @@
|
||||||
|
//===--------------------------- main.cpp ---------------------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
@ -21,6 +29,7 @@
|
||||||
#include <boost/program_options.hpp>
|
#include <boost/program_options.hpp>
|
||||||
|
|
||||||
#include "src/builder/frontend_dialect_transformer.hpp"
|
#include "src/builder/frontend_dialect_transformer.hpp"
|
||||||
|
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||||
#include "src/compiler/pass/passes.hpp"
|
#include "src/compiler/pass/passes.hpp"
|
||||||
|
|
||||||
|
@ -57,6 +66,7 @@ int main(int ac, char* av[]) {
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||||
|
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
||||||
|
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
mlir::OwningModuleRef module;
|
mlir::OwningModuleRef module;
|
||||||
|
|
Loading…
Reference in New Issue