Fix building ONNF with latest LLVM/MLIR (#89)
* Fix build and link errors. * Fix end to end tests. * Fix indentation. * Fix type conversion. * Use newest LLVM version. * Use newest LLVM version.
This commit is contained in:
parent
b9f2f25b56
commit
b28c6906b4
|
@ -18,7 +18,7 @@ jobs:
|
|||
git submodule update --init --recursive
|
||||
# Use cached mlir installation if possible.
|
||||
- restore_cache:
|
||||
key: V4-LLVM-PROJECT-{{ arch }}
|
||||
key: V6-LLVM-PROJECT-{{ arch }}
|
||||
- run:
|
||||
name: Install MLIR
|
||||
command: |
|
||||
|
@ -29,7 +29,7 @@ jobs:
|
|||
source ONNF/utils/install-mlir.sh
|
||||
fi
|
||||
- save_cache:
|
||||
key: V4-LLVM-PROJECT-{{ arch }}
|
||||
key: V6-LLVM-PROJECT-{{ arch }}
|
||||
paths:
|
||||
- llvm-project
|
||||
- run:
|
||||
|
|
10
MLIR.cmake
10
MLIR.cmake
|
@ -58,9 +58,11 @@ find_mlir_lib(MLIRAffineOps)
|
|||
find_mlir_lib(MLIRAffineToStandard)
|
||||
find_mlir_lib(MLIRAnalysis)
|
||||
find_mlir_lib(MLIRDialect)
|
||||
find_mlir_lib(MLIREDSC)
|
||||
find_mlir_lib(MLIRExecutionEngine)
|
||||
find_mlir_lib(MLIRIR)
|
||||
find_mlir_lib(MLIRLLVMIR)
|
||||
find_mlir_lib(MLIRLoopAnalysis)
|
||||
find_mlir_lib(MLIRLoopToStandard)
|
||||
find_mlir_lib(MLIRLoopOps)
|
||||
find_mlir_lib(MLIRParser)
|
||||
|
@ -71,7 +73,8 @@ find_mlir_lib(MLIRTargetLLVMIR)
|
|||
find_mlir_lib(MLIRTransforms)
|
||||
find_mlir_lib(MLIRTransformUtils)
|
||||
find_mlir_lib(MLIRSupport)
|
||||
find_mlir_lib(MLIROptMain)
|
||||
find_mlir_lib(MLIRMlirOptMain)
|
||||
find_mlir_lib(MLIROptLib)
|
||||
find_mlir_lib(MLIRTargetLLVMIRModuleTranslation)
|
||||
find_mlir_lib(MLIRTargetLLVMIR)
|
||||
find_mlir_lib(MLIRTransformUtils)
|
||||
|
@ -117,12 +120,15 @@ set(MLIRLibsOnce
|
|||
${MLIRAffineToStandard}
|
||||
${MLIRAnalysis}
|
||||
${MLIRDialect}
|
||||
${MLIREDSC}
|
||||
${MLIRExecutionEngine}
|
||||
${MLIRIR}
|
||||
${MLIRLLVMIR}
|
||||
${MLIRLoopToStandard}
|
||||
${MLIRLoopOps}
|
||||
${MLIROptMain}
|
||||
${MLIRLoopAnalysis}
|
||||
${MLIRMlirOptMain}
|
||||
${MLIROptLib}
|
||||
${MLIRParser}
|
||||
${MLIRPass}
|
||||
${MLIRStandardOps}
|
||||
|
|
|
@ -425,7 +425,11 @@ public:
|
|||
struct TensorTypeConverter : public TypeConverter {
|
||||
using TypeConverter::TypeConverter;
|
||||
|
||||
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
|
||||
TensorTypeConverter() {
|
||||
addConversion(convertType);
|
||||
}
|
||||
|
||||
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
|
||||
if (auto tensor_type = t.dyn_cast<TensorType>()) {
|
||||
results.push_back(convertTensorToMemRef(tensor_type));
|
||||
return success();
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/Parser.h"
|
||||
|
@ -69,6 +70,10 @@ void EmitLLVMBitCode(const mlir::OwningModuleRef &module) {
|
|||
}
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
mlir::registerDialect<mlir::AffineOpsDialect>();
|
||||
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
|
||||
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include <llvm/Support/CommandLine.h>
|
||||
#include <llvm/Support/InitLLVM.h>
|
||||
#include <llvm/Support/ToolOutputFile.h>
|
||||
#include <mlir/InitAllDialects.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
#include <mlir/Support/FileUtilities.h>
|
||||
|
@ -46,6 +47,10 @@ static llvm::cl::opt<bool> verify_passes(
|
|||
llvm::cl::init(true));
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
mlir::registerDialect<mlir::AffineOpsDialect>();
|
||||
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
|
||||
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
|
||||
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||
|
|
|
@ -224,7 +224,10 @@ public:
|
|||
|
||||
// Based on the static entry point type signature, unpack dynamic memory
|
||||
// refs to corresponding static memory refs.
|
||||
auto *staticEntryPointFunc = module.lookupSymbol(staticEntryPointFuncName);
|
||||
auto wrappedStaticEntryPointFuncName =
|
||||
"_mlir_ciface_" + staticEntryPointFuncName.lower();
|
||||
auto *staticEntryPointFunc =
|
||||
module.lookupSymbol(wrappedStaticEntryPointFuncName);
|
||||
assert(staticEntryPointFunc &&
|
||||
isa<LLVM::LLVMFuncOp>(staticEntryPointFunc) &&
|
||||
"entry point func must exist and be an llvm func op");
|
||||
|
@ -268,7 +271,8 @@ public:
|
|||
// Call static entry point with the memref ptrs created, and get output.
|
||||
auto outputMemRefs = rewriter.create<LLVM::CallOp>(
|
||||
loc, staticEntryPointTy.getFunctionResultType(),
|
||||
rewriter.getSymbolRefAttr(staticEntryPointFuncName), staticInputs);
|
||||
rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
|
||||
staticInputs);
|
||||
|
||||
// Create wrapped output.
|
||||
auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
|
||||
|
@ -563,7 +567,9 @@ void KrnlToLLVMLoweringPass::runOnModule() {
|
|||
OwningRewritePatternList patterns;
|
||||
populateAffineToStdConversionPatterns(patterns, &getContext());
|
||||
populateLoopToStdConversionPatterns(patterns, &getContext());
|
||||
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
populateStdToLLVMConversionPatterns(typeConverter, patterns,
|
||||
/*useAlloca=*/false,
|
||||
/*emitCWrapper=*/true);
|
||||
|
||||
// Lower from the `krnl` dialect i.e. the Reshape operation.
|
||||
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering,
|
||||
|
|
|
@ -5,10 +5,18 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*x
|
|||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
|
||||
// CHECK: [[TMP:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.insertvalue %arg5, %4[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.insertvalue %arg4, %5[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[TMP1:%.+]] = llvm.insertvalue %arg6, %6[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
||||
// CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
||||
// CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
|
||||
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue [[TMP1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
|
||||
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
|
||||
// CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(0 : i1) : !llvm.i1
|
||||
|
|
Loading…
Reference in New Issue