Make onnx-mlir work with latest mlir. (#93)
* Make onnx-mlir work with latest mlir. * Bump CircleCI cache version. * Fix missing passes in onnx-mlir-opt. * Fix backend test failure. * Fix doc. * Fix doc and exclude the generated _site directory from DocCheck. * Remove debug code. * Do not hard code target name, on Mac shared lib can end with .dylib. * FunctionPass -> PassWrapper.
This commit is contained in:
parent
137ce767e6
commit
fad2ad7d03
|
@ -18,7 +18,7 @@ jobs:
|
||||||
git submodule update --init --recursive
|
git submodule update --init --recursive
|
||||||
# Use cached mlir installation if possible.
|
# Use cached mlir installation if possible.
|
||||||
- restore_cache:
|
- restore_cache:
|
||||||
key: V10-1-LLVM-PROJECT-{{ arch }}
|
key: V11-LLVM-PROJECT-{{ arch }}
|
||||||
- run:
|
- run:
|
||||||
name: Install MLIR
|
name: Install MLIR
|
||||||
command: |
|
command: |
|
||||||
|
@ -29,7 +29,7 @@ jobs:
|
||||||
source onnx-mlir/utils/install-mlir.sh
|
source onnx-mlir/utils/install-mlir.sh
|
||||||
fi
|
fi
|
||||||
- save_cache:
|
- save_cache:
|
||||||
key: V10-1-LLVM-PROJECT-{{ arch }}
|
key: V11-LLVM-PROJECT-{{ arch }}
|
||||||
paths:
|
paths:
|
||||||
- llvm-project
|
- llvm-project
|
||||||
- run:
|
- run:
|
||||||
|
|
|
@ -1,3 +1,6 @@
|
||||||
|
.idea/
|
||||||
|
cmake-*/
|
||||||
|
|
||||||
# Prerequisites
|
# Prerequisites
|
||||||
*.d
|
*.d
|
||||||
|
|
||||||
|
|
11
MLIR.cmake
11
MLIR.cmake
|
@ -132,8 +132,9 @@ function(find_mlir_lib lib)
|
||||||
endif()
|
endif()
|
||||||
endfunction(find_mlir_lib)
|
endfunction(find_mlir_lib)
|
||||||
|
|
||||||
find_mlir_lib(MLIRAffine)
|
find_mlir_lib(MLIRAffineOps)
|
||||||
find_mlir_lib(MLIRAffineToStandard)
|
find_mlir_lib(MLIRAffineToStandard)
|
||||||
|
find_mlir_lib(MLIRAffineTransforms)
|
||||||
find_mlir_lib(MLIRAnalysis)
|
find_mlir_lib(MLIRAnalysis)
|
||||||
find_mlir_lib(MLIRCallInterfaces)
|
find_mlir_lib(MLIRCallInterfaces)
|
||||||
find_mlir_lib(MLIRControlFlowInterfaces)
|
find_mlir_lib(MLIRControlFlowInterfaces)
|
||||||
|
@ -193,8 +194,9 @@ set(MLIRLibs
|
||||||
${MLIRTargetLLVMIRModuleTranslation}
|
${MLIRTargetLLVMIRModuleTranslation}
|
||||||
${MLIRTransforms}
|
${MLIRTransforms}
|
||||||
${MLIRTransformUtils}
|
${MLIRTransformUtils}
|
||||||
${MLIRAffine}
|
${MLIRAffineOps}
|
||||||
${MLIRAffineToStandard}
|
${MLIRAffineToStandard}
|
||||||
|
${MLIRAffineTransforms}
|
||||||
${MLIRAnalysis}
|
${MLIRAnalysis}
|
||||||
${MLIRCallInterfaces}
|
${MLIRCallInterfaces}
|
||||||
${MLIRControlFlowInterfaces}
|
${MLIRControlFlowInterfaces}
|
||||||
|
@ -244,14 +246,15 @@ set(MLIRLibs
|
||||||
# must be specified on LD_PRELOAD for shared build.
|
# must be specified on LD_PRELOAD for shared build.
|
||||||
set(MLIRWholeArchiveLibs
|
set(MLIRWholeArchiveLibs
|
||||||
MLIRAffineToStandard
|
MLIRAffineToStandard
|
||||||
MLIRAffine
|
MLIRAffineOps
|
||||||
MLIRLLVMIR
|
MLIRLLVMIR
|
||||||
MLIRStandardOps
|
MLIRStandardOps
|
||||||
MLIRStandardToLLVM
|
MLIRStandardToLLVM
|
||||||
MLIRTransforms
|
MLIRTransforms
|
||||||
MLIRLoopToStandard
|
MLIRLoopToStandard
|
||||||
MLIRVector
|
MLIRVector
|
||||||
MLIRLoopOps)
|
MLIRLoopOps
|
||||||
|
MLIRIR)
|
||||||
|
|
||||||
# ONNX MLIR libraries that must be linked with --whole-archive for static build or
|
# ONNX MLIR libraries that must be linked with --whole-archive for static build or
|
||||||
# must be specified on LD_PRELOAD for shared build.
|
# must be specified on LD_PRELOAD for shared build.
|
||||||
|
|
|
@ -20,7 +20,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
|
||||||
``` bash
|
``` bash
|
||||||
git clone https://github.com/llvm/llvm-project.git
|
git clone https://github.com/llvm/llvm-project.git
|
||||||
# Check out a specific branch that is known to work with ONNX MLIR.
|
# Check out a specific branch that is known to work with ONNX MLIR.
|
||||||
cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd ..
|
cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd ..
|
||||||
```
|
```
|
||||||
|
|
||||||
[same-as-file]: <> (utils/build-mlir.sh)
|
[same-as-file]: <> (utils/build-mlir.sh)
|
||||||
|
@ -110,7 +110,7 @@ Install MLIR (as a part of LLVM-Project):
|
||||||
```shell
|
```shell
|
||||||
git clone https://github.com/llvm/llvm-project.git
|
git clone https://github.com/llvm/llvm-project.git
|
||||||
# Check out a specific branch that is known to work with ONNX MLIR.
|
# Check out a specific branch that is known to work with ONNX MLIR.
|
||||||
cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd ..
|
cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd ..
|
||||||
```
|
```
|
||||||
|
|
||||||
[same-as-file]: <> (utils/build-mlir.cmd)
|
[same-as-file]: <> (utils/build-mlir.cmd)
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
_site
|
||||||
|
Gemfile*
|
|
@ -44,7 +44,7 @@ Once it is invoked, you will need to add literal tests in ` test/mlir/onnx/onnx_
|
||||||
|
|
||||||
Files related to the lowering of the new operations resides in the `src/Conversion/ONNXtoKRNL` directory and subdirectories. For the `concat` operation, we added code to lower it to krnl dialect in the ` src/Conversion/ONNXToKrnl/Tensor/concat.cpp` file. See other similar lowering for inspiration. We suggest to use `assert` statements for any unexpected values encountered while lowering the operation, as illegal parameter values should be caught in the shape inference phase, not successive passes such as lowering to the krnl dialect.
|
Files related to the lowering of the new operations resides in the `src/Conversion/ONNXtoKRNL` directory and subdirectories. For the `concat` operation, we added code to lower it to krnl dialect in the ` src/Conversion/ONNXToKrnl/Tensor/concat.cpp` file. See other similar lowering for inspiration. We suggest to use `assert` statements for any unexpected values encountered while lowering the operation, as illegal parameter values should be caught in the shape inference phase, not successive passes such as lowering to the krnl dialect.
|
||||||
|
|
||||||
In that file, the `populateLoweringONNXConcatOpPattern` operation (where `Concat` would be replaced with the actual new operation) will need to be defined in ` src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp` and invoked in the ` runOnModule` function in the ` src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp` file.
|
In that file, the `populateLoweringONNXConcatOpPattern` operation (where `Concat` would be replaced with the actual new operation) will need to be defined in ` src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp` and invoked in the ` runOnOperation` function in the ` src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp` file.
|
||||||
|
|
||||||
To compile properly, you will also need to add the new `.cpp` file in the ` src/Conversion/ONNXToKrnl/CMakeLists.txt` file.
|
To compile properly, you will also need to add the new `.cpp` file in the ` src/Conversion/ONNXToKrnl/CMakeLists.txt` file.
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
|
||||||
``` bash
|
``` bash
|
||||||
git clone https://github.com/llvm/llvm-project.git
|
git clone https://github.com/llvm/llvm-project.git
|
||||||
# Check out a specific branch that is known to work with ONNX MLIR.
|
# Check out a specific branch that is known to work with ONNX MLIR.
|
||||||
cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd ..
|
cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd ..
|
||||||
```
|
```
|
||||||
|
|
||||||
[same-as-file]: <> (utils/build-mlir.sh)
|
[same-as-file]: <> (utils/build-mlir.sh)
|
||||||
|
@ -110,7 +110,7 @@ Install MLIR (as a part of LLVM-Project):
|
||||||
```shell
|
```shell
|
||||||
git clone https://github.com/llvm/llvm-project.git
|
git clone https://github.com/llvm/llvm-project.git
|
||||||
# Check out a specific branch that is known to work with ONNX MLIR.
|
# Check out a specific branch that is known to work with ONNX MLIR.
|
||||||
cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd ..
|
cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd ..
|
||||||
```
|
```
|
||||||
|
|
||||||
[same-as-file]: <> (utils/build-mlir.cmd)
|
[same-as-file]: <> (utils/build-mlir.cmd)
|
||||||
|
|
|
@ -4,5 +4,8 @@ add_custom_target(check-doc
|
||||||
COMMAND ${PYTHON_EXECUTABLE}
|
COMMAND ${PYTHON_EXECUTABLE}
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/check.py
|
${CMAKE_CURRENT_SOURCE_DIR}/check.py
|
||||||
${ONNX_MLIR_SRC_ROOT}
|
${ONNX_MLIR_SRC_ROOT}
|
||||||
--exclude_dirs third_party docs/doc_check/test)
|
--exclude_dirs
|
||||||
|
third_party
|
||||||
|
docs/doc_check/test
|
||||||
|
docs/_site)
|
||||||
|
|
||||||
|
|
|
@ -41,13 +41,13 @@ public:
|
||||||
/// This is a partial lowering to Krnl loops of the ONNX operations.
|
/// This is a partial lowering to Krnl loops of the ONNX operations.
|
||||||
namespace {
|
namespace {
|
||||||
struct FrontendToKrnlLoweringPass
|
struct FrontendToKrnlLoweringPass
|
||||||
: public ModulePass<FrontendToKrnlLoweringPass> {
|
: public PassWrapper<FrontendToKrnlLoweringPass, OperationPass<ModuleOp>> {
|
||||||
void runOnModule() final;
|
void runOnOperation() final;
|
||||||
};
|
};
|
||||||
} // end anonymous namespace.
|
} // end anonymous namespace.
|
||||||
|
|
||||||
void FrontendToKrnlLoweringPass::runOnModule() {
|
void FrontendToKrnlLoweringPass::runOnOperation() {
|
||||||
ModuleOp module = getModule();
|
ModuleOp module = getOperation();
|
||||||
|
|
||||||
// The first thing to define is the conversion target. This will define the
|
// The first thing to define is the conversion target. This will define the
|
||||||
// final target for this lowering.
|
// final target for this lowering.
|
||||||
|
|
|
@ -456,11 +456,11 @@ struct ONNXPoolOpLowering : public ConversionPattern {
|
||||||
}
|
}
|
||||||
dimExpr.emplace_back(de);
|
dimExpr.emplace_back(de);
|
||||||
}
|
}
|
||||||
poolDimMap = AffineMap::get(1, 5, dimExpr);
|
poolDimMap = AffineMap::get(1, 5, dimExpr, rewriter.getContext());
|
||||||
|
|
||||||
// poolStartMap and poolEndMap
|
// poolStartMap and poolEndMap
|
||||||
poolStartMap = AffineMap::get(1, 5, {start1, start2});
|
poolStartMap = AffineMap::get(1, 5, {start1, start2}, rewriter.getContext());
|
||||||
poolEndMap = AffineMap::get(1, 5, {end1, end2});
|
poolEndMap = AffineMap::get(1, 5, {end1, end2}, rewriter.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Obtain values from the affine maps.
|
// Obtain values from the affine maps.
|
||||||
|
|
|
@ -11,7 +11,9 @@
|
||||||
#include <llvm/Support/CommandLine.h>
|
#include <llvm/Support/CommandLine.h>
|
||||||
#include <llvm/Support/InitLLVM.h>
|
#include <llvm/Support/InitLLVM.h>
|
||||||
#include <llvm/Support/ToolOutputFile.h>
|
#include <llvm/Support/ToolOutputFile.h>
|
||||||
|
#include <mlir/IR/AsmState.h>
|
||||||
#include <mlir/InitAllDialects.h>
|
#include <mlir/InitAllDialects.h>
|
||||||
|
#include <mlir/InitAllPasses.h>
|
||||||
#include <mlir/Pass/Pass.h>
|
#include <mlir/Pass/Pass.h>
|
||||||
#include <mlir/Pass/PassManager.h>
|
#include <mlir/Pass/PassManager.h>
|
||||||
#include <mlir/Support/FileUtilities.h>
|
#include <mlir/Support/FileUtilities.h>
|
||||||
|
@ -23,28 +25,27 @@
|
||||||
|
|
||||||
using namespace onnx_mlir;
|
using namespace onnx_mlir;
|
||||||
|
|
||||||
static llvm::cl::opt<std::string> input_filename(llvm::cl::Positional,
|
// TODO(tjingrant): disable the following namespace import.
|
||||||
llvm::cl::desc("<input file>"),
|
using namespace mlir;
|
||||||
llvm::cl::init("-"));
|
|
||||||
|
|
||||||
static llvm::cl::opt<std::string>
|
static llvm::cl::opt<std::string> input_filename(
|
||||||
output_filename("o", llvm::cl::desc("Output filename"),
|
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"));
|
||||||
llvm::cl::value_desc("filename"), llvm::cl::init("-"));
|
|
||||||
|
|
||||||
static llvm::cl::opt<bool> split_input_file(
|
static llvm::cl::opt<std::string> output_filename("o",
|
||||||
"split-input-file",
|
llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
|
||||||
|
llvm::cl::init("-"));
|
||||||
|
|
||||||
|
static llvm::cl::opt<bool> split_input_file("split-input-file",
|
||||||
llvm::cl::desc("Split the input file into pieces and process each "
|
llvm::cl::desc("Split the input file into pieces and process each "
|
||||||
"chunk independently"),
|
"chunk independently"),
|
||||||
llvm::cl::init(false));
|
llvm::cl::init(false));
|
||||||
|
|
||||||
static llvm::cl::opt<bool> verify_diagnostics(
|
static llvm::cl::opt<bool> verify_diagnostics("verify-diagnostics",
|
||||||
"verify-diagnostics",
|
|
||||||
llvm::cl::desc("Check that emitted diagnostics match "
|
llvm::cl::desc("Check that emitted diagnostics match "
|
||||||
"expected-* lines on the corresponding line"),
|
"expected-* lines on the corresponding line"),
|
||||||
llvm::cl::init(false));
|
llvm::cl::init(false));
|
||||||
|
|
||||||
static llvm::cl::opt<bool> verify_passes(
|
static llvm::cl::opt<bool> verify_passes("verify-each",
|
||||||
"verify-each",
|
|
||||||
llvm::cl::desc("Run the verifier after each transformation pass"),
|
llvm::cl::desc("Run the verifier after each transformation pass"),
|
||||||
llvm::cl::init(true));
|
llvm::cl::init(true));
|
||||||
|
|
||||||
|
@ -58,16 +59,23 @@ int main(int argc, char **argv) {
|
||||||
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
|
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
|
||||||
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
|
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
|
||||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||||
|
|
||||||
|
// Register transformation passes.
|
||||||
|
#define GEN_PASS_REGISTRATION
|
||||||
|
#include "mlir/Transforms/Passes.h.inc"
|
||||||
|
|
||||||
llvm::InitLLVM y(argc, argv);
|
llvm::InitLLVM y(argc, argv);
|
||||||
|
|
||||||
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||||
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
||||||
|
|
||||||
|
mlir::registerAsmPrinterCLOptions();
|
||||||
|
mlir::registerMLIRContextCLOptions();
|
||||||
// Register any pass manager command line options.
|
// Register any pass manager command line options.
|
||||||
mlir::registerPassManagerCLOptions();
|
mlir::registerPassManagerCLOptions();
|
||||||
mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
|
mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
|
||||||
llvm::cl::ParseCommandLineOptions(argc, argv,
|
llvm::cl::ParseCommandLineOptions(
|
||||||
"ONNX MLIR modular optimizer driver\n");
|
argc, argv, "ONNX MLIR modular optimizer driver\n");
|
||||||
|
|
||||||
// Set up the input file.
|
// Set up the input file.
|
||||||
std::string error_message;
|
std::string error_message;
|
||||||
|
@ -78,6 +86,6 @@ int main(int argc, char **argv) {
|
||||||
assert(output);
|
assert(output);
|
||||||
|
|
||||||
return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
|
return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
|
||||||
split_input_file, verify_diagnostics,
|
split_input_file, verify_diagnostics, verify_passes,
|
||||||
verify_passes, allowUnregisteredDialects));
|
allowUnregisteredDialects));
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,13 +36,13 @@ class KrnlConstGlobalValueElision : public OpRewritePattern<KrnlGlobalOp> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern<KrnlGlobalOp>::OpRewritePattern;
|
using OpRewritePattern<KrnlGlobalOp>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(KrnlGlobalOp op,
|
LogicalResult matchAndRewrite(
|
||||||
PatternRewriter &rewriter) const override {
|
KrnlGlobalOp op, PatternRewriter &rewriter) const override {
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
if (op.value().hasValue()) {
|
if (op.value().hasValue()) {
|
||||||
auto newGlobalOp = rewriter.create<KrnlGlobalOp>(
|
auto newGlobalOp = rewriter.create<KrnlGlobalOp>(
|
||||||
loc, op.getResult().getType(), op.shape(), op.name(), nullptr);
|
loc, op.getResult().getType(), op.shape(), op.name(), nullptr);
|
||||||
rewriter.replaceOp(op, newGlobalOp.getResult());
|
rewriter.replaceOp(op, newGlobalOp.getResult());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ public:
|
||||||
* Function pass that performs constant value elision of Krnl globals.
|
* Function pass that performs constant value elision of Krnl globals.
|
||||||
*/
|
*/
|
||||||
class ElideConstGlobalValuePass
|
class ElideConstGlobalValuePass
|
||||||
: public mlir::FunctionPass<ElideConstGlobalValuePass> {
|
: public PassWrapper<ElideConstGlobalValuePass, FunctionPass> {
|
||||||
public:
|
public:
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
auto function = getFunction();
|
auto function = getFunction();
|
||||||
|
@ -63,7 +63,7 @@ public:
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
patterns.insert<KrnlConstGlobalValueElision>(&getContext());
|
patterns.insert<KrnlConstGlobalValueElision>(&getContext());
|
||||||
|
|
||||||
applyPatternsGreedily(function, patterns);
|
applyPatternsAndFoldGreedily(function, patterns);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -149,7 +149,7 @@ public:
|
||||||
/// add and multiply, this pass will leave these operations intact.
|
/// add and multiply, this pass will leave these operations intact.
|
||||||
namespace {
|
namespace {
|
||||||
struct KrnlToAffineLoweringPass
|
struct KrnlToAffineLoweringPass
|
||||||
: public FunctionPass<KrnlToAffineLoweringPass> {
|
: public PassWrapper<KrnlToAffineLoweringPass, FunctionPass> {
|
||||||
void runOnFunction() final;
|
void runOnFunction() final;
|
||||||
};
|
};
|
||||||
} // end anonymous namespace.
|
} // end anonymous namespace.
|
||||||
|
|
|
@ -16,8 +16,8 @@
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
#include "mlir/Dialect/LoopOps/LoopOps.h"
|
#include "mlir/Dialect/LoopOps/LoopOps.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "llvm/ADT/Sequence.h"
|
#include "llvm/ADT/Sequence.h"
|
||||||
|
|
||||||
|
@ -94,14 +94,13 @@ static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
|
||||||
|
|
||||||
class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
|
class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
|
||||||
public:
|
public:
|
||||||
explicit KrnlGlobalOpLowering(MLIRContext *context,
|
explicit KrnlGlobalOpLowering(
|
||||||
LLVMTypeConverter &lowering_)
|
MLIRContext *context, LLVMTypeConverter &lowering_)
|
||||||
: ConvertToLLVMPattern(KrnlGlobalOp::getOperationName(), context,
|
: ConvertToLLVMPattern(
|
||||||
lowering_) {}
|
KrnlGlobalOp::getOperationName(), context, lowering_) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
auto *context = op->getContext();
|
auto *context = op->getContext();
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
auto *llvmDialect =
|
auto *llvmDialect =
|
||||||
|
@ -119,7 +118,7 @@ public:
|
||||||
// Compute total number of elements.
|
// Compute total number of elements.
|
||||||
auto shape = (krnlGlobalOp.shape()).dyn_cast<ArrayAttr>();
|
auto shape = (krnlGlobalOp.shape()).dyn_cast<ArrayAttr>();
|
||||||
int64_t numElements = 1;
|
int64_t numElements = 1;
|
||||||
for (int i=0; i<shape.size(); ++i)
|
for (int i = 0; i < shape.size(); ++i)
|
||||||
numElements *= ArrayAttrIntVal(shape, i);
|
numElements *= ArrayAttrIntVal(shape, i);
|
||||||
|
|
||||||
// Create the global at the entry of the module.
|
// Create the global at the entry of the module.
|
||||||
|
@ -133,7 +132,7 @@ public:
|
||||||
auto constantElementType =
|
auto constantElementType =
|
||||||
typeConverter.convertType(memRefTy.getElementType());
|
typeConverter.convertType(memRefTy.getElementType());
|
||||||
auto globalType = constantElementType;
|
auto globalType = constantElementType;
|
||||||
for (int i=shape.size() - 1; i >= 0; i--)
|
for (int i = shape.size() - 1; i >= 0; i--)
|
||||||
globalType = LLVM::LLVMType::getArrayTy(
|
globalType = LLVM::LLVMType::getArrayTy(
|
||||||
globalType.cast<LLVM::LLVMType>(), ArrayAttrIntVal(shape, i));
|
globalType.cast<LLVM::LLVMType>(), ArrayAttrIntVal(shape, i));
|
||||||
// The llvm type of the global (example: [2 x [8 x float]])
|
// The llvm type of the global (example: [2 x [8 x float]])
|
||||||
|
@ -145,7 +144,6 @@ public:
|
||||||
|
|
||||||
assert(krnlGlobalOp.value().hasValue() &&
|
assert(krnlGlobalOp.value().hasValue() &&
|
||||||
"Krnl Global must always have a value");
|
"Krnl Global must always have a value");
|
||||||
|
|
||||||
global = rewriter.create<LLVM::GlobalOp>(loc,
|
global = rewriter.create<LLVM::GlobalOp>(loc,
|
||||||
llvmGlobalType, /*isConstant=*/true,
|
llvmGlobalType, /*isConstant=*/true,
|
||||||
LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue());
|
LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue());
|
||||||
|
@ -157,47 +155,46 @@ public:
|
||||||
|
|
||||||
// Allocate the memory where the constants will be used from.
|
// Allocate the memory where the constants will be used from.
|
||||||
// This is a region of local memory and needs to be emitted as an alloca.
|
// This is a region of local memory and needs to be emitted as an alloca.
|
||||||
auto one = rewriter.create<LLVM::ConstantOp>(loc,
|
auto one = rewriter.create<LLVM::ConstantOp>(
|
||||||
llvmI64Ty, rewriter.getI64IntegerAttr(1));
|
loc, llvmI64Ty, rewriter.getI64IntegerAttr(1));
|
||||||
auto alloc = rewriter.create<LLVM::AllocaOp>(
|
auto alloc = rewriter.create<LLVM::AllocaOp>(
|
||||||
loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0);
|
loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0);
|
||||||
|
|
||||||
// Copy constant value into the local alloca:
|
// Copy constant value into the local alloca:
|
||||||
// - Bitcast alloc to i8*
|
// - Bitcast alloc to i8*
|
||||||
Value int8PtrAlloc = rewriter.create<LLVM::BitcastOp>(
|
Value int8PtrAlloc =
|
||||||
loc, llvmI8PtrTy, alloc);
|
rewriter.create<LLVM::BitcastOp>(loc, llvmI8PtrTy, alloc);
|
||||||
// - Bitcast global to i8*
|
// - Bitcast global to i8*
|
||||||
Value globalValue = rewriter.create<LLVM::AddressOfOp>(loc, global);
|
Value globalValue = rewriter.create<LLVM::AddressOfOp>(loc, global);
|
||||||
Value i8PtrGlobal = rewriter.create<LLVM::BitcastOp>(
|
Value i8PtrGlobal =
|
||||||
loc, llvmI8PtrTy, globalValue);
|
rewriter.create<LLVM::BitcastOp>(loc, llvmI8PtrTy, globalValue);
|
||||||
// - Set size.
|
// - Set size.
|
||||||
Value memRefElementSize = rewriter.create<LLVM::ConstantOp>(loc,
|
Value memRefElementSize = rewriter.create<LLVM::ConstantOp>(loc, llvmI64Ty,
|
||||||
llvmI64Ty, rewriter.getI64IntegerAttr(
|
rewriter.getI64IntegerAttr(getMemRefEltSizeInBytes(memRefTy)));
|
||||||
getMemRefEltSizeInBytes(memRefTy)));
|
|
||||||
Value numElementsValue = rewriter.create<LLVM::ConstantOp>(
|
Value numElementsValue = rewriter.create<LLVM::ConstantOp>(
|
||||||
loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements));
|
loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements));
|
||||||
Value totalElementsSize = rewriter.create<LLVM::MulOp>(
|
Value totalElementsSize =
|
||||||
loc, memRefElementSize, numElementsValue);
|
rewriter.create<LLVM::MulOp>(loc, memRefElementSize, numElementsValue);
|
||||||
Value int64Size = rewriter.create<LLVM::SExtOp>(
|
Value int64Size =
|
||||||
loc, llvmI64Ty, totalElementsSize);
|
rewriter.create<LLVM::SExtOp>(loc, llvmI64Ty, totalElementsSize);
|
||||||
// - Set volatile.
|
// - Set volatile.
|
||||||
Value isVolatile = rewriter.create<LLVM::ConstantOp>(
|
Value isVolatile = rewriter.create<LLVM::ConstantOp>(loc,
|
||||||
loc, LLVM::LLVMType::getInt1Ty(llvmDialect),
|
LLVM::LLVMType::getInt1Ty(llvmDialect),
|
||||||
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
||||||
// - Copy constant data into the alloca.
|
// - Copy constant data into the alloca.
|
||||||
auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect);
|
auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect);
|
||||||
rewriter.create<CallOp>(
|
rewriter.create<CallOp>(loc, memcpyRef,
|
||||||
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
|
LLVM::LLVMType::getVoidTy(llvmDialect),
|
||||||
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
|
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
|
||||||
|
|
||||||
// Prepare data to be inserted into MemRef.
|
// Prepare data to be inserted into MemRef.
|
||||||
auto llvmConstantElementType = constantElementType.cast<LLVM::LLVMType>();
|
auto llvmConstantElementType = constantElementType.cast<LLVM::LLVMType>();
|
||||||
Value typedAlloc = rewriter.create<LLVM::BitcastOp>(
|
Value typedAlloc = rewriter.create<LLVM::BitcastOp>(
|
||||||
loc, llvmConstantElementType.getPointerTo(), alloc);
|
loc, llvmConstantElementType.getPointerTo(), alloc);
|
||||||
|
|
||||||
// Create llvm MemRef from original MemRef and fill the data pointers.
|
// Create llvm MemRef from original MemRef and fill the data pointers.
|
||||||
auto llvmMemRef = MemRefDescriptor::fromStaticShape(
|
auto llvmMemRef = MemRefDescriptor::fromStaticShape(
|
||||||
rewriter, loc, typeConverter, memRefTy, typedAlloc);
|
rewriter, loc, typeConverter, memRefTy, typedAlloc);
|
||||||
|
|
||||||
rewriter.replaceOp(op, {llvmMemRef});
|
rewriter.replaceOp(op, {llvmMemRef});
|
||||||
return success();
|
return success();
|
||||||
|
@ -219,7 +216,7 @@ public:
|
||||||
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
|
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto *context = op->getContext();
|
auto *context = op->getContext();
|
||||||
KrnlMemcpyOpOperandAdaptor operandAdaptor(operands);
|
KrnlMemcpyOpOperandAdaptor operandAdaptor(operands);
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
|
@ -605,16 +602,18 @@ private:
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
struct KrnlToLLVMLoweringPass : public ModulePass<KrnlToLLVMLoweringPass> {
|
struct KrnlToLLVMLoweringPass
|
||||||
void runOnModule() final;
|
: public PassWrapper<KrnlToLLVMLoweringPass, OperationPass<ModuleOp>> {
|
||||||
|
void runOnOperation() final;
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
void KrnlToLLVMLoweringPass::runOnModule() {
|
void KrnlToLLVMLoweringPass::runOnOperation() {
|
||||||
// Define the target for this lowering i.e. the LLVM dialect.
|
// Define the target for this lowering i.e. the LLVM dialect.
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||||
|
// target.addLegalOp<KrnlEntryPointOp>();
|
||||||
|
|
||||||
// Lower the MemRef types to a representation in LLVM.
|
// Lower the MemRef types to a representation in LLVM.
|
||||||
LLVMTypeConverter typeConverter(&getContext());
|
LLVMTypeConverter typeConverter(&getContext());
|
||||||
|
@ -625,8 +624,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
|
||||||
populateAffineToStdConversionPatterns(patterns, &getContext());
|
populateAffineToStdConversionPatterns(patterns, &getContext());
|
||||||
populateLoopToStdConversionPatterns(patterns, &getContext());
|
populateLoopToStdConversionPatterns(patterns, &getContext());
|
||||||
populateStdToLLVMConversionPatterns(typeConverter, patterns,
|
populateStdToLLVMConversionPatterns(typeConverter, patterns,
|
||||||
/*useAlloca=*/false,
|
/*emitCWrapperS=*/true,
|
||||||
/*emitCWrapper=*/true);
|
/*useAlignedAlloc=*/false);
|
||||||
|
|
||||||
patterns.insert<KrnlGlobalOpLowering>(&getContext(), typeConverter);
|
patterns.insert<KrnlGlobalOpLowering>(&getContext(), typeConverter);
|
||||||
|
|
||||||
|
@ -636,8 +635,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
|
||||||
|
|
||||||
// We want to completely lower to LLVM, so we use a `FullConversion`. This
|
// We want to completely lower to LLVM, so we use a `FullConversion`. This
|
||||||
// ensures that only legal operations will remain after the conversion.
|
// ensures that only legal operations will remain after the conversion.
|
||||||
if (failed(
|
if (failed(applyFullConversion(
|
||||||
applyFullConversion(getModule(), target, patterns, &typeConverter)))
|
getOperation(), target, patterns, &typeConverter)))
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,7 @@ void getOrCreateNoneValue(llvm::Optional<mlir::Value> &none, FuncOp f) {
|
||||||
* desirable (as instructed by the PromotableConstOperandsOpInterface).
|
* desirable (as instructed by the PromotableConstOperandsOpInterface).
|
||||||
*/
|
*/
|
||||||
class AttributePromotionPass
|
class AttributePromotionPass
|
||||||
: public mlir::FunctionPass<AttributePromotionPass> {
|
: public mlir::PassWrapper<AttributePromotionPass, mlir::FunctionPass> {
|
||||||
public:
|
public:
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
auto f = getFunction();
|
auto f = getFunction();
|
||||||
|
@ -75,7 +75,7 @@ public:
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
ConstantOp::getCanonicalizationPatterns(patterns, context);
|
ConstantOp::getCanonicalizationPatterns(patterns, context);
|
||||||
applyPatternsGreedily(f, patterns);
|
applyPatternsAndFoldGreedily(f, patterns);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
|
@ -37,8 +37,8 @@ class ConstantValueElision : public OpRewritePattern<ONNXConstantOp> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern<ONNXConstantOp>::OpRewritePattern;
|
using OpRewritePattern<ONNXConstantOp>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(ONNXConstantOp op,
|
LogicalResult matchAndRewrite(
|
||||||
PatternRewriter &rewriter) const override {
|
ONNXConstantOp op, PatternRewriter &rewriter) const override {
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
auto constOp = llvm::dyn_cast<ONNXConstantOp>(&op);
|
auto constOp = llvm::dyn_cast<ONNXConstantOp>(&op);
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ public:
|
||||||
* Function pass that performs constant value elision.
|
* Function pass that performs constant value elision.
|
||||||
*/
|
*/
|
||||||
class ElideConstantValuePass
|
class ElideConstantValuePass
|
||||||
: public mlir::FunctionPass<ElideConstantValuePass> {
|
: public PassWrapper<ElideConstantValuePass, FunctionPass> {
|
||||||
public:
|
public:
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
auto function = getFunction();
|
auto function = getFunction();
|
||||||
|
@ -67,7 +67,7 @@ public:
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
patterns.insert<ConstantValueElision>(&getContext());
|
patterns.insert<ConstantValueElision>(&getContext());
|
||||||
|
|
||||||
applyPatternsGreedily(function, patterns);
|
applyPatternsAndFoldGreedily(function, patterns);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
|
@ -27,7 +27,7 @@ namespace {
|
||||||
/// Include the patterns defined in the Declarative Rewrite framework.
|
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||||
#include "src/Transform/ONNX/ONNXDecompose.inc"
|
#include "src/Transform/ONNX/ONNXDecompose.inc"
|
||||||
|
|
||||||
struct DecomposeONNXToONNXPass : public FunctionPass<DecomposeONNXToONNXPass> {
|
struct DecomposeONNXToONNXPass : public PassWrapper<DecomposeONNXToONNXPass, FunctionPass> {
|
||||||
void runOnFunction() final;
|
void runOnFunction() final;
|
||||||
};
|
};
|
||||||
} // end anonymous namespace.
|
} // end anonymous namespace.
|
||||||
|
|
|
@ -25,7 +25,7 @@ namespace {
|
||||||
* candidate operations and propagating the shape information until the list
|
* candidate operations and propagating the shape information until the list
|
||||||
* of operations is empty [credit MLIR authors].
|
* of operations is empty [credit MLIR authors].
|
||||||
*/
|
*/
|
||||||
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
class ShapeInferencePass : public mlir::PassWrapper<ShapeInferencePass, mlir::FunctionPass> {
|
||||||
public:
|
public:
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
auto f = getFunction();
|
auto f = getFunction();
|
||||||
|
|
|
@ -3,7 +3,7 @@ configure_file(test_config.py.in test_config.py)
|
||||||
|
|
||||||
find_package(PythonInterp 3 REQUIRED)
|
find_package(PythonInterp 3 REQUIRED)
|
||||||
add_custom_target(check-onnx-backend
|
add_custom_target(check-onnx-backend
|
||||||
COMMAND LD_PRELOAD=${CMAKE_BINARY_DIR}/lib/libcruntime.so ${PYTHON_EXECUTABLE}
|
COMMAND LD_PRELOAD=$<TARGET_FILE:cruntime> ${PYTHON_EXECUTABLE}
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/test.py)
|
${CMAKE_CURRENT_BINARY_DIR}/test.py)
|
||||||
|
|
||||||
add_dependencies(check-onnx-backend onnx-mlir)
|
add_dependencies(check-onnx-backend onnx-mlir)
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
git clone https://github.com/llvm/llvm-project.git
|
git clone https://github.com/llvm/llvm-project.git
|
||||||
# Check out a specific branch that is known to work with ONNX MLIR.
|
# Check out a specific branch that is known to work with ONNX MLIR.
|
||||||
cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd ..
|
cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd ..
|
||||||
|
|
Loading…
Reference in New Issue