[MLIR] Lowering of frontend dialect to KRNL dialect (#382)

* Partial support for lowering operations to KRNL dialect.

* Attempt to lower to KRNL IR.

* Update file.

* Add lowering.

* Address comments. Fix alloc dynamic dimensions. Correctly link StandardOps.

* Temporarily remove deallocation of locally allocated tensors.
GHEORGHE-TEOD BERCEA 2019-11-26 13:55:44 -05:00 committed by Tian Jin
parent d61cf35471
commit b02652dd76
13 changed files with 709 additions and 413 deletions


@@ -46,3 +46,5 @@ add_subdirectory(src/builder)
add_subdirectory(src/compiler)
add_subdirectory(src)
add_subdirectory(test)


@@ -44,90 +44,96 @@ set(
)
include_directories(${MLIR_INCLUDE_PATHS})
find_library(MLIR_LIB_ANALYSIS
NAMES MLIRAnalysis
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIR_LIB_IR NAMES MLIRIR PATHS ${LLVM_PROJECT_LIB} NO_DEFAULT_PATH)
find_library(MLIR_LIB_PARSER
NAMES MLIRParser
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIR_LIB_PASS
NAMES MLIRPass
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIR_LIB_TRANSFORMS
NAMES MLIRTransforms
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIR_LIB_VECTOR_OPS
NAMES MLIRVectorOps
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIR_LIB_SUPPORT
NAMES MLIRSupport
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIR_LIB_STANDARD_OPS
NAMES MLIRStandardOps
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIR_LIB_OPT_MAIN
NAMES MLIROptMain
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIR_LLVM_IR
NAMES MLIRLLVMIR
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIR_LIB_TRANSFORM_UTILS
NAMES MLIRTransformUtils
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(LLVM_LIB_SUPPORT
NAMES LLVMSupport
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
# Threading libraries required due to parallel pass execution.
find_package(Threads REQUIRED)
set(MLIRLIBS
${MLIR_LIB_ANALYSIS}
${MLIR_LIB_IR}
${MLIR_LIB_PARSER}
${MLIR_LIB_PASS}
${MLIR_LIB_TRANSFORMS}
${MLIR_LIB_VECTOR_OPS}
${MLIR_LIB_STANDARD_OPS}
${MLIR_LIB_OPT_MAIN}
${MLIR_LIB_SUPPORT}
${MLIR_LIB_TRANSFORM_UTILS}
${MLIR_LIB_ANALYSIS}
${MLIR_LIB_IR}
${MLIR_LIB_PARSER}
${MLIR_LIB_PASS}
${MLIR_LIB_TRANSFORMS}
${MLIR_LIB_VECTOR_OPS}
${MLIR_LIB_STANDARD_OPS}
${MLIR_LIB_OPT_MAIN}
${MLIR_LIB_SUPPORT}
${MLIR_LIB_TRANSFORM_UTILS}
${LLVM_LIB_SUPPORT}
function(find_mlir_lib lib)
find_library(${lib}
NAMES ${lib}
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
endfunction(find_mlir_lib)
find_mlir_lib(MLIRAffineOps)
find_mlir_lib(MLIRAffineToStandard)
find_mlir_lib(MLIRAnalysis)
find_mlir_lib(MLIRExecutionEngine)
find_mlir_lib(MLIRIR)
find_mlir_lib(MLIRLLVMIR)
find_mlir_lib(MLIRLoopToStandard)
find_mlir_lib(MLIRParser)
find_mlir_lib(MLIRPass)
find_mlir_lib(MLIRStandardOps)
find_mlir_lib(MLIRStandardToLLVM)
find_mlir_lib(MLIRTargetLLVMIR)
find_mlir_lib(MLIRTransforms)
find_mlir_lib(MLIRTransforms)
find_mlir_lib(MLIRTransformUtils)
find_mlir_lib(MLIRSupport)
find_mlir_lib(MLIROptMain)
find_mlir_lib(LLVMCore)
find_mlir_lib(LLVMSupport)
find_mlir_lib(LLVMAsmParser)
find_mlir_lib(LLVMBinaryFormat)
find_mlir_lib(LLVMRemarks)
find_mlir_lib(LLVMIRReader)
find_mlir_lib(LLVMTransformUtils)
find_mlir_lib(LLVMBitstreamReader)
set(MLIRLibsOnce
MLIRAffineOps
MLIRAffineToStandard
MLIRAnalysis
MLIRExecutionEngine
MLIRIR
MLIRLLVMIR
MLIRLoopToStandard
MLIRParser
MLIRPass
MLIRStandardOps
MLIRStandardToLLVM
MLIRTargetLLVMIR
MLIRTransforms
MLIRAffineOps
MLIRAffineToStandard
MLIRAnalysis
MLIRExecutionEngine
MLIRIR
MLIRLLVMIR
MLIRLoopToStandard
MLIRParser
MLIRPass
MLIRStandardOps
MLIRStandardToLLVM
MLIRTargetLLVMIR
MLIRTransforms
MLIRTransformUtils
MLIRLoopOps
MLIRSupport
MLIROptMain
LLVMCore
LLVMSupport
LLVMAsmParser
LLVMIRReader
LLVMTransformUtils
LLVMBinaryFormat
LLVMRemarks
LLVMBitstreamReader)
set(MLIRLibs
${MLIRLibsOnce}
${MLIRLibsOnce}
Threads::Threads)
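# Note: ${MLIRLibsOnce} is listed twice above, presumably to resolve
# circular link-order dependencies among the static archives; the commit
# does not state the reason explicitly.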
set(MLIRWholeArchiveLibs
MLIRAffineToStandard
MLIRAffineOps
MLIRLLVMIR
MLIRStandardOps
MLIRStandardToLLVM
MLIRLoopToStandard)
function(whole_archive_link target lib_dir)
get_property(link_flags TARGET ${target} PROPERTY LINK_FLAGS)
if("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin")
@@ -155,6 +161,9 @@ function(whole_archive_link_mlir target)
endfunction(whole_archive_link_mlir)
function(whole_archive_link_onnf target)
foreach(LIB ${ARGN})
add_dependencies(${target} ${LIB})
endforeach(LIB)
whole_archive_link(${target} ${CMAKE_BINARY_DIR}/lib ${ARGN})
endfunction(whole_archive_link_onnf)
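# Note: whole-archive linking keeps object files that are referenced only
# through static initializers (e.g. pass registrations) from being dropped
# by the linker; the commit does not document this motivation explicitly.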


@@ -1,7 +1,9 @@
add_executable(onnf main.cpp)
target_link_libraries(onnf builder compiler ${MLIRLibs} ${Boost_LIBRARIES})
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
target_link_libraries(onnf builder compiler ${Boost_LIBRARIES})
install(TARGETS onnf DESTINATION bin)


@@ -5,7 +5,8 @@ add_library(builder
target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR})
target_include_directories(builder PRIVATE ${CMAKE_BINARY_DIR})
target_link_libraries(builder compiler onnx ${MLIRLIBS} curses)
target_link_libraries(builder compiler onnx ${MLIRLibs} curses)
target_include_directories(builder
PRIVATE
${CMAKE_SOURCE_DIR}/third_party/onnx


@@ -10,9 +10,10 @@ add_library(
dialect/krnl/parser_helper.hpp
pass/shape_inference_pass.cpp
pass/shape_inference_interface.hpp
pass/passes.hpp
dialect/onnx/onnxop.inc
pass/onnx_combine.cpp)
pass/onnx_combine.cpp
pass/lower_frontend_to_krnl.cpp
pass/passes.hpp)
# Include root src directory.
target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT})
@@ -41,7 +42,7 @@ target_link_libraries(compiler
${Boost_LIBRARIES}
${CMAKE_THREAD_LIBS_INIT}
${CMAKE_DL_LIBS}
${MLIRLIBS}
${MLIRLibs}
curses)
add_subdirectory(tool)


@@ -176,7 +176,7 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result,
result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(),
builder->getI32ArrayAttr(bound_types));
// Create a region and a block for the body. The arguments of the region is
// Create a region and a block for the body. The arguments of the region are
// the loop induction variables; there can be multiple induction variables
// associated with the same krnl.iterate operation.
Region* bodyRegion = result.addRegion();
@@ -207,16 +207,16 @@ void print(OpAsmPrinter& p, KrnlIterateOp& op) {
auto print_bound = [&](ArrayRef<Attribute> bound_types, size_t idx) {
IntegerAttr type = bound_types[idx].dyn_cast<IntegerAttr>();
if (type.getValue().getSExtValue() == 0) {
// Bound is an operand.
p.printOperand(*next_operand_bound);
next_operand_bound = std::next(next_operand_bound);
} else {
// Bound is an integer attribute.
auto bound_idx = idx / 2;
auto is_ub = idx % 2;
IntegerAttr bound = op.getAttrOfType<IntegerAttr>(
KrnlIterateOp::getBoundAttrName(bound_idx, is_ub));
p << bound.getValue().getSExtValue();
} else {
// Bound is an operand.
p.printOperand(*next_operand_bound);
next_operand_bound = std::next(next_operand_bound);
}
};

File diff suppressed because it is too large.


@@ -0,0 +1,282 @@
//====- lower_frontend_to_krnl.cpp - Frontend dialects to Krnl lowering ---===//
//
// Copyright 2019 The DLC Authors.
//
// =============================================================================
//
// This file implements the lowering of frontend operations to a combination of
// Krnl IR and standard operations.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
#include "src/compiler/pass/passes.hpp"
using namespace mlir;
//===----------------------------------------------------------------------===//
// FrontendToKrnl RewritePatterns
//===----------------------------------------------------------------------===//
/// Check whether all dimensions are known at compile time.
static bool hasAllConstantDimensions(MemRefType type) {
auto memRefShape = type.getShape();
for (int i = 0; i < memRefShape.size(); ++i)
if (memRefShape[i] < 0)
return false;
return true;
}
/// Convert the given TensorType into the corresponding MemRefType.
static MemRefType convertTensorToMemRef(TensorType type) {
assert(type.hasRank() && "expected only ranked shapes");
return MemRefType::get(type.getShape(), type.getElementType());
}
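// Illustrative note (not part of the pass logic): the conversion keeps
// shape and element type unchanged, e.g.
//   tensor<10x20xf32> -> memref<10x20xf32>  (hasAllConstantDimensions: true)
//   tensor<?x20xf32>  -> memref<?x20xf32>   (hasAllConstantDimensions: false,
//                                            a '?' dimension is stored as -1)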
/// Insert an allocation and deallocation for the given MemRefType.
static Value* insertAllocAndDealloc(
MemRefType type, Location loc, PatternRewriter& rewriter,
Value *oldMemRef = nullptr) {
// Put together alloc operands for any dynamic dimensions of the memref.
AllocOp alloc;
if (oldMemRef) {
SmallVector<Value *, 4> allocOperands;
auto memRefShape = type.getShape();
for (int i = 0; i < memRefShape.size(); ++i)
if (memRefShape[i] < 0)
allocOperands.push_back(rewriter.create<DimOp>(loc, oldMemRef, i));
alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
} else {
alloc = rewriter.create<AllocOp>(loc, type);
}
// Make sure to allocate at the beginning of the block if
// all dimensions are known.
auto* parentBlock = alloc.getOperation()->getBlock();
if (hasAllConstantDimensions(type))
alloc.getOperation()->moveBefore(&parentBlock->front());
return alloc;
}
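// Note: per the commit message, deallocation of locally allocated buffers
// is temporarily omitted here; normally a matching DeallocOp would be
// inserted before the block terminator.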
namespace {
//===----------------------------------------------------------------------===//
// AddOp lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
struct ONNXAddOpLowering : public ConversionPattern {
ONNXAddOpLowering(MLIRContext* ctx)
: ConversionPattern(mlir::ONNXAddOp::getOperationName(), 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
// TODO: Check that the types are valid.
// Add is an operation that requires all of its operands and its result to
// have the same type. This should have been verified by the verifier.
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
// If the output has a dynamic dimension, pass the operands required for
// each dynamic dimension to the AllocOp. The first operand of the Add
// operation is used. The operands of the Add need to match in terms of
// dimensions with the result at this pre-optimization phase.
// TODO: verify that dimensions match.
// TODO: can the dimension of the result differ after optimizations?
Value *alloc;
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, operands[0]);
// Number of loops (one per dimension of the result).
auto memRefShape = memRefType.getShape();
int64_t rank = memRefShape.size();
// Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
std::vector<Value*> originalLoops;
originalLoops.reserve(rank);
for (auto result : loopsOp.getResults()) {
originalLoops.push_back(result);
}
// Define loop optimization.
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
std::vector<Value*> optimizedLoops;
optimizedLoops.reserve(rank);
for (auto result : optimizedLoopsOp.getResults()) {
optimizedLoops.push_back(result);
}
Block& optimizationBlock = optimizedLoopsOp.region().front();
// Iterate over the loop nest.
// TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape
// to KrnlIterateOp instead.
SmallVector<Value*, 8> operandBounds;
SmallVector<int64_t, 8> constBounds;
SmallVector<int, 16> boundTypes;
for (int i = 0; i < rank; ++i) {
if (memRefShape[i] < 0) {
// This is a dynamic dimension, so its upper bound comes from an operand.
// Lower bound
constBounds.push_back(0);
boundTypes.push_back(0);
// Upper bound
operandBounds.push_back(
rewriter.create<DimOp>(loc, operands[0], i).getResult());
boundTypes.push_back(1);
} else {
// Lower bound
constBounds.push_back(0);
boundTypes.push_back(0);
// Upper bound
constBounds.push_back(memRefShape[i]);
boundTypes.push_back(0);
}
}
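// To summarize the encoding built above: bounds are emitted as a
// (lower, upper) pair per loop; a boundTypes entry of 0 means the bound is
// the next constant in constBounds, while 1 means it is the next SSA value
// in operandBounds (used for dynamic dimensions).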
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, originalLoops,
optimizedLoops, operandBounds, constBounds, boundTypes);
Block& iterationBlock = iterateOp.bodyRegion().front();
// Now insert instructions into the bodies of the
// operations just created:
// 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
rewriter.setInsertionPointToEnd(&optimizationBlock);
// Return from KrnlOptimizeLoopsOp body.
// When no optimizations are present we just return the loops
// unchanged.
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
rewriter.setInsertionPoint(optimizedLoopsOp);
// 2. Insert instructions inside the KernelIterateOp body.
rewriter.setInsertionPointToStart(&iterationBlock);
// Handle AddOp:
SmallVector<Value*, 4> loopIVs;
for (auto arg : iterationBlock.getArguments())
loopIVs.push_back(arg);
auto loadedFirstVal =
rewriter.create<LoadOp>(loc, operands[0], loopIVs);
auto loadedSecondVal =
rewriter.create<LoadOp>(loc, operands[1], loopIVs);
// TODO: choose the Add variant based on the element type; for now use the
// floating-point AddFOp.
auto addOpResult = rewriter.create<AddFOp>(
loc, loadedFirstVal, loadedSecondVal);
// Store result in the resulting array.
rewriter.create<StoreOp>(loc, addOpResult, alloc, loopIVs);
rewriter.replaceOp(op, alloc);
return matchSuccess();
}
};
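// For orientation, the pattern above emits IR of roughly this shape
// (a sketch only; the exact krnl textual syntax may differ):
//
//   %l0, %l1 = krnl.define_loops          // one loop per result dimension
//   %o0, %o1 = krnl.optimize_loops { krnl.return_loops %l0, %l1 }
//   %res = alloc() : memref<...>          // hoisted when the shape is static
//   krnl.iterate(...) with per-dimension bounds {
//     %a = load operands[0][ivs]; %b = load operands[1][ivs]
//     store (addf %a, %b), %res[ivs]
//   }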
//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===//
struct TensorTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter;
LogicalResult convertType(Type t, SmallVectorImpl<Type>& results) override {
if (auto tensor_type = t.dyn_cast<TensorType>()) {
results.push_back(convertTensorToMemRef(tensor_type));
return success();
}
results.push_back(t);
return success();
}
/// Return true if the inputs and outputs of the given function type are
/// legal. [Taken from MLIR and adapted to only check the legality of the
/// inputs. Once unranked results can be handled gracefully this
/// override needs to be removed in favour of the original MLIR one.]
bool isSignatureLegal(FunctionType funcType) {
return llvm::all_of(funcType.getInputs(),
[this](Type type) { return isLegal(type); });
}
};
} // end anonymous namespace.
//===----------------------------------------------------------------------===//
// Frontend to Krnl Dialect lowering pass
//===----------------------------------------------------------------------===//
/// This is a partial lowering to Krnl loops of the ONNX operations.
namespace {
struct FrontendToKrnlLoweringPass
: public ModulePass<FrontendToKrnlLoweringPass> {
void runOnModule() final;
};
} // end anonymous namespace.
void FrontendToKrnlLoweringPass::runOnModule() {
auto module = getModule();
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
ConversionTarget target(getContext());
// We define the specific operations, or dialects, that are legal targets for
// this lowering.
target
.addLegalDialect<KrnlOpsDialect, AffineOpsDialect, StandardOpsDialect>();
// TODO: enable this once more ops are supported.
// We also define the ONNX dialect as illegal so that the conversion will fail
// if any of its operations are *not* converted.
// target.addIllegalDialect<mlir::ONNXOpsDialect>();
// TODO: add any other ops which are considered legal.
// Some operations can be marked as being still legal.
// Example: target.addLegalOp<mlir::OpName>();
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the frontend operations.
OwningRewritePatternList patterns;
// Convert TensorType to MemRef
TensorTypeConverter tensor_to_memref_converter;
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
// FuncOp is legal only if types have been converted to Std types.
return tensor_to_memref_converter.isSignatureLegal(op.getType());
});
// Type conversion for function signatures.
// Call MLIR FuncOp signature conversion when result type is
// a ranked tensor.
populateFuncOpTypeConversionPattern(
patterns, &getContext(), tensor_to_memref_converter);
// Frontend operation lowering.
patterns.insert<ONNXAddOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
// operations were not converted successfully.
if (failed(applyPartialConversion(
module, target, patterns)))
signalPassFailure();
}
std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
return std::make_unique<FrontendToKrnlLoweringPass>();
}


@@ -17,7 +17,8 @@ class Pass;
std::unique_ptr<Pass> createShapeInferencePass();
// TODO: Add pass for lowering to kernel IR.
/// Add pass for lowering to Krnl IR.
std::unique_ptr<mlir::Pass> createLowerToKrnlPass();
// TODO: Add pass for lowering to LLVM IR.


@@ -71,6 +71,12 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
<< op_worklist.size() << " operations couldn't be inferred\n";
signalPassFailure();
}
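// Update the function signature so that its declared result types match
// the operand types of the terminator, which shape inference may have
// refined above.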
if (auto terminator_op = f.getBody().back().getTerminator()) {
auto results = terminator_op->getOperandTypes();
f.setType(FunctionType::get(f.getType().getInputs(),
std::vector<Type>(results.begin(), results.end()), f.getContext()));
}
}
/*!


@@ -3,14 +3,8 @@ add_executable(onnf-opt onnf_opt.cpp)
target_include_directories(onnf-opt PRIVATE ${ONNF_SRC_ROOT})
target_include_directories(onnf-opt PRIVATE ${ONNF_BIN_ROOT})
set(LIB_LIST
MLIRStandardOps
MLIRAffineOps
MLIRLoopOps
MLIRTransformUtils
MLIREDSC
MLIRTransforms)
whole_archive_link_mlir(onnf-opt ${LIB_LIST})
target_link_libraries(onnf-opt compiler ${MLIRLibs})
whole_archive_link_mlir(onnf-opt ${MLIRWholeArchiveLibs})
# TODO: need to investigate how to whole-archive link compiler pass to onnf-opt.
target_link_libraries(onnf-opt compiler)


@@ -124,6 +124,7 @@ int main(int ac, char* av[]) {
mlir::PassManager pm(&context);
pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createLowerToKrnlPass());
pm.run(*module);
return 0;


@@ -1,4 +1 @@
add_subdirectory(models)
add_subdirectory(nodes)
add_subdirectory(mlir)