From 03dae57189a11d62bc78c52a4a3cd3206b78691d Mon Sep 17 00:00:00 2001 From: Kevin Wu <6334443+kwu91@users.noreply.github.com> Date: Wed, 9 Sep 2020 21:29:55 -0500 Subject: [PATCH] Using onnx-mlir through incremental stages (#257) * Add lowering of Vector dialect for lower-all-llvm pass * Fix generating CallOp instructions when return type is void * Fix lowering of memref * Reformat using clang-format * Record more context. * Reflow comments. Co-authored-by: Tian Jin --- src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp | 24 +++++++++++++++++++----- test/mlir/krnl/constant.mlir | 2 +- test/mlir/krnl/reshape.mlir | 2 +- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp b/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp index febbc70..108d7c1 100644 --- a/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp @@ -12,6 +12,7 @@ #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SCF/SCF.h" @@ -287,8 +288,7 @@ public: rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); // - Copy constant data into the alloca. auto memcpyRef = getOrInsertMemcpy(rewriter, module); - rewriter.create(loc, memcpyRef, - LLVM::LLVMType::getVoidTy(context), + rewriter.create(loc, memcpyRef, ArrayRef({}), ArrayRef({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile})); } else { // Some frequently used types. @@ -381,7 +381,7 @@ public: rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); // Memcpy call - rewriter.create(loc, memcpyRef, LLVM::LLVMType::getVoidTy(context), + rewriter.create(loc, memcpyRef, ArrayRef({}), ArrayRef({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size, isVolatile})); @@ -612,8 +612,19 @@ private: // returned, otherwise return nullptr. Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry, API apiId, ArrayRef params) const { + // To be used as parameters in LLVM::CallOp, voidTy must be converted + // to empty list to avoid emission of an SSA value with voidTy. However, + // we still keep using LLVM voidTy (as opposed to empty list) when recording + // API function signatures in API registry because when declaring API + // functions in LLVM IR, the correct way to indicate an output type for + // "void" is still LLVM voidTy. Relevant discussion thread: + // https://github.com/onnx/onnx-mlir/issues/255. + SmallVector outputTys; + auto outputTy = registry.at(apiId).outputTy; + if (!outputTy.isVoidTy()) + outputTys.emplace_back(outputTy); auto returnVals = - rewriter.create(loc, registry.at(apiId).outputTy, + rewriter.create(loc, ArrayRef(outputTys), registry.at(apiId).symbolRef, ArrayRef(params)); if (returnVals.getNumResults() == 1) return returnVals.getResult(0); @@ -642,7 +653,7 @@ private: auto memRefTy = memRefPtrTy.getPointerElementTy(); auto int64Ty = LLVM::LLVMType::getInt64Ty(context); - Value memRef = rewriter.create(loc, memRefTy, ptrToMemRef); + Value memRef = rewriter.create(loc, memRefTy); // Set dataPtr and alignedDataPtr; auto dataPtr = @@ -859,6 +870,8 @@ void mlir::populateAffineAndKrnlToLLVMConversion( populateAffineToStdConversionPatterns(patterns, ctx); populateLoopToStdConversionPatterns(patterns, ctx); populateShapeToStandardConversionPatterns(patterns, ctx); + populateVectorToLLVMMatrixConversionPatterns(typeConverter, patterns); + populateVectorToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns); patterns.insert( @@ -883,6 +896,7 @@ void ConvertKrnlToLLVMPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalOp(); + target.addIllegalOp(); // Lower the MemRef types to a representation in LLVM. LowerToLLVMOptions options; diff --git a/test/mlir/krnl/constant.mlir b/test/mlir/krnl/constant.mlir index 884c210..5790529 100644 --- a/test/mlir/krnl/constant.mlir +++ b/test/mlir/krnl/constant.mlir @@ -26,7 +26,7 @@ func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> { /// Volatile flag // CHECK: [[CONST0:%.+]] = llvm.mlir.constant(false) : !llvm.i1 - // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i1) -> !llvm.void + // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i1) -> () /// Prepare data for MemRef insertion. // CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr>> to !llvm.ptr diff --git a/test/mlir/krnl/reshape.mlir b/test/mlir/krnl/reshape.mlir index 5b349e5..81c0e28 100644 --- a/test/mlir/krnl/reshape.mlir +++ b/test/mlir/krnl/reshape.mlir @@ -22,6 +22,6 @@ func @test_reshape(%arg0 : tensor, %arg1 : tensor<4xi64>) -> tensor<*x // CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm.ptr to !llvm.ptr // CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64 // CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(false) : !llvm.i1 - // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i1) -> !llvm.void + // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i1) -> () // CHECK: llvm.return [[RES]] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> }