diff --git a/.circleci/config.yml b/.circleci/config.yml index 7314f1c..5a8c23f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -18,7 +18,7 @@ jobs: git submodule update --init --recursive # Use cached mlir installation if possible. - restore_cache: - key: V10-1-LLVM-PROJECT-{{ arch }} + key: V11-LLVM-PROJECT-{{ arch }} - run: name: Install MLIR command: | @@ -29,7 +29,7 @@ jobs: source onnx-mlir/utils/install-mlir.sh fi - save_cache: - key: V10-1-LLVM-PROJECT-{{ arch }} + key: V11-LLVM-PROJECT-{{ arch }} paths: - llvm-project - run: diff --git a/.gitignore b/.gitignore index 7f8814f..6c866a7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +.idea/ +cmake-*/ + # Prerequisites *.d diff --git a/MLIR.cmake b/MLIR.cmake index bacb967..2dbb7de 100644 --- a/MLIR.cmake +++ b/MLIR.cmake @@ -132,8 +132,9 @@ function(find_mlir_lib lib) endif() endfunction(find_mlir_lib) -find_mlir_lib(MLIRAffine) +find_mlir_lib(MLIRAffineOps) find_mlir_lib(MLIRAffineToStandard) +find_mlir_lib(MLIRAffineTransforms) find_mlir_lib(MLIRAnalysis) find_mlir_lib(MLIRCallInterfaces) find_mlir_lib(MLIRControlFlowInterfaces) @@ -193,8 +194,9 @@ set(MLIRLibs ${MLIRTargetLLVMIRModuleTranslation} ${MLIRTransforms} ${MLIRTransformUtils} - ${MLIRAffine} + ${MLIRAffineOps} ${MLIRAffineToStandard} + ${MLIRAffineTransforms} ${MLIRAnalysis} ${MLIRCallInterfaces} ${MLIRControlFlowInterfaces} @@ -244,14 +246,15 @@ set(MLIRLibs # must be specified on LD_PRELOAD for shared build. set(MLIRWholeArchiveLibs MLIRAffineToStandard - MLIRAffine + MLIRAffineOps MLIRLLVMIR MLIRStandardOps MLIRStandardToLLVM MLIRTransforms MLIRLoopToStandard MLIRVector - MLIRLoopOps) + MLIRLoopOps + MLIRIR) # ONNX MLIR libraries that must be linked with --whole-archive for static build or # must be specified on LD_PRELOAD for shared build. diff --git a/README.md b/README.md index 7e12e65..1293ba0 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ Firstly, install MLIR (as a part of LLVM-Project): ``` bash git clone https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX MLIR. -cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd .. +cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd .. ``` [same-as-file]: <> (utils/build-mlir.sh) @@ -110,7 +110,7 @@ Install MLIR (as a part of LLVM-Project): ```shell git clone https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX MLIR. -cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd .. +cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd .. ``` [same-as-file]: <> (utils/build-mlir.cmd) diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000..73c79c6 --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1,2 @@ +_site +Gemfile* diff --git a/docs/HowToAddAnOperation.md b/docs/HowToAddAnOperation.md index 7a4186e..d70e0c4 100644 --- a/docs/HowToAddAnOperation.md +++ b/docs/HowToAddAnOperation.md @@ -44,7 +44,7 @@ Once it is invoked, you will need to add literal tests in ` test/mlir/onnx/onnx_ Files related to the lowering of the new operations resides in the `src/Conversion/ONNXtoKRNL` directory and subdirectories. For the `concat` operation, we added code to lower it to krnl dialect in the ` src/Conversion/ONNXToKrnl/Tensor/concat.cpp` file. See other similar lowering for inspiration. We suggest to use `assert` statements for any unexpected values encountered while lowering the operation, as illegal parameter values should be caught in the shape inference phase, not successive passes such as lowering to the krnl dialect. -In that file, the `populateLoweringONNXConcatOpPattern` operation (where `Concat` would be replaced with the actual new operation) will need to be defined in ` src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp` and invoked in the ` runOnModule` function in the ` src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp` file. +In that file, the `populateLoweringONNXConcatOpPattern` operation (where `Concat` would be replaced with the actual new operation) will need to be defined in ` src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp` and invoked in the ` runOnOperation` function in the ` src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp` file. To compile properly, you will also need to add the new `.cpp` file in the ` src/Conversion/ONNXToKrnl/CMakeLists.txt` file. diff --git a/docs/README.md b/docs/README.md index 7e12e65..1293ba0 100644 --- a/docs/README.md +++ b/docs/README.md @@ -20,7 +20,7 @@ Firstly, install MLIR (as a part of LLVM-Project): ``` bash git clone https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX MLIR. -cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd .. +cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd .. ``` [same-as-file]: <> (utils/build-mlir.sh) @@ -110,7 +110,7 @@ Install MLIR (as a part of LLVM-Project): ```shell git clone https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX MLIR. -cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd .. +cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd .. ``` [same-as-file]: <> (utils/build-mlir.cmd) diff --git a/docs/doc_check/CMakeLists.txt b/docs/doc_check/CMakeLists.txt index 95eec9d..bd1dedf 100644 --- a/docs/doc_check/CMakeLists.txt +++ b/docs/doc_check/CMakeLists.txt @@ -4,5 +4,8 @@ add_custom_target(check-doc COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/check.py ${ONNX_MLIR_SRC_ROOT} - --exclude_dirs third_party docs/doc_check/test) + --exclude_dirs + third_party + docs/doc_check/test + docs/_site) diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp index be67331..95750a5 100644 --- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp +++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp @@ -41,13 +41,13 @@ public: /// This is a partial lowering to Krnl loops of the ONNX operations. namespace { struct FrontendToKrnlLoweringPass - : public ModulePass { - void runOnModule() final; + : public PassWrapper> { + void runOnOperation() final; }; } // end anonymous namespace. -void FrontendToKrnlLoweringPass::runOnModule() { - ModuleOp module = getModule(); +void FrontendToKrnlLoweringPass::runOnOperation() { + ModuleOp module = getOperation(); // The first thing to define is the conversion target. This will define the // final target for this lowering. diff --git a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp index d008f10..5d25081 100644 --- a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp @@ -456,11 +456,11 @@ struct ONNXPoolOpLowering : public ConversionPattern { } dimExpr.emplace_back(de); } - poolDimMap = AffineMap::get(1, 5, dimExpr); + poolDimMap = AffineMap::get(1, 5, dimExpr, rewriter.getContext()); // poolStartMap and poolEndMap - poolStartMap = AffineMap::get(1, 5, {start1, start2}); - poolEndMap = AffineMap::get(1, 5, {end1, end2}); + poolStartMap = AffineMap::get(1, 5, {start1, start2}, rewriter.getContext()); + poolEndMap = AffineMap::get(1, 5, {end1, end2}, rewriter.getContext()); } // Obtain values from the affine maps. diff --git a/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp b/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp index 5c46194..83f601f 100644 --- a/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp +++ b/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp @@ -11,7 +11,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -23,28 +25,27 @@ using namespace onnx_mlir; -static llvm::cl::opt input_filename(llvm::cl::Positional, - llvm::cl::desc(""), - llvm::cl::init("-")); +// TODO(tjingrant): disable the following namespace import. +using namespace mlir; -static llvm::cl::opt - output_filename("o", llvm::cl::desc("Output filename"), - llvm::cl::value_desc("filename"), llvm::cl::init("-")); +static llvm::cl::opt input_filename( + llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-")); -static llvm::cl::opt split_input_file( - "split-input-file", +static llvm::cl::opt output_filename("o", + llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), + llvm::cl::init("-")); + +static llvm::cl::opt split_input_file("split-input-file", llvm::cl::desc("Split the input file into pieces and process each " "chunk independently"), llvm::cl::init(false)); -static llvm::cl::opt verify_diagnostics( - "verify-diagnostics", +static llvm::cl::opt verify_diagnostics("verify-diagnostics", llvm::cl::desc("Check that emitted diagnostics match " "expected-* lines on the corresponding line"), llvm::cl::init(false)); -static llvm::cl::opt verify_passes( - "verify-each", +static llvm::cl::opt verify_passes("verify-each", llvm::cl::desc("Run the verifier after each transformation pass"), llvm::cl::init(true)); @@ -58,16 +59,23 @@ int main(int argc, char **argv) { mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); + + // Register transformation passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Transforms/Passes.h.inc" + llvm::InitLLVM y(argc, argv); mlir::registerDialect(); mlir::registerDialect(); + mlir::registerAsmPrinterCLOptions(); + mlir::registerMLIRContextCLOptions(); // Register any pass manager command line options. mlir::registerPassManagerCLOptions(); mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run"); - llvm::cl::ParseCommandLineOptions(argc, argv, - "ONNX MLIR modular optimizer driver\n"); + llvm::cl::ParseCommandLineOptions( + argc, argv, "ONNX MLIR modular optimizer driver\n"); // Set up the input file. std::string error_message; @@ -78,6 +86,6 @@ int main(int argc, char **argv) { assert(output); return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline, - split_input_file, verify_diagnostics, - verify_passes, allowUnregisteredDialects)); + split_input_file, verify_diagnostics, verify_passes, + allowUnregisteredDialects)); } diff --git a/src/Transform/ElideKrnlGlobalConstants.cpp b/src/Transform/ElideKrnlGlobalConstants.cpp index 3f61e35..73a0bbe 100644 --- a/src/Transform/ElideKrnlGlobalConstants.cpp +++ b/src/Transform/ElideKrnlGlobalConstants.cpp @@ -36,13 +36,13 @@ class KrnlConstGlobalValueElision : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(KrnlGlobalOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + KrnlGlobalOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); - + if (op.value().hasValue()) { auto newGlobalOp = rewriter.create( - loc, op.getResult().getType(), op.shape(), op.name(), nullptr); + loc, op.getResult().getType(), op.shape(), op.name(), nullptr); rewriter.replaceOp(op, newGlobalOp.getResult()); } @@ -54,7 +54,7 @@ public: * Function pass that performs constant value elision of Krnl globals. */ class ElideConstGlobalValuePass - : public mlir::FunctionPass { + : public PassWrapper { public: void runOnFunction() override { auto function = getFunction(); @@ -63,7 +63,7 @@ public: OwningRewritePatternList patterns; patterns.insert(&getContext()); - applyPatternsGreedily(function, patterns); + applyPatternsAndFoldGreedily(function, patterns); } }; } // namespace diff --git a/src/Transform/LowerKrnl.cpp b/src/Transform/LowerKrnl.cpp index 552f0f3..79e3803 100644 --- a/src/Transform/LowerKrnl.cpp +++ b/src/Transform/LowerKrnl.cpp @@ -149,7 +149,7 @@ public: /// add and multiply, this pass will leave these operations intact. namespace { struct KrnlToAffineLoweringPass - : public FunctionPass { + : public PassWrapper { void runOnFunction() final; }; } // end anonymous namespace. diff --git a/src/Transform/LowerToLLVM.cpp b/src/Transform/LowerToLLVM.cpp index 1f88bd2..666955f 100644 --- a/src/Transform/LowerToLLVM.cpp +++ b/src/Transform/LowerToLLVM.cpp @@ -16,8 +16,8 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Pass/Pass.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/Sequence.h" @@ -94,14 +94,13 @@ static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter, class KrnlGlobalOpLowering : public ConvertToLLVMPattern { public: - explicit KrnlGlobalOpLowering(MLIRContext *context, - LLVMTypeConverter &lowering_) - : ConvertToLLVMPattern(KrnlGlobalOp::getOperationName(), context, - lowering_) {} + explicit KrnlGlobalOpLowering( + MLIRContext *context, LLVMTypeConverter &lowering_) + : ConvertToLLVMPattern( + KrnlGlobalOp::getOperationName(), context, lowering_) {} - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto *context = op->getContext(); auto loc = op->getLoc(); auto *llvmDialect = @@ -119,7 +118,7 @@ public: // Compute total number of elements. auto shape = (krnlGlobalOp.shape()).dyn_cast(); int64_t numElements = 1; - for (int i=0; i= 0; i--) + for (int i = shape.size() - 1; i >= 0; i--) globalType = LLVM::LLVMType::getArrayTy( globalType.cast(), ArrayAttrIntVal(shape, i)); // The llvm type of the global (example: [2 x [8 x float]]) @@ -145,7 +144,6 @@ public: assert(krnlGlobalOp.value().hasValue() && "Krnl Global must always have a value"); - global = rewriter.create(loc, llvmGlobalType, /*isConstant=*/true, LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue()); @@ -157,47 +155,46 @@ public: // Allocate the memory where the constants will be used from. // This is a region of local memory and needs to be emitted as an alloca. - auto one = rewriter.create(loc, - llvmI64Ty, rewriter.getI64IntegerAttr(1)); + auto one = rewriter.create( + loc, llvmI64Ty, rewriter.getI64IntegerAttr(1)); auto alloc = rewriter.create( loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0); // Copy constant value into the local alloca: // - Bitcast alloc to i8* - Value int8PtrAlloc = rewriter.create( - loc, llvmI8PtrTy, alloc); + Value int8PtrAlloc = + rewriter.create(loc, llvmI8PtrTy, alloc); // - Bitcast global to i8* Value globalValue = rewriter.create(loc, global); - Value i8PtrGlobal = rewriter.create( - loc, llvmI8PtrTy, globalValue); + Value i8PtrGlobal = + rewriter.create(loc, llvmI8PtrTy, globalValue); // - Set size. - Value memRefElementSize = rewriter.create(loc, - llvmI64Ty, rewriter.getI64IntegerAttr( - getMemRefEltSizeInBytes(memRefTy))); + Value memRefElementSize = rewriter.create(loc, llvmI64Ty, + rewriter.getI64IntegerAttr(getMemRefEltSizeInBytes(memRefTy))); Value numElementsValue = rewriter.create( loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements)); - Value totalElementsSize = rewriter.create( - loc, memRefElementSize, numElementsValue); - Value int64Size = rewriter.create( - loc, llvmI64Ty, totalElementsSize); + Value totalElementsSize = + rewriter.create(loc, memRefElementSize, numElementsValue); + Value int64Size = + rewriter.create(loc, llvmI64Ty, totalElementsSize); // - Set volatile. - Value isVolatile = rewriter.create( - loc, LLVM::LLVMType::getInt1Ty(llvmDialect), + Value isVolatile = rewriter.create(loc, + LLVM::LLVMType::getInt1Ty(llvmDialect), rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); // - Copy constant data into the alloca. auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect); - rewriter.create( - loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect), + rewriter.create(loc, memcpyRef, + LLVM::LLVMType::getVoidTy(llvmDialect), ArrayRef({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile})); // Prepare data to be inserted into MemRef. auto llvmConstantElementType = constantElementType.cast(); Value typedAlloc = rewriter.create( - loc, llvmConstantElementType.getPointerTo(), alloc); + loc, llvmConstantElementType.getPointerTo(), alloc); // Create llvm MemRef from original MemRef and fill the data pointers. auto llvmMemRef = MemRefDescriptor::fromStaticShape( - rewriter, loc, typeConverter, memRefTy, typedAlloc); + rewriter, loc, typeConverter, memRefTy, typedAlloc); rewriter.replaceOp(op, {llvmMemRef}); return success(); @@ -219,7 +216,7 @@ public: : ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { + ConversionPatternRewriter &rewriter) const override { auto *context = op->getContext(); KrnlMemcpyOpOperandAdaptor operandAdaptor(operands); auto loc = op->getLoc(); @@ -605,16 +602,18 @@ private: //===----------------------------------------------------------------------===// namespace { -struct KrnlToLLVMLoweringPass : public ModulePass { - void runOnModule() final; +struct KrnlToLLVMLoweringPass + : public PassWrapper> { + void runOnOperation() final; }; } // end anonymous namespace -void KrnlToLLVMLoweringPass::runOnModule() { +void KrnlToLLVMLoweringPass::runOnOperation() { // Define the target for this lowering i.e. the LLVM dialect. ConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalOp(); + // target.addLegalOp(); // Lower the MemRef types to a representation in LLVM. LLVMTypeConverter typeConverter(&getContext()); @@ -625,8 +624,8 @@ void KrnlToLLVMLoweringPass::runOnModule() { populateAffineToStdConversionPatterns(patterns, &getContext()); populateLoopToStdConversionPatterns(patterns, &getContext()); populateStdToLLVMConversionPatterns(typeConverter, patterns, - /*useAlloca=*/false, - /*emitCWrapper=*/true); + /*emitCWrapperS=*/true, + /*useAlignedAlloc=*/false); patterns.insert(&getContext(), typeConverter); @@ -636,8 +635,8 @@ void KrnlToLLVMLoweringPass::runOnModule() { // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. - if (failed( - applyFullConversion(getModule(), target, patterns, &typeConverter))) + if (failed(applyFullConversion( + getOperation(), target, patterns, &typeConverter))) signalPassFailure(); } diff --git a/src/Transform/ONNX/AttributePromotion.cpp b/src/Transform/ONNX/AttributePromotion.cpp index 7abf2f8..8420740 100644 --- a/src/Transform/ONNX/AttributePromotion.cpp +++ b/src/Transform/ONNX/AttributePromotion.cpp @@ -39,7 +39,7 @@ void getOrCreateNoneValue(llvm::Optional &none, FuncOp f) { * desirable (as instructed by the PromotableConstOperandsOpInterface). */ class AttributePromotionPass - : public mlir::FunctionPass { + : public mlir::PassWrapper { public: void runOnFunction() override { auto f = getFunction(); @@ -75,7 +75,7 @@ public: OwningRewritePatternList patterns; auto *context = &getContext(); ConstantOp::getCanonicalizationPatterns(patterns, context); - applyPatternsGreedily(f, patterns); + applyPatternsAndFoldGreedily(f, patterns); } }; } // end anonymous namespace diff --git a/src/Transform/ONNX/ElideConstants.cpp b/src/Transform/ONNX/ElideConstants.cpp index b8fc493..7f09a91 100644 --- a/src/Transform/ONNX/ElideConstants.cpp +++ b/src/Transform/ONNX/ElideConstants.cpp @@ -37,8 +37,8 @@ class ConstantValueElision : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ONNXConstantOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + ONNXConstantOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto constOp = llvm::dyn_cast(&op); @@ -58,7 +58,7 @@ public: * Function pass that performs constant value elision. */ class ElideConstantValuePass - : public mlir::FunctionPass { + : public PassWrapper { public: void runOnFunction() override { auto function = getFunction(); @@ -67,7 +67,7 @@ public: OwningRewritePatternList patterns; patterns.insert(&getContext()); - applyPatternsGreedily(function, patterns); + applyPatternsAndFoldGreedily(function, patterns); } }; } // end anonymous namespace diff --git a/src/Transform/ONNX/ONNXDecompose.cpp b/src/Transform/ONNX/ONNXDecompose.cpp index 1145a73..78dc07a 100644 --- a/src/Transform/ONNX/ONNXDecompose.cpp +++ b/src/Transform/ONNX/ONNXDecompose.cpp @@ -27,7 +27,7 @@ namespace { /// Include the patterns defined in the Declarative Rewrite framework. #include "src/Transform/ONNX/ONNXDecompose.inc" -struct DecomposeONNXToONNXPass : public FunctionPass { +struct DecomposeONNXToONNXPass : public PassWrapper { void runOnFunction() final; }; } // end anonymous namespace. diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp index 2188628..ec8112b 100644 --- a/src/Transform/ONNX/ShapeInferencePass.cpp +++ b/src/Transform/ONNX/ShapeInferencePass.cpp @@ -25,7 +25,7 @@ namespace { * candidate operations and propagating the shape information until the list * of operations is empty [credit MLIR authors]. */ -class ShapeInferencePass : public mlir::FunctionPass { +class ShapeInferencePass : public mlir::PassWrapper { public: void runOnFunction() override { auto f = getFunction(); diff --git a/test/backend/CMakeLists.txt b/test/backend/CMakeLists.txt index 39075b3..4e4c3ce 100644 --- a/test/backend/CMakeLists.txt +++ b/test/backend/CMakeLists.txt @@ -3,7 +3,7 @@ configure_file(test_config.py.in test_config.py) find_package(PythonInterp 3 REQUIRED) add_custom_target(check-onnx-backend - COMMAND LD_PRELOAD=${CMAKE_BINARY_DIR}/lib/libcruntime.so ${PYTHON_EXECUTABLE} + COMMAND LD_PRELOAD=$ ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/test.py) add_dependencies(check-onnx-backend onnx-mlir) diff --git a/utils/clone-mlir.sh b/utils/clone-mlir.sh index b47c69c..99858da 100644 --- a/utils/clone-mlir.sh +++ b/utils/clone-mlir.sh @@ -1,3 +1,3 @@ git clone https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX MLIR. -cd llvm-project && git checkout 07e462526d0cbae40b320e1a4307ce11e197fb0a && cd .. +cd llvm-project && git checkout 3ce0ad1b336e67a76d78ae7ff7d66fe127586620 && cd ..