Make onnx-mlir work with latest mlir. (#93)

* Make onnx-mlir work with latest mlir.

* Bump CircleCI cache version.

* Fix missing passes in onnx-mlir-opt.

* Fix backend test failure.

* Fix doc.

* Fix doc and exclude the generated _site directory from DocCheck.

* Remove debug code.

* Do not hard code target name, on Mac shared lib can end with .dylib.

* FunctionPass -> PassWrapper.
This commit is contained in:
Tian Jin 2020-04-27 17:03:56 +08:00 committed by GitHub
parent 137ce767e6
commit fad2ad7d03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 108 additions and 90 deletions

View File

@ -18,7 +18,7 @@ jobs:
git submodule update --init --recursive git submodule update --init --recursive
# Use cached mlir installation if possible. # Use cached mlir installation if possible.
- restore_cache: - restore_cache:
key: V10-1-LLVM-PROJECT-{{ arch }} key: V11-LLVM-PROJECT-{{ arch }}
- run: - run:
name: Install MLIR name: Install MLIR
command: | command: |
@ -29,7 +29,7 @@ jobs:
source onnx-mlir/utils/install-mlir.sh source onnx-mlir/utils/install-mlir.sh
fi fi
- save_cache: - save_cache:
key: V10-1-LLVM-PROJECT-{{ arch }} key: V11-LLVM-PROJECT-{{ arch }}
paths: paths:
- llvm-project - llvm-project
- run: - run:

3
.gitignore vendored
View File

@ -1,3 +1,6 @@
.idea/
cmake-*/
# Prerequisites # Prerequisites
*.d *.d

View File

@ -132,8 +132,9 @@ function(find_mlir_lib lib)
endif() endif()
endfunction(find_mlir_lib) endfunction(find_mlir_lib)
find_mlir_lib(MLIRAffine) find_mlir_lib(MLIRAffineOps)
find_mlir_lib(MLIRAffineToStandard) find_mlir_lib(MLIRAffineToStandard)
find_mlir_lib(MLIRAffineTransforms)
find_mlir_lib(MLIRAnalysis) find_mlir_lib(MLIRAnalysis)
find_mlir_lib(MLIRCallInterfaces) find_mlir_lib(MLIRCallInterfaces)
find_mlir_lib(MLIRControlFlowInterfaces) find_mlir_lib(MLIRControlFlowInterfaces)
@ -193,8 +194,9 @@ set(MLIRLibs
${MLIRTargetLLVMIRModuleTranslation} ${MLIRTargetLLVMIRModuleTranslation}
${MLIRTransforms} ${MLIRTransforms}
${MLIRTransformUtils} ${MLIRTransformUtils}
${MLIRAffine} ${MLIRAffineOps}
${MLIRAffineToStandard} ${MLIRAffineToStandard}
${MLIRAffineTransforms}
${MLIRAnalysis} ${MLIRAnalysis}
${MLIRCallInterfaces} ${MLIRCallInterfaces}
${MLIRControlFlowInterfaces} ${MLIRControlFlowInterfaces}
@ -244,14 +246,15 @@ set(MLIRLibs
# must be specified on LD_PRELOAD for shared build. # must be specified on LD_PRELOAD for shared build.
set(MLIRWholeArchiveLibs set(MLIRWholeArchiveLibs
MLIRAffineToStandard MLIRAffineToStandard
MLIRAffine MLIRAffineOps
MLIRLLVMIR MLIRLLVMIR
MLIRStandardOps MLIRStandardOps
MLIRStandardToLLVM MLIRStandardToLLVM
MLIRTransforms MLIRTransforms
MLIRLoopToStandard MLIRLoopToStandard
MLIRVector MLIRVector
MLIRLoopOps) MLIRLoopOps
MLIRIR)
# ONNX MLIR libraries that must be linked with --whole-archive for static build or # ONNX MLIR libraries that must be linked with --whole-archive for static build or
# must be specified on LD_PRELOAD for shared build. # must be specified on LD_PRELOAD for shared build.

View File

@ -20,7 +20,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash ``` bash
git clone https://github.com/llvm/llvm-project.git git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX MLIR. # Check out a specific branch that is known to work with ONNX MLIR.
cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd .. cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd ..
``` ```
[same-as-file]: <> (utils/build-mlir.sh) [same-as-file]: <> (utils/build-mlir.sh)
@ -110,7 +110,7 @@ Install MLIR (as a part of LLVM-Project):
```shell ```shell
git clone https://github.com/llvm/llvm-project.git git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX MLIR. # Check out a specific branch that is known to work with ONNX MLIR.
cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd .. cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd ..
``` ```
[same-as-file]: <> (utils/build-mlir.cmd) [same-as-file]: <> (utils/build-mlir.cmd)

2
docs/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
_site
Gemfile*

View File

@ -44,7 +44,7 @@ Once it is invoked, you will need to add literal tests in ` test/mlir/onnx/onnx_
Files related to the lowering of the new operations resides in the `src/Conversion/ONNXtoKRNL` directory and subdirectories. For the `concat` operation, we added code to lower it to krnl dialect in the ` src/Conversion/ONNXToKrnl/Tensor/concat.cpp` file. See other similar lowering for inspiration. We suggest to use `assert` statements for any unexpected values encountered while lowering the operation, as illegal parameter values should be caught in the shape inference phase, not successive passes such as lowering to the krnl dialect. Files related to the lowering of the new operations resides in the `src/Conversion/ONNXtoKRNL` directory and subdirectories. For the `concat` operation, we added code to lower it to krnl dialect in the ` src/Conversion/ONNXToKrnl/Tensor/concat.cpp` file. See other similar lowering for inspiration. We suggest to use `assert` statements for any unexpected values encountered while lowering the operation, as illegal parameter values should be caught in the shape inference phase, not successive passes such as lowering to the krnl dialect.
In that file, the `populateLoweringONNXConcatOpPattern` operation (where `Concat` would be replaced with the actual new operation) will need to be defined in ` src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp` and invoked in the ` runOnModule` function in the ` src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp` file. In that file, the `populateLoweringONNXConcatOpPattern` operation (where `Concat` would be replaced with the actual new operation) will need to be defined in ` src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp` and invoked in the ` runOnOperation` function in the ` src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp` file.
To compile properly, you will also need to add the new `.cpp` file in the ` src/Conversion/ONNXToKrnl/CMakeLists.txt` file. To compile properly, you will also need to add the new `.cpp` file in the ` src/Conversion/ONNXToKrnl/CMakeLists.txt` file.

View File

@ -20,7 +20,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash ``` bash
git clone https://github.com/llvm/llvm-project.git git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX MLIR. # Check out a specific branch that is known to work with ONNX MLIR.
cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd .. cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd ..
``` ```
[same-as-file]: <> (utils/build-mlir.sh) [same-as-file]: <> (utils/build-mlir.sh)
@ -110,7 +110,7 @@ Install MLIR (as a part of LLVM-Project):
```shell ```shell
git clone https://github.com/llvm/llvm-project.git git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX MLIR. # Check out a specific branch that is known to work with ONNX MLIR.
cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd .. cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd ..
``` ```
[same-as-file]: <> (utils/build-mlir.cmd) [same-as-file]: <> (utils/build-mlir.cmd)

View File

@ -4,5 +4,8 @@ add_custom_target(check-doc
COMMAND ${PYTHON_EXECUTABLE} COMMAND ${PYTHON_EXECUTABLE}
${CMAKE_CURRENT_SOURCE_DIR}/check.py ${CMAKE_CURRENT_SOURCE_DIR}/check.py
${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_SRC_ROOT}
--exclude_dirs third_party docs/doc_check/test) --exclude_dirs
third_party
docs/doc_check/test
docs/_site)

View File

@ -41,13 +41,13 @@ public:
/// This is a partial lowering to Krnl loops of the ONNX operations. /// This is a partial lowering to Krnl loops of the ONNX operations.
namespace { namespace {
struct FrontendToKrnlLoweringPass struct FrontendToKrnlLoweringPass
: public ModulePass<FrontendToKrnlLoweringPass> { : public PassWrapper<FrontendToKrnlLoweringPass, OperationPass<ModuleOp>> {
void runOnModule() final; void runOnOperation() final;
}; };
} // end anonymous namespace. } // end anonymous namespace.
void FrontendToKrnlLoweringPass::runOnModule() { void FrontendToKrnlLoweringPass::runOnOperation() {
ModuleOp module = getModule(); ModuleOp module = getOperation();
// The first thing to define is the conversion target. This will define the // The first thing to define is the conversion target. This will define the
// final target for this lowering. // final target for this lowering.

View File

@ -456,11 +456,11 @@ struct ONNXPoolOpLowering : public ConversionPattern {
} }
dimExpr.emplace_back(de); dimExpr.emplace_back(de);
} }
poolDimMap = AffineMap::get(1, 5, dimExpr); poolDimMap = AffineMap::get(1, 5, dimExpr, rewriter.getContext());
// poolStartMap and poolEndMap // poolStartMap and poolEndMap
poolStartMap = AffineMap::get(1, 5, {start1, start2}); poolStartMap = AffineMap::get(1, 5, {start1, start2}, rewriter.getContext());
poolEndMap = AffineMap::get(1, 5, {end1, end2}); poolEndMap = AffineMap::get(1, 5, {end1, end2}, rewriter.getContext());
} }
// Obtain values from the affine maps. // Obtain values from the affine maps.

View File

@ -11,7 +11,9 @@
#include <llvm/Support/CommandLine.h> #include <llvm/Support/CommandLine.h>
#include <llvm/Support/InitLLVM.h> #include <llvm/Support/InitLLVM.h>
#include <llvm/Support/ToolOutputFile.h> #include <llvm/Support/ToolOutputFile.h>
#include <mlir/IR/AsmState.h>
#include <mlir/InitAllDialects.h> #include <mlir/InitAllDialects.h>
#include <mlir/InitAllPasses.h>
#include <mlir/Pass/Pass.h> #include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h> #include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h> #include <mlir/Support/FileUtilities.h>
@ -23,28 +25,27 @@
using namespace onnx_mlir; using namespace onnx_mlir;
static llvm::cl::opt<std::string> input_filename(llvm::cl::Positional, // TODO(tjingrant): disable the following namespace import.
llvm::cl::desc("<input file>"), using namespace mlir;
llvm::cl::init("-"));
static llvm::cl::opt<std::string> static llvm::cl::opt<std::string> input_filename(
output_filename("o", llvm::cl::desc("Output filename"), llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"));
llvm::cl::value_desc("filename"), llvm::cl::init("-"));
static llvm::cl::opt<bool> split_input_file( static llvm::cl::opt<std::string> output_filename("o",
"split-input-file", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
static llvm::cl::opt<bool> split_input_file("split-input-file",
llvm::cl::desc("Split the input file into pieces and process each " llvm::cl::desc("Split the input file into pieces and process each "
"chunk independently"), "chunk independently"),
llvm::cl::init(false)); llvm::cl::init(false));
static llvm::cl::opt<bool> verify_diagnostics( static llvm::cl::opt<bool> verify_diagnostics("verify-diagnostics",
"verify-diagnostics",
llvm::cl::desc("Check that emitted diagnostics match " llvm::cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"), "expected-* lines on the corresponding line"),
llvm::cl::init(false)); llvm::cl::init(false));
static llvm::cl::opt<bool> verify_passes( static llvm::cl::opt<bool> verify_passes("verify-each",
"verify-each",
llvm::cl::desc("Run the verifier after each transformation pass"), llvm::cl::desc("Run the verifier after each transformation pass"),
llvm::cl::init(true)); llvm::cl::init(true));
@ -58,16 +59,23 @@ int main(int argc, char **argv) {
mlir::registerDialect<mlir::LLVM::LLVMDialect>(); mlir::registerDialect<mlir::LLVM::LLVMDialect>();
mlir::registerDialect<mlir::loop::LoopOpsDialect>(); mlir::registerDialect<mlir::loop::LoopOpsDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>(); mlir::registerDialect<mlir::StandardOpsDialect>();
// Register transformation passes.
#define GEN_PASS_REGISTRATION
#include "mlir/Transforms/Passes.h.inc"
llvm::InitLLVM y(argc, argv); llvm::InitLLVM y(argc, argv);
mlir::registerDialect<mlir::ONNXOpsDialect>(); mlir::registerDialect<mlir::ONNXOpsDialect>();
mlir::registerDialect<mlir::KrnlOpsDialect>(); mlir::registerDialect<mlir::KrnlOpsDialect>();
mlir::registerAsmPrinterCLOptions();
mlir::registerMLIRContextCLOptions();
// Register any pass manager command line options. // Register any pass manager command line options.
mlir::registerPassManagerCLOptions(); mlir::registerPassManagerCLOptions();
mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run"); mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
llvm::cl::ParseCommandLineOptions(argc, argv, llvm::cl::ParseCommandLineOptions(
"ONNX MLIR modular optimizer driver\n"); argc, argv, "ONNX MLIR modular optimizer driver\n");
// Set up the input file. // Set up the input file.
std::string error_message; std::string error_message;
@ -78,6 +86,6 @@ int main(int argc, char **argv) {
assert(output); assert(output);
return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline, return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
split_input_file, verify_diagnostics, split_input_file, verify_diagnostics, verify_passes,
verify_passes, allowUnregisteredDialects)); allowUnregisteredDialects));
} }

View File

@ -36,13 +36,13 @@ class KrnlConstGlobalValueElision : public OpRewritePattern<KrnlGlobalOp> {
public: public:
using OpRewritePattern<KrnlGlobalOp>::OpRewritePattern; using OpRewritePattern<KrnlGlobalOp>::OpRewritePattern;
LogicalResult matchAndRewrite(KrnlGlobalOp op, LogicalResult matchAndRewrite(
PatternRewriter &rewriter) const override { KrnlGlobalOp op, PatternRewriter &rewriter) const override {
auto loc = op.getLoc(); auto loc = op.getLoc();
if (op.value().hasValue()) { if (op.value().hasValue()) {
auto newGlobalOp = rewriter.create<KrnlGlobalOp>( auto newGlobalOp = rewriter.create<KrnlGlobalOp>(
loc, op.getResult().getType(), op.shape(), op.name(), nullptr); loc, op.getResult().getType(), op.shape(), op.name(), nullptr);
rewriter.replaceOp(op, newGlobalOp.getResult()); rewriter.replaceOp(op, newGlobalOp.getResult());
} }
@ -54,7 +54,7 @@ public:
* Function pass that performs constant value elision of Krnl globals. * Function pass that performs constant value elision of Krnl globals.
*/ */
class ElideConstGlobalValuePass class ElideConstGlobalValuePass
: public mlir::FunctionPass<ElideConstGlobalValuePass> { : public PassWrapper<ElideConstGlobalValuePass, FunctionPass> {
public: public:
void runOnFunction() override { void runOnFunction() override {
auto function = getFunction(); auto function = getFunction();
@ -63,7 +63,7 @@ public:
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
patterns.insert<KrnlConstGlobalValueElision>(&getContext()); patterns.insert<KrnlConstGlobalValueElision>(&getContext());
applyPatternsGreedily(function, patterns); applyPatternsAndFoldGreedily(function, patterns);
} }
}; };
} // namespace } // namespace

View File

@ -149,7 +149,7 @@ public:
/// add and multiply, this pass will leave these operations intact. /// add and multiply, this pass will leave these operations intact.
namespace { namespace {
struct KrnlToAffineLoweringPass struct KrnlToAffineLoweringPass
: public FunctionPass<KrnlToAffineLoweringPass> { : public PassWrapper<KrnlToAffineLoweringPass, FunctionPass> {
void runOnFunction() final; void runOnFunction() final;
}; };
} // end anonymous namespace. } // end anonymous namespace.

View File

@ -16,8 +16,8 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h" #include "llvm/ADT/Sequence.h"
@ -94,14 +94,13 @@ static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
class KrnlGlobalOpLowering : public ConvertToLLVMPattern { class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
public: public:
explicit KrnlGlobalOpLowering(MLIRContext *context, explicit KrnlGlobalOpLowering(
LLVMTypeConverter &lowering_) MLIRContext *context, LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(KrnlGlobalOp::getOperationName(), context, : ConvertToLLVMPattern(
lowering_) {} KrnlGlobalOp::getOperationName(), context, lowering_) {}
LogicalResult LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override {
ConversionPatternRewriter &rewriter) const override {
auto *context = op->getContext(); auto *context = op->getContext();
auto loc = op->getLoc(); auto loc = op->getLoc();
auto *llvmDialect = auto *llvmDialect =
@ -119,7 +118,7 @@ public:
// Compute total number of elements. // Compute total number of elements.
auto shape = (krnlGlobalOp.shape()).dyn_cast<ArrayAttr>(); auto shape = (krnlGlobalOp.shape()).dyn_cast<ArrayAttr>();
int64_t numElements = 1; int64_t numElements = 1;
for (int i=0; i<shape.size(); ++i) for (int i = 0; i < shape.size(); ++i)
numElements *= ArrayAttrIntVal(shape, i); numElements *= ArrayAttrIntVal(shape, i);
// Create the global at the entry of the module. // Create the global at the entry of the module.
@ -133,7 +132,7 @@ public:
auto constantElementType = auto constantElementType =
typeConverter.convertType(memRefTy.getElementType()); typeConverter.convertType(memRefTy.getElementType());
auto globalType = constantElementType; auto globalType = constantElementType;
for (int i=shape.size() - 1; i >= 0; i--) for (int i = shape.size() - 1; i >= 0; i--)
globalType = LLVM::LLVMType::getArrayTy( globalType = LLVM::LLVMType::getArrayTy(
globalType.cast<LLVM::LLVMType>(), ArrayAttrIntVal(shape, i)); globalType.cast<LLVM::LLVMType>(), ArrayAttrIntVal(shape, i));
// The llvm type of the global (example: [2 x [8 x float]]) // The llvm type of the global (example: [2 x [8 x float]])
@ -145,7 +144,6 @@ public:
assert(krnlGlobalOp.value().hasValue() && assert(krnlGlobalOp.value().hasValue() &&
"Krnl Global must always have a value"); "Krnl Global must always have a value");
global = rewriter.create<LLVM::GlobalOp>(loc, global = rewriter.create<LLVM::GlobalOp>(loc,
llvmGlobalType, /*isConstant=*/true, llvmGlobalType, /*isConstant=*/true,
LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue()); LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue());
@ -157,47 +155,46 @@ public:
// Allocate the memory where the constants will be used from. // Allocate the memory where the constants will be used from.
// This is a region of local memory and needs to be emitted as an alloca. // This is a region of local memory and needs to be emitted as an alloca.
auto one = rewriter.create<LLVM::ConstantOp>(loc, auto one = rewriter.create<LLVM::ConstantOp>(
llvmI64Ty, rewriter.getI64IntegerAttr(1)); loc, llvmI64Ty, rewriter.getI64IntegerAttr(1));
auto alloc = rewriter.create<LLVM::AllocaOp>( auto alloc = rewriter.create<LLVM::AllocaOp>(
loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0); loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0);
// Copy constant value into the local alloca: // Copy constant value into the local alloca:
// - Bitcast alloc to i8* // - Bitcast alloc to i8*
Value int8PtrAlloc = rewriter.create<LLVM::BitcastOp>( Value int8PtrAlloc =
loc, llvmI8PtrTy, alloc); rewriter.create<LLVM::BitcastOp>(loc, llvmI8PtrTy, alloc);
// - Bitcast global to i8* // - Bitcast global to i8*
Value globalValue = rewriter.create<LLVM::AddressOfOp>(loc, global); Value globalValue = rewriter.create<LLVM::AddressOfOp>(loc, global);
Value i8PtrGlobal = rewriter.create<LLVM::BitcastOp>( Value i8PtrGlobal =
loc, llvmI8PtrTy, globalValue); rewriter.create<LLVM::BitcastOp>(loc, llvmI8PtrTy, globalValue);
// - Set size. // - Set size.
Value memRefElementSize = rewriter.create<LLVM::ConstantOp>(loc, Value memRefElementSize = rewriter.create<LLVM::ConstantOp>(loc, llvmI64Ty,
llvmI64Ty, rewriter.getI64IntegerAttr( rewriter.getI64IntegerAttr(getMemRefEltSizeInBytes(memRefTy)));
getMemRefEltSizeInBytes(memRefTy)));
Value numElementsValue = rewriter.create<LLVM::ConstantOp>( Value numElementsValue = rewriter.create<LLVM::ConstantOp>(
loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements)); loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements));
Value totalElementsSize = rewriter.create<LLVM::MulOp>( Value totalElementsSize =
loc, memRefElementSize, numElementsValue); rewriter.create<LLVM::MulOp>(loc, memRefElementSize, numElementsValue);
Value int64Size = rewriter.create<LLVM::SExtOp>( Value int64Size =
loc, llvmI64Ty, totalElementsSize); rewriter.create<LLVM::SExtOp>(loc, llvmI64Ty, totalElementsSize);
// - Set volatile. // - Set volatile.
Value isVolatile = rewriter.create<LLVM::ConstantOp>( Value isVolatile = rewriter.create<LLVM::ConstantOp>(loc,
loc, LLVM::LLVMType::getInt1Ty(llvmDialect), LLVM::LLVMType::getInt1Ty(llvmDialect),
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
// - Copy constant data into the alloca. // - Copy constant data into the alloca.
auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect); auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect);
rewriter.create<CallOp>( rewriter.create<CallOp>(loc, memcpyRef,
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect), LLVM::LLVMType::getVoidTy(llvmDialect),
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile})); ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
// Prepare data to be inserted into MemRef. // Prepare data to be inserted into MemRef.
auto llvmConstantElementType = constantElementType.cast<LLVM::LLVMType>(); auto llvmConstantElementType = constantElementType.cast<LLVM::LLVMType>();
Value typedAlloc = rewriter.create<LLVM::BitcastOp>( Value typedAlloc = rewriter.create<LLVM::BitcastOp>(
loc, llvmConstantElementType.getPointerTo(), alloc); loc, llvmConstantElementType.getPointerTo(), alloc);
// Create llvm MemRef from original MemRef and fill the data pointers. // Create llvm MemRef from original MemRef and fill the data pointers.
auto llvmMemRef = MemRefDescriptor::fromStaticShape( auto llvmMemRef = MemRefDescriptor::fromStaticShape(
rewriter, loc, typeConverter, memRefTy, typedAlloc); rewriter, loc, typeConverter, memRefTy, typedAlloc);
rewriter.replaceOp(op, {llvmMemRef}); rewriter.replaceOp(op, {llvmMemRef});
return success(); return success();
@ -219,7 +216,7 @@ public:
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {} : ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto *context = op->getContext(); auto *context = op->getContext();
KrnlMemcpyOpOperandAdaptor operandAdaptor(operands); KrnlMemcpyOpOperandAdaptor operandAdaptor(operands);
auto loc = op->getLoc(); auto loc = op->getLoc();
@ -605,16 +602,18 @@ private:
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
namespace { namespace {
struct KrnlToLLVMLoweringPass : public ModulePass<KrnlToLLVMLoweringPass> { struct KrnlToLLVMLoweringPass
void runOnModule() final; : public PassWrapper<KrnlToLLVMLoweringPass, OperationPass<ModuleOp>> {
void runOnOperation() final;
}; };
} // end anonymous namespace } // end anonymous namespace
void KrnlToLLVMLoweringPass::runOnModule() { void KrnlToLLVMLoweringPass::runOnOperation() {
// Define the target for this lowering i.e. the LLVM dialect. // Define the target for this lowering i.e. the LLVM dialect.
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>(); target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
// target.addLegalOp<KrnlEntryPointOp>();
// Lower the MemRef types to a representation in LLVM. // Lower the MemRef types to a representation in LLVM.
LLVMTypeConverter typeConverter(&getContext()); LLVMTypeConverter typeConverter(&getContext());
@ -625,8 +624,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
populateAffineToStdConversionPatterns(patterns, &getContext()); populateAffineToStdConversionPatterns(patterns, &getContext());
populateLoopToStdConversionPatterns(patterns, &getContext()); populateLoopToStdConversionPatterns(patterns, &getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns, populateStdToLLVMConversionPatterns(typeConverter, patterns,
/*useAlloca=*/false, /*emitCWrapperS=*/true,
/*emitCWrapper=*/true); /*useAlignedAlloc=*/false);
patterns.insert<KrnlGlobalOpLowering>(&getContext(), typeConverter); patterns.insert<KrnlGlobalOpLowering>(&getContext(), typeConverter);
@ -636,8 +635,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
// We want to completely lower to LLVM, so we use a `FullConversion`. This // We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion. // ensures that only legal operations will remain after the conversion.
if (failed( if (failed(applyFullConversion(
applyFullConversion(getModule(), target, patterns, &typeConverter))) getOperation(), target, patterns, &typeConverter)))
signalPassFailure(); signalPassFailure();
} }

View File

@ -39,7 +39,7 @@ void getOrCreateNoneValue(llvm::Optional<mlir::Value> &none, FuncOp f) {
* desirable (as instructed by the PromotableConstOperandsOpInterface). * desirable (as instructed by the PromotableConstOperandsOpInterface).
*/ */
class AttributePromotionPass class AttributePromotionPass
: public mlir::FunctionPass<AttributePromotionPass> { : public mlir::PassWrapper<AttributePromotionPass, mlir::FunctionPass> {
public: public:
void runOnFunction() override { void runOnFunction() override {
auto f = getFunction(); auto f = getFunction();
@ -75,7 +75,7 @@ public:
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
auto *context = &getContext(); auto *context = &getContext();
ConstantOp::getCanonicalizationPatterns(patterns, context); ConstantOp::getCanonicalizationPatterns(patterns, context);
applyPatternsGreedily(f, patterns); applyPatternsAndFoldGreedily(f, patterns);
} }
}; };
} // end anonymous namespace } // end anonymous namespace

View File

@ -37,8 +37,8 @@ class ConstantValueElision : public OpRewritePattern<ONNXConstantOp> {
public: public:
using OpRewritePattern<ONNXConstantOp>::OpRewritePattern; using OpRewritePattern<ONNXConstantOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXConstantOp op, LogicalResult matchAndRewrite(
PatternRewriter &rewriter) const override { ONNXConstantOp op, PatternRewriter &rewriter) const override {
auto loc = op.getLoc(); auto loc = op.getLoc();
auto constOp = llvm::dyn_cast<ONNXConstantOp>(&op); auto constOp = llvm::dyn_cast<ONNXConstantOp>(&op);
@ -58,7 +58,7 @@ public:
* Function pass that performs constant value elision. * Function pass that performs constant value elision.
*/ */
class ElideConstantValuePass class ElideConstantValuePass
: public mlir::FunctionPass<ElideConstantValuePass> { : public PassWrapper<ElideConstantValuePass, FunctionPass> {
public: public:
void runOnFunction() override { void runOnFunction() override {
auto function = getFunction(); auto function = getFunction();
@ -67,7 +67,7 @@ public:
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
patterns.insert<ConstantValueElision>(&getContext()); patterns.insert<ConstantValueElision>(&getContext());
applyPatternsGreedily(function, patterns); applyPatternsAndFoldGreedily(function, patterns);
} }
}; };
} // end anonymous namespace } // end anonymous namespace

View File

@ -27,7 +27,7 @@ namespace {
/// Include the patterns defined in the Declarative Rewrite framework. /// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Transform/ONNX/ONNXDecompose.inc" #include "src/Transform/ONNX/ONNXDecompose.inc"
struct DecomposeONNXToONNXPass : public FunctionPass<DecomposeONNXToONNXPass> { struct DecomposeONNXToONNXPass : public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> {
void runOnFunction() final; void runOnFunction() final;
}; };
} // end anonymous namespace. } // end anonymous namespace.

View File

@ -25,7 +25,7 @@ namespace {
* candidate operations and propagating the shape information until the list * candidate operations and propagating the shape information until the list
* of operations is empty [credit MLIR authors]. * of operations is empty [credit MLIR authors].
*/ */
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> { class ShapeInferencePass : public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> {
public: public:
void runOnFunction() override { void runOnFunction() override {
auto f = getFunction(); auto f = getFunction();

View File

@ -3,7 +3,7 @@ configure_file(test_config.py.in test_config.py)
find_package(PythonInterp 3 REQUIRED) find_package(PythonInterp 3 REQUIRED)
add_custom_target(check-onnx-backend add_custom_target(check-onnx-backend
COMMAND LD_PRELOAD=${CMAKE_BINARY_DIR}/lib/libcruntime.so ${PYTHON_EXECUTABLE} COMMAND LD_PRELOAD=$<TARGET_FILE:cruntime> ${PYTHON_EXECUTABLE}
${CMAKE_CURRENT_BINARY_DIR}/test.py) ${CMAKE_CURRENT_BINARY_DIR}/test.py)
add_dependencies(check-onnx-backend onnx-mlir) add_dependencies(check-onnx-backend onnx-mlir)

View File

@ -1,3 +1,3 @@
git clone https://github.com/llvm/llvm-project.git git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX MLIR. # Check out a specific branch that is known to work with ONNX MLIR.
cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd .. cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd ..