Update LLVM commit ID to the version that corresponds to MLIR News, 13th edition (8/7/2020) (#248)

* Update LLVM commit ID to include to the new modeling of LLVM type in MLIR

* Fix commit id discrepancy

* Update README.md

* Update MLIR version

* Force rebuild prereq dockers and see what happens.

* Use LLVM commit ID that corresponds to MLIR News, 13th edition (8/7/2020)

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
Tung D. Le 2020-08-14 13:52:48 +09:00 committed by GitHub
parent c7b5eecc7f
commit 1b42d0b4eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 157 additions and 167 deletions

View File

@ -18,7 +18,7 @@ jobs:
git submodule update --init --recursive
# Use cached mlir installation if possible.
- restore_cache:
key: V15-LLVM-PROJECT-{{ arch }}
key: V17-LLVM-PROJECT-{{ arch }}
- run:
name: Install MLIR
command: |
@ -29,7 +29,7 @@ jobs:
source onnx-mlir/utils/install-mlir.sh
fi
- save_cache:
key: V15-LLVM-PROJECT-{{ arch }}
key: V17-LLVM-PROJECT-{{ arch }}
paths:
- llvm-project
- run:

View File

@ -11,7 +11,7 @@ script:
- echo "CPU Architecture is " $TRAVIS_CPU_ARCH
- echo "commit is " $TRAVIS_COMMIT
- df -h
- if [ "$TRAVIS_COMMIT_MESSAGE" == "Update MLIR version" ]; then
- if true; then
echo "Building Prereq";
docker build --tag onnxmlirczar/onnx-mlir-llvmimage:$TRAVIS_CPU_ARCH -f ./docker/prereq.$TRAVIS_CPU_ARCH.Dockerfile ./utils;
docker login -u onnxmlirczar -p 143f1da2-332f-45a1-8587-d6cb07c13230

View File

@ -133,11 +133,13 @@ function(find_mlir_lib lib)
endfunction(find_mlir_lib)
find_mlir_lib(MLIRAffineOps)
find_mlir_lib(MLIRAffineUtils)
find_mlir_lib(MLIRAffineToStandard)
find_mlir_lib(MLIRAffineTransforms)
find_mlir_lib(MLIRAnalysis)
find_mlir_lib(MLIRCallInterfaces)
find_mlir_lib(MLIRControlFlowInterfaces)
find_mlir_lib(MLIRCopyOpInterface)
find_mlir_lib(MLIRDialect)
find_mlir_lib(MLIREDSC)
find_mlir_lib(MLIRExecutionEngine)
@ -158,6 +160,7 @@ find_mlir_lib(MLIRMlirOptMain)
find_mlir_lib(MLIRParser)
find_mlir_lib(MLIRPass)
find_mlir_lib(MLIRStandardOps)
find_mlir_lib(MLIRStandardOpsTransforms)
find_mlir_lib(MLIRStandardToLLVM)
find_mlir_lib(MLIRSideEffectInterfaces)
find_mlir_lib(MLIRTargetLLVMIR)
@ -172,12 +175,14 @@ find_mlir_lib(MLIRTargetLLVMIR)
find_mlir_lib(MLIRTransformUtils)
find_mlir_lib(MLIRTranslation)
find_mlir_lib(MLIRVector)
find_mlir_lib(MLIRVectorInterfaces)
find_mlir_lib(MLIRVectorToLLVM)
find_mlir_lib(MLIRVectorToSCF)
find_mlir_lib(MLIRMlirOptMain)
find_mlir_lib(MLIRAffineEDSC)
find_mlir_lib(MLIRLinalgEDSC)
find_mlir_lib(MLIRViewLikeInterface)
find_mlir_lib(MLIRPresburger)
find_mlir_lib(LLVMCore)
find_mlir_lib(LLVMSupport)
@ -200,12 +205,16 @@ find_mlir_lib(LLVMFrontendOpenMP)
set(MLIRLibs
${MLIRAffineToStandard}
${MLIRAffineOps}
${MLIRAffineUtils}
${MLIRCopyOpInterface}
${MLIRLLVMIR}
${MLIRStandardOps}
${MLIRStandardOpsTransforms}
${MLIRStandardToLLVM}
${MLIRTransforms}
${MLIRSCFToStandard}
${MLIRVector}
${MLIRVectorInterfaces}
${MLIRVectorToLLVM}
${MLIRVectorToSCF}
${MLIRSCF}
@ -249,6 +258,7 @@ set(MLIRLibs
${MLIRAffineEDSC}
${MLIRLinalgEDSC}
${MLIRViewLikeInterface}
${MLIRPresburger}
# strict order verified
${LLVMBitWriter}
${LLVMObject}

View File

@ -62,7 +62,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash
git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX MLIR.
cd llvm-project && git checkout 32791937d7aceb0a5e1eaabf1bb1a6dbe1639792 && cd ..
cd llvm-project && git checkout 9c94908320549a1a2328c758d6bbb694466021e7 && cd ..
```
[same-as-file]: <> (utils/build-mlir.sh)
@ -152,7 +152,7 @@ Install MLIR (as a part of LLVM-Project):
```shell
git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX MLIR.
cd llvm-project && git checkout 32791937d7aceb0a5e1eaabf1bb1a6dbe1639792 && cd ..
cd llvm-project && git checkout 9c94908320549a1a2328c758d6bbb694466021e7 && cd ..
```
[same-as-file]: <> (utils/build-mlir.cmd)

View File

@ -20,7 +20,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash
git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX MLIR.
cd llvm-project && git checkout 32791937d7aceb0a5e1eaabf1bb1a6dbe1639792 && cd ..
cd llvm-project && git checkout 9c94908320549a1a2328c758d6bbb694466021e7 && cd ..
```
[same-as-file]: <> (utils/build-mlir.sh)
@ -110,7 +110,7 @@ Install MLIR (as a part of LLVM-Project):
```shell
git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX MLIR.
cd llvm-project && git checkout 32791937d7aceb0a5e1eaabf1bb1a6dbe1639792 && cd ..
cd llvm-project && git checkout 9c94908320549a1a2328c758d6bbb694466021e7 && cd ..
```
[same-as-file]: <> (utils/build-mlir.cmd)

View File

@ -31,7 +31,7 @@ public:
LogicalResult matchAndRewrite(
KrnlTerminatorOp op, PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op);
rewriter.replaceOpWithNewOp<AffineYieldOp>(op);
return success();
}
};
@ -233,7 +233,7 @@ void ConvertKrnlToAffinePass::runOnFunction() {
ConversionTarget target(getContext());
target.addIllegalOp<KrnlTerminatorOp>();
target.addLegalOp<AffineTerminatorOp>();
target.addLegalOp<AffineYieldOp>();
OwningRewritePatternList patterns;
patterns.insert<KrnlTerminatorLowering>(&getContext());
DenseSet<Operation *> unconverted;

View File

@ -105,17 +105,17 @@ static size_t getRankFromMemRefType(LLVM::LLVMType memRefTy) {
/// Return a symbol reference to the memcpy function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
ModuleOp module, LLVM::LLVMDialect *llvmDialect) {
static FlatSymbolRefAttr getOrInsertMemcpy(
PatternRewriter &rewriter, ModuleOp module) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
// Create a function declaration for memcpy, the signature is:
// * `void (i8*, i8* , i64, i1)`
auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
auto llvmVoidTy = LLVM::LLVMType::getVoidTy(context);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(context);
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(context);
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
ArrayRef<mlir::LLVM::LLVMType>(
{llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
@ -144,9 +144,6 @@ public:
ConversionPatternRewriter &rewriter) const override {
auto *context = op->getContext();
auto loc = op->getLoc();
auto *llvmDialect =
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(llvmDialect && "expected llvm dialect to be registered");
KrnlGetRefOpAdaptor operandAdaptor(operands);
@ -209,9 +206,6 @@ public:
ConversionPatternRewriter &rewriter) const override {
auto *context = op->getContext();
auto loc = op->getLoc();
auto *llvmDialect =
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(llvmDialect && "expected llvm dialect to be registered");
auto krnlGlobalOp = llvm::dyn_cast<KrnlGlobalOp>(op);
@ -258,8 +252,8 @@ public:
}
// Some frequently used types.
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(context);
// Allocate the memory where the constants will be used from.
// This is a region of local memory and needs to be emitted as an alloca.
@ -288,17 +282,17 @@ public:
rewriter.create<LLVM::SExtOp>(loc, llvmI64Ty, totalElementsSize);
// - Set volatile.
Value isVolatile = rewriter.create<LLVM::ConstantOp>(loc,
LLVM::LLVMType::getInt1Ty(llvmDialect),
LLVM::LLVMType::getInt1Ty(context),
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
// - Copy constant data into the alloca.
auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect);
auto memcpyRef = getOrInsertMemcpy(rewriter, module);
rewriter.create<CallOp>(loc, memcpyRef,
LLVM::LLVMType::getVoidTy(llvmDialect),
LLVM::LLVMType::getVoidTy(context),
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
} else {
// Some frequently used types.
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(context);
// Allocate the memory where the constants will be used from.
// This is a region of local memory and needs to be emitted as an alloca.
@ -351,13 +345,10 @@ public:
auto *context = op->getContext();
KrnlMemcpyOpAdaptor operandAdaptor(operands);
auto loc = op->getLoc();
auto *llvmDialect =
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(llvmDialect && "expected llvm dialect to be registered");
// Get a symbol reference to the memcpy function, inserting it if necessary.
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule, llvmDialect);
auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule);
// First operand.
Type dstType = operandAdaptor.dest()
@ -367,7 +358,7 @@ public:
Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
loc, dstType, operandAdaptor.dest(), rewriter.getI64ArrayAttr(1));
Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
loc, LLVM::LLVMType::getInt8PtrTy(context), alignedDstMemory);
// Second operand.
Type srcType = operandAdaptor.src()
@ -377,20 +368,19 @@ public:
Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
loc, srcType, operandAdaptor.src(), rewriter.getI64ArrayAttr(1));
Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
loc, LLVM::LLVMType::getInt8PtrTy(context), alignedSrcMemory);
// Size.
Value int64Size = rewriter.create<LLVM::SExtOp>(
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operandAdaptor.size());
loc, LLVM::LLVMType::getInt64Ty(context), operandAdaptor.size());
// Is volatile (set to false).
Value isVolatile = rewriter.create<LLVM::ConstantOp>(loc,
LLVM::LLVMType::getInt1Ty(llvmDialect),
LLVM::LLVMType::getInt1Ty(context),
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
// Memcpy call
rewriter.create<CallOp>(loc, memcpyRef,
LLVM::LLVMType::getVoidTy(llvmDialect),
rewriter.create<CallOp>(loc, memcpyRef, LLVM::LLVMType::getVoidTy(context),
ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
int64Size, isVolatile}));
@ -441,19 +431,17 @@ public:
LogicalResult matchAndRewrite(
KrnlEntryPointOp op, PatternRewriter &rewriter) const override {
auto *llvmDialect =
op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(llvmDialect && "expected llvm dialect to be registered");
auto module = op.getParentOfType<ModuleOp>();
auto apiRegistry = RegisterAllApis(module, rewriter, llvmDialect);
auto *context = module.getContext();
auto apiRegistry = RegisterAllApis(module, rewriter);
auto loc = op.getLoc();
auto numOutputs =
op.getAttrOfType<IntegerAttr>(KrnlEntryPointOp::getNumOutputsAttrName())
.getInt();
using LLVMType = LLVM::LLVMType;
auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect);
auto int32Ty = LLVMType::getInt32Ty(llvmDialect);
auto opaquePtrTy = LLVMType::getInt8PtrTy(context);
auto int32Ty = LLVMType::getInt32Ty(context);
// Rewrite Krnl Entry Point Operation to an LLVM function with a dynamic
// signature. The signature is dynamic because it remains the same no matter
@ -514,7 +502,7 @@ public:
// Fill in the memref underlying ptrToMemRef with information extracted
// from dynMemRef.
fillPtrToMemRefWithRtMemRef(
dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, llvmDialect);
dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, module);
// ptrToMemRef will be an input to main computation graph function.
staticInputs.emplace_back(ptrToMemRef);
@ -565,7 +553,7 @@ public:
auto outRtMemRef = callApi(rewriter, loc, apiRegistry,
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
fillRtMemRefWithMemRef(
memRef, outRtMemRef, rewriter, loc, apiRegistry, llvmDialect);
memRef, outRtMemRef, rewriter, loc, apiRegistry, module);
auto idx = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(i));
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
@ -580,13 +568,14 @@ public:
private:
using ApiRegistry = std::map<API, ApiSpec>;
ApiRegistry RegisterAllApis(ModuleOp &module, PatternRewriter &rewriter,
LLVM::LLVMDialect *llvmDialect) const {
ApiRegistry RegisterAllApis(
ModuleOp &module, PatternRewriter &rewriter) const {
auto *context = module.getContext();
using LLVMType = LLVM::LLVMType;
auto voidTy = LLVMType::getVoidTy(llvmDialect);
auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect);
auto int32Ty = LLVMType::getInt32Ty(llvmDialect);
auto int64Ty = LLVMType::getInt64Ty(llvmDialect);
auto voidTy = LLVMType::getVoidTy(context);
auto opaquePtrTy = LLVMType::getInt8PtrTy(context);
auto int32Ty = LLVMType::getInt32Ty(context);
auto int64Ty = LLVMType::getInt64Ty(context);
auto int64PtrTy = int64Ty.getPointerTo();
// Declare API type as an enum value, its string name and an LLVM Type
@ -646,11 +635,11 @@ private:
void fillPtrToMemRefWithRtMemRef(Value &dynMemRef, Value &ptrToMemRef,
PatternRewriter &rewriter, const Location &loc,
const std::map<API, ApiSpec> &apiRegistry,
LLVM::LLVMDialect *llvmDialect) const {
const std::map<API, ApiSpec> &apiRegistry, ModuleOp &module) const {
auto *context = module.getContext();
auto memRefPtrTy = ptrToMemRef.getType().dyn_cast<LLVM::LLVMType>();
auto memRefTy = memRefPtrTy.getPointerElementTy();
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
auto int64Ty = LLVM::LLVMType::getInt64Ty(context);
Value memRef = rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, ptrToMemRef);
@ -707,18 +696,18 @@ private:
void fillRtMemRefWithMemRef(Value &outMemRef, Value &outRtMemRef,
PatternRewriter &rewriter, const Location &loc,
const std::map<API, ApiSpec> &apiRegistry,
LLVM::LLVMDialect *llvmDialect) const {
const std::map<API, ApiSpec> &apiRegistry, ModuleOp &module) const {
auto *context = module.getContext();
auto outMemRefTy = outMemRef.getType().dyn_cast<LLVM::LLVMType>();
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
auto int64Ty = LLVM::LLVMType::getInt64Ty(context);
auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
// Extract the data pointer, and record it in dynamic mem ref created.
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
outMemRefTy.getStructElementType(0), outMemRef,
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
loc, LLVM::LLVMType::getInt8PtrTy(context), outMemRefDataPtr);
callApi(rewriter, loc, apiRegistry, API::SET_DATA,
{outRtMemRef, outMemRefDataPtr});
auto elemTy = outMemRefTy.getStructElementType(0).getPointerElementTy();
@ -776,15 +765,11 @@ public:
ModuleOp module = op->getParentOfType<ModuleOp>();
auto loc = op->getLoc();
auto *llvmDialect =
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(llvmDialect && "expected llvm dialect to be registered");
auto packedConstOp = llvm::dyn_cast<KrnlPackedConstantOp>(op);
LLVM::GlobalOp globalBase;
// Some frequently used types.
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(context);
{
OpBuilder::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
@ -808,8 +793,7 @@ public:
llvmI8PtrTy, {llvmI64Ty}, /*isVarArg=*/false),
rewriter);
auto constPackSize = rewriter.create<LLVM::ConstantOp>(loc,
LLVM::LLVMType::getInt64Ty(llvmDialect),
packedConstOp.size_in_bytesAttr());
LLVM::LLVMType::getInt64Ty(context), packedConstOp.size_in_bytesAttr());
Value alloc = rewriter
.create<CallOp>(loc, getEmbeddedConstPoolRef, llvmI8PtrTy,
ArrayRef<Value>({constPackSize}))
@ -822,14 +806,13 @@ public:
// Record constant pack *file path* as a global variable (by recording the
// file path string's underlying char array + its length).
const auto &fileNameAttr = packedConstOp.file_nameAttr();
auto type =
LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect),
fileNameAttr.getValue().size());
auto type = LLVM::LLVMType::getArrayTy(
LLVM::LLVMType::getInt8Ty(context), fileNameAttr.getValue().size());
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::External,
mlir::KrnlPackedConstantOp::getConstPackFilePathSymbolName(),
fileNameAttr);
type = LLVM::LLVMType::getInt64Ty(llvmDialect);
type = LLVM::LLVMType::getInt64Ty(context);
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::External,
mlir::KrnlPackedConstantOp::getConstPackFilePathStrLenSymbolName(),
@ -840,18 +823,18 @@ public:
auto constPackFileName =
llvm::sys::path::filename(fileNameAttr.getValue());
type = LLVM::LLVMType::getArrayTy(
LLVM::LLVMType::getInt8Ty(llvmDialect), constPackFileName.size());
LLVM::LLVMType::getInt8Ty(context), constPackFileName.size());
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::External,
mlir::KrnlPackedConstantOp::getConstPackFileNameSymbolName(),
rewriter.getStringAttr(constPackFileName));
type = LLVM::LLVMType::getInt64Ty(llvmDialect);
type = LLVM::LLVMType::getInt64Ty(context);
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::External,
mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName(),
rewriter.getI64IntegerAttr(constPackFileName.size()));
type = LLVM::LLVMType::getInt8Ty(llvmDialect);
type = LLVM::LLVMType::getInt8Ty(context);
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::External,
mlir::KrnlPackedConstantOp::getConstPackIsLESymbolName(),
@ -889,14 +872,14 @@ void ConvertKrlnToLLVMPass::runOnOperation() {
// Lower the MemRef types to a representation in LLVM.
LowerToLLVMOptions options;
options.emitCWrappers = true;
LLVMTypeConverter typeConverter(&getContext());
LLVMTypeConverter typeConverter(&getContext(), options);
// We have a combination of `krnl`, `affine`, and `std` operations. We
// lower in stages until all the code is in the LLVM dialect.
OwningRewritePatternList patterns;
populateAffineToStdConversionPatterns(patterns, &getContext());
populateLoopToStdConversionPatterns(patterns, &getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns, options);
populateStdToLLVMConversionPatterns(typeConverter, patterns);
patterns.insert<KrnlGlobalOpLowering, KrnlPackedConstOpLowering>(
&getContext(), typeConverter);

View File

@ -23,7 +23,9 @@ enum Kinds {
};
}
class LoopType : public mlir::Type::TypeBase<LoopType, mlir::Type> {
class LoopType
: public mlir::Type::TypeBase<LoopType, mlir::Type, mlir::TypeStorage> {
public:
using Base::Base;

View File

@ -67,7 +67,8 @@ enum Kind {
};
} // namespace ONNXTypes
class StringType : public mlir::Type::TypeBase<StringType, mlir::Type> {
class StringType
: public mlir::Type::TypeBase<StringType, mlir::Type, mlir::TypeStorage> {
public:
using Base::Base;
static bool kindof(unsigned kind) { return kind == ONNXTypes::STRING; }

View File

@ -262,8 +262,9 @@ void genLLVMBitcode(const mlir::OwningModuleRef &module,
llvm::raw_fd_ostream moduleBitcodeStream(
unoptimizedBitcodePath, error, llvm::sys::fs::F_None);
llvm::WriteBitcodeToFile(
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
llvm::LLVMContext llvmContext;
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module, llvmContext),
moduleBitcodeStream);
moduleBitcodeStream.flush();
// Use the LLVM's 'opt' command to optimize the bitcode.

View File

@ -67,18 +67,11 @@ int main(int argc, char **argv) {
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::vector::VectorDialect>();
// Register transformation passes.
#define GEN_PASS_REGISTRATION
#include "mlir/Transforms/Passes.h.inc"
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Affine/Passes.h.inc"
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Linalg/Passes.h.inc"
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/SCF/Passes.h.inc"
registerTransformsPasses();
registerAffinePasses();
registerLinalgPasses();
registerSCFPasses();
registerStandardPasses();
llvm::InitLLVM y(argc, argv);

View File

@ -6,16 +6,16 @@ func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> {
%0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
// CHECK: llvm.mlir.global internal constant [[GLOBAL_CONST:@.+]](dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>) : !llvm<"[3 x [2 x float]]">
// CHECK: llvm.func @test_constant({{.*}}) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1)
// CHECK: llvm.mlir.global internal constant [[GLOBAL_CONST:@.+]](dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>) : !llvm.array<3 x array<2 x float>>
// CHECK: llvm.func @test_constant({{.*}}) -> !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)> {
// CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : i64) : !llvm.i64
// CHECK: [[ALLOCA:%.+]] = llvm.alloca [[CONST1]] x !llvm<"[3 x [2 x float]]"> : (!llvm.i64) -> !llvm<"[3 x [2 x float]]*">
// CHECK: [[I8ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"i8*">
// CHECK: [[ALLOCA:%.+]] = llvm.alloca [[CONST1]] x !llvm.array<3 x array<2 x float>> : (!llvm.i64) -> !llvm.ptr<array<3 x array<2 x float>>>
// CHECK: [[I8ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr<array<3 x array<2 x float>>> to !llvm.ptr<i8>
// CHECK: [[GLOBAL_ADDR:%.+]] = llvm.mlir.addressof [[GLOBAL_CONST]] : !llvm<"[3 x [2 x float]]*">
// CHECK: [[I8GLOBAL:%.+]] = llvm.bitcast [[GLOBAL_ADDR]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"i8*">
// CHECK: [[GLOBAL_ADDR:%.+]] = llvm.mlir.addressof [[GLOBAL_CONST]] : !llvm.ptr<array<3 x array<2 x float>>>
// CHECK: [[I8GLOBAL:%.+]] = llvm.bitcast [[GLOBAL_ADDR]] : !llvm.ptr<array<3 x array<2 x float>>> to !llvm.ptr<i8>
/// Size of the constant tensor in bytes.
// CHECK: [[CONST4:%.+]] = llvm.mlir.constant(4 : i64) : !llvm.i64
@ -26,30 +26,30 @@ func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> {
/// Volatile flag
// CHECK: [[CONST0:%.+]] = llvm.mlir.constant(false) : !llvm.i1
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> !llvm.void
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> !llvm.void
/// Prepare data for MemRef insertion.
// CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"float*">
// CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr<array<3 x array<2 x float>>> to !llvm.ptr<float>
/// Insert the constant value in the local MemRef.
// CHECK: [[LOCAL_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[LOCAL_MEMREF0:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[LOCAL_MEMREF1:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[LOCAL_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[LOCAL_MEMREF0:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[LOCAL_MEMREF1:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF0]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
/// Insert offset.
// CHECK: [[CONST00:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[CONST00]], [[LOCAL_MEMREF1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[CONST00]], [[LOCAL_MEMREF1]][2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
/// Insert sizes and strides.
// CHECK: [[CONST3:%.+]] = llvm.mlir.constant(3 : index) : !llvm.i64
// CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1]][3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[CONST1:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF2]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF2]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[CONST2:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF3]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF3]][3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[MEMREF5:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF4]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF5:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF4]][4, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.return [[MEMREF5]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.return [[MEMREF5]] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
}

View File

@ -11,16 +11,16 @@ func @test_getref_lowering(%arg0: memref<2x2xf32>) -> memref<2x2xf32> {
// CHECK: [[CONST_10_0:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: [[CONST_10_1:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: [[MUL1:%.+]] = llvm.mul [[CONST_10_0]], [[CONST_10_1]] : !llvm.i64
// CHECK: [[FLOAT_STAR:%.+]] = llvm.mlir.null : !llvm<"float*">
// CHECK: [[FLOAT_STAR:%.+]] = llvm.mlir.null : !llvm.ptr<float>
// CHECK: %[[CONST_1:.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[ELEM1:%.+]] = llvm.getelementptr [[FLOAT_STAR]][%[[CONST_1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK: [[ELEM_SIZE:%.+]] = llvm.ptrtoint [[ELEM1]] : !llvm<"float*"> to !llvm.i64
// CHECK: [[ELEM1:%.+]] = llvm.getelementptr [[FLOAT_STAR]][%[[CONST_1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK: [[ELEM_SIZE:%.+]] = llvm.ptrtoint [[ELEM1]] : !llvm.ptr<float> to !llvm.i64
// CHECK: [[MUL2:%.+]] = llvm.mul [[MUL1]], [[ELEM_SIZE]] : !llvm.i64
// CHECK: [[MEMPOOL:%.+]] = llvm.call @malloc([[MUL2]]) : (!llvm.i64) -> !llvm<"i8*">
// CHECK: [[TYPED_MEMPOOL:%.+]] = llvm.bitcast [[MEMPOOL]] : !llvm<"i8*"> to !llvm<"float*">
// CHECK: [[MEMPOOL_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMPOOL_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMREF1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMPOOL:%.+]] = llvm.call @malloc([[MUL2]]) : (!llvm.i64) -> !llvm.ptr<i8>
// CHECK: [[TYPED_MEMPOOL:%.+]] = llvm.bitcast [[MEMPOOL]] : !llvm.ptr<i8> to !llvm.ptr<float>
// CHECK: [[MEMPOOL_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMPOOL_MEMREF]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMREF1]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.mlir.constant
// CHECK: llvm.insertvalue
// CHECK: llvm.mlir.constant
@ -29,10 +29,10 @@ func @test_getref_lowering(%arg0: memref<2x2xf32>) -> memref<2x2xf32> {
// CHECK: llvm.insertvalue
// CHECK: llvm.insertvalue
// CHECK: llvm.insertvalue
// CHECK: [[MEMPOOL1:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMPOOL_ALLOC:%.+]] = llvm.getelementptr [[MEMPOOL1]][%[[OFFSET]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK: [[TYPED_MEMPOOL_ALLOC:%.+]] = llvm.bitcast [[MEMPOOL_ALLOC]] : !llvm<"float*"> to !llvm<"float*">
// CHECK: [[MEMPOOLED_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMPOOLED_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMREF3]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMPOOL1:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[MEMPOOL_ALLOC:%.+]] = llvm.getelementptr [[MEMPOOL1]][%[[OFFSET]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK: [[TYPED_MEMPOOL_ALLOC:%.+]] = llvm.bitcast [[MEMPOOL_ALLOC]] : !llvm.ptr<float> to !llvm.ptr<float>
// CHECK: [[MEMPOOLED_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMPOOLED_MEMREF]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMREF3]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
}

View File

@ -10,46 +10,46 @@ func @test_memory_pool(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
/// Allocate memory for the memory pool.
// CHECK: [[MEMPOOL_SIZE:%.+]] = llvm.mlir.constant(400 : index) : !llvm.i64
// CHECK: [[TMP1:%.+]] = llvm.mlir.null : !llvm<"i8*">
// CHECK: [[TMP1:%.+]] = llvm.mlir.null : !llvm.ptr<i8>
// CHECK: %[[CONST1:.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[TMP2:%.+]] = llvm.getelementptr [[TMP1]][%[[CONST1]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
// CHECK: [[TYPE_SIZE_IN_BYTES:%.+]] = llvm.ptrtoint [[TMP2]] : !llvm<"i8*"> to !llvm.i64
// CHECK: [[TMP2:%.+]] = llvm.getelementptr [[TMP1]][%[[CONST1]]] : (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8>
// CHECK: [[TYPE_SIZE_IN_BYTES:%.+]] = llvm.ptrtoint [[TMP2]] : !llvm.ptr<i8> to !llvm.i64
// CHECK: [[TOTAL_SIZE:%.+]] = llvm.mul [[MEMPOOL_SIZE]], [[TYPE_SIZE_IN_BYTES]] : !llvm.i64
// CHECK: [[ALLOC_MEM_POOL:%.+]] = llvm.call @malloc([[TOTAL_SIZE]]) : (!llvm.i64) -> !llvm<"i8*">
// CHECK: [[BITCAST_ALLOC_MEM_POOL:%.+]] = llvm.bitcast [[ALLOC_MEM_POOL]] : !llvm<"i8*"> to !llvm<"i8*">
// CHECK: [[ALLOC_MEM_POOL:%.+]] = llvm.call @malloc([[TOTAL_SIZE]]) : (!llvm.i64) -> !llvm.ptr<i8>
// CHECK: [[BITCAST_ALLOC_MEM_POOL:%.+]] = llvm.bitcast [[ALLOC_MEM_POOL]] : !llvm.ptr<i8> to !llvm.ptr<i8>
/// MemRef representing the memory pool and which contains the memory allocated above.
// CHECK: [[MEMREF0:%.+]] = llvm.mlir.undef : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
// CHECK: [[TMP3:%.+]] = llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[MEMREF0]][0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
// CHECK: llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[TMP3]][1] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
// CHECK: [[MEMREF0:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[TMP3:%.+]] = llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[MEMREF0]][0] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[TMP3]][1] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: llvm.insertvalue
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: llvm.insertvalue
// CHECK: [[TMP4:%.+]] = llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
// CHECK: [[TMP4:%.+]] = llvm.insertvalue {{.*}}[4, 0] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
/// Get reference within the memory pool where the data of the getref instruction has already been allocated.
// CHECK: [[MEMPOOL_BASE:%.+]] = llvm.extractvalue [[TMP4]][1] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
// CHECK: [[GETREF_MEMORY:%.+]] = llvm.getelementptr [[MEMPOOL_BASE]][%[[OFFSET]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
// CHECK: [[CASTED_GETREF_MEMORY:%.+]] = llvm.bitcast [[GETREF_MEMORY]] : !llvm<"i8*"> to !llvm<"float*">
// CHECK: [[MEMPOOL_BASE:%.+]] = llvm.extractvalue [[TMP4]][1] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[GETREF_MEMORY:%.+]] = llvm.getelementptr [[MEMPOOL_BASE]][%[[OFFSET]]] : (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8>
// CHECK: [[CASTED_GETREF_MEMORY:%.+]] = llvm.bitcast [[GETREF_MEMORY]] : !llvm.ptr<i8> to !llvm.ptr<float>
/// Create MemRef for krnl.getref.
// CHECK: [[MEMREF1:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF1_TMP1:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF1_TMP2:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1_TMP1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF1:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[MEMREF1_TMP1:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[MEMREF1_TMP2:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1_TMP1]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[CONST2:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: [[MEMREF1_TMP3:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF1_TMP2]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF1_TMP3:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF1_TMP2]][2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[CONST3:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: [[MEMREF1_TMP4:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1_TMP3]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF1_TMP4:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1_TMP3]][3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[CONST4:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: [[MEMREF1_TMP5:%.+]] = llvm.insertvalue [[CONST4]], [[MEMREF1_TMP4]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF1_TMP5:%.+]] = llvm.insertvalue [[CONST4]], [[MEMREF1_TMP4]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[CONST5:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: [[MEMREF1_TMP6:%.+]] = llvm.insertvalue [[CONST5]], [[MEMREF1_TMP5]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF1_TMP6:%.+]] = llvm.insertvalue [[CONST5]], [[MEMREF1_TMP5]][3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[CONST6:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[MEMREF1_TMP7:%.+]] = llvm.insertvalue [[CONST6]], [[MEMREF1_TMP6]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF1_TMP7:%.+]] = llvm.insertvalue [[CONST6]], [[MEMREF1_TMP6]][4, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
/// Usage of the getref MemRef.
// CHECK: [[MEM0:%.+]] = llvm.extractvalue [[MEMREF1_TMP7]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEM0:%.+]] = llvm.extractvalue [[MEMREF1_TMP7]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[CONST7:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: [[CONST8:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: [[MUL1:%.+]] = llvm.mul {{.*}}, [[CONST8]] : !llvm.i64
@ -57,10 +57,10 @@ func @test_memory_pool(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
// CHECK: [[CONST9:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[MUL2:%.+]] = llvm.mul {{.*}}, [[CONST9]] : !llvm.i64
// CHECK: %[[ADD2:.+]] = llvm.add [[ADD1]], [[MUL2]] : !llvm.i64
// CHECK: llvm.getelementptr [[MEM0]][%[[ADD2]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK: llvm.getelementptr [[MEM0]][%[[ADD2]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
/// Deallocation of the memory pool.
// CHECK: [[MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.extractvalue [[TMP4]][0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
// CHECK: [[CASTED_MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.bitcast [[MEMPOOL_BASE_UNALIGNED]] : !llvm<"i8*"> to !llvm<"i8*">
// CHECK: llvm.call @free([[CASTED_MEMPOOL_BASE_UNALIGNED]]) : (!llvm<"i8*">) -> ()
// CHECK: [[MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.extractvalue [[TMP4]][0] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[CASTED_MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.bitcast [[MEMPOOL_BASE_UNALIGNED]] : !llvm.ptr<i8> to !llvm.ptr<i8>
// CHECK: llvm.call @free([[CASTED_MEMPOOL_BASE_UNALIGNED]]) : (!llvm.ptr<i8>) -> ()
}

View File

@ -6,22 +6,22 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi64>) -> tensor<*x
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi64>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
// CHECK: [[TMP:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg5, %4[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg4, %5[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[TMP1:%.+]] = llvm.insertvalue %arg6, %6[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
// CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
// CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue [[TMP1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1)
// CHECK: [[TMP:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %arg5, %4[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %arg4, %5[3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[TMP1:%.+]] = llvm.insertvalue %arg6, %6[4, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<4 x i64>, array<4 x i64>)>
// CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<4 x i64>, array<4 x i64>)>
// CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm.ptr<float> to !llvm.ptr<i8>
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue [[TMP1]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm.ptr<float> to !llvm.ptr<i8>
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
// CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(false) : !llvm.i1
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> !llvm.void
// CHECK: llvm.return [[RES]] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> !llvm.void
// CHECK: llvm.return [[RES]] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<4 x i64>, array<4 x i64>)>
}

View File

@ -1,3 +1,3 @@
git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX MLIR.
cd llvm-project && git checkout 32791937d7aceb0a5e1eaabf1bb1a6dbe1639792 && cd ..
cd llvm-project && git checkout 9c94908320549a1a2328c758d6bbb694466021e7 && cd ..