Update LLVM commit ID to the version that corresponds to MLIR News, 13th edition (8/7/2020) (#248)
* Update LLVM commit ID to include the new modeling of LLVM types in MLIR * Fix commit ID discrepancy * Update README.md * Update MLIR version * Force rebuild of prereq Docker images and see what happens. * Use LLVM commit ID that corresponds to MLIR News, 13th edition (8/7/2020) Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
c7b5eecc7f
commit
1b42d0b4eb
|
@ -18,7 +18,7 @@ jobs:
|
|||
git submodule update --init --recursive
|
||||
# Use cached mlir installation if possible.
|
||||
- restore_cache:
|
||||
key: V15-LLVM-PROJECT-{{ arch }}
|
||||
key: V17-LLVM-PROJECT-{{ arch }}
|
||||
- run:
|
||||
name: Install MLIR
|
||||
command: |
|
||||
|
@ -29,7 +29,7 @@ jobs:
|
|||
source onnx-mlir/utils/install-mlir.sh
|
||||
fi
|
||||
- save_cache:
|
||||
key: V15-LLVM-PROJECT-{{ arch }}
|
||||
key: V17-LLVM-PROJECT-{{ arch }}
|
||||
paths:
|
||||
- llvm-project
|
||||
- run:
|
||||
|
|
|
@ -11,7 +11,7 @@ script:
|
|||
- echo "CPU Architecture is " $TRAVIS_CPU_ARCH
|
||||
- echo "commit is " $TRAVIS_COMMIT
|
||||
- df -h
|
||||
- if [ "$TRAVIS_COMMIT_MESSAGE" == "Update MLIR version" ]; then
|
||||
- if true; then
|
||||
echo "Building Prereq";
|
||||
docker build --tag onnxmlirczar/onnx-mlir-llvmimage:$TRAVIS_CPU_ARCH -f ./docker/prereq.$TRAVIS_CPU_ARCH.Dockerfile ./utils;
|
||||
docker login -u onnxmlirczar -p 143f1da2-332f-45a1-8587-d6cb07c13230
|
||||
|
|
10
MLIR.cmake
10
MLIR.cmake
|
@ -133,11 +133,13 @@ function(find_mlir_lib lib)
|
|||
endfunction(find_mlir_lib)
|
||||
|
||||
find_mlir_lib(MLIRAffineOps)
|
||||
find_mlir_lib(MLIRAffineUtils)
|
||||
find_mlir_lib(MLIRAffineToStandard)
|
||||
find_mlir_lib(MLIRAffineTransforms)
|
||||
find_mlir_lib(MLIRAnalysis)
|
||||
find_mlir_lib(MLIRCallInterfaces)
|
||||
find_mlir_lib(MLIRControlFlowInterfaces)
|
||||
find_mlir_lib(MLIRCopyOpInterface)
|
||||
find_mlir_lib(MLIRDialect)
|
||||
find_mlir_lib(MLIREDSC)
|
||||
find_mlir_lib(MLIRExecutionEngine)
|
||||
|
@ -158,6 +160,7 @@ find_mlir_lib(MLIRMlirOptMain)
|
|||
find_mlir_lib(MLIRParser)
|
||||
find_mlir_lib(MLIRPass)
|
||||
find_mlir_lib(MLIRStandardOps)
|
||||
find_mlir_lib(MLIRStandardOpsTransforms)
|
||||
find_mlir_lib(MLIRStandardToLLVM)
|
||||
find_mlir_lib(MLIRSideEffectInterfaces)
|
||||
find_mlir_lib(MLIRTargetLLVMIR)
|
||||
|
@ -172,12 +175,14 @@ find_mlir_lib(MLIRTargetLLVMIR)
|
|||
find_mlir_lib(MLIRTransformUtils)
|
||||
find_mlir_lib(MLIRTranslation)
|
||||
find_mlir_lib(MLIRVector)
|
||||
find_mlir_lib(MLIRVectorInterfaces)
|
||||
find_mlir_lib(MLIRVectorToLLVM)
|
||||
find_mlir_lib(MLIRVectorToSCF)
|
||||
find_mlir_lib(MLIRMlirOptMain)
|
||||
find_mlir_lib(MLIRAffineEDSC)
|
||||
find_mlir_lib(MLIRLinalgEDSC)
|
||||
find_mlir_lib(MLIRViewLikeInterface)
|
||||
find_mlir_lib(MLIRPresburger)
|
||||
|
||||
find_mlir_lib(LLVMCore)
|
||||
find_mlir_lib(LLVMSupport)
|
||||
|
@ -200,12 +205,16 @@ find_mlir_lib(LLVMFrontendOpenMP)
|
|||
set(MLIRLibs
|
||||
${MLIRAffineToStandard}
|
||||
${MLIRAffineOps}
|
||||
${MLIRAffineUtils}
|
||||
${MLIRCopyOpInterface}
|
||||
${MLIRLLVMIR}
|
||||
${MLIRStandardOps}
|
||||
${MLIRStandardOpsTransforms}
|
||||
${MLIRStandardToLLVM}
|
||||
${MLIRTransforms}
|
||||
${MLIRSCFToStandard}
|
||||
${MLIRVector}
|
||||
${MLIRVectorInterfaces}
|
||||
${MLIRVectorToLLVM}
|
||||
${MLIRVectorToSCF}
|
||||
${MLIRSCF}
|
||||
|
@ -249,6 +258,7 @@ set(MLIRLibs
|
|||
${MLIRAffineEDSC}
|
||||
${MLIRLinalgEDSC}
|
||||
${MLIRViewLikeInterface}
|
||||
${MLIRPresburger}
|
||||
# strict order verified
|
||||
${LLVMBitWriter}
|
||||
${LLVMObject}
|
||||
|
|
|
@ -62,7 +62,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
|
|||
``` bash
|
||||
git clone https://github.com/llvm/llvm-project.git
|
||||
# Check out a specific branch that is known to work with ONNX MLIR.
|
||||
cd llvm-project && git checkout 32791937d7aceb0a5e1eaabf1bb1a6dbe1639792 && cd ..
|
||||
cd llvm-project && git checkout 9c94908320549a1a2328c758d6bbb694466021e7 && cd ..
|
||||
```
|
||||
|
||||
[same-as-file]: <> (utils/build-mlir.sh)
|
||||
|
@ -152,7 +152,7 @@ Install MLIR (as a part of LLVM-Project):
|
|||
```shell
|
||||
git clone https://github.com/llvm/llvm-project.git
|
||||
# Check out a specific branch that is known to work with ONNX MLIR.
|
||||
cd llvm-project && git checkout 32791937d7aceb0a5e1eaabf1bb1a6dbe1639792 && cd ..
|
||||
cd llvm-project && git checkout 9c94908320549a1a2328c758d6bbb694466021e7 && cd ..
|
||||
```
|
||||
|
||||
[same-as-file]: <> (utils/build-mlir.cmd)
|
||||
|
|
|
@ -20,7 +20,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
|
|||
``` bash
|
||||
git clone https://github.com/llvm/llvm-project.git
|
||||
# Check out a specific branch that is known to work with ONNX MLIR.
|
||||
cd llvm-project && git checkout 32791937d7aceb0a5e1eaabf1bb1a6dbe1639792 && cd ..
|
||||
cd llvm-project && git checkout 9c94908320549a1a2328c758d6bbb694466021e7 && cd ..
|
||||
```
|
||||
|
||||
[same-as-file]: <> (utils/build-mlir.sh)
|
||||
|
@ -110,7 +110,7 @@ Install MLIR (as a part of LLVM-Project):
|
|||
```shell
|
||||
git clone https://github.com/llvm/llvm-project.git
|
||||
# Check out a specific branch that is known to work with ONNX MLIR.
|
||||
cd llvm-project && git checkout 32791937d7aceb0a5e1eaabf1bb1a6dbe1639792 && cd ..
|
||||
cd llvm-project && git checkout 9c94908320549a1a2328c758d6bbb694466021e7 && cd ..
|
||||
```
|
||||
|
||||
[same-as-file]: <> (utils/build-mlir.cmd)
|
||||
|
|
|
@ -31,7 +31,7 @@ public:
|
|||
|
||||
LogicalResult matchAndRewrite(
|
||||
KrnlTerminatorOp op, PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op);
|
||||
rewriter.replaceOpWithNewOp<AffineYieldOp>(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -233,7 +233,7 @@ void ConvertKrnlToAffinePass::runOnFunction() {
|
|||
|
||||
ConversionTarget target(getContext());
|
||||
target.addIllegalOp<KrnlTerminatorOp>();
|
||||
target.addLegalOp<AffineTerminatorOp>();
|
||||
target.addLegalOp<AffineYieldOp>();
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<KrnlTerminatorLowering>(&getContext());
|
||||
DenseSet<Operation *> unconverted;
|
||||
|
|
|
@ -105,17 +105,17 @@ static size_t getRankFromMemRefType(LLVM::LLVMType memRefTy) {
|
|||
|
||||
/// Return a symbol reference to the memcpy function, inserting it into the
|
||||
/// module if necessary.
|
||||
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
|
||||
ModuleOp module, LLVM::LLVMDialect *llvmDialect) {
|
||||
static FlatSymbolRefAttr getOrInsertMemcpy(
|
||||
PatternRewriter &rewriter, ModuleOp module) {
|
||||
auto *context = module.getContext();
|
||||
if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
|
||||
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
||||
// Create a function declaration for memcpy, the signature is:
|
||||
// * `void (i8*, i8* , i64, i1)`
|
||||
auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
|
||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
|
||||
auto llvmVoidTy = LLVM::LLVMType::getVoidTy(context);
|
||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
|
||||
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(context);
|
||||
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(context);
|
||||
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
|
||||
ArrayRef<mlir::LLVM::LLVMType>(
|
||||
{llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
|
||||
|
@ -144,9 +144,6 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto *context = op->getContext();
|
||||
auto loc = op->getLoc();
|
||||
auto *llvmDialect =
|
||||
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
assert(llvmDialect && "expected llvm dialect to be registered");
|
||||
|
||||
KrnlGetRefOpAdaptor operandAdaptor(operands);
|
||||
|
||||
|
@ -209,9 +206,6 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto *context = op->getContext();
|
||||
auto loc = op->getLoc();
|
||||
auto *llvmDialect =
|
||||
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
assert(llvmDialect && "expected llvm dialect to be registered");
|
||||
|
||||
auto krnlGlobalOp = llvm::dyn_cast<KrnlGlobalOp>(op);
|
||||
|
||||
|
@ -258,8 +252,8 @@ public:
|
|||
}
|
||||
|
||||
// Some frequently used types.
|
||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
|
||||
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(context);
|
||||
|
||||
// Allocate the memory where the constants will be used from.
|
||||
// This is a region of local memory and needs to be emitted as an alloca.
|
||||
|
@ -288,17 +282,17 @@ public:
|
|||
rewriter.create<LLVM::SExtOp>(loc, llvmI64Ty, totalElementsSize);
|
||||
// - Set volatile.
|
||||
Value isVolatile = rewriter.create<LLVM::ConstantOp>(loc,
|
||||
LLVM::LLVMType::getInt1Ty(llvmDialect),
|
||||
LLVM::LLVMType::getInt1Ty(context),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
||||
// - Copy constant data into the alloca.
|
||||
auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect);
|
||||
auto memcpyRef = getOrInsertMemcpy(rewriter, module);
|
||||
rewriter.create<CallOp>(loc, memcpyRef,
|
||||
LLVM::LLVMType::getVoidTy(llvmDialect),
|
||||
LLVM::LLVMType::getVoidTy(context),
|
||||
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
|
||||
} else {
|
||||
// Some frequently used types.
|
||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
|
||||
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(context);
|
||||
|
||||
// Allocate the memory where the constants will be used from.
|
||||
// This is a region of local memory and needs to be emitted as an alloca.
|
||||
|
@ -351,13 +345,10 @@ public:
|
|||
auto *context = op->getContext();
|
||||
KrnlMemcpyOpAdaptor operandAdaptor(operands);
|
||||
auto loc = op->getLoc();
|
||||
auto *llvmDialect =
|
||||
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
assert(llvmDialect && "expected llvm dialect to be registered");
|
||||
|
||||
// Get a symbol reference to the memcpy function, inserting it if necessary.
|
||||
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
|
||||
auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule, llvmDialect);
|
||||
auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule);
|
||||
|
||||
// First operand.
|
||||
Type dstType = operandAdaptor.dest()
|
||||
|
@ -367,7 +358,7 @@ public:
|
|||
Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, dstType, operandAdaptor.dest(), rewriter.getI64ArrayAttr(1));
|
||||
Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(context), alignedDstMemory);
|
||||
|
||||
// Second operand.
|
||||
Type srcType = operandAdaptor.src()
|
||||
|
@ -377,20 +368,19 @@ public:
|
|||
Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, srcType, operandAdaptor.src(), rewriter.getI64ArrayAttr(1));
|
||||
Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(context), alignedSrcMemory);
|
||||
|
||||
// Size.
|
||||
Value int64Size = rewriter.create<LLVM::SExtOp>(
|
||||
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operandAdaptor.size());
|
||||
loc, LLVM::LLVMType::getInt64Ty(context), operandAdaptor.size());
|
||||
|
||||
// Is volatile (set to false).
|
||||
Value isVolatile = rewriter.create<LLVM::ConstantOp>(loc,
|
||||
LLVM::LLVMType::getInt1Ty(llvmDialect),
|
||||
LLVM::LLVMType::getInt1Ty(context),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
||||
|
||||
// Memcpy call
|
||||
rewriter.create<CallOp>(loc, memcpyRef,
|
||||
LLVM::LLVMType::getVoidTy(llvmDialect),
|
||||
rewriter.create<CallOp>(loc, memcpyRef, LLVM::LLVMType::getVoidTy(context),
|
||||
ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
|
||||
int64Size, isVolatile}));
|
||||
|
||||
|
@ -441,19 +431,17 @@ public:
|
|||
LogicalResult matchAndRewrite(
|
||||
KrnlEntryPointOp op, PatternRewriter &rewriter) const override {
|
||||
|
||||
auto *llvmDialect =
|
||||
op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
assert(llvmDialect && "expected llvm dialect to be registered");
|
||||
auto module = op.getParentOfType<ModuleOp>();
|
||||
auto apiRegistry = RegisterAllApis(module, rewriter, llvmDialect);
|
||||
auto *context = module.getContext();
|
||||
auto apiRegistry = RegisterAllApis(module, rewriter);
|
||||
auto loc = op.getLoc();
|
||||
auto numOutputs =
|
||||
op.getAttrOfType<IntegerAttr>(KrnlEntryPointOp::getNumOutputsAttrName())
|
||||
.getInt();
|
||||
|
||||
using LLVMType = LLVM::LLVMType;
|
||||
auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect);
|
||||
auto int32Ty = LLVMType::getInt32Ty(llvmDialect);
|
||||
auto opaquePtrTy = LLVMType::getInt8PtrTy(context);
|
||||
auto int32Ty = LLVMType::getInt32Ty(context);
|
||||
|
||||
// Rewrite Krnl Entry Point Operation to an LLVM function with a dynamic
|
||||
// signature. The signature is dynamic because it remains the same no matter
|
||||
|
@ -514,7 +502,7 @@ public:
|
|||
// Fill in the memref underlying ptrToMemRef with information extracted
|
||||
// from dynMemRef.
|
||||
fillPtrToMemRefWithRtMemRef(
|
||||
dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, llvmDialect);
|
||||
dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, module);
|
||||
|
||||
// ptrToMemRef will be an input to main computation graph function.
|
||||
staticInputs.emplace_back(ptrToMemRef);
|
||||
|
@ -565,7 +553,7 @@ public:
|
|||
auto outRtMemRef = callApi(rewriter, loc, apiRegistry,
|
||||
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
|
||||
fillRtMemRefWithMemRef(
|
||||
memRef, outRtMemRef, rewriter, loc, apiRegistry, llvmDialect);
|
||||
memRef, outRtMemRef, rewriter, loc, apiRegistry, module);
|
||||
auto idx = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(i));
|
||||
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
|
||||
|
@ -580,13 +568,14 @@ public:
|
|||
private:
|
||||
using ApiRegistry = std::map<API, ApiSpec>;
|
||||
|
||||
ApiRegistry RegisterAllApis(ModuleOp &module, PatternRewriter &rewriter,
|
||||
LLVM::LLVMDialect *llvmDialect) const {
|
||||
ApiRegistry RegisterAllApis(
|
||||
ModuleOp &module, PatternRewriter &rewriter) const {
|
||||
auto *context = module.getContext();
|
||||
using LLVMType = LLVM::LLVMType;
|
||||
auto voidTy = LLVMType::getVoidTy(llvmDialect);
|
||||
auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect);
|
||||
auto int32Ty = LLVMType::getInt32Ty(llvmDialect);
|
||||
auto int64Ty = LLVMType::getInt64Ty(llvmDialect);
|
||||
auto voidTy = LLVMType::getVoidTy(context);
|
||||
auto opaquePtrTy = LLVMType::getInt8PtrTy(context);
|
||||
auto int32Ty = LLVMType::getInt32Ty(context);
|
||||
auto int64Ty = LLVMType::getInt64Ty(context);
|
||||
auto int64PtrTy = int64Ty.getPointerTo();
|
||||
|
||||
// Declare API type as an enum value, its string name and an LLVM Type
|
||||
|
@ -646,11 +635,11 @@ private:
|
|||
|
||||
void fillPtrToMemRefWithRtMemRef(Value &dynMemRef, Value &ptrToMemRef,
|
||||
PatternRewriter &rewriter, const Location &loc,
|
||||
const std::map<API, ApiSpec> &apiRegistry,
|
||||
LLVM::LLVMDialect *llvmDialect) const {
|
||||
const std::map<API, ApiSpec> &apiRegistry, ModuleOp &module) const {
|
||||
auto *context = module.getContext();
|
||||
auto memRefPtrTy = ptrToMemRef.getType().dyn_cast<LLVM::LLVMType>();
|
||||
auto memRefTy = memRefPtrTy.getPointerElementTy();
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(context);
|
||||
|
||||
Value memRef = rewriter.create<LLVM::LoadOp>(loc, memRefPtrTy, ptrToMemRef);
|
||||
|
||||
|
@ -707,18 +696,18 @@ private:
|
|||
|
||||
void fillRtMemRefWithMemRef(Value &outMemRef, Value &outRtMemRef,
|
||||
PatternRewriter &rewriter, const Location &loc,
|
||||
const std::map<API, ApiSpec> &apiRegistry,
|
||||
LLVM::LLVMDialect *llvmDialect) const {
|
||||
const std::map<API, ApiSpec> &apiRegistry, ModuleOp &module) const {
|
||||
auto *context = module.getContext();
|
||||
auto outMemRefTy = outMemRef.getType().dyn_cast<LLVM::LLVMType>();
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(context);
|
||||
auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
|
||||
|
||||
// Extract the data pointer, and record it in dynamic mem ref created.
|
||||
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
|
||||
outMemRefTy.getStructElementType(0), outMemRef,
|
||||
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
|
||||
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(context), outMemRefDataPtr);
|
||||
callApi(rewriter, loc, apiRegistry, API::SET_DATA,
|
||||
{outRtMemRef, outMemRefDataPtr});
|
||||
auto elemTy = outMemRefTy.getStructElementType(0).getPointerElementTy();
|
||||
|
@ -776,15 +765,11 @@ public:
|
|||
ModuleOp module = op->getParentOfType<ModuleOp>();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
auto *llvmDialect =
|
||||
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
assert(llvmDialect && "expected llvm dialect to be registered");
|
||||
|
||||
auto packedConstOp = llvm::dyn_cast<KrnlPackedConstantOp>(op);
|
||||
LLVM::GlobalOp globalBase;
|
||||
// Some frequently used types.
|
||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
|
||||
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(context);
|
||||
{
|
||||
OpBuilder::InsertionGuard insertGuard(rewriter);
|
||||
rewriter.setInsertionPointToStart(module.getBody());
|
||||
|
@ -808,8 +793,7 @@ public:
|
|||
llvmI8PtrTy, {llvmI64Ty}, /*isVarArg=*/false),
|
||||
rewriter);
|
||||
auto constPackSize = rewriter.create<LLVM::ConstantOp>(loc,
|
||||
LLVM::LLVMType::getInt64Ty(llvmDialect),
|
||||
packedConstOp.size_in_bytesAttr());
|
||||
LLVM::LLVMType::getInt64Ty(context), packedConstOp.size_in_bytesAttr());
|
||||
Value alloc = rewriter
|
||||
.create<CallOp>(loc, getEmbeddedConstPoolRef, llvmI8PtrTy,
|
||||
ArrayRef<Value>({constPackSize}))
|
||||
|
@ -822,14 +806,13 @@ public:
|
|||
// Record constant pack *file path* as a global variable (by recording the
|
||||
// file path string's underlying char array + its length).
|
||||
const auto &fileNameAttr = packedConstOp.file_nameAttr();
|
||||
auto type =
|
||||
LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect),
|
||||
fileNameAttr.getValue().size());
|
||||
auto type = LLVM::LLVMType::getArrayTy(
|
||||
LLVM::LLVMType::getInt8Ty(context), fileNameAttr.getValue().size());
|
||||
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
|
||||
LLVM::Linkage::External,
|
||||
mlir::KrnlPackedConstantOp::getConstPackFilePathSymbolName(),
|
||||
fileNameAttr);
|
||||
type = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
type = LLVM::LLVMType::getInt64Ty(context);
|
||||
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
|
||||
LLVM::Linkage::External,
|
||||
mlir::KrnlPackedConstantOp::getConstPackFilePathStrLenSymbolName(),
|
||||
|
@ -840,18 +823,18 @@ public:
|
|||
auto constPackFileName =
|
||||
llvm::sys::path::filename(fileNameAttr.getValue());
|
||||
type = LLVM::LLVMType::getArrayTy(
|
||||
LLVM::LLVMType::getInt8Ty(llvmDialect), constPackFileName.size());
|
||||
LLVM::LLVMType::getInt8Ty(context), constPackFileName.size());
|
||||
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
|
||||
LLVM::Linkage::External,
|
||||
mlir::KrnlPackedConstantOp::getConstPackFileNameSymbolName(),
|
||||
rewriter.getStringAttr(constPackFileName));
|
||||
type = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
type = LLVM::LLVMType::getInt64Ty(context);
|
||||
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
|
||||
LLVM::Linkage::External,
|
||||
mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName(),
|
||||
rewriter.getI64IntegerAttr(constPackFileName.size()));
|
||||
|
||||
type = LLVM::LLVMType::getInt8Ty(llvmDialect);
|
||||
type = LLVM::LLVMType::getInt8Ty(context);
|
||||
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
|
||||
LLVM::Linkage::External,
|
||||
mlir::KrnlPackedConstantOp::getConstPackIsLESymbolName(),
|
||||
|
@ -889,14 +872,14 @@ void ConvertKrlnToLLVMPass::runOnOperation() {
|
|||
// Lower the MemRef types to a representation in LLVM.
|
||||
LowerToLLVMOptions options;
|
||||
options.emitCWrappers = true;
|
||||
LLVMTypeConverter typeConverter(&getContext());
|
||||
LLVMTypeConverter typeConverter(&getContext(), options);
|
||||
|
||||
// We have a combination of `krnl`, `affine`, and `std` operations. We
|
||||
// lower in stages until all the code is in the LLVM dialect.
|
||||
OwningRewritePatternList patterns;
|
||||
populateAffineToStdConversionPatterns(patterns, &getContext());
|
||||
populateLoopToStdConversionPatterns(patterns, &getContext());
|
||||
populateStdToLLVMConversionPatterns(typeConverter, patterns, options);
|
||||
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
patterns.insert<KrnlGlobalOpLowering, KrnlPackedConstOpLowering>(
|
||||
&getContext(), typeConverter);
|
||||
|
|
|
@ -23,7 +23,9 @@ enum Kinds {
|
|||
};
|
||||
}
|
||||
|
||||
class LoopType : public mlir::Type::TypeBase<LoopType, mlir::Type> {
|
||||
class LoopType
|
||||
: public mlir::Type::TypeBase<LoopType, mlir::Type, mlir::TypeStorage> {
|
||||
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
|
|
|
@ -67,7 +67,8 @@ enum Kind {
|
|||
};
|
||||
} // namespace ONNXTypes
|
||||
|
||||
class StringType : public mlir::Type::TypeBase<StringType, mlir::Type> {
|
||||
class StringType
|
||||
: public mlir::Type::TypeBase<StringType, mlir::Type, mlir::TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static bool kindof(unsigned kind) { return kind == ONNXTypes::STRING; }
|
||||
|
|
|
@ -262,8 +262,9 @@ void genLLVMBitcode(const mlir::OwningModuleRef &module,
|
|||
llvm::raw_fd_ostream moduleBitcodeStream(
|
||||
unoptimizedBitcodePath, error, llvm::sys::fs::F_None);
|
||||
|
||||
llvm::WriteBitcodeToFile(
|
||||
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
|
||||
llvm::LLVMContext llvmContext;
|
||||
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module, llvmContext),
|
||||
moduleBitcodeStream);
|
||||
moduleBitcodeStream.flush();
|
||||
|
||||
// Use the LLVM's 'opt' command to optimize the bitcode.
|
||||
|
|
|
@ -67,18 +67,11 @@ int main(int argc, char **argv) {
|
|||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
mlir::registerDialect<mlir::vector::VectorDialect>();
|
||||
|
||||
// Register transformation passes.
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "mlir/Transforms/Passes.h.inc"
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "mlir/Dialect/Affine/Passes.h.inc"
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "mlir/Dialect/Linalg/Passes.h.inc"
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "mlir/Dialect/SCF/Passes.h.inc"
|
||||
registerTransformsPasses();
|
||||
registerAffinePasses();
|
||||
registerLinalgPasses();
|
||||
registerSCFPasses();
|
||||
registerStandardPasses();
|
||||
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
|
||||
|
|
|
@ -6,16 +6,16 @@ func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> {
|
|||
%0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
|
||||
// CHECK: llvm.mlir.global internal constant [[GLOBAL_CONST:@.+]](dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>) : !llvm<"[3 x [2 x float]]">
|
||||
// CHECK: llvm.func @test_constant({{.*}}) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
|
||||
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1)
|
||||
// CHECK: llvm.mlir.global internal constant [[GLOBAL_CONST:@.+]](dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>) : !llvm.array<3 x array<2 x float>>
|
||||
// CHECK: llvm.func @test_constant({{.*}}) -> !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)> {
|
||||
|
||||
// CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : i64) : !llvm.i64
|
||||
// CHECK: [[ALLOCA:%.+]] = llvm.alloca [[CONST1]] x !llvm<"[3 x [2 x float]]"> : (!llvm.i64) -> !llvm<"[3 x [2 x float]]*">
|
||||
// CHECK: [[I8ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"i8*">
|
||||
// CHECK: [[ALLOCA:%.+]] = llvm.alloca [[CONST1]] x !llvm.array<3 x array<2 x float>> : (!llvm.i64) -> !llvm.ptr<array<3 x array<2 x float>>>
|
||||
// CHECK: [[I8ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr<array<3 x array<2 x float>>> to !llvm.ptr<i8>
|
||||
|
||||
// CHECK: [[GLOBAL_ADDR:%.+]] = llvm.mlir.addressof [[GLOBAL_CONST]] : !llvm<"[3 x [2 x float]]*">
|
||||
// CHECK: [[I8GLOBAL:%.+]] = llvm.bitcast [[GLOBAL_ADDR]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"i8*">
|
||||
// CHECK: [[GLOBAL_ADDR:%.+]] = llvm.mlir.addressof [[GLOBAL_CONST]] : !llvm.ptr<array<3 x array<2 x float>>>
|
||||
// CHECK: [[I8GLOBAL:%.+]] = llvm.bitcast [[GLOBAL_ADDR]] : !llvm.ptr<array<3 x array<2 x float>>> to !llvm.ptr<i8>
|
||||
|
||||
/// Size of the constant tensor in bytes.
|
||||
// CHECK: [[CONST4:%.+]] = llvm.mlir.constant(4 : i64) : !llvm.i64
|
||||
|
@ -26,30 +26,30 @@ func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> {
|
|||
/// Volatile flag
|
||||
// CHECK: [[CONST0:%.+]] = llvm.mlir.constant(false) : !llvm.i1
|
||||
|
||||
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> !llvm.void
|
||||
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> !llvm.void
|
||||
|
||||
/// Prepare data for MemRef insertion.
|
||||
// CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"float*">
|
||||
// CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr<array<3 x array<2 x float>>> to !llvm.ptr<float>
|
||||
|
||||
/// Insert the constant value in the local MemRef.
|
||||
// CHECK: [[LOCAL_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[LOCAL_MEMREF0:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[LOCAL_MEMREF1:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[LOCAL_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[LOCAL_MEMREF0:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[LOCAL_MEMREF1:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF0]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
|
||||
/// Insert offset.
|
||||
// CHECK: [[CONST00:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
// CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[CONST00]], [[LOCAL_MEMREF1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[CONST00]], [[LOCAL_MEMREF1]][2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
|
||||
/// Insert sizes and strides.
|
||||
// CHECK: [[CONST3:%.+]] = llvm.mlir.constant(3 : index) : !llvm.i64
|
||||
// CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1]][3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[CONST1:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64
|
||||
// CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF2]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF2]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
|
||||
// CHECK: [[CONST2:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64
|
||||
// CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF3]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF3]][3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: [[MEMREF5:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF4]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF5:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF4]][4, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
|
||||
// CHECK: llvm.return [[MEMREF5]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.return [[MEMREF5]] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
}
|
||||
|
|
|
@ -11,16 +11,16 @@ func @test_getref_lowering(%arg0: memref<2x2xf32>) -> memref<2x2xf32> {
|
|||
// CHECK: [[CONST_10_0:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
|
||||
// CHECK: [[CONST_10_1:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
|
||||
// CHECK: [[MUL1:%.+]] = llvm.mul [[CONST_10_0]], [[CONST_10_1]] : !llvm.i64
|
||||
// CHECK: [[FLOAT_STAR:%.+]] = llvm.mlir.null : !llvm<"float*">
|
||||
// CHECK: [[FLOAT_STAR:%.+]] = llvm.mlir.null : !llvm.ptr<float>
|
||||
// CHECK: %[[CONST_1:.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: [[ELEM1:%.+]] = llvm.getelementptr [[FLOAT_STAR]][%[[CONST_1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
|
||||
// CHECK: [[ELEM_SIZE:%.+]] = llvm.ptrtoint [[ELEM1]] : !llvm<"float*"> to !llvm.i64
|
||||
// CHECK: [[ELEM1:%.+]] = llvm.getelementptr [[FLOAT_STAR]][%[[CONST_1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
|
||||
// CHECK: [[ELEM_SIZE:%.+]] = llvm.ptrtoint [[ELEM1]] : !llvm.ptr<float> to !llvm.i64
|
||||
// CHECK: [[MUL2:%.+]] = llvm.mul [[MUL1]], [[ELEM_SIZE]] : !llvm.i64
|
||||
// CHECK: [[MEMPOOL:%.+]] = llvm.call @malloc([[MUL2]]) : (!llvm.i64) -> !llvm<"i8*">
|
||||
// CHECK: [[TYPED_MEMPOOL:%.+]] = llvm.bitcast [[MEMPOOL]] : !llvm<"i8*"> to !llvm<"float*">
|
||||
// CHECK: [[MEMPOOL_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMPOOL_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMREF1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMPOOL:%.+]] = llvm.call @malloc([[MUL2]]) : (!llvm.i64) -> !llvm.ptr<i8>
|
||||
// CHECK: [[TYPED_MEMPOOL:%.+]] = llvm.bitcast [[MEMPOOL]] : !llvm.ptr<i8> to !llvm.ptr<float>
|
||||
// CHECK: [[MEMPOOL_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMPOOL_MEMREF]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMREF1]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: llvm.mlir.constant
|
||||
// CHECK: llvm.insertvalue
|
||||
// CHECK: llvm.mlir.constant
|
||||
|
@ -29,10 +29,10 @@ func @test_getref_lowering(%arg0: memref<2x2xf32>) -> memref<2x2xf32> {
|
|||
// CHECK: llvm.insertvalue
|
||||
// CHECK: llvm.insertvalue
|
||||
// CHECK: llvm.insertvalue
|
||||
// CHECK: [[MEMPOOL1:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMPOOL_ALLOC:%.+]] = llvm.getelementptr [[MEMPOOL1]][%[[OFFSET]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
|
||||
// CHECK: [[TYPED_MEMPOOL_ALLOC:%.+]] = llvm.bitcast [[MEMPOOL_ALLOC]] : !llvm<"float*"> to !llvm<"float*">
|
||||
// CHECK: [[MEMPOOLED_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMPOOLED_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMREF3]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMPOOL1:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[MEMPOOL_ALLOC:%.+]] = llvm.getelementptr [[MEMPOOL1]][%[[OFFSET]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
|
||||
// CHECK: [[TYPED_MEMPOOL_ALLOC:%.+]] = llvm.bitcast [[MEMPOOL_ALLOC]] : !llvm.ptr<float> to !llvm.ptr<float>
|
||||
// CHECK: [[MEMPOOLED_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMPOOLED_MEMREF]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMREF3]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
}
|
||||
|
|
|
@ -10,46 +10,46 @@ func @test_memory_pool(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
|
|||
|
||||
/// Allocate memory for the memory pool.
|
||||
// CHECK: [[MEMPOOL_SIZE:%.+]] = llvm.mlir.constant(400 : index) : !llvm.i64
|
||||
// CHECK: [[TMP1:%.+]] = llvm.mlir.null : !llvm<"i8*">
|
||||
// CHECK: [[TMP1:%.+]] = llvm.mlir.null : !llvm.ptr<i8>
|
||||
// CHECK: %[[CONST1:.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: [[TMP2:%.+]] = llvm.getelementptr [[TMP1]][%[[CONST1]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
|
||||
// CHECK: [[TYPE_SIZE_IN_BYTES:%.+]] = llvm.ptrtoint [[TMP2]] : !llvm<"i8*"> to !llvm.i64
|
||||
// CHECK: [[TMP2:%.+]] = llvm.getelementptr [[TMP1]][%[[CONST1]]] : (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8>
|
||||
// CHECK: [[TYPE_SIZE_IN_BYTES:%.+]] = llvm.ptrtoint [[TMP2]] : !llvm.ptr<i8> to !llvm.i64
|
||||
// CHECK: [[TOTAL_SIZE:%.+]] = llvm.mul [[MEMPOOL_SIZE]], [[TYPE_SIZE_IN_BYTES]] : !llvm.i64
|
||||
// CHECK: [[ALLOC_MEM_POOL:%.+]] = llvm.call @malloc([[TOTAL_SIZE]]) : (!llvm.i64) -> !llvm<"i8*">
|
||||
// CHECK: [[BITCAST_ALLOC_MEM_POOL:%.+]] = llvm.bitcast [[ALLOC_MEM_POOL]] : !llvm<"i8*"> to !llvm<"i8*">
|
||||
// CHECK: [[ALLOC_MEM_POOL:%.+]] = llvm.call @malloc([[TOTAL_SIZE]]) : (!llvm.i64) -> !llvm.ptr<i8>
|
||||
// CHECK: [[BITCAST_ALLOC_MEM_POOL:%.+]] = llvm.bitcast [[ALLOC_MEM_POOL]] : !llvm.ptr<i8> to !llvm.ptr<i8>
|
||||
|
||||
/// MemRef representing the memory pool and which contains the memory allocated above.
|
||||
// CHECK: [[MEMREF0:%.+]] = llvm.mlir.undef : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK: [[TMP3:%.+]] = llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[MEMREF0]][0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK: llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[TMP3]][1] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK: [[MEMREF0:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[TMP3:%.+]] = llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[MEMREF0]][0] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[TMP3]][1] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue
|
||||
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue
|
||||
// CHECK: [[TMP4:%.+]] = llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK: [[TMP4:%.+]] = llvm.insertvalue {{.*}}[4, 0] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
|
||||
|
||||
/// Get reference within the memory pool where the data of the getref instruction has already been allocated.
|
||||
// CHECK: [[MEMPOOL_BASE:%.+]] = llvm.extractvalue [[TMP4]][1] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK: [[GETREF_MEMORY:%.+]] = llvm.getelementptr [[MEMPOOL_BASE]][%[[OFFSET]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
|
||||
// CHECK: [[CASTED_GETREF_MEMORY:%.+]] = llvm.bitcast [[GETREF_MEMORY]] : !llvm<"i8*"> to !llvm<"float*">
|
||||
// CHECK: [[MEMPOOL_BASE:%.+]] = llvm.extractvalue [[TMP4]][1] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[GETREF_MEMORY:%.+]] = llvm.getelementptr [[MEMPOOL_BASE]][%[[OFFSET]]] : (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8>
|
||||
// CHECK: [[CASTED_GETREF_MEMORY:%.+]] = llvm.bitcast [[GETREF_MEMORY]] : !llvm.ptr<i8> to !llvm.ptr<float>
|
||||
|
||||
/// Create MemRef for krnl.getref.
|
||||
// CHECK: [[MEMREF1:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF1_TMP1:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF1_TMP2:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1_TMP1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF1:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[MEMREF1_TMP1:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1]][0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[MEMREF1_TMP2:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1_TMP1]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[CONST2:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
// CHECK: [[MEMREF1_TMP3:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF1_TMP2]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF1_TMP3:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF1_TMP2]][2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[CONST3:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
|
||||
// CHECK: [[MEMREF1_TMP4:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1_TMP3]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF1_TMP4:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1_TMP3]][3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[CONST4:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
|
||||
// CHECK: [[MEMREF1_TMP5:%.+]] = llvm.insertvalue [[CONST4]], [[MEMREF1_TMP4]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF1_TMP5:%.+]] = llvm.insertvalue [[CONST4]], [[MEMREF1_TMP4]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[CONST5:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
|
||||
// CHECK: [[MEMREF1_TMP6:%.+]] = llvm.insertvalue [[CONST5]], [[MEMREF1_TMP5]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF1_TMP6:%.+]] = llvm.insertvalue [[CONST5]], [[MEMREF1_TMP5]][3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[CONST6:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: [[MEMREF1_TMP7:%.+]] = llvm.insertvalue [[CONST6]], [[MEMREF1_TMP6]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEMREF1_TMP7:%.+]] = llvm.insertvalue [[CONST6]], [[MEMREF1_TMP6]][4, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
|
||||
/// Usage of the getref MemRef.
|
||||
// CHECK: [[MEM0:%.+]] = llvm.extractvalue [[MEMREF1_TMP7]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[MEM0:%.+]] = llvm.extractvalue [[MEMREF1_TMP7]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[CONST7:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
// CHECK: [[CONST8:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
|
||||
// CHECK: [[MUL1:%.+]] = llvm.mul {{.*}}, [[CONST8]] : !llvm.i64
|
||||
|
@ -57,10 +57,10 @@ func @test_memory_pool(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
|
|||
// CHECK: [[CONST9:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: [[MUL2:%.+]] = llvm.mul {{.*}}, [[CONST9]] : !llvm.i64
|
||||
// CHECK: %[[ADD2:.+]] = llvm.add [[ADD1]], [[MUL2]] : !llvm.i64
|
||||
// CHECK: llvm.getelementptr [[MEM0]][%[[ADD2]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
|
||||
// CHECK: llvm.getelementptr [[MEM0]][%[[ADD2]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
|
||||
|
||||
/// Deallocation of the memory pool.
|
||||
// CHECK: [[MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.extractvalue [[TMP4]][0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK: [[CASTED_MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.bitcast [[MEMPOOL_BASE_UNALIGNED]] : !llvm<"i8*"> to !llvm<"i8*">
|
||||
// CHECK: llvm.call @free([[CASTED_MEMPOOL_BASE_UNALIGNED]]) : (!llvm<"i8*">) -> ()
|
||||
// CHECK: [[MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.extractvalue [[TMP4]][0] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
|
||||
// CHECK: [[CASTED_MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.bitcast [[MEMPOOL_BASE_UNALIGNED]] : !llvm.ptr<i8> to !llvm.ptr<i8>
|
||||
// CHECK: llvm.call @free([[CASTED_MEMPOOL_BASE_UNALIGNED]]) : (!llvm.ptr<i8>) -> ()
|
||||
}
|
||||
|
|
|
@ -6,22 +6,22 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi64>) -> tensor<*x
|
|||
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi64>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
|
||||
// CHECK: [[TMP:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.insertvalue %arg5, %4[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.insertvalue %arg4, %5[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[TMP1:%.+]] = llvm.insertvalue %arg6, %6[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
||||
// CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
||||
// CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
|
||||
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue [[TMP1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
|
||||
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1)
|
||||
// CHECK: [[TMP:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: llvm.insertvalue %arg5, %4[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: llvm.insertvalue %arg4, %5[3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[TMP1:%.+]] = llvm.insertvalue %arg6, %6[4, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<4 x i64>, array<4 x i64>)>
|
||||
// CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<4 x i64>, array<4 x i64>)>
|
||||
// CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm.ptr<float> to !llvm.ptr<i8>
|
||||
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue [[TMP1]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
|
||||
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm.ptr<float> to !llvm.ptr<i8>
|
||||
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
|
||||
// CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(false) : !llvm.i1
|
||||
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> !llvm.void
|
||||
// CHECK: llvm.return [[RES]] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
||||
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> !llvm.void
|
||||
// CHECK: llvm.return [[RES]] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<4 x i64>, array<4 x i64>)>
|
||||
}
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
git clone https://github.com/llvm/llvm-project.git
|
||||
# Check out a specific branch that is known to work with ONNX MLIR.
|
||||
cd llvm-project && git checkout 32791937d7aceb0a5e1eaabf1bb1a6dbe1639792 && cd ..
|
||||
cd llvm-project && git checkout 9c94908320549a1a2328c758d6bbb694466021e7 && cd ..
|
||||
|
|
Loading…
Reference in New Issue