Make onnx-mlir work with latest mlir. (#93)

* Make onnx-mlir work with latest mlir.

* Bump CircleCI cache version.

* Fix missing passes in onnx-mlir-opt.

* Fix backend test failure.

* Fix doc.

* Fix doc and exclude the generated _site directory from DocCheck.

* Remove debug code.

* Do not hard code target name, on Mac shared lib can end with .dylib.

* FunctionPass -> PassWrapper.
This commit is contained in:
Tian Jin 2020-04-27 17:03:56 +08:00 committed by GitHub
parent 137ce767e6
commit fad2ad7d03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 108 additions and 90 deletions

View File

@ -18,7 +18,7 @@ jobs:
git submodule update --init --recursive
# Use cached mlir installation if possible.
- restore_cache:
key: V10-1-LLVM-PROJECT-{{ arch }}
key: V11-LLVM-PROJECT-{{ arch }}
- run:
name: Install MLIR
command: |
@ -29,7 +29,7 @@ jobs:
source onnx-mlir/utils/install-mlir.sh
fi
- save_cache:
key: V10-1-LLVM-PROJECT-{{ arch }}
key: V11-LLVM-PROJECT-{{ arch }}
paths:
- llvm-project
- run:

3
.gitignore vendored
View File

@ -1,3 +1,6 @@
.idea/
cmake-*/
# Prerequisites
*.d

View File

@ -132,8 +132,9 @@ function(find_mlir_lib lib)
endif()
endfunction(find_mlir_lib)
find_mlir_lib(MLIRAffine)
find_mlir_lib(MLIRAffineOps)
find_mlir_lib(MLIRAffineToStandard)
find_mlir_lib(MLIRAffineTransforms)
find_mlir_lib(MLIRAnalysis)
find_mlir_lib(MLIRCallInterfaces)
find_mlir_lib(MLIRControlFlowInterfaces)
@ -193,8 +194,9 @@ set(MLIRLibs
${MLIRTargetLLVMIRModuleTranslation}
${MLIRTransforms}
${MLIRTransformUtils}
${MLIRAffine}
${MLIRAffineOps}
${MLIRAffineToStandard}
${MLIRAffineTransforms}
${MLIRAnalysis}
${MLIRCallInterfaces}
${MLIRControlFlowInterfaces}
@ -244,14 +246,15 @@ set(MLIRLibs
# must be specified on LD_PRELOAD for shared build.
set(MLIRWholeArchiveLibs
MLIRAffineToStandard
MLIRAffine
MLIRAffineOps
MLIRLLVMIR
MLIRStandardOps
MLIRStandardToLLVM
MLIRTransforms
MLIRLoopToStandard
MLIRVector
MLIRLoopOps)
MLIRLoopOps
MLIRIR)
# ONNX MLIR libraries that must be linked with --whole-archive for static build or
# must be specified on LD_PRELOAD for shared build.

View File

@ -20,7 +20,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash
git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX MLIR.
cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd ..
cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd ..
```
[same-as-file]: <> (utils/build-mlir.sh)
@ -110,7 +110,7 @@ Install MLIR (as a part of LLVM-Project):
```shell
git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX MLIR.
cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd ..
cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd ..
```
[same-as-file]: <> (utils/build-mlir.cmd)

2
docs/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
_site
Gemfile*

View File

@ -44,7 +44,7 @@ Once it is invoked, you will need to add literal tests in ` test/mlir/onnx/onnx_
Files related to the lowering of the new operations resides in the `src/Conversion/ONNXtoKRNL` directory and subdirectories. For the `concat` operation, we added code to lower it to krnl dialect in the ` src/Conversion/ONNXToKrnl/Tensor/concat.cpp` file. See other similar lowering for inspiration. We suggest to use `assert` statements for any unexpected values encountered while lowering the operation, as illegal parameter values should be caught in the shape inference phase, not successive passes such as lowering to the krnl dialect.
In that file, the `populateLoweringONNXConcatOpPattern` operation (where `Concat` would be replaced with the actual new operation) will need to be defined in ` src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp` and invoked in the ` runOnModule` function in the ` src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp` file.
In that file, the `populateLoweringONNXConcatOpPattern` operation (where `Concat` would be replaced with the actual new operation) will need to be defined in ` src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp` and invoked in the ` runOnOperation` function in the ` src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp` file.
To compile properly, you will also need to add the new `.cpp` file in the ` src/Conversion/ONNXToKrnl/CMakeLists.txt` file.

View File

@ -20,7 +20,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash
git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX MLIR.
cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd ..
cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd ..
```
[same-as-file]: <> (utils/build-mlir.sh)
@ -110,7 +110,7 @@ Install MLIR (as a part of LLVM-Project):
```shell
git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX MLIR.
cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd ..
cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd ..
```
[same-as-file]: <> (utils/build-mlir.cmd)

View File

@ -4,5 +4,8 @@ add_custom_target(check-doc
COMMAND ${PYTHON_EXECUTABLE}
${CMAKE_CURRENT_SOURCE_DIR}/check.py
${ONNX_MLIR_SRC_ROOT}
--exclude_dirs third_party docs/doc_check/test)
--exclude_dirs
third_party
docs/doc_check/test
docs/_site)

View File

@ -41,13 +41,13 @@ public:
/// This is a partial lowering to Krnl loops of the ONNX operations.
namespace {
struct FrontendToKrnlLoweringPass
: public ModulePass<FrontendToKrnlLoweringPass> {
void runOnModule() final;
: public PassWrapper<FrontendToKrnlLoweringPass, OperationPass<ModuleOp>> {
void runOnOperation() final;
};
} // end anonymous namespace.
void FrontendToKrnlLoweringPass::runOnModule() {
ModuleOp module = getModule();
void FrontendToKrnlLoweringPass::runOnOperation() {
ModuleOp module = getOperation();
// The first thing to define is the conversion target. This will define the
// final target for this lowering.

View File

@ -456,11 +456,11 @@ struct ONNXPoolOpLowering : public ConversionPattern {
}
dimExpr.emplace_back(de);
}
poolDimMap = AffineMap::get(1, 5, dimExpr);
poolDimMap = AffineMap::get(1, 5, dimExpr, rewriter.getContext());
// poolStartMap and poolEndMap
poolStartMap = AffineMap::get(1, 5, {start1, start2});
poolEndMap = AffineMap::get(1, 5, {end1, end2});
poolStartMap = AffineMap::get(1, 5, {start1, start2}, rewriter.getContext());
poolEndMap = AffineMap::get(1, 5, {end1, end2}, rewriter.getContext());
}
// Obtain values from the affine maps.

View File

@ -11,7 +11,9 @@
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/InitLLVM.h>
#include <llvm/Support/ToolOutputFile.h>
#include <mlir/IR/AsmState.h>
#include <mlir/InitAllDialects.h>
#include <mlir/InitAllPasses.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
@ -23,28 +25,27 @@
using namespace onnx_mlir;
static llvm::cl::opt<std::string> input_filename(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
// TODO(tjingrant): disable the following namespace import.
using namespace mlir;
static llvm::cl::opt<std::string> input_filename(
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"));
static llvm::cl::opt<std::string> output_filename("o",
llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
static llvm::cl::opt<std::string>
output_filename("o", llvm::cl::desc("Output filename"),
llvm::cl::value_desc("filename"), llvm::cl::init("-"));
static llvm::cl::opt<bool> split_input_file(
"split-input-file",
static llvm::cl::opt<bool> split_input_file("split-input-file",
llvm::cl::desc("Split the input file into pieces and process each "
"chunk independently"),
llvm::cl::init(false));
static llvm::cl::opt<bool> verify_diagnostics(
"verify-diagnostics",
static llvm::cl::opt<bool> verify_diagnostics("verify-diagnostics",
llvm::cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"),
llvm::cl::init(false));
static llvm::cl::opt<bool> verify_passes(
"verify-each",
static llvm::cl::opt<bool> verify_passes("verify-each",
llvm::cl::desc("Run the verifier after each transformation pass"),
llvm::cl::init(true));
@ -58,16 +59,23 @@ int main(int argc, char **argv) {
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
// Register transformation passes.
#define GEN_PASS_REGISTRATION
#include "mlir/Transforms/Passes.h.inc"
llvm::InitLLVM y(argc, argv);
mlir::registerDialect<mlir::ONNXOpsDialect>();
mlir::registerDialect<mlir::KrnlOpsDialect>();
mlir::registerAsmPrinterCLOptions();
mlir::registerMLIRContextCLOptions();
// Register any pass manager command line options.
mlir::registerPassManagerCLOptions();
mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
llvm::cl::ParseCommandLineOptions(argc, argv,
"ONNX MLIR modular optimizer driver\n");
llvm::cl::ParseCommandLineOptions(
argc, argv, "ONNX MLIR modular optimizer driver\n");
// Set up the input file.
std::string error_message;
@ -78,6 +86,6 @@ int main(int argc, char **argv) {
assert(output);
return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
split_input_file, verify_diagnostics,
verify_passes, allowUnregisteredDialects));
split_input_file, verify_diagnostics, verify_passes,
allowUnregisteredDialects));
}

View File

@ -36,8 +36,8 @@ class KrnlConstGlobalValueElision : public OpRewritePattern<KrnlGlobalOp> {
public:
using OpRewritePattern<KrnlGlobalOp>::OpRewritePattern;
LogicalResult matchAndRewrite(KrnlGlobalOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(
KrnlGlobalOp op, PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
if (op.value().hasValue()) {
@ -54,7 +54,7 @@ public:
* Function pass that performs constant value elision of Krnl globals.
*/
class ElideConstGlobalValuePass
: public mlir::FunctionPass<ElideConstGlobalValuePass> {
: public PassWrapper<ElideConstGlobalValuePass, FunctionPass> {
public:
void runOnFunction() override {
auto function = getFunction();
@ -63,7 +63,7 @@ public:
OwningRewritePatternList patterns;
patterns.insert<KrnlConstGlobalValueElision>(&getContext());
applyPatternsGreedily(function, patterns);
applyPatternsAndFoldGreedily(function, patterns);
}
};
} // namespace

View File

@ -149,7 +149,7 @@ public:
/// add and multiply, this pass will leave these operations intact.
namespace {
struct KrnlToAffineLoweringPass
: public FunctionPass<KrnlToAffineLoweringPass> {
: public PassWrapper<KrnlToAffineLoweringPass, FunctionPass> {
void runOnFunction() final;
};
} // end anonymous namespace.

View File

@ -16,8 +16,8 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
@ -94,13 +94,12 @@ static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
public:
explicit KrnlGlobalOpLowering(MLIRContext *context,
LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(KrnlGlobalOp::getOperationName(), context,
lowering_) {}
explicit KrnlGlobalOpLowering(
MLIRContext *context, LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(
KrnlGlobalOp::getOperationName(), context, lowering_) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto *context = op->getContext();
auto loc = op->getLoc();
@ -119,7 +118,7 @@ public:
// Compute total number of elements.
auto shape = (krnlGlobalOp.shape()).dyn_cast<ArrayAttr>();
int64_t numElements = 1;
for (int i=0; i<shape.size(); ++i)
for (int i = 0; i < shape.size(); ++i)
numElements *= ArrayAttrIntVal(shape, i);
// Create the global at the entry of the module.
@ -133,7 +132,7 @@ public:
auto constantElementType =
typeConverter.convertType(memRefTy.getElementType());
auto globalType = constantElementType;
for (int i=shape.size() - 1; i >= 0; i--)
for (int i = shape.size() - 1; i >= 0; i--)
globalType = LLVM::LLVMType::getArrayTy(
globalType.cast<LLVM::LLVMType>(), ArrayAttrIntVal(shape, i));
// The llvm type of the global (example: [2 x [8 x float]])
@ -145,7 +144,6 @@ public:
assert(krnlGlobalOp.value().hasValue() &&
"Krnl Global must always have a value");
global = rewriter.create<LLVM::GlobalOp>(loc,
llvmGlobalType, /*isConstant=*/true,
LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue());
@ -157,37 +155,36 @@ public:
// Allocate the memory where the constants will be used from.
// This is a region of local memory and needs to be emitted as an alloca.
auto one = rewriter.create<LLVM::ConstantOp>(loc,
llvmI64Ty, rewriter.getI64IntegerAttr(1));
auto one = rewriter.create<LLVM::ConstantOp>(
loc, llvmI64Ty, rewriter.getI64IntegerAttr(1));
auto alloc = rewriter.create<LLVM::AllocaOp>(
loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0);
// Copy constant value into the local alloca:
// - Bitcast alloc to i8*
Value int8PtrAlloc = rewriter.create<LLVM::BitcastOp>(
loc, llvmI8PtrTy, alloc);
Value int8PtrAlloc =
rewriter.create<LLVM::BitcastOp>(loc, llvmI8PtrTy, alloc);
// - Bitcast global to i8*
Value globalValue = rewriter.create<LLVM::AddressOfOp>(loc, global);
Value i8PtrGlobal = rewriter.create<LLVM::BitcastOp>(
loc, llvmI8PtrTy, globalValue);
Value i8PtrGlobal =
rewriter.create<LLVM::BitcastOp>(loc, llvmI8PtrTy, globalValue);
// - Set size.
Value memRefElementSize = rewriter.create<LLVM::ConstantOp>(loc,
llvmI64Ty, rewriter.getI64IntegerAttr(
getMemRefEltSizeInBytes(memRefTy)));
Value memRefElementSize = rewriter.create<LLVM::ConstantOp>(loc, llvmI64Ty,
rewriter.getI64IntegerAttr(getMemRefEltSizeInBytes(memRefTy)));
Value numElementsValue = rewriter.create<LLVM::ConstantOp>(
loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements));
Value totalElementsSize = rewriter.create<LLVM::MulOp>(
loc, memRefElementSize, numElementsValue);
Value int64Size = rewriter.create<LLVM::SExtOp>(
loc, llvmI64Ty, totalElementsSize);
Value totalElementsSize =
rewriter.create<LLVM::MulOp>(loc, memRefElementSize, numElementsValue);
Value int64Size =
rewriter.create<LLVM::SExtOp>(loc, llvmI64Ty, totalElementsSize);
// - Set volatile.
Value isVolatile = rewriter.create<LLVM::ConstantOp>(
loc, LLVM::LLVMType::getInt1Ty(llvmDialect),
Value isVolatile = rewriter.create<LLVM::ConstantOp>(loc,
LLVM::LLVMType::getInt1Ty(llvmDialect),
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
// - Copy constant data into the alloca.
auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect);
rewriter.create<CallOp>(
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
rewriter.create<CallOp>(loc, memcpyRef,
LLVM::LLVMType::getVoidTy(llvmDialect),
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
// Prepare data to be inserted into MemRef.
@ -605,16 +602,18 @@ private:
//===----------------------------------------------------------------------===//
namespace {
struct KrnlToLLVMLoweringPass : public ModulePass<KrnlToLLVMLoweringPass> {
void runOnModule() final;
struct KrnlToLLVMLoweringPass
: public PassWrapper<KrnlToLLVMLoweringPass, OperationPass<ModuleOp>> {
void runOnOperation() final;
};
} // end anonymous namespace
void KrnlToLLVMLoweringPass::runOnModule() {
void KrnlToLLVMLoweringPass::runOnOperation() {
// Define the target for this lowering i.e. the LLVM dialect.
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
// target.addLegalOp<KrnlEntryPointOp>();
// Lower the MemRef types to a representation in LLVM.
LLVMTypeConverter typeConverter(&getContext());
@ -625,8 +624,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
populateAffineToStdConversionPatterns(patterns, &getContext());
populateLoopToStdConversionPatterns(patterns, &getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns,
/*useAlloca=*/false,
/*emitCWrapper=*/true);
/*emitCWrapperS=*/true,
/*useAlignedAlloc=*/false);
patterns.insert<KrnlGlobalOpLowering>(&getContext(), typeConverter);
@ -636,8 +635,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
if (failed(
applyFullConversion(getModule(), target, patterns, &typeConverter)))
if (failed(applyFullConversion(
getOperation(), target, patterns, &typeConverter)))
signalPassFailure();
}

View File

@ -39,7 +39,7 @@ void getOrCreateNoneValue(llvm::Optional<mlir::Value> &none, FuncOp f) {
* desirable (as instructed by the PromotableConstOperandsOpInterface).
*/
class AttributePromotionPass
: public mlir::FunctionPass<AttributePromotionPass> {
: public mlir::PassWrapper<AttributePromotionPass, mlir::FunctionPass> {
public:
void runOnFunction() override {
auto f = getFunction();
@ -75,7 +75,7 @@ public:
OwningRewritePatternList patterns;
auto *context = &getContext();
ConstantOp::getCanonicalizationPatterns(patterns, context);
applyPatternsGreedily(f, patterns);
applyPatternsAndFoldGreedily(f, patterns);
}
};
} // end anonymous namespace

View File

@ -37,8 +37,8 @@ class ConstantValueElision : public OpRewritePattern<ONNXConstantOp> {
public:
using OpRewritePattern<ONNXConstantOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXConstantOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(
ONNXConstantOp op, PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto constOp = llvm::dyn_cast<ONNXConstantOp>(&op);
@ -58,7 +58,7 @@ public:
* Function pass that performs constant value elision.
*/
class ElideConstantValuePass
: public mlir::FunctionPass<ElideConstantValuePass> {
: public PassWrapper<ElideConstantValuePass, FunctionPass> {
public:
void runOnFunction() override {
auto function = getFunction();
@ -67,7 +67,7 @@ public:
OwningRewritePatternList patterns;
patterns.insert<ConstantValueElision>(&getContext());
applyPatternsGreedily(function, patterns);
applyPatternsAndFoldGreedily(function, patterns);
}
};
} // end anonymous namespace

View File

@ -27,7 +27,7 @@ namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Transform/ONNX/ONNXDecompose.inc"
struct DecomposeONNXToONNXPass : public FunctionPass<DecomposeONNXToONNXPass> {
struct DecomposeONNXToONNXPass : public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> {
void runOnFunction() final;
};
} // end anonymous namespace.

View File

@ -25,7 +25,7 @@ namespace {
* candidate operations and propagating the shape information until the list
* of operations is empty [credit MLIR authors].
*/
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
class ShapeInferencePass : public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> {
public:
void runOnFunction() override {
auto f = getFunction();

View File

@ -3,7 +3,7 @@ configure_file(test_config.py.in test_config.py)
find_package(PythonInterp 3 REQUIRED)
add_custom_target(check-onnx-backend
COMMAND LD_PRELOAD=${CMAKE_BINARY_DIR}/lib/libcruntime.so ${PYTHON_EXECUTABLE}
COMMAND LD_PRELOAD=$<TARGET_FILE:cruntime> ${PYTHON_EXECUTABLE}
${CMAKE_CURRENT_BINARY_DIR}/test.py)
add_dependencies(check-onnx-backend onnx-mlir)

View File

@ -1,3 +1,3 @@
git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX MLIR.
cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd ..
cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd ..