Fix building ONNF with latest LLVM/MLIR (#89)
* Fix build and link errors. * Fix end to end tests. * Fix indentation. * Fix type conversion. * Use newest LLVM version. * Use newest LLVM version.
This commit is contained in:
parent
b9f2f25b56
commit
b28c6906b4
|
@ -18,7 +18,7 @@ jobs:
|
||||||
git submodule update --init --recursive
|
git submodule update --init --recursive
|
||||||
# Use cached mlir installation if possible.
|
# Use cached mlir installation if possible.
|
||||||
- restore_cache:
|
- restore_cache:
|
||||||
key: V4-LLVM-PROJECT-{{ arch }}
|
key: V6-LLVM-PROJECT-{{ arch }}
|
||||||
- run:
|
- run:
|
||||||
name: Install MLIR
|
name: Install MLIR
|
||||||
command: |
|
command: |
|
||||||
|
@ -29,7 +29,7 @@ jobs:
|
||||||
source ONNF/utils/install-mlir.sh
|
source ONNF/utils/install-mlir.sh
|
||||||
fi
|
fi
|
||||||
- save_cache:
|
- save_cache:
|
||||||
key: V4-LLVM-PROJECT-{{ arch }}
|
key: V6-LLVM-PROJECT-{{ arch }}
|
||||||
paths:
|
paths:
|
||||||
- llvm-project
|
- llvm-project
|
||||||
- run:
|
- run:
|
||||||
|
|
10
MLIR.cmake
10
MLIR.cmake
|
@ -58,9 +58,11 @@ find_mlir_lib(MLIRAffineOps)
|
||||||
find_mlir_lib(MLIRAffineToStandard)
|
find_mlir_lib(MLIRAffineToStandard)
|
||||||
find_mlir_lib(MLIRAnalysis)
|
find_mlir_lib(MLIRAnalysis)
|
||||||
find_mlir_lib(MLIRDialect)
|
find_mlir_lib(MLIRDialect)
|
||||||
|
find_mlir_lib(MLIREDSC)
|
||||||
find_mlir_lib(MLIRExecutionEngine)
|
find_mlir_lib(MLIRExecutionEngine)
|
||||||
find_mlir_lib(MLIRIR)
|
find_mlir_lib(MLIRIR)
|
||||||
find_mlir_lib(MLIRLLVMIR)
|
find_mlir_lib(MLIRLLVMIR)
|
||||||
|
find_mlir_lib(MLIRLoopAnalysis)
|
||||||
find_mlir_lib(MLIRLoopToStandard)
|
find_mlir_lib(MLIRLoopToStandard)
|
||||||
find_mlir_lib(MLIRLoopOps)
|
find_mlir_lib(MLIRLoopOps)
|
||||||
find_mlir_lib(MLIRParser)
|
find_mlir_lib(MLIRParser)
|
||||||
|
@ -71,7 +73,8 @@ find_mlir_lib(MLIRTargetLLVMIR)
|
||||||
find_mlir_lib(MLIRTransforms)
|
find_mlir_lib(MLIRTransforms)
|
||||||
find_mlir_lib(MLIRTransformUtils)
|
find_mlir_lib(MLIRTransformUtils)
|
||||||
find_mlir_lib(MLIRSupport)
|
find_mlir_lib(MLIRSupport)
|
||||||
find_mlir_lib(MLIROptMain)
|
find_mlir_lib(MLIRMlirOptMain)
|
||||||
|
find_mlir_lib(MLIROptLib)
|
||||||
find_mlir_lib(MLIRTargetLLVMIRModuleTranslation)
|
find_mlir_lib(MLIRTargetLLVMIRModuleTranslation)
|
||||||
find_mlir_lib(MLIRTargetLLVMIR)
|
find_mlir_lib(MLIRTargetLLVMIR)
|
||||||
find_mlir_lib(MLIRTransformUtils)
|
find_mlir_lib(MLIRTransformUtils)
|
||||||
|
@ -117,12 +120,15 @@ set(MLIRLibsOnce
|
||||||
${MLIRAffineToStandard}
|
${MLIRAffineToStandard}
|
||||||
${MLIRAnalysis}
|
${MLIRAnalysis}
|
||||||
${MLIRDialect}
|
${MLIRDialect}
|
||||||
|
${MLIREDSC}
|
||||||
${MLIRExecutionEngine}
|
${MLIRExecutionEngine}
|
||||||
${MLIRIR}
|
${MLIRIR}
|
||||||
${MLIRLLVMIR}
|
${MLIRLLVMIR}
|
||||||
${MLIRLoopToStandard}
|
${MLIRLoopToStandard}
|
||||||
${MLIRLoopOps}
|
${MLIRLoopOps}
|
||||||
${MLIROptMain}
|
${MLIRLoopAnalysis}
|
||||||
|
${MLIRMlirOptMain}
|
||||||
|
${MLIROptLib}
|
||||||
${MLIRParser}
|
${MLIRParser}
|
||||||
${MLIRPass}
|
${MLIRPass}
|
||||||
${MLIRStandardOps}
|
${MLIRStandardOps}
|
||||||
|
|
|
@ -425,7 +425,11 @@ public:
|
||||||
struct TensorTypeConverter : public TypeConverter {
|
struct TensorTypeConverter : public TypeConverter {
|
||||||
using TypeConverter::TypeConverter;
|
using TypeConverter::TypeConverter;
|
||||||
|
|
||||||
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
|
TensorTypeConverter() {
|
||||||
|
addConversion(convertType);
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
|
||||||
if (auto tensor_type = t.dyn_cast<TensorType>()) {
|
if (auto tensor_type = t.dyn_cast<TensorType>()) {
|
||||||
results.push_back(convertTensorToMemRef(tensor_type));
|
results.push_back(convertTensorToMemRef(tensor_type));
|
||||||
return success();
|
return success();
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
|
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
|
||||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||||
|
#include "mlir/InitAllDialects.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
#include "mlir/IR/Module.h"
|
#include "mlir/IR/Module.h"
|
||||||
#include "mlir/Parser.h"
|
#include "mlir/Parser.h"
|
||||||
|
@ -69,6 +70,10 @@ void EmitLLVMBitCode(const mlir::OwningModuleRef &module) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
|
mlir::registerDialect<mlir::AffineOpsDialect>();
|
||||||
|
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
|
||||||
|
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
|
||||||
|
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||||
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||||
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@
|
||||||
#include <llvm/Support/CommandLine.h>
|
#include <llvm/Support/CommandLine.h>
|
||||||
#include <llvm/Support/InitLLVM.h>
|
#include <llvm/Support/InitLLVM.h>
|
||||||
#include <llvm/Support/ToolOutputFile.h>
|
#include <llvm/Support/ToolOutputFile.h>
|
||||||
|
#include <mlir/InitAllDialects.h>
|
||||||
#include <mlir/Pass/Pass.h>
|
#include <mlir/Pass/Pass.h>
|
||||||
#include <mlir/Pass/PassManager.h>
|
#include <mlir/Pass/PassManager.h>
|
||||||
#include <mlir/Support/FileUtilities.h>
|
#include <mlir/Support/FileUtilities.h>
|
||||||
|
@ -46,6 +47,10 @@ static llvm::cl::opt<bool> verify_passes(
|
||||||
llvm::cl::init(true));
|
llvm::cl::init(true));
|
||||||
|
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
|
mlir::registerDialect<mlir::AffineOpsDialect>();
|
||||||
|
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
|
||||||
|
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
|
||||||
|
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||||
llvm::InitLLVM y(argc, argv);
|
llvm::InitLLVM y(argc, argv);
|
||||||
|
|
||||||
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||||
|
|
|
@ -224,7 +224,10 @@ public:
|
||||||
|
|
||||||
// Based on the static entry point type signature, unpack dynamic memory
|
// Based on the static entry point type signature, unpack dynamic memory
|
||||||
// refs to corresponding static memory refs.
|
// refs to corresponding static memory refs.
|
||||||
auto *staticEntryPointFunc = module.lookupSymbol(staticEntryPointFuncName);
|
auto wrappedStaticEntryPointFuncName =
|
||||||
|
"_mlir_ciface_" + staticEntryPointFuncName.lower();
|
||||||
|
auto *staticEntryPointFunc =
|
||||||
|
module.lookupSymbol(wrappedStaticEntryPointFuncName);
|
||||||
assert(staticEntryPointFunc &&
|
assert(staticEntryPointFunc &&
|
||||||
isa<LLVM::LLVMFuncOp>(staticEntryPointFunc) &&
|
isa<LLVM::LLVMFuncOp>(staticEntryPointFunc) &&
|
||||||
"entry point func must exist and be an llvm func op");
|
"entry point func must exist and be an llvm func op");
|
||||||
|
@ -268,7 +271,8 @@ public:
|
||||||
// Call static entry point with the memref ptrs created, and get output.
|
// Call static entry point with the memref ptrs created, and get output.
|
||||||
auto outputMemRefs = rewriter.create<LLVM::CallOp>(
|
auto outputMemRefs = rewriter.create<LLVM::CallOp>(
|
||||||
loc, staticEntryPointTy.getFunctionResultType(),
|
loc, staticEntryPointTy.getFunctionResultType(),
|
||||||
rewriter.getSymbolRefAttr(staticEntryPointFuncName), staticInputs);
|
rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
|
||||||
|
staticInputs);
|
||||||
|
|
||||||
// Create wrapped output.
|
// Create wrapped output.
|
||||||
auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
|
auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
|
||||||
|
@ -563,7 +567,9 @@ void KrnlToLLVMLoweringPass::runOnModule() {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
populateAffineToStdConversionPatterns(patterns, &getContext());
|
populateAffineToStdConversionPatterns(patterns, &getContext());
|
||||||
populateLoopToStdConversionPatterns(patterns, &getContext());
|
populateLoopToStdConversionPatterns(patterns, &getContext());
|
||||||
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
populateStdToLLVMConversionPatterns(typeConverter, patterns,
|
||||||
|
/*useAlloca=*/false,
|
||||||
|
/*emitCWrapper=*/true);
|
||||||
|
|
||||||
// Lower from the `krnl` dialect i.e. the Reshape operation.
|
// Lower from the `krnl` dialect i.e. the Reshape operation.
|
||||||
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering,
|
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering,
|
||||||
|
|
|
@ -5,10 +5,18 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*x
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
|
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
|
||||||
|
// CHECK: [[TMP:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
// CHECK: llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
// CHECK: llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
// CHECK: llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
// CHECK: llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
// CHECK: llvm.insertvalue %arg5, %4[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
// CHECK: llvm.insertvalue %arg4, %5[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
// CHECK: [[TMP1:%.+]] = llvm.insertvalue %arg6, %6[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
// CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
// CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
||||||
// CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
// CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
||||||
// CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
|
// CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
|
||||||
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue [[TMP1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
|
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
|
||||||
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
|
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
|
||||||
// CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(0 : i1) : !llvm.i1
|
// CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(0 : i1) : !llvm.i1
|
||||||
|
|
Loading…
Reference in New Issue