Using onnx-mlir through incremental stages (#257)

* Add lowering of Vector dialect for lower-all-llvm pass

* Fix generating CallOp instructions when return type is void

* Fix lowering of memref

* Reformat using clang-format

* Record more context.

* Reflow comments.

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
Kevin Wu 2020-09-09 21:29:55 -05:00 committed by GitHub
parent dbc41d2330
commit 03dae57189
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 7 deletions

View File

@ -12,6 +12,7 @@
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
@ -287,8 +288,7 @@ public:
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
// - Copy constant data into the alloca. // - Copy constant data into the alloca.
auto memcpyRef = getOrInsertMemcpy(rewriter, module); auto memcpyRef = getOrInsertMemcpy(rewriter, module);
rewriter.create<CallOp>(loc, memcpyRef, rewriter.create<CallOp>(loc, memcpyRef, ArrayRef<Type>({}),
LLVM::LLVMType::getVoidTy(context),
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile})); ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
} else { } else {
// Some frequently used types. // Some frequently used types.
@ -381,7 +381,7 @@ public:
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
// Memcpy call // Memcpy call
rewriter.create<CallOp>(loc, memcpyRef, LLVM::LLVMType::getVoidTy(context), rewriter.create<CallOp>(loc, memcpyRef, ArrayRef<Type>({}),
ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
int64Size, isVolatile})); int64Size, isVolatile}));
@ -612,8 +612,19 @@ private:
// returned, otherwise return nullptr. // returned, otherwise return nullptr.
Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry, Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
API apiId, ArrayRef<Value> params) const { API apiId, ArrayRef<Value> params) const {
// To be used as parameters in LLVM::CallOp, voidTy must be converted
// to empty list to avoid emission of an SSA value with voidTy. However,
// we still keep using LLVM voidTy (as opposed to empty list) when recording
// API function signatures in API registry because when declaring API
// functions in LLVM IR, the correct way to indicate an output type for
// "void" is still LLVM voidTy. Relevant discussion thread:
// https://github.com/onnx/onnx-mlir/issues/255.
SmallVector<Type, 1> outputTys;
auto outputTy = registry.at(apiId).outputTy;
if (!outputTy.isVoidTy())
outputTys.emplace_back(outputTy);
auto returnVals = auto returnVals =
rewriter.create<LLVM::CallOp>(loc, registry.at(apiId).outputTy, rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>(outputTys),
registry.at(apiId).symbolRef, ArrayRef<Value>(params)); registry.at(apiId).symbolRef, ArrayRef<Value>(params));
if (returnVals.getNumResults() == 1) if (returnVals.getNumResults() == 1)
return returnVals.getResult(0); return returnVals.getResult(0);
@ -642,7 +653,7 @@ private:
auto memRefTy = memRefPtrTy.getPointerElementTy(); auto memRefTy = memRefPtrTy.getPointerElementTy();
auto int64Ty = LLVM::LLVMType::getInt64Ty(context); auto int64Ty = LLVM::LLVMType::getInt64Ty(context);
Value memRef = rewriter.create<LLVM::LoadOp>(loc, memRefTy, ptrToMemRef); Value memRef = rewriter.create<LLVM::UndefOp>(loc, memRefTy);
// Set dataPtr and alignedDataPtr; // Set dataPtr and alignedDataPtr;
auto dataPtr = auto dataPtr =
@ -859,6 +870,8 @@ void mlir::populateAffineAndKrnlToLLVMConversion(
populateAffineToStdConversionPatterns(patterns, ctx); populateAffineToStdConversionPatterns(patterns, ctx);
populateLoopToStdConversionPatterns(patterns, ctx); populateLoopToStdConversionPatterns(patterns, ctx);
populateShapeToStandardConversionPatterns(patterns, ctx); populateShapeToStandardConversionPatterns(patterns, ctx);
populateVectorToLLVMMatrixConversionPatterns(typeConverter, patterns);
populateVectorToLLVMConversionPatterns(typeConverter, patterns);
populateStdToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns);
patterns.insert<KrnlGlobalOpLowering, KrnlPackedConstOpLowering>( patterns.insert<KrnlGlobalOpLowering, KrnlPackedConstOpLowering>(
@ -883,6 +896,7 @@ void ConvertKrnlToLLVMPass::runOnOperation() {
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>(); target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addIllegalOp<LLVM::DialectCastOp>();
// Lower the MemRef types to a representation in LLVM. // Lower the MemRef types to a representation in LLVM.
LowerToLLVMOptions options; LowerToLLVMOptions options;

View File

@ -26,7 +26,7 @@ func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> {
/// Volatile flag /// Volatile flag
// CHECK: [[CONST0:%.+]] = llvm.mlir.constant(false) : !llvm.i1 // CHECK: [[CONST0:%.+]] = llvm.mlir.constant(false) : !llvm.i1
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> !llvm.void // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> ()
/// Prepare data for MemRef insertion. /// Prepare data for MemRef insertion.
// CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr<array<3 x array<2 x float>>> to !llvm.ptr<float> // CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr<array<3 x array<2 x float>>> to !llvm.ptr<float>

View File

@ -22,6 +22,6 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi64>) -> tensor<*x
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm.ptr<float> to !llvm.ptr<i8> // CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm.ptr<float> to !llvm.ptr<i8>
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64 // CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
// CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(false) : !llvm.i1 // CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(false) : !llvm.i1
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> !llvm.void // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> ()
// CHECK: llvm.return [[RES]] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<4 x i64>, array<4 x i64>)> // CHECK: llvm.return [[RES]] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<4 x i64>, array<4 x i64>)>
} }