Fix building ONNF with latest LLVM/MLIR (#89)

* Fix build and link errors.

* Fix end to end tests.

* Fix indentation.

* Fix type conversion.

* Use newest LLVM version.

* Use newest LLVM version.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-02-19 18:15:02 -05:00 committed by GitHub
parent b9f2f25b56
commit b28c6906b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 44 additions and 10 deletions

View File

@ -18,7 +18,7 @@ jobs:
git submodule update --init --recursive
# Use cached mlir installation if possible.
- restore_cache:
key: V4-LLVM-PROJECT-{{ arch }}
key: V6-LLVM-PROJECT-{{ arch }}
- run:
name: Install MLIR
command: |
@ -29,7 +29,7 @@ jobs:
source ONNF/utils/install-mlir.sh
fi
- save_cache:
key: V4-LLVM-PROJECT-{{ arch }}
key: V6-LLVM-PROJECT-{{ arch }}
paths:
- llvm-project
- run:

View File

@ -58,9 +58,11 @@ find_mlir_lib(MLIRAffineOps)
find_mlir_lib(MLIRAffineToStandard)
find_mlir_lib(MLIRAnalysis)
find_mlir_lib(MLIRDialect)
find_mlir_lib(MLIREDSC)
find_mlir_lib(MLIRExecutionEngine)
find_mlir_lib(MLIRIR)
find_mlir_lib(MLIRLLVMIR)
find_mlir_lib(MLIRLoopAnalysis)
find_mlir_lib(MLIRLoopToStandard)
find_mlir_lib(MLIRLoopOps)
find_mlir_lib(MLIRParser)
@ -71,7 +73,8 @@ find_mlir_lib(MLIRTargetLLVMIR)
find_mlir_lib(MLIRTransforms)
find_mlir_lib(MLIRTransformUtils)
find_mlir_lib(MLIRSupport)
find_mlir_lib(MLIROptMain)
find_mlir_lib(MLIRMlirOptMain)
find_mlir_lib(MLIROptLib)
find_mlir_lib(MLIRTargetLLVMIRModuleTranslation)
find_mlir_lib(MLIRTargetLLVMIR)
find_mlir_lib(MLIRTransformUtils)
@ -117,12 +120,15 @@ set(MLIRLibsOnce
${MLIRAffineToStandard}
${MLIRAnalysis}
${MLIRDialect}
${MLIREDSC}
${MLIRExecutionEngine}
${MLIRIR}
${MLIRLLVMIR}
${MLIRLoopToStandard}
${MLIRLoopOps}
${MLIROptMain}
${MLIRLoopAnalysis}
${MLIRMlirOptMain}
${MLIROptLib}
${MLIRParser}
${MLIRPass}
${MLIRStandardOps}

View File

@ -425,7 +425,11 @@ public:
struct TensorTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter;
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
TensorTypeConverter() {
addConversion(convertType);
}
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
if (auto tensor_type = t.dyn_cast<TensorType>()) {
results.push_back(convertTensorToMemRef(tensor_type));
return success();

View File

@ -24,6 +24,7 @@
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/InitAllDialects.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"
@ -69,6 +70,10 @@ void EmitLLVMBitCode(const mlir::OwningModuleRef &module) {
}
int main(int argc, char *argv[]) {
mlir::registerDialect<mlir::AffineOpsDialect>();
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::ONNXOpsDialect>();
mlir::registerDialect<mlir::KrnlOpsDialect>();

View File

@ -9,6 +9,7 @@
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/InitLLVM.h>
#include <llvm/Support/ToolOutputFile.h>
#include <mlir/InitAllDialects.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
@ -46,6 +47,10 @@ static llvm::cl::opt<bool> verify_passes(
llvm::cl::init(true));
int main(int argc, char **argv) {
mlir::registerDialect<mlir::AffineOpsDialect>();
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
llvm::InitLLVM y(argc, argv);
mlir::registerDialect<mlir::ONNXOpsDialect>();

View File

@ -224,7 +224,10 @@ public:
// Based on the static entry point type signature, unpack dynamic memory
// refs to corresponding static memory refs.
auto *staticEntryPointFunc = module.lookupSymbol(staticEntryPointFuncName);
auto wrappedStaticEntryPointFuncName =
"_mlir_ciface_" + staticEntryPointFuncName.lower();
auto *staticEntryPointFunc =
module.lookupSymbol(wrappedStaticEntryPointFuncName);
assert(staticEntryPointFunc &&
isa<LLVM::LLVMFuncOp>(staticEntryPointFunc) &&
"entry point func must exist and be an llvm func op");
@ -268,7 +271,8 @@ public:
// Call static entry point with the memref ptrs created, and get output.
auto outputMemRefs = rewriter.create<LLVM::CallOp>(
loc, staticEntryPointTy.getFunctionResultType(),
rewriter.getSymbolRefAttr(staticEntryPointFuncName), staticInputs);
rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
staticInputs);
// Create wrapped output.
auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
@ -563,7 +567,9 @@ void KrnlToLLVMLoweringPass::runOnModule() {
OwningRewritePatternList patterns;
populateAffineToStdConversionPatterns(patterns, &getContext());
populateLoopToStdConversionPatterns(patterns, &getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns);
populateStdToLLVMConversionPatterns(typeConverter, patterns,
/*useAlloca=*/false,
/*emitCWrapper=*/true);
// Lower from the `krnl` dialect i.e. the Reshape operation.
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering,

View File

@ -5,10 +5,18 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*x
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
// CHECK: [[TMP:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg5, %4[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg4, %5[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[TMP1:%.+]] = llvm.insertvalue %arg6, %6[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
// CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
// CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue [[TMP1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
// CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(0 : i1) : !llvm.i1