Fix building ONNF with latest LLVM/MLIR (#89)

* Fix build and link errors.

* Fix end to end tests.

* Fix indentation.

* Fix type conversion.

* Use newest LLVM version.

* Use newest LLVM version.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-02-19 18:15:02 -05:00 committed by GitHub
parent b9f2f25b56
commit b28c6906b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 44 additions and 10 deletions

View File

@ -18,7 +18,7 @@ jobs:
git submodule update --init --recursive git submodule update --init --recursive
# Use cached mlir installation if possible. # Use cached mlir installation if possible.
- restore_cache: - restore_cache:
key: V4-LLVM-PROJECT-{{ arch }} key: V6-LLVM-PROJECT-{{ arch }}
- run: - run:
name: Install MLIR name: Install MLIR
command: | command: |
@ -29,7 +29,7 @@ jobs:
source ONNF/utils/install-mlir.sh source ONNF/utils/install-mlir.sh
fi fi
- save_cache: - save_cache:
key: V4-LLVM-PROJECT-{{ arch }} key: V6-LLVM-PROJECT-{{ arch }}
paths: paths:
- llvm-project - llvm-project
- run: - run:

View File

@ -58,9 +58,11 @@ find_mlir_lib(MLIRAffineOps)
find_mlir_lib(MLIRAffineToStandard) find_mlir_lib(MLIRAffineToStandard)
find_mlir_lib(MLIRAnalysis) find_mlir_lib(MLIRAnalysis)
find_mlir_lib(MLIRDialect) find_mlir_lib(MLIRDialect)
find_mlir_lib(MLIREDSC)
find_mlir_lib(MLIRExecutionEngine) find_mlir_lib(MLIRExecutionEngine)
find_mlir_lib(MLIRIR) find_mlir_lib(MLIRIR)
find_mlir_lib(MLIRLLVMIR) find_mlir_lib(MLIRLLVMIR)
find_mlir_lib(MLIRLoopAnalysis)
find_mlir_lib(MLIRLoopToStandard) find_mlir_lib(MLIRLoopToStandard)
find_mlir_lib(MLIRLoopOps) find_mlir_lib(MLIRLoopOps)
find_mlir_lib(MLIRParser) find_mlir_lib(MLIRParser)
@ -71,7 +73,8 @@ find_mlir_lib(MLIRTargetLLVMIR)
find_mlir_lib(MLIRTransforms) find_mlir_lib(MLIRTransforms)
find_mlir_lib(MLIRTransformUtils) find_mlir_lib(MLIRTransformUtils)
find_mlir_lib(MLIRSupport) find_mlir_lib(MLIRSupport)
find_mlir_lib(MLIROptMain) find_mlir_lib(MLIRMlirOptMain)
find_mlir_lib(MLIROptLib)
find_mlir_lib(MLIRTargetLLVMIRModuleTranslation) find_mlir_lib(MLIRTargetLLVMIRModuleTranslation)
find_mlir_lib(MLIRTargetLLVMIR) find_mlir_lib(MLIRTargetLLVMIR)
find_mlir_lib(MLIRTransformUtils) find_mlir_lib(MLIRTransformUtils)
@ -117,12 +120,15 @@ set(MLIRLibsOnce
${MLIRAffineToStandard} ${MLIRAffineToStandard}
${MLIRAnalysis} ${MLIRAnalysis}
${MLIRDialect} ${MLIRDialect}
${MLIREDSC}
${MLIRExecutionEngine} ${MLIRExecutionEngine}
${MLIRIR} ${MLIRIR}
${MLIRLLVMIR} ${MLIRLLVMIR}
${MLIRLoopToStandard} ${MLIRLoopToStandard}
${MLIRLoopOps} ${MLIRLoopOps}
${MLIROptMain} ${MLIRLoopAnalysis}
${MLIRMlirOptMain}
${MLIROptLib}
${MLIRParser} ${MLIRParser}
${MLIRPass} ${MLIRPass}
${MLIRStandardOps} ${MLIRStandardOps}

View File

@ -425,7 +425,11 @@ public:
struct TensorTypeConverter : public TypeConverter { struct TensorTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter; using TypeConverter::TypeConverter;
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override { TensorTypeConverter() {
addConversion(convertType);
}
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
if (auto tensor_type = t.dyn_cast<TensorType>()) { if (auto tensor_type = t.dyn_cast<TensorType>()) {
results.push_back(convertTensorToMemRef(tensor_type)); results.push_back(convertTensorToMemRef(tensor_type));
return success(); return success();

View File

@ -24,6 +24,7 @@
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/InitAllDialects.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h" #include "mlir/IR/Module.h"
#include "mlir/Parser.h" #include "mlir/Parser.h"
@ -69,6 +70,10 @@ void EmitLLVMBitCode(const mlir::OwningModuleRef &module) {
} }
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
mlir::registerDialect<mlir::AffineOpsDialect>();
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::ONNXOpsDialect>(); mlir::registerDialect<mlir::ONNXOpsDialect>();
mlir::registerDialect<mlir::KrnlOpsDialect>(); mlir::registerDialect<mlir::KrnlOpsDialect>();

View File

@ -9,6 +9,7 @@
#include <llvm/Support/CommandLine.h> #include <llvm/Support/CommandLine.h>
#include <llvm/Support/InitLLVM.h> #include <llvm/Support/InitLLVM.h>
#include <llvm/Support/ToolOutputFile.h> #include <llvm/Support/ToolOutputFile.h>
#include <mlir/InitAllDialects.h>
#include <mlir/Pass/Pass.h> #include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h> #include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h> #include <mlir/Support/FileUtilities.h>
@ -46,6 +47,10 @@ static llvm::cl::opt<bool> verify_passes(
llvm::cl::init(true)); llvm::cl::init(true));
int main(int argc, char **argv) { int main(int argc, char **argv) {
mlir::registerDialect<mlir::AffineOpsDialect>();
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
llvm::InitLLVM y(argc, argv); llvm::InitLLVM y(argc, argv);
mlir::registerDialect<mlir::ONNXOpsDialect>(); mlir::registerDialect<mlir::ONNXOpsDialect>();

View File

@ -224,7 +224,10 @@ public:
// Based on the static entry point type signature, unpack dynamic memory // Based on the static entry point type signature, unpack dynamic memory
// refs to corresponding static memory refs. // refs to corresponding static memory refs.
auto *staticEntryPointFunc = module.lookupSymbol(staticEntryPointFuncName); auto wrappedStaticEntryPointFuncName =
"_mlir_ciface_" + staticEntryPointFuncName.lower();
auto *staticEntryPointFunc =
module.lookupSymbol(wrappedStaticEntryPointFuncName);
assert(staticEntryPointFunc && assert(staticEntryPointFunc &&
isa<LLVM::LLVMFuncOp>(staticEntryPointFunc) && isa<LLVM::LLVMFuncOp>(staticEntryPointFunc) &&
"entry point func must exist and be an llvm func op"); "entry point func must exist and be an llvm func op");
@ -268,7 +271,8 @@ public:
// Call static entry point with the memref ptrs created, and get output. // Call static entry point with the memref ptrs created, and get output.
auto outputMemRefs = rewriter.create<LLVM::CallOp>( auto outputMemRefs = rewriter.create<LLVM::CallOp>(
loc, staticEntryPointTy.getFunctionResultType(), loc, staticEntryPointTy.getFunctionResultType(),
rewriter.getSymbolRefAttr(staticEntryPointFuncName), staticInputs); rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
staticInputs);
// Create wrapped output. // Create wrapped output.
auto wrappedOutput = callApi(rewriter, loc, apiRegistry, auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
@ -563,7 +567,9 @@ void KrnlToLLVMLoweringPass::runOnModule() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
populateAffineToStdConversionPatterns(patterns, &getContext()); populateAffineToStdConversionPatterns(patterns, &getContext());
populateLoopToStdConversionPatterns(patterns, &getContext()); populateLoopToStdConversionPatterns(patterns, &getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns,
/*useAlloca=*/false,
/*emitCWrapper=*/true);
// Lower from the `krnl` dialect i.e. the Reshape operation. // Lower from the `krnl` dialect i.e. the Reshape operation.
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering, patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering,

View File

@ -5,10 +5,18 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*x
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
// CHECK: [[TMP:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg5, %4[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg4, %5[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[TMP1:%.+]] = llvm.insertvalue %arg6, %6[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }"> // CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
// CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }"> // CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
// CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*"> // CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> // CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue [[TMP1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*"> // CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64 // CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
// CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(0 : i1) : !llvm.i1 // CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(0 : i1) : !llvm.i1