diff --git a/.circleci/config.yml b/.circleci/config.yml index f5e283c..1b97883 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -18,7 +18,7 @@ jobs: git submodule update --init --recursive # Use cached mlir installation if possible. - restore_cache: - key: V15-LLVM-PROJECT-{{ arch }} + key: V17-LLVM-PROJECT-{{ arch }} - run: name: Install MLIR command: | @@ -29,7 +29,7 @@ jobs: source onnx-mlir/utils/install-mlir.sh fi - save_cache: - key: V15-LLVM-PROJECT-{{ arch }} + key: V17-LLVM-PROJECT-{{ arch }} paths: - llvm-project - run: diff --git a/.travis.yml b/.travis.yml index 5be1f34..15cfbc2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,7 +11,7 @@ script: - echo "CPU Architecture is " $TRAVIS_CPU_ARCH - echo "commit is " $TRAVIS_COMMIT - df -h - - if [ "$TRAVIS_COMMIT_MESSAGE" == "Update MLIR version" ]; then + - if true; then echo "Building Prereq"; docker build --tag onnxmlirczar/onnx-mlir-llvmimage:$TRAVIS_CPU_ARCH -f ./docker/prereq.$TRAVIS_CPU_ARCH.Dockerfile ./utils; docker login -u onnxmlirczar -p 143f1da2-332f-45a1-8587-d6cb07c13230 diff --git a/MLIR.cmake b/MLIR.cmake index 41d0050..ab8a1a5 100644 --- a/MLIR.cmake +++ b/MLIR.cmake @@ -133,11 +133,13 @@ function(find_mlir_lib lib) endfunction(find_mlir_lib) find_mlir_lib(MLIRAffineOps) +find_mlir_lib(MLIRAffineUtils) find_mlir_lib(MLIRAffineToStandard) find_mlir_lib(MLIRAffineTransforms) find_mlir_lib(MLIRAnalysis) find_mlir_lib(MLIRCallInterfaces) find_mlir_lib(MLIRControlFlowInterfaces) +find_mlir_lib(MLIRCopyOpInterface) find_mlir_lib(MLIRDialect) find_mlir_lib(MLIREDSC) find_mlir_lib(MLIRExecutionEngine) @@ -158,6 +160,7 @@ find_mlir_lib(MLIRMlirOptMain) find_mlir_lib(MLIRParser) find_mlir_lib(MLIRPass) find_mlir_lib(MLIRStandardOps) +find_mlir_lib(MLIRStandardOpsTransforms) find_mlir_lib(MLIRStandardToLLVM) find_mlir_lib(MLIRSideEffectInterfaces) find_mlir_lib(MLIRTargetLLVMIR) @@ -172,12 +175,14 @@ find_mlir_lib(MLIRTargetLLVMIR) find_mlir_lib(MLIRTransformUtils) find_mlir_lib(MLIRTranslation) find_mlir_lib(MLIRVector) +find_mlir_lib(MLIRVectorInterfaces) find_mlir_lib(MLIRVectorToLLVM) find_mlir_lib(MLIRVectorToSCF) find_mlir_lib(MLIRMlirOptMain) find_mlir_lib(MLIRAffineEDSC) find_mlir_lib(MLIRLinalgEDSC) find_mlir_lib(MLIRViewLikeInterface) +find_mlir_lib(MLIRPresburger) find_mlir_lib(LLVMCore) find_mlir_lib(LLVMSupport) @@ -200,12 +205,16 @@ find_mlir_lib(LLVMFrontendOpenMP) set(MLIRLibs ${MLIRAffineToStandard} ${MLIRAffineOps} + ${MLIRAffineUtils} + ${MLIRCopyOpInterface} ${MLIRLLVMIR} ${MLIRStandardOps} + ${MLIRStandardOpsTransforms} ${MLIRStandardToLLVM} ${MLIRTransforms} ${MLIRSCFToStandard} ${MLIRVector} + ${MLIRVectorInterfaces} ${MLIRVectorToLLVM} ${MLIRVectorToSCF} ${MLIRSCF} @@ -249,6 +258,7 @@ set(MLIRLibs ${MLIRAffineEDSC} ${MLIRLinalgEDSC} ${MLIRViewLikeInterface} + ${MLIRPresburger} # strict order verified ${LLVMBitWriter} ${LLVMObject} diff --git a/README.md b/README.md index 977bdb1..3d803db 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ Firstly, install MLIR (as a part of LLVM-Project): ``` bash git clone https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX MLIR. -cd llvm-project && git checkout 32791937d7aceb0a5e1eaabf1bb1a6dbe1639792 && cd .. +cd llvm-project && git checkout 9c94908320549a1a2328c758d6bbb694466021e7 && cd .. ``` [same-as-file]: <> (utils/build-mlir.sh) @@ -152,7 +152,7 @@ Install MLIR (as a part of LLVM-Project): ```shell git clone https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX MLIR. -cd llvm-project && git checkout 32791937d7aceb0a5e1eaabf1bb1a6dbe1639792 && cd .. +cd llvm-project && git checkout 9c94908320549a1a2328c758d6bbb694466021e7 && cd .. ``` [same-as-file]: <> (utils/build-mlir.cmd) diff --git a/docs/README.md b/docs/README.md index 80ac015..e9cf1f4 100644 --- a/docs/README.md +++ b/docs/README.md @@ -20,7 +20,7 @@ Firstly, install MLIR (as a part of LLVM-Project): ``` bash git clone https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX MLIR. -cd llvm-project && git checkout 32791937d7aceb0a5e1eaabf1bb1a6dbe1639792 && cd .. +cd llvm-project && git checkout 9c94908320549a1a2328c758d6bbb694466021e7 && cd .. ``` [same-as-file]: <> (utils/build-mlir.sh) @@ -110,7 +110,7 @@ Install MLIR (as a part of LLVM-Project): ```shell git clone https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX MLIR. -cd llvm-project && git checkout 32791937d7aceb0a5e1eaabf1bb1a6dbe1639792 && cd .. +cd llvm-project && git checkout 9c94908320549a1a2328c758d6bbb694466021e7 && cd .. ``` [same-as-file]: <> (utils/build-mlir.cmd) diff --git a/src/Conversion/KrnlToAffine/KrnlToAffine.cpp b/src/Conversion/KrnlToAffine/KrnlToAffine.cpp index a0ae2de..5c9626a 100644 --- a/src/Conversion/KrnlToAffine/KrnlToAffine.cpp +++ b/src/Conversion/KrnlToAffine/KrnlToAffine.cpp @@ -31,7 +31,7 @@ public: LogicalResult matchAndRewrite( KrnlTerminatorOp op, PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op); + rewriter.replaceOpWithNewOp(op); return success(); } }; @@ -233,7 +233,7 @@ void ConvertKrnlToAffinePass::runOnFunction() { ConversionTarget target(getContext()); target.addIllegalOp(); - target.addLegalOp(); + target.addLegalOp(); OwningRewritePatternList patterns; patterns.insert(&getContext()); DenseSet unconverted; diff --git a/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp b/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp index c3f5a60..5ca2a89 100644 --- a/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp @@ -105,17 +105,17 @@ static size_t getRankFromMemRefType(LLVM::LLVMType memRefTy) { /// Return a symbol reference to the memcpy function, inserting it into the /// module if necessary. -static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter, - ModuleOp module, LLVM::LLVMDialect *llvmDialect) { +static FlatSymbolRefAttr getOrInsertMemcpy( + PatternRewriter &rewriter, ModuleOp module) { auto *context = module.getContext(); if (module.lookupSymbol("llvm.memcpy.p0i8.p0i8.i64")) return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context); // Create a function declaration for memcpy, the signature is: // * `void (i8*, i8* , i64, i1)` - auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect); - auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); - auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); - auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect); + auto llvmVoidTy = LLVM::LLVMType::getVoidTy(context); + auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context); + auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(context); + auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(context); auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy, ArrayRef( {llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}), @@ -144,9 +144,6 @@ public: ConversionPatternRewriter &rewriter) const override { auto *context = op->getContext(); auto loc = op->getLoc(); - auto *llvmDialect = - op->getContext()->getRegisteredDialect(); - assert(llvmDialect && "expected llvm dialect to be registered"); KrnlGetRefOpAdaptor operandAdaptor(operands); @@ -209,9 +206,6 @@ public: ConversionPatternRewriter &rewriter) const override { auto *context = op->getContext(); auto loc = op->getLoc(); - auto *llvmDialect = - op->getContext()->getRegisteredDialect(); - assert(llvmDialect && "expected llvm dialect to be registered"); auto krnlGlobalOp = llvm::dyn_cast(op); @@ -258,8 +252,8 @@ public: } // Some frequently used types. - auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); - auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); + auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context); + auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(context); // Allocate the memory where the constants will be used from. // This is a region of local memory and needs to be emitted as an alloca. @@ -288,17 +282,17 @@ public: rewriter.create(loc, llvmI64Ty, totalElementsSize); // - Set volatile. Value isVolatile = rewriter.create(loc, - LLVM::LLVMType::getInt1Ty(llvmDialect), + LLVM::LLVMType::getInt1Ty(context), rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); // - Copy constant data into the alloca. - auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect); + auto memcpyRef = getOrInsertMemcpy(rewriter, module); rewriter.create(loc, memcpyRef, - LLVM::LLVMType::getVoidTy(llvmDialect), + LLVM::LLVMType::getVoidTy(context), ArrayRef({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile})); } else { // Some frequently used types. - auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); - auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); + auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context); + auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(context); // Allocate the memory where the constants will be used from. // This is a region of local memory and needs to be emitted as an alloca. @@ -351,13 +345,10 @@ public: auto *context = op->getContext(); KrnlMemcpyOpAdaptor operandAdaptor(operands); auto loc = op->getLoc(); - auto *llvmDialect = - op->getContext()->getRegisteredDialect(); - assert(llvmDialect && "expected llvm dialect to be registered"); // Get a symbol reference to the memcpy function, inserting it if necessary. ModuleOp parentModule = op->getParentOfType(); - auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule, llvmDialect); + auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule); // First operand. Type dstType = operandAdaptor.dest() @@ -367,7 +358,7 @@ public: Value alignedDstMemory = rewriter.create( loc, dstType, operandAdaptor.dest(), rewriter.getI64ArrayAttr(1)); Value alignedInt8PtrDstMemory = rewriter.create( - loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory); + loc, LLVM::LLVMType::getInt8PtrTy(context), alignedDstMemory); // Second operand. Type srcType = operandAdaptor.src() @@ -377,20 +368,19 @@ public: Value alignedSrcMemory = rewriter.create( loc, srcType, operandAdaptor.src(), rewriter.getI64ArrayAttr(1)); Value alignedInt8PtrSrcMemory = rewriter.create( - loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory); + loc, LLVM::LLVMType::getInt8PtrTy(context), alignedSrcMemory); // Size. Value int64Size = rewriter.create( - loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operandAdaptor.size()); + loc, LLVM::LLVMType::getInt64Ty(context), operandAdaptor.size()); // Is volatile (set to false). Value isVolatile = rewriter.create(loc, - LLVM::LLVMType::getInt1Ty(llvmDialect), + LLVM::LLVMType::getInt1Ty(context), rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); // Memcpy call - rewriter.create(loc, memcpyRef, - LLVM::LLVMType::getVoidTy(llvmDialect), + rewriter.create(loc, memcpyRef, LLVM::LLVMType::getVoidTy(context), ArrayRef({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size, isVolatile})); @@ -441,19 +431,17 @@ public: LogicalResult matchAndRewrite( KrnlEntryPointOp op, PatternRewriter &rewriter) const override { - auto *llvmDialect = - op.getContext()->getRegisteredDialect(); - assert(llvmDialect && "expected llvm dialect to be registered"); auto module = op.getParentOfType(); - auto apiRegistry = RegisterAllApis(module, rewriter, llvmDialect); + auto *context = module.getContext(); + auto apiRegistry = RegisterAllApis(module, rewriter); auto loc = op.getLoc(); auto numOutputs = op.getAttrOfType(KrnlEntryPointOp::getNumOutputsAttrName()) .getInt(); using LLVMType = LLVM::LLVMType; - auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect); - auto int32Ty = LLVMType::getInt32Ty(llvmDialect); + auto opaquePtrTy = LLVMType::getInt8PtrTy(context); + auto int32Ty = LLVMType::getInt32Ty(context); // Rewrite Krnl Entry Point Operation to an LLVM function with a dynamic // signature. The signature is dynamic because it remains the same no matter @@ -514,7 +502,7 @@ public: // Fill in the memref underlying ptrToMemRef with information extracted // from dynMemRef. fillPtrToMemRefWithRtMemRef( - dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, llvmDialect); + dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, module); // ptrToMemRef will be an input to main computation graph function. staticInputs.emplace_back(ptrToMemRef); @@ -565,7 +553,7 @@ public: auto outRtMemRef = callApi(rewriter, loc, apiRegistry, API::CREATE_DYN_MEM_REF, {outMemRefRankVal}); fillRtMemRefWithMemRef( - memRef, outRtMemRef, rewriter, loc, apiRegistry, llvmDialect); + memRef, outRtMemRef, rewriter, loc, apiRegistry, module); auto idx = rewriter.create( loc, int32Ty, rewriter.getI32IntegerAttr(i)); callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF, @@ -580,13 +568,14 @@ public: private: using ApiRegistry = std::map; - ApiRegistry RegisterAllApis(ModuleOp &module, PatternRewriter &rewriter, - LLVM::LLVMDialect *llvmDialect) const { + ApiRegistry RegisterAllApis( + ModuleOp &module, PatternRewriter &rewriter) const { + auto *context = module.getContext(); using LLVMType = LLVM::LLVMType; - auto voidTy = LLVMType::getVoidTy(llvmDialect); - auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect); - auto int32Ty = LLVMType::getInt32Ty(llvmDialect); - auto int64Ty = LLVMType::getInt64Ty(llvmDialect); + auto voidTy = LLVMType::getVoidTy(context); + auto opaquePtrTy = LLVMType::getInt8PtrTy(context); + auto int32Ty = LLVMType::getInt32Ty(context); + auto int64Ty = LLVMType::getInt64Ty(context); auto int64PtrTy = int64Ty.getPointerTo(); // Declare API type as an enum value, its string name and an LLVM Type @@ -646,11 +635,11 @@ private: void fillPtrToMemRefWithRtMemRef(Value &dynMemRef, Value &ptrToMemRef, PatternRewriter &rewriter, const Location &loc, - const std::map &apiRegistry, - LLVM::LLVMDialect *llvmDialect) const { + const std::map &apiRegistry, ModuleOp &module) const { + auto *context = module.getContext(); auto memRefPtrTy = ptrToMemRef.getType().dyn_cast(); auto memRefTy = memRefPtrTy.getPointerElementTy(); - auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); + auto int64Ty = LLVM::LLVMType::getInt64Ty(context); Value memRef = rewriter.create(loc, memRefPtrTy, ptrToMemRef); @@ -707,18 +696,18 @@ private: void fillRtMemRefWithMemRef(Value &outMemRef, Value &outRtMemRef, PatternRewriter &rewriter, const Location &loc, - const std::map &apiRegistry, - LLVM::LLVMDialect *llvmDialect) const { + const std::map &apiRegistry, ModuleOp &module) const { + auto *context = module.getContext(); auto outMemRefTy = outMemRef.getType().dyn_cast(); - auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); - auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect); + auto int64Ty = LLVM::LLVMType::getInt64Ty(context); + auto int32Ty = LLVM::LLVMType::getInt32Ty(context); // Extract the data pointer, and record it in dynamic mem ref created. Value outMemRefDataPtr = rewriter.create(loc, outMemRefTy.getStructElementType(0), outMemRef, rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)})); outMemRefDataPtr = rewriter.create( - loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr); + loc, LLVM::LLVMType::getInt8PtrTy(context), outMemRefDataPtr); callApi(rewriter, loc, apiRegistry, API::SET_DATA, {outRtMemRef, outMemRefDataPtr}); auto elemTy = outMemRefTy.getStructElementType(0).getPointerElementTy(); @@ -776,15 +765,11 @@ public: ModuleOp module = op->getParentOfType(); auto loc = op->getLoc(); - auto *llvmDialect = - op->getContext()->getRegisteredDialect(); - assert(llvmDialect && "expected llvm dialect to be registered"); - auto packedConstOp = llvm::dyn_cast(op); LLVM::GlobalOp globalBase; // Some frequently used types. - auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); - auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); + auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context); + auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(context); { OpBuilder::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); @@ -808,8 +793,7 @@ public: llvmI8PtrTy, {llvmI64Ty}, /*isVarArg=*/false), rewriter); auto constPackSize = rewriter.create(loc, - LLVM::LLVMType::getInt64Ty(llvmDialect), - packedConstOp.size_in_bytesAttr()); + LLVM::LLVMType::getInt64Ty(context), packedConstOp.size_in_bytesAttr()); Value alloc = rewriter .create(loc, getEmbeddedConstPoolRef, llvmI8PtrTy, ArrayRef({constPackSize})) @@ -822,14 +806,13 @@ public: // Record constant pack *file path* as a global variable (by recording the // file path string's underlying char array + its length). const auto &fileNameAttr = packedConstOp.file_nameAttr(); - auto type = - LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect), - fileNameAttr.getValue().size()); + auto type = LLVM::LLVMType::getArrayTy( + LLVM::LLVMType::getInt8Ty(context), fileNameAttr.getValue().size()); rewriter.create(loc, type, /*isConstant=*/true, LLVM::Linkage::External, mlir::KrnlPackedConstantOp::getConstPackFilePathSymbolName(), fileNameAttr); - type = LLVM::LLVMType::getInt64Ty(llvmDialect); + type = LLVM::LLVMType::getInt64Ty(context); rewriter.create(loc, type, /*isConstant=*/true, LLVM::Linkage::External, mlir::KrnlPackedConstantOp::getConstPackFilePathStrLenSymbolName(), @@ -840,18 +823,18 @@ public: auto constPackFileName = llvm::sys::path::filename(fileNameAttr.getValue()); type = LLVM::LLVMType::getArrayTy( - LLVM::LLVMType::getInt8Ty(llvmDialect), constPackFileName.size()); + LLVM::LLVMType::getInt8Ty(context), constPackFileName.size()); rewriter.create(loc, type, /*isConstant=*/true, LLVM::Linkage::External, mlir::KrnlPackedConstantOp::getConstPackFileNameSymbolName(), rewriter.getStringAttr(constPackFileName)); - type = LLVM::LLVMType::getInt64Ty(llvmDialect); + type = LLVM::LLVMType::getInt64Ty(context); rewriter.create(loc, type, /*isConstant=*/true, LLVM::Linkage::External, mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName(), rewriter.getI64IntegerAttr(constPackFileName.size())); - type = LLVM::LLVMType::getInt8Ty(llvmDialect); + type = LLVM::LLVMType::getInt8Ty(context); rewriter.create(loc, type, /*isConstant=*/true, LLVM::Linkage::External, mlir::KrnlPackedConstantOp::getConstPackIsLESymbolName(), @@ -889,14 +872,14 @@ void ConvertKrlnToLLVMPass::runOnOperation() { // Lower the MemRef types to a representation in LLVM. LowerToLLVMOptions options; options.emitCWrappers = true; - LLVMTypeConverter typeConverter(&getContext()); + LLVMTypeConverter typeConverter(&getContext(), options); // We have a combination of `krnl`, `affine`, and `std` operations. We // lower in stages until all the code is in the LLVM dialect. OwningRewritePatternList patterns; populateAffineToStdConversionPatterns(patterns, &getContext()); populateLoopToStdConversionPatterns(patterns, &getContext()); - populateStdToLLVMConversionPatterns(typeConverter, patterns, options); + populateStdToLLVMConversionPatterns(typeConverter, patterns); patterns.insert( &getContext(), typeConverter); diff --git a/src/Dialect/Krnl/KrnlTypes.hpp b/src/Dialect/Krnl/KrnlTypes.hpp index db75855..4adc1c6 100644 --- a/src/Dialect/Krnl/KrnlTypes.hpp +++ b/src/Dialect/Krnl/KrnlTypes.hpp @@ -23,7 +23,9 @@ enum Kinds { }; } -class LoopType : public mlir::Type::TypeBase { +class LoopType + : public mlir::Type::TypeBase { + public: using Base::Base; diff --git a/src/Dialect/ONNX/ONNXOps.hpp b/src/Dialect/ONNX/ONNXOps.hpp index 5ab42c6..f54a92d 100644 --- a/src/Dialect/ONNX/ONNXOps.hpp +++ b/src/Dialect/ONNX/ONNXOps.hpp @@ -67,7 +67,8 @@ enum Kind { }; } // namespace ONNXTypes -class StringType : public mlir::Type::TypeBase { +class StringType + : public mlir::Type::TypeBase { public: using Base::Base; static bool kindof(unsigned kind) { return kind == ONNXTypes::STRING; } diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp index 20816f5..bebad23 100644 --- a/src/MainUtils.cpp +++ b/src/MainUtils.cpp @@ -262,8 +262,9 @@ void genLLVMBitcode(const mlir::OwningModuleRef &module, llvm::raw_fd_ostream moduleBitcodeStream( unoptimizedBitcodePath, error, llvm::sys::fs::F_None); - llvm::WriteBitcodeToFile( - *mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream); + llvm::LLVMContext llvmContext; + llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module, llvmContext), + moduleBitcodeStream); moduleBitcodeStream.flush(); // Use the LLVM's 'opt' command to optimize the bitcode. diff --git a/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp b/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp index f3a4c27..65560e0 100644 --- a/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp +++ b/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp @@ -67,18 +67,11 @@ int main(int argc, char **argv) { mlir::registerDialect(); mlir::registerDialect(); - // Register transformation passes. -#define GEN_PASS_REGISTRATION -#include "mlir/Transforms/Passes.h.inc" - -#define GEN_PASS_REGISTRATION -#include "mlir/Dialect/Affine/Passes.h.inc" - -#define GEN_PASS_REGISTRATION -#include "mlir/Dialect/Linalg/Passes.h.inc" - -#define GEN_PASS_REGISTRATION -#include "mlir/Dialect/SCF/Passes.h.inc" + registerTransformsPasses(); + registerAffinePasses(); + registerLinalgPasses(); + registerSCFPasses(); + registerStandardPasses(); llvm::InitLLVM y(argc, argv); diff --git a/test/mlir/krnl/constant.mlir b/test/mlir/krnl/constant.mlir index 3c48d2f..884c210 100644 --- a/test/mlir/krnl/constant.mlir +++ b/test/mlir/krnl/constant.mlir @@ -6,16 +6,16 @@ func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> { %0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () - // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) - // CHECK: llvm.mlir.global internal constant [[GLOBAL_CONST:@.+]](dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>) : !llvm<"[3 x [2 x float]]"> - // CHECK: llvm.func @test_constant({{.*}}) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> { + // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i1) + // CHECK: llvm.mlir.global internal constant [[GLOBAL_CONST:@.+]](dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>) : !llvm.array<3 x array<2 x float>> + // CHECK: llvm.func @test_constant({{.*}}) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { // CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : i64) : !llvm.i64 - // CHECK: [[ALLOCA:%.+]] = llvm.alloca [[CONST1]] x !llvm<"[3 x [2 x float]]"> : (!llvm.i64) -> !llvm<"[3 x [2 x float]]*"> - // CHECK: [[I8ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"i8*"> + // CHECK: [[ALLOCA:%.+]] = llvm.alloca [[CONST1]] x !llvm.array<3 x array<2 x float>> : (!llvm.i64) -> !llvm.ptr>> + // CHECK: [[I8ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr>> to !llvm.ptr - // CHECK: [[GLOBAL_ADDR:%.+]] = llvm.mlir.addressof [[GLOBAL_CONST]] : !llvm<"[3 x [2 x float]]*"> - // CHECK: [[I8GLOBAL:%.+]] = llvm.bitcast [[GLOBAL_ADDR]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"i8*"> + // CHECK: [[GLOBAL_ADDR:%.+]] = llvm.mlir.addressof [[GLOBAL_CONST]] : !llvm.ptr>> + // CHECK: [[I8GLOBAL:%.+]] = llvm.bitcast [[GLOBAL_ADDR]] : !llvm.ptr>> to !llvm.ptr /// Size of the constant tensor in bytes. // CHECK: [[CONST4:%.+]] = llvm.mlir.constant(4 : i64) : !llvm.i64 @@ -26,30 +26,30 @@ func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> { /// Volatile flag // CHECK: [[CONST0:%.+]] = llvm.mlir.constant(false) : !llvm.i1 - // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> !llvm.void + // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i1) -> !llvm.void /// Prepare data for MemRef insertion. - // CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"float*"> + // CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr>> to !llvm.ptr /// Insert the constant value in the local MemRef. - // CHECK: [[LOCAL_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: [[LOCAL_MEMREF0:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: [[LOCAL_MEMREF1:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: [[LOCAL_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[LOCAL_MEMREF0:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[LOCAL_MEMREF1:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> /// Insert offset. // CHECK: [[CONST00:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64 - // CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[CONST00]], [[LOCAL_MEMREF1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[CONST00]], [[LOCAL_MEMREF1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> /// Insert sizes and strides. // CHECK: [[CONST3:%.+]] = llvm.mlir.constant(3 : index) : !llvm.i64 - // CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[CONST1:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64 - // CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF2]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF2]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[CONST2:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64 - // CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF3]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF3]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64 - // CHECK: [[MEMREF5:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF4]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: [[MEMREF5:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF4]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: llvm.return [[MEMREF5]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.return [[MEMREF5]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> } diff --git a/test/mlir/krnl/krnl_getref_lowering.mlir b/test/mlir/krnl/krnl_getref_lowering.mlir index da86731..b026b5e 100644 --- a/test/mlir/krnl/krnl_getref_lowering.mlir +++ b/test/mlir/krnl/krnl_getref_lowering.mlir @@ -11,16 +11,16 @@ func @test_getref_lowering(%arg0: memref<2x2xf32>) -> memref<2x2xf32> { // CHECK: [[CONST_10_0:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64 // CHECK: [[CONST_10_1:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64 // CHECK: [[MUL1:%.+]] = llvm.mul [[CONST_10_0]], [[CONST_10_1]] : !llvm.i64 - // CHECK: [[FLOAT_STAR:%.+]] = llvm.mlir.null : !llvm<"float*"> + // CHECK: [[FLOAT_STAR:%.+]] = llvm.mlir.null : !llvm.ptr // CHECK: %[[CONST_1:.+]] = llvm.mlir.constant(1 : index) : !llvm.i64 - // CHECK: [[ELEM1:%.+]] = llvm.getelementptr [[FLOAT_STAR]][%[[CONST_1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> - // CHECK: [[ELEM_SIZE:%.+]] = llvm.ptrtoint [[ELEM1]] : !llvm<"float*"> to !llvm.i64 + // CHECK: [[ELEM1:%.+]] = llvm.getelementptr [[FLOAT_STAR]][%[[CONST_1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr + // CHECK: [[ELEM_SIZE:%.+]] = llvm.ptrtoint [[ELEM1]] : !llvm.ptr to !llvm.i64 // CHECK: [[MUL2:%.+]] = llvm.mul [[MUL1]], [[ELEM_SIZE]] : !llvm.i64 - // CHECK: [[MEMPOOL:%.+]] = llvm.call @malloc([[MUL2]]) : (!llvm.i64) -> !llvm<"i8*"> - // CHECK: [[TYPED_MEMPOOL:%.+]] = llvm.bitcast [[MEMPOOL]] : !llvm<"i8*"> to !llvm<"float*"> - // CHECK: [[MEMPOOL_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMPOOL_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMREF1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: [[MEMPOOL:%.+]] = llvm.call @malloc([[MUL2]]) : (!llvm.i64) -> !llvm.ptr + // CHECK: [[TYPED_MEMPOOL:%.+]] = llvm.bitcast [[MEMPOOL]] : !llvm.ptr to !llvm.ptr + // CHECK: [[MEMPOOL_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMPOOL_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMREF1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.mlir.constant // CHECK: llvm.insertvalue // CHECK: llvm.mlir.constant @@ -29,10 +29,10 @@ func @test_getref_lowering(%arg0: memref<2x2xf32>) -> memref<2x2xf32> { // CHECK: llvm.insertvalue // CHECK: llvm.insertvalue // CHECK: llvm.insertvalue - // CHECK: [[MEMPOOL1:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: [[MEMPOOL_ALLOC:%.+]] = llvm.getelementptr [[MEMPOOL1]][%[[OFFSET]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> - // CHECK: [[TYPED_MEMPOOL_ALLOC:%.+]] = llvm.bitcast [[MEMPOOL_ALLOC]] : !llvm<"float*"> to !llvm<"float*"> - // CHECK: [[MEMPOOLED_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMPOOLED_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMREF3]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: [[MEMPOOL1:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[MEMPOOL_ALLOC:%.+]] = llvm.getelementptr [[MEMPOOL1]][%[[OFFSET]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr + // CHECK: [[TYPED_MEMPOOL_ALLOC:%.+]] = llvm.bitcast [[MEMPOOL_ALLOC]] : !llvm.ptr to !llvm.ptr + // CHECK: [[MEMPOOLED_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMPOOLED_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMREF3]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> } diff --git a/test/mlir/krnl/memory_pool.mlir b/test/mlir/krnl/memory_pool.mlir index c3d0084..b32a5b7 100644 --- a/test/mlir/krnl/memory_pool.mlir +++ b/test/mlir/krnl/memory_pool.mlir @@ -10,46 +10,46 @@ func @test_memory_pool(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { /// Allocate memory for the memory pool. // CHECK: [[MEMPOOL_SIZE:%.+]] = llvm.mlir.constant(400 : index) : !llvm.i64 - // CHECK: [[TMP1:%.+]] = llvm.mlir.null : !llvm<"i8*"> + // CHECK: [[TMP1:%.+]] = llvm.mlir.null : !llvm.ptr // CHECK: %[[CONST1:.+]] = llvm.mlir.constant(1 : index) : !llvm.i64 - // CHECK: [[TMP2:%.+]] = llvm.getelementptr [[TMP1]][%[[CONST1]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*"> - // CHECK: [[TYPE_SIZE_IN_BYTES:%.+]] = llvm.ptrtoint [[TMP2]] : !llvm<"i8*"> to !llvm.i64 + // CHECK: [[TMP2:%.+]] = llvm.getelementptr [[TMP1]][%[[CONST1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr + // CHECK: [[TYPE_SIZE_IN_BYTES:%.+]] = llvm.ptrtoint [[TMP2]] : !llvm.ptr to !llvm.i64 // CHECK: [[TOTAL_SIZE:%.+]] = llvm.mul [[MEMPOOL_SIZE]], [[TYPE_SIZE_IN_BYTES]] : !llvm.i64 - // CHECK: [[ALLOC_MEM_POOL:%.+]] = llvm.call @malloc([[TOTAL_SIZE]]) : (!llvm.i64) -> !llvm<"i8*"> - // CHECK: [[BITCAST_ALLOC_MEM_POOL:%.+]] = llvm.bitcast [[ALLOC_MEM_POOL]] : !llvm<"i8*"> to !llvm<"i8*"> + // CHECK: [[ALLOC_MEM_POOL:%.+]] = llvm.call @malloc([[TOTAL_SIZE]]) : (!llvm.i64) -> !llvm.ptr + // CHECK: [[BITCAST_ALLOC_MEM_POOL:%.+]] = llvm.bitcast [[ALLOC_MEM_POOL]] : !llvm.ptr to !llvm.ptr /// MemRef representing the memory pool and which contains the memory allocated above. - // CHECK: [[MEMREF0:%.+]] = llvm.mlir.undef : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }"> - // CHECK: [[TMP3:%.+]] = llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[MEMREF0]][0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }"> - // CHECK: llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[TMP3]][1] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }"> + // CHECK: [[MEMREF0:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[TMP3:%.+]] = llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[MEMREF0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[TMP3]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK: llvm.insertvalue // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK: llvm.insertvalue - // CHECK: [[TMP4:%.+]] = llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }"> + // CHECK: [[TMP4:%.+]] = llvm.insertvalue {{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> /// Get reference within the memory pool where the data of the getref instruction has already been allocated. - // CHECK: [[MEMPOOL_BASE:%.+]] = llvm.extractvalue [[TMP4]][1] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }"> - // CHECK: [[GETREF_MEMORY:%.+]] = llvm.getelementptr [[MEMPOOL_BASE]][%[[OFFSET]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*"> - // CHECK: [[CASTED_GETREF_MEMORY:%.+]] = llvm.bitcast [[GETREF_MEMORY]] : !llvm<"i8*"> to !llvm<"float*"> + // CHECK: [[MEMPOOL_BASE:%.+]] = llvm.extractvalue [[TMP4]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[GETREF_MEMORY:%.+]] = llvm.getelementptr [[MEMPOOL_BASE]][%[[OFFSET]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr + // CHECK: [[CASTED_GETREF_MEMORY:%.+]] = llvm.bitcast [[GETREF_MEMORY]] : !llvm.ptr to !llvm.ptr /// Create MemRef for krnl.getref. - // CHECK: [[MEMREF1:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: [[MEMREF1_TMP1:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: [[MEMREF1_TMP2:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1_TMP1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: [[MEMREF1:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[MEMREF1_TMP1:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[MEMREF1_TMP2:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1_TMP1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[CONST2:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64 - // CHECK: [[MEMREF1_TMP3:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF1_TMP2]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: [[MEMREF1_TMP3:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF1_TMP2]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[CONST3:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64 - // CHECK: [[MEMREF1_TMP4:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1_TMP3]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: [[MEMREF1_TMP4:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1_TMP3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[CONST4:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64 - // CHECK: [[MEMREF1_TMP5:%.+]] = llvm.insertvalue [[CONST4]], [[MEMREF1_TMP4]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: [[MEMREF1_TMP5:%.+]] = llvm.insertvalue [[CONST4]], [[MEMREF1_TMP4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[CONST5:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64 - // CHECK: [[MEMREF1_TMP6:%.+]] = llvm.insertvalue [[CONST5]], [[MEMREF1_TMP5]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: [[MEMREF1_TMP6:%.+]] = llvm.insertvalue [[CONST5]], [[MEMREF1_TMP5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[CONST6:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64 - // CHECK: [[MEMREF1_TMP7:%.+]] = llvm.insertvalue [[CONST6]], [[MEMREF1_TMP6]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: [[MEMREF1_TMP7:%.+]] = llvm.insertvalue [[CONST6]], [[MEMREF1_TMP6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> /// Usage of the getref MemRef. - // CHECK: [[MEM0:%.+]] = llvm.extractvalue [[MEMREF1_TMP7]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: [[MEM0:%.+]] = llvm.extractvalue [[MEMREF1_TMP7]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[CONST7:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK: [[CONST8:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64 // CHECK: [[MUL1:%.+]] = llvm.mul {{.*}}, [[CONST8]] : !llvm.i64 @@ -57,10 +57,10 @@ func @test_memory_pool(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { // CHECK: [[CONST9:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK: [[MUL2:%.+]] = llvm.mul {{.*}}, [[CONST9]] : !llvm.i64 // CHECK: %[[ADD2:.+]] = llvm.add [[ADD1]], [[MUL2]] : !llvm.i64 - // CHECK: llvm.getelementptr [[MEM0]][%[[ADD2]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> + // CHECK: llvm.getelementptr [[MEM0]][%[[ADD2]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr /// Deallocation of the memory pool. - // CHECK: [[MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.extractvalue [[TMP4]][0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }"> - // CHECK: [[CASTED_MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.bitcast [[MEMPOOL_BASE_UNALIGNED]] : !llvm<"i8*"> to !llvm<"i8*"> - // CHECK: llvm.call @free([[CASTED_MEMPOOL_BASE_UNALIGNED]]) : (!llvm<"i8*">) -> () + // CHECK: [[MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.extractvalue [[TMP4]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[CASTED_MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.bitcast [[MEMPOOL_BASE_UNALIGNED]] : !llvm.ptr to !llvm.ptr + // CHECK: llvm.call @free([[CASTED_MEMPOOL_BASE_UNALIGNED]]) : (!llvm.ptr) -> () } diff --git a/test/mlir/krnl/reshape.mlir b/test/mlir/krnl/reshape.mlir index 65f24ce..5b349e5 100644 --- a/test/mlir/krnl/reshape.mlir +++ b/test/mlir/krnl/reshape.mlir @@ -6,22 +6,22 @@ func @test_reshape(%arg0 : tensor, %arg1 : tensor<4xi64>) -> tensor<*x %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor, tensor<4xi64>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () - // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) - // CHECK: [[TMP:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: llvm.insertvalue %arg5, %4[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: llvm.insertvalue %arg4, %5[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: [[TMP1:%.+]] = llvm.insertvalue %arg6, %6[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }"> - // CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }"> - // CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*"> - // CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue [[TMP1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*"> + // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i1) + // CHECK: [[TMP:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.insertvalue %arg5, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.insertvalue %arg4, %5[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[TMP1:%.+]] = llvm.insertvalue %arg6, %6[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + // CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + // CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm.ptr to !llvm.ptr + // CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue [[TMP1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm.ptr to !llvm.ptr // CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64 // CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(false) : !llvm.i1 - // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> !llvm.void - // CHECK: llvm.return [[RES]] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }"> + // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i1) -> !llvm.void + // CHECK: llvm.return [[RES]] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> } diff --git a/utils/clone-mlir.sh b/utils/clone-mlir.sh index 4abfc92..7c1e059 100644 --- a/utils/clone-mlir.sh +++ b/utils/clone-mlir.sh @@ -1,3 +1,3 @@ git clone https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX MLIR. -cd llvm-project && git checkout 32791937d7aceb0a5e1eaabf1bb1a6dbe1639792 && cd .. +cd llvm-project && git checkout 9c94908320549a1a2328c758d6bbb694466021e7 && cd ..