Using onnx-mlir through incremental stages (#257)
* Add lowering of Vector dialect for lower-all-llvm pass * Fix generating CallOp instructions when return type is void * Fix lowering of memref * Reformat using clang-format * Record more context. * Reflow comments. Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
dbc41d2330
commit
03dae57189
|
@ -12,6 +12,7 @@
|
|||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
|
@ -287,8 +288,7 @@ public:
|
|||
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
||||
// - Copy constant data into the alloca.
|
||||
auto memcpyRef = getOrInsertMemcpy(rewriter, module);
|
||||
rewriter.create<CallOp>(loc, memcpyRef,
|
||||
LLVM::LLVMType::getVoidTy(context),
|
||||
rewriter.create<CallOp>(loc, memcpyRef, ArrayRef<Type>({}),
|
||||
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
|
||||
} else {
|
||||
// Some frequently used types.
|
||||
|
@ -381,7 +381,7 @@ public:
|
|||
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
||||
|
||||
// Memcpy call
|
||||
rewriter.create<CallOp>(loc, memcpyRef, LLVM::LLVMType::getVoidTy(context),
|
||||
rewriter.create<CallOp>(loc, memcpyRef, ArrayRef<Type>({}),
|
||||
ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
|
||||
int64Size, isVolatile}));
|
||||
|
||||
|
@ -612,8 +612,19 @@ private:
|
|||
// returned, otherwise return nullptr.
|
||||
Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
|
||||
API apiId, ArrayRef<Value> params) const {
|
||||
// To be used as parameters in LLVM::CallOp, voidTy must be converted
|
||||
// to empty list to avoid emission of an SSA value with voidTy. However,
|
||||
// we still keep using LLVM voidTy (as opposed to empty list) when recording
|
||||
// API function signatures in API registry because when declaring API
|
||||
// functions in LLVM IR, the correct way to indicate an output type for
|
||||
// "void" is still LLVM voidTy. Relevant discussion thread:
|
||||
// https://github.com/onnx/onnx-mlir/issues/255.
|
||||
SmallVector<Type, 1> outputTys;
|
||||
auto outputTy = registry.at(apiId).outputTy;
|
||||
if (!outputTy.isVoidTy())
|
||||
outputTys.emplace_back(outputTy);
|
||||
auto returnVals =
|
||||
rewriter.create<LLVM::CallOp>(loc, registry.at(apiId).outputTy,
|
||||
rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>(outputTys),
|
||||
registry.at(apiId).symbolRef, ArrayRef<Value>(params));
|
||||
if (returnVals.getNumResults() == 1)
|
||||
return returnVals.getResult(0);
|
||||
|
@ -642,7 +653,7 @@ private:
|
|||
auto memRefTy = memRefPtrTy.getPointerElementTy();
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(context);
|
||||
|
||||
Value memRef = rewriter.create<LLVM::LoadOp>(loc, memRefTy, ptrToMemRef);
|
||||
Value memRef = rewriter.create<LLVM::UndefOp>(loc, memRefTy);
|
||||
|
||||
// Set dataPtr and alignedDataPtr;
|
||||
auto dataPtr =
|
||||
|
@ -859,6 +870,8 @@ void mlir::populateAffineAndKrnlToLLVMConversion(
|
|||
populateAffineToStdConversionPatterns(patterns, ctx);
|
||||
populateLoopToStdConversionPatterns(patterns, ctx);
|
||||
populateShapeToStandardConversionPatterns(patterns, ctx);
|
||||
populateVectorToLLVMMatrixConversionPatterns(typeConverter, patterns);
|
||||
populateVectorToLLVMConversionPatterns(typeConverter, patterns);
|
||||
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
patterns.insert<KrnlGlobalOpLowering, KrnlPackedConstOpLowering>(
|
||||
|
@ -883,6 +896,7 @@ void ConvertKrnlToLLVMPass::runOnOperation() {
|
|||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target.addIllegalOp<LLVM::DialectCastOp>();
|
||||
|
||||
// Lower the MemRef types to a representation in LLVM.
|
||||
LowerToLLVMOptions options;
|
||||
|
|
|
@ -26,7 +26,7 @@ func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> {
|
|||
/// Volatile flag
|
||||
// CHECK: [[CONST0:%.+]] = llvm.mlir.constant(false) : !llvm.i1
|
||||
|
||||
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> !llvm.void
|
||||
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> ()
|
||||
|
||||
/// Prepare data for MemRef insertion.
|
||||
// CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr<array<3 x array<2 x float>>> to !llvm.ptr<float>
|
||||
|
|
|
@ -22,6 +22,6 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi64>) -> tensor<*x
|
|||
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm.ptr<float> to !llvm.ptr<i8>
|
||||
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
|
||||
// CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(false) : !llvm.i1
|
||||
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> !llvm.void
|
||||
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> ()
|
||||
// CHECK: llvm.return [[RES]] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<4 x i64>, array<4 x i64>)>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue