[MLIR] Lowering of frontend dialect to KRNL dialect (#382)
* Partial support for lowering operations to KRNL dialect. * Attempt to lower to KRNL IR. * Update file. * Add lowering. * Address comments. Fix alloc dynamic dimensions. Correctly link StandardOps. * Temporarily remove deallocation of locally allocated tensors.
This commit is contained in:
parent
d61cf35471
commit
b02652dd76
|
@ -46,3 +46,5 @@ add_subdirectory(src/builder)
|
|||
add_subdirectory(src/compiler)
|
||||
add_subdirectory(src)
|
||||
|
||||
add_subdirectory(test)
|
||||
|
||||
|
|
169
MLIR.cmake
169
MLIR.cmake
|
@ -44,89 +44,95 @@ set(
|
|||
)
|
||||
include_directories(${MLIR_INCLUDE_PATHS})
|
||||
|
||||
find_library(MLIR_LIB_ANALYSIS
|
||||
NAMES MLIRAnalysis
|
||||
PATHS ${LLVM_PROJECT_LIB}
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
find_library(MLIR_LIB_IR NAMES MLIRIR PATHS ${LLVM_PROJECT_LIB} NO_DEFAULT_PATH)
|
||||
|
||||
find_library(MLIR_LIB_PARSER
|
||||
NAMES MLIRParser
|
||||
PATHS ${LLVM_PROJECT_LIB}
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
find_library(MLIR_LIB_PASS
|
||||
NAMES MLIRPass
|
||||
PATHS ${LLVM_PROJECT_LIB}
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
find_library(MLIR_LIB_TRANSFORMS
|
||||
NAMES MLIRTransforms
|
||||
PATHS ${LLVM_PROJECT_LIB}
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
find_library(MLIR_LIB_VECTOR_OPS
|
||||
NAMES MLIRVectorOps
|
||||
PATHS ${LLVM_PROJECT_LIB}
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
find_library(MLIR_LIB_SUPPORT
|
||||
NAMES MLIRSupport
|
||||
PATHS ${LLVM_PROJECT_LIB}
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
find_library(MLIR_LIB_STANDARD_OPS
|
||||
NAMES MLIRStandardOps
|
||||
PATHS ${LLVM_PROJECT_LIB}
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
find_library(MLIR_LIB_OPT_MAIN
|
||||
NAMES MLIROptMain
|
||||
PATHS ${LLVM_PROJECT_LIB}
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
find_library(MLIR_LLVM_IR
|
||||
NAMES MLIRLLVMIR
|
||||
PATHS ${LLVM_PROJECT_LIB}
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
find_library(MLIR_LIB_TRANSFORM_UTILS
|
||||
NAMES MLIRTransformUtils
|
||||
PATHS ${LLVM_PROJECT_LIB}
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
find_library(LLVM_LIB_SUPPORT
|
||||
NAMES LLVMSupport
|
||||
PATHS ${LLVM_PROJECT_LIB}
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
# Threading libraries required due to parallel pass execution.
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
set(MLIRLIBS
|
||||
${MLIR_LIB_ANALYSIS}
|
||||
${MLIR_LIB_IR}
|
||||
${MLIR_LIB_PARSER}
|
||||
${MLIR_LIB_PASS}
|
||||
${MLIR_LIB_TRANSFORMS}
|
||||
${MLIR_LIB_VECTOR_OPS}
|
||||
${MLIR_LIB_STANDARD_OPS}
|
||||
${MLIR_LIB_OPT_MAIN}
|
||||
${MLIR_LIB_SUPPORT}
|
||||
${MLIR_LIB_TRANSFORM_UTILS}
|
||||
${MLIR_LIB_ANALYSIS}
|
||||
${MLIR_LIB_IR}
|
||||
${MLIR_LIB_PARSER}
|
||||
${MLIR_LIB_PASS}
|
||||
${MLIR_LIB_TRANSFORMS}
|
||||
${MLIR_LIB_VECTOR_OPS}
|
||||
${MLIR_LIB_STANDARD_OPS}
|
||||
${MLIR_LIB_OPT_MAIN}
|
||||
${MLIR_LIB_SUPPORT}
|
||||
${MLIR_LIB_TRANSFORM_UTILS}
|
||||
${LLVM_LIB_SUPPORT}
|
||||
Threads::Threads)
|
||||
function(find_mlir_lib lib)
|
||||
find_library(${lib}
|
||||
NAMES ${lib}
|
||||
PATHS ${LLVM_PROJECT_LIB}
|
||||
NO_DEFAULT_PATH)
|
||||
endfunction(find_mlir_lib)
|
||||
|
||||
find_mlir_lib(MLIRAffineOps)
|
||||
find_mlir_lib(MLIRAffineToStandard)
|
||||
find_mlir_lib(MLIRAnalysis)
|
||||
find_mlir_lib(MLIRExecutionEngine)
|
||||
find_mlir_lib(MLIRIR)
|
||||
find_mlir_lib(MLIRLLVMIR)
|
||||
find_mlir_lib(MLIRLoopToStandard)
|
||||
find_mlir_lib(MLIRParser)
|
||||
find_mlir_lib(MLIRPass)
|
||||
find_mlir_lib(MLIRStandardOps)
|
||||
find_mlir_lib(MLIRStandardToLLVM)
|
||||
find_mlir_lib(MLIRTargetLLVMIR)
|
||||
find_mlir_lib(MLIRTransforms)
|
||||
find_mlir_lib(MLIRTransforms)
|
||||
find_mlir_lib(MLIRTransformUtils)
|
||||
find_mlir_lib(MLIRSupport)
|
||||
find_mlir_lib(MLIROptMain)
|
||||
|
||||
find_mlir_lib(LLVMCore)
|
||||
find_mlir_lib(LLVMSupport)
|
||||
find_mlir_lib(LLVMAsmParser)
|
||||
find_mlir_lib(LLVMBinaryFormat)
|
||||
find_mlir_lib(LLVMRemarks)
|
||||
find_mlir_lib(LLVMIRReader)
|
||||
find_mlir_lib(LLVMTransformUtils)
|
||||
find_mlir_lib(LLVMBitstreamReader)
|
||||
|
||||
set(MLIRLibsOnce
|
||||
MLIRAffineOps
|
||||
MLIRAffineToStandard
|
||||
MLIRAnalysis
|
||||
MLIRExecutionEngine
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
MLIRLoopToStandard
|
||||
MLIRParser
|
||||
MLIRPass
|
||||
MLIRStandardOps
|
||||
MLIRStandardToLLVM
|
||||
MLIRTargetLLVMIR
|
||||
MLIRTransforms
|
||||
MLIRAffineOps
|
||||
MLIRAffineToStandard
|
||||
MLIRAnalysis
|
||||
MLIRExecutionEngine
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
MLIRLoopToStandard
|
||||
MLIRParser
|
||||
MLIRPass
|
||||
MLIRStandardOps
|
||||
MLIRStandardToLLVM
|
||||
MLIRTargetLLVMIR
|
||||
MLIRTransforms
|
||||
MLIRTransformUtils
|
||||
MLIRLoopOps
|
||||
MLIRSupport
|
||||
MLIROptMain
|
||||
LLVMCore
|
||||
LLVMSupport
|
||||
LLVMAsmParser
|
||||
LLVMIRReader
|
||||
LLVMTransformUtils
|
||||
LLVMBinaryFormat
|
||||
LLVMRemarks
|
||||
LLVMBitstreamReader)
|
||||
|
||||
set(MLIRLibs
|
||||
${MLIRLibsOnce}
|
||||
${MLIRLibsOnce}
|
||||
Threads::Threads)
|
||||
|
||||
set(MLIRWholeArchiveLibs
|
||||
MLIRAffineToStandard
|
||||
MLIRAffineOps
|
||||
MLIRLLVMIR
|
||||
MLIRStandardOps
|
||||
MLIRStandardToLLVM
|
||||
MLIRLoopToStandard)
|
||||
|
||||
function(whole_archive_link target lib_dir)
|
||||
get_property(link_flags TARGET ${target} PROPERTY LINK_FLAGS)
|
||||
|
@ -155,6 +161,9 @@ function(whole_archive_link_mlir target)
|
|||
endfunction(whole_archive_link_mlir)
|
||||
|
||||
function(whole_archive_link_onnf target)
|
||||
foreach(LIB ${ARGN})
|
||||
add_dependencies(${target} ${LIB})
|
||||
endforeach(LIB)
|
||||
whole_archive_link(${target} ${CMAKE_BINARY_DIR}/lib ${ARGN})
|
||||
endfunction(whole_archive_link_onnf)
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
|
||||
add_executable(onnf main.cpp)
|
||||
|
||||
target_link_libraries(onnf builder compiler ${MLIRLibs} ${Boost_LIBRARIES})
|
||||
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
|
||||
|
||||
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
|
||||
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
|
||||
target_link_libraries(onnf builder compiler ${Boost_LIBRARIES})
|
||||
|
||||
install(TARGETS onnf DESTINATION bin)
|
|
@ -5,7 +5,8 @@ add_library(builder
|
|||
|
||||
target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR})
|
||||
target_include_directories(builder PRIVATE ${CMAKE_BINARY_DIR})
|
||||
target_link_libraries(builder compiler onnx ${MLIRLIBS} curses)
|
||||
|
||||
target_link_libraries(builder compiler onnx ${MLIRLibs} curses)
|
||||
target_include_directories(builder
|
||||
PRIVATE
|
||||
${CMAKE_SOURCE_DIR}/third_party/onnx
|
||||
|
|
|
@ -10,9 +10,10 @@ add_library(
|
|||
dialect/krnl/parser_helper.hpp
|
||||
pass/shape_inference_pass.cpp
|
||||
pass/shape_inference_interface.hpp
|
||||
pass/passes.hpp
|
||||
dialect/onnx/onnxop.inc
|
||||
pass/onnx_combine.cpp)
|
||||
pass/onnx_combine.cpp
|
||||
pass/lower_frontend_to_krnl.cpp
|
||||
pass/passes.hpp)
|
||||
|
||||
# Include root src directory.
|
||||
target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT})
|
||||
|
@ -41,7 +42,7 @@ target_link_libraries(compiler
|
|||
${Boost_LIBRARIES}
|
||||
${CMAKE_THREAD_LIBS_INIT}
|
||||
${CMAKE_DL_LIBS}
|
||||
${MLIRLIBS}
|
||||
${MLIRLibs}
|
||||
curses)
|
||||
|
||||
add_subdirectory(tool)
|
||||
|
|
|
@ -176,7 +176,7 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result,
|
|||
result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(),
|
||||
builder->getI32ArrayAttr(bound_types));
|
||||
|
||||
// Create a region and a block for the body. The arguments of the region is
|
||||
// Create a region and a block for the body. The arguments of the region are
|
||||
// the loop induction variables; there can be multiple induction variables
|
||||
// associated with the same krnl.iterate operation.
|
||||
Region* bodyRegion = result.addRegion();
|
||||
|
@ -207,16 +207,16 @@ void print(OpAsmPrinter& p, KrnlIterateOp& op) {
|
|||
auto print_bound = [&](ArrayRef<Attribute> bound_types, size_t idx) {
|
||||
IntegerAttr type = bound_types[idx].dyn_cast<IntegerAttr>();
|
||||
if (type.getValue().getSExtValue() == 0) {
|
||||
// Bound is an operand.
|
||||
p.printOperand(*next_operand_bound);
|
||||
next_operand_bound = std::next(next_operand_bound);
|
||||
} else {
|
||||
// Bound is an integer attribute.
|
||||
auto bound_idx = idx / 2;
|
||||
auto is_ub = idx % 2;
|
||||
IntegerAttr bound = op.getAttrOfType<IntegerAttr>(
|
||||
KrnlIterateOp::getBoundAttrName(bound_idx, is_ub));
|
||||
p << bound.getValue().getSExtValue();
|
||||
} else {
|
||||
// Bound is an operand.
|
||||
p.printOperand(*next_operand_bound);
|
||||
next_operand_bound = std::next(next_operand_bound);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,282 @@
|
|||
//====- lower_frontend_to_krnl.cpp - Frontend dialects to Krnl lowering ---===//
|
||||
//
|
||||
// Copyright 2019 The DLC Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements the lowering of frontend operations to a combination of
|
||||
// Krnl IR and standard operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||
|
||||
#include "src/compiler/pass/passes.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FrontendToAffine RewritePatterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Check is all dimensions are known at compile time.
|
||||
static bool hasAllConstantDimensions(MemRefType type) {
|
||||
auto memRefShape = type.getShape();
|
||||
for (int i = 0; i < memRefShape.size(); ++i)
|
||||
if (memRefShape[i] < 0)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Convert the given TensorType into the corresponding MemRefType.
|
||||
static MemRefType convertTensorToMemRef(TensorType type) {
|
||||
assert(type.hasRank() && "expected only ranked shapes");
|
||||
return MemRefType::get(type.getShape(), type.getElementType());
|
||||
}
|
||||
|
||||
/// Insert an allocation and deallocation for the given MemRefType.
|
||||
static Value* insertAllocAndDealloc(
|
||||
MemRefType type, Location loc, PatternRewriter& rewriter,
|
||||
Value *oldMemRef = nullptr) {
|
||||
// Put together alloc operands for any dynamic dimensions of the memref.
|
||||
AllocOp alloc;
|
||||
if (oldMemRef) {
|
||||
SmallVector<Value *, 4> allocOperands;
|
||||
auto memRefShape = type.getShape();
|
||||
for (int i = 0; i < memRefShape.size(); ++i)
|
||||
if (memRefShape[i] < 0)
|
||||
allocOperands.push_back(rewriter.create<DimOp>(loc, oldMemRef, i));
|
||||
|
||||
alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
|
||||
} else {
|
||||
alloc = rewriter.create<AllocOp>(loc, type);
|
||||
}
|
||||
|
||||
// Make sure to allocate at the beginning of the block if
|
||||
// all dimensions are known.
|
||||
auto* parentBlock = alloc.getOperation()->getBlock();
|
||||
if (hasAllConstantDimensions(type))
|
||||
alloc.getOperation()->moveBefore(&parentBlock->front());
|
||||
|
||||
return alloc;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AddOp lowering to Krnl dialect.
|
||||
//===----------------------------------------------------------------------===//
|
||||
struct ONNXAddOpLowering : public ConversionPattern {
|
||||
ONNXAddOpLowering(MLIRContext* ctx)
|
||||
: ConversionPattern(mlir::ONNXAddOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
// TODO: Check that the types are valid.
|
||||
// Add is an operation that must have all operands and the result of
|
||||
// the same type. This should have been verified by the verifier.
|
||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
|
||||
// If the output has a dynamic dimension, pass the operands required for
|
||||
// each dynamic dimension to the AllocOp. The first operand of the Add
|
||||
// operation is used. The operands of the Add need to match in terms of
|
||||
// dimensions with the result at this pre-optimization phase.
|
||||
// TODO: verify that dimensions match.
|
||||
// TODO: can the dimension of the result differ after optimizations?
|
||||
Value *alloc;
|
||||
if (hasAllConstantDimensions(memRefType))
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
||||
else
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, operands[0]);
|
||||
|
||||
// Number of loops
|
||||
auto memRefShape = memRefType.getShape();
|
||||
int64_t rank = memRefShape.size();
|
||||
|
||||
// Define loops.
|
||||
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
|
||||
std::vector<Value*> originalLoops;
|
||||
originalLoops.reserve(rank);
|
||||
for (auto result : loopsOp.getResults()) {
|
||||
originalLoops.push_back(result);
|
||||
}
|
||||
|
||||
// Define loop optimization.
|
||||
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
|
||||
std::vector<Value*> optimizedLoops;
|
||||
optimizedLoops.reserve(rank);
|
||||
for (auto result : optimizedLoopsOp.getResults()) {
|
||||
optimizedLoops.push_back(result);
|
||||
}
|
||||
Block& optimizationBlock = optimizedLoopsOp.region().front();
|
||||
|
||||
// Iterate over the loop nest.
|
||||
// TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape
|
||||
// to KrnlIterateOp instead.
|
||||
SmallVector<Value*, 8> operandBounds;
|
||||
SmallVector<int64_t, 8> constBounds;
|
||||
SmallVector<int, 16> boundTypes;
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
if (memRefShape[i] < 0) {
|
||||
// This is a dynamic value, hence use operands.
|
||||
// Lower bound
|
||||
constBounds.push_back(0);
|
||||
boundTypes.push_back(0);
|
||||
// Upper bound
|
||||
operandBounds.push_back(
|
||||
rewriter.create<DimOp>(loc, operands[0], i).getResult());
|
||||
boundTypes.push_back(1);
|
||||
} else {
|
||||
// Lower bound
|
||||
constBounds.push_back(0);
|
||||
boundTypes.push_back(0);
|
||||
// Upper bound
|
||||
constBounds.push_back(memRefShape[i]);
|
||||
boundTypes.push_back(0);
|
||||
}
|
||||
}
|
||||
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, originalLoops,
|
||||
optimizedLoops, operandBounds, constBounds, boundTypes);
|
||||
Block& iterationBlock = iterateOp.bodyRegion().front();
|
||||
|
||||
// Now perform the insertions into the body of the
|
||||
// just generated instructions:
|
||||
|
||||
// 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
|
||||
rewriter.setInsertionPointToEnd(&optimizationBlock);
|
||||
// Return from KrnlOptimizeLoopsOp body.
|
||||
// When no optimizations are present we just return the loops
|
||||
// unchaged.
|
||||
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
|
||||
rewriter.setInsertionPoint(optimizedLoopsOp);
|
||||
|
||||
// 2. Insert instructions inside the KernelIterateOp body.
|
||||
rewriter.setInsertionPointToStart(&iterationBlock);
|
||||
|
||||
// Handle AddOp:
|
||||
SmallVector<Value*, 4> loopIVs;
|
||||
for (auto arg : iterationBlock.getArguments())
|
||||
loopIVs.push_back(arg);
|
||||
auto loadedFirstVal =
|
||||
rewriter.create<LoadOp>(loc, operands[0], loopIVs);
|
||||
auto loadedSecondVal =
|
||||
rewriter.create<LoadOp>(loc, operands[1], loopIVs);
|
||||
|
||||
// TODO: Choose type of the Add for now use the Float Add.
|
||||
auto addOpResult = rewriter.create<AddFOp>(
|
||||
loc, loadedFirstVal, loadedSecondVal);
|
||||
|
||||
// Store result in the resulting array.
|
||||
rewriter.create<StoreOp>(loc, addOpResult, alloc, loopIVs);
|
||||
|
||||
rewriter.replaceOp(op, alloc);
|
||||
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conversion from Tensor type to the Standard dialect MemRef type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct TensorTypeConverter : public TypeConverter {
|
||||
using TypeConverter::TypeConverter;
|
||||
|
||||
LogicalResult convertType(Type t, SmallVectorImpl<Type>& results) override {
|
||||
if (auto tensor_type = t.dyn_cast<TensorType>()) {
|
||||
results.push_back(convertTensorToMemRef(tensor_type));
|
||||
return success();
|
||||
}
|
||||
|
||||
results.push_back(t);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Return true if the inputs and outputs of the given function type are
|
||||
/// legal. [Taken from MLIR and adapted to only check the legality of the
|
||||
/// inputs. Once unranked results can be handled gracefully this
|
||||
/// override needs to be removed in favour of the original MLIR one.]
|
||||
bool isSignatureLegal(FunctionType funcType) {
|
||||
return llvm::all_of(funcType.getInputs(),
|
||||
[this](Type type) { return isLegal(type); });
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace.
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Frontend to Krnl Dialect lowering pass
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This is a partial lowering to Krnl loops of the ONNX operations.
|
||||
namespace {
|
||||
struct FrontendToKrnlLoweringPass
|
||||
: public ModulePass<FrontendToKrnlLoweringPass> {
|
||||
void runOnModule() final;
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
void FrontendToKrnlLoweringPass::runOnModule() {
|
||||
auto module = getModule();
|
||||
|
||||
// The first thing to define is the conversion target. This will define the
|
||||
// final target for this lowering.
|
||||
ConversionTarget target(getContext());
|
||||
|
||||
// We define the specific operations, or dialects, that are legal targets for
|
||||
// this lowering.
|
||||
target
|
||||
.addLegalDialect<KrnlOpsDialect, AffineOpsDialect, StandardOpsDialect>();
|
||||
|
||||
// TODO: enable this once more ops are supported.
|
||||
// We also define the ONNX dialect as Illegal so that the conversion will fail
|
||||
// if any of these operations are *not* converted.
|
||||
// target.addIllegalDialect<mlir::ONNXOpsDialect>();
|
||||
|
||||
// TODO: add any other ops which are considered legal.
|
||||
// Some operations can be marked as being still legal.
|
||||
// Example: target.addLegalOp<mlir::OpName>();
|
||||
|
||||
// Now that the conversion target has been defined, we just need to provide
|
||||
// the set of patterns that will lower the frontend operations.
|
||||
OwningRewritePatternList patterns;
|
||||
|
||||
// Convert TensorType to MemRef
|
||||
TensorTypeConverter tensor_to_memref_converter;
|
||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||
// FuncOp is legal only if types have been converted to Std types.
|
||||
return tensor_to_memref_converter.isSignatureLegal(op.getType());
|
||||
});
|
||||
|
||||
// Type conversion for function signatures.
|
||||
// Call MLIR FuncOp signature conversion when result type is
|
||||
// a ranked tensor.
|
||||
populateFuncOpTypeConversionPattern(
|
||||
patterns, &getContext(), tensor_to_memref_converter);
|
||||
|
||||
// Frontent operation lowering.
|
||||
patterns.insert<ONNXAddOpLowering>(&getContext());
|
||||
|
||||
// With the target and rewrite patterns defined, we can now attempt the
|
||||
// conversion. The conversion will signal failure if any of our `illegal`
|
||||
// operations were not converted successfully.
|
||||
if (failed(applyPartialConversion(
|
||||
module, target, patterns)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
|
||||
return std::make_unique<FrontendToKrnlLoweringPass>();
|
||||
}
|
|
@ -17,7 +17,8 @@ class Pass;
|
|||
|
||||
std::unique_ptr<Pass> createShapeInferencePass();
|
||||
|
||||
// TODO: Add pass for lowering to kernel IR.
|
||||
/// Add pass for lowering to Krnl IR.
|
||||
std::unique_ptr<mlir::Pass> createLowerToKrnlPass();
|
||||
|
||||
// TODO: Add pass for lowering to LLVM IR.
|
||||
|
||||
|
|
|
@ -71,6 +71,12 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
|||
<< op_worklist.size() << " operations couldn't be inferred\n";
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
if (auto terminator_op = f.getBody().back().getTerminator()) {
|
||||
auto results = terminator_op->getOperandTypes();
|
||||
f.setType(FunctionType::get(f.getType().getInputs(),
|
||||
std::vector<Type>(results.begin(), results.end()), f.getContext()));
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
|
|
|
@ -3,14 +3,8 @@ add_executable(onnf-opt onnf_opt.cpp)
|
|||
target_include_directories(onnf-opt PRIVATE ${ONNF_SRC_ROOT})
|
||||
target_include_directories(onnf-opt PRIVATE ${ONNF_BIN_ROOT})
|
||||
|
||||
set(LIB_LIST
|
||||
MLIRStandardOps
|
||||
MLIRAffineOps
|
||||
MLIRLoopOps
|
||||
MLIRTransformUtils
|
||||
MLIREDSC
|
||||
MLIRTransforms)
|
||||
whole_archive_link_mlir(onnf-opt ${LIB_LIST})
|
||||
target_link_libraries(onnf-opt compiler ${MLIRLibs})
|
||||
whole_archive_link_mlir(onnf-opt ${MLIRWholeArchiveLibs})
|
||||
|
||||
# TODO: need to investigate how to whole-archive link compiler pass to onnf-opt.
|
||||
target_link_libraries(onnf-opt compiler)
|
||||
|
|
|
@ -124,6 +124,7 @@ int main(int ac, char* av[]) {
|
|||
mlir::PassManager pm(&context);
|
||||
pm.addPass(mlir::createShapeInferencePass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createLowerToKrnlPass());
|
||||
pm.run(*module);
|
||||
|
||||
return 0;
|
||||
|
|
|
@ -1,4 +1 @@
|
|||
add_subdirectory(models)
|
||||
add_subdirectory(nodes)
|
||||
|
||||
add_subdirectory(mlir)
|
Loading…
Reference in New Issue