Add selected MLIR options (#213)

* Reorganize main function.

* Follow review comments.

* Emit constants are globals in Krnl and LLVM dialects.

* Add selected MLIR passes.

* Refactor.

* Fix descriptions.

* Refactor.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-07-09 18:12:57 -04:00 committed by GitHub
parent c557bd3444
commit af75b4c75e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 0 deletions

View File

@ -152,6 +152,7 @@ find_mlir_lib(MLIRLinalgAnalysis)
find_mlir_lib(MLIRLinalgTransforms) find_mlir_lib(MLIRLinalgTransforms)
find_mlir_lib(MLIRLinalgUtils) find_mlir_lib(MLIRLinalgUtils)
find_mlir_lib(MLIRSCF) find_mlir_lib(MLIRSCF)
find_mlir_lib(MLIRSCFTransforms)
find_mlir_lib(MLIRLLVMIRTransforms) find_mlir_lib(MLIRLLVMIRTransforms)
find_mlir_lib(MLIRMlirOptMain) find_mlir_lib(MLIRMlirOptMain)
find_mlir_lib(MLIRParser) find_mlir_lib(MLIRParser)
@ -171,6 +172,8 @@ find_mlir_lib(MLIRTargetLLVMIR)
find_mlir_lib(MLIRTransformUtils) find_mlir_lib(MLIRTransformUtils)
find_mlir_lib(MLIRTranslation) find_mlir_lib(MLIRTranslation)
find_mlir_lib(MLIRVector) find_mlir_lib(MLIRVector)
find_mlir_lib(MLIRVectorToLLVM)
find_mlir_lib(MLIRVectorToSCF)
find_mlir_lib(MLIRMlirOptMain) find_mlir_lib(MLIRMlirOptMain)
find_mlir_lib(MLIRAffineEDSC) find_mlir_lib(MLIRAffineEDSC)
find_mlir_lib(MLIRLinalgEDSC) find_mlir_lib(MLIRLinalgEDSC)
@ -203,6 +206,8 @@ set(MLIRLibs
${MLIRTransforms} ${MLIRTransforms}
${MLIRSCFToStandard} ${MLIRSCFToStandard}
${MLIRVector} ${MLIRVector}
${MLIRVectorToLLVM}
${MLIRVectorToSCF}
${MLIRSCF} ${MLIRSCF}
${MLIRIR} ${MLIRIR}
${MLIRLLVMIR} ${MLIRLLVMIR}
@ -226,6 +231,7 @@ set(MLIRLibs
${MLIRLLVMIRTransforms} ${MLIRLLVMIRTransforms}
${MLIRSCFToStandard} ${MLIRSCFToStandard}
${MLIRSCF} ${MLIRSCF}
${MLIRSCFTransforms}
${MLIRLoopAnalysis} ${MLIRLoopAnalysis}
${MLIRLoopLikeInterface} ${MLIRLoopLikeInterface}
${MLIROpenMP} ${MLIROpenMP}

25
src/InitMLIRPasses.hpp Normal file
View File

@ -0,0 +1,25 @@
#include "mlir/Pass/Pass.h"
namespace onnx_mlir {
void initMLIRPasses() {
mlir::registerPass("convert-vector-to-scf", "Convert vector to SCF.",
[]() -> std::unique_ptr<mlir::Pass> {
return mlir::createConvertVectorToSCFPass();
});
mlir::registerPass("lower-affine",
"Lower Affine Dialect to Standard Dialect.",
[]() -> std::unique_ptr<mlir::Pass> {
return mlir::createLowerAffinePass();
});
mlir::registerPass("convert-scf-to-std", "Lower SCF to Standard Dialect.",
[]() -> std::unique_ptr<mlir::Pass> {
return mlir::createLowerToCFGPass();
});
mlir::registerPass("convert-vector-to-llvm",
"Lower Vector Dialect to LLVM IR Dialect.",
[]() -> std::unique_ptr<mlir::Pass> {
return mlir::createConvertVectorToLLVMPass();
});
}
} // namespace onnx_mlir

View File

@ -24,6 +24,7 @@
#include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/InitMLIRPasses.hpp"
#include "src/InitOMPasses.hpp" #include "src/InitOMPasses.hpp"
#include "src/Pass/Passes.hpp" #include "src/Pass/Passes.hpp"
@ -64,6 +65,7 @@ int main(int argc, char **argv) {
mlir::registerDialect<mlir::LLVM::LLVMDialect>(); mlir::registerDialect<mlir::LLVM::LLVMDialect>();
mlir::registerDialect<mlir::scf::SCFDialect>(); mlir::registerDialect<mlir::scf::SCFDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>(); mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::vector::VectorDialect>();
// Register transformation passes. // Register transformation passes.
#define GEN_PASS_REGISTRATION #define GEN_PASS_REGISTRATION
@ -75,11 +77,15 @@ int main(int argc, char **argv) {
#define GEN_PASS_REGISTRATION #define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Linalg/Passes.h.inc" #include "mlir/Dialect/Linalg/Passes.h.inc"
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/SCF/Passes.h.inc"
llvm::InitLLVM y(argc, argv); llvm::InitLLVM y(argc, argv);
mlir::registerDialect<mlir::ONNXOpsDialect>(); mlir::registerDialect<mlir::ONNXOpsDialect>();
mlir::registerDialect<mlir::KrnlOpsDialect>(); mlir::registerDialect<mlir::KrnlOpsDialect>();
initOMPasses(); initOMPasses();
initMLIRPasses();
mlir::registerAsmPrinterCLOptions(); mlir::registerAsmPrinterCLOptions();
mlir::registerMLIRContextCLOptions(); mlir::registerMLIRContextCLOptions();