Using onnx-mlir through incremental stages (#257)
* Add lowering of Vector dialect for lower-all-llvm pass
* Fix generating CallOp instructions when return type is void
* Fix lowering of memref
* Reformat using clang-format
* Record more context.
* Reflow comments.

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
dbc41d2330
commit
03dae57189
|
@ -12,6 +12,7 @@
|
||||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||||
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
||||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||||
|
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
|
@ -287,8 +288,7 @@ public:
|
||||||
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
||||||
// - Copy constant data into the alloca.
|
// - Copy constant data into the alloca.
|
||||||
auto memcpyRef = getOrInsertMemcpy(rewriter, module);
|
auto memcpyRef = getOrInsertMemcpy(rewriter, module);
|
||||||
rewriter.create<CallOp>(loc, memcpyRef,
|
rewriter.create<CallOp>(loc, memcpyRef, ArrayRef<Type>({}),
|
||||||
LLVM::LLVMType::getVoidTy(context),
|
|
||||||
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
|
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
|
||||||
} else {
|
} else {
|
||||||
// Some frequently used types.
|
// Some frequently used types.
|
||||||
|
@ -381,7 +381,7 @@ public:
|
||||||
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
||||||
|
|
||||||
// Memcpy call
|
// Memcpy call
|
||||||
rewriter.create<CallOp>(loc, memcpyRef, LLVM::LLVMType::getVoidTy(context),
|
rewriter.create<CallOp>(loc, memcpyRef, ArrayRef<Type>({}),
|
||||||
ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
|
ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
|
||||||
int64Size, isVolatile}));
|
int64Size, isVolatile}));
|
||||||
|
|
||||||
|
@ -612,8 +612,19 @@ private:
|
||||||
// returned, otherwise return nullptr.
|
// returned, otherwise return nullptr.
|
||||||
Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
|
Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
|
||||||
API apiId, ArrayRef<Value> params) const {
|
API apiId, ArrayRef<Value> params) const {
|
||||||
|
// To be used as parameters in LLVM::CallOp, voidTy must be converted
|
||||||
|
// to empty list to avoid emission of an SSA value with voidTy. However,
|
||||||
|
// we still keep using LLVM voidTy (as opposed to empty list) when recording
|
||||||
|
// API function signatures in API registry because when declaring API
|
||||||
|
// functions in LLVM IR, the correct way to indicate an output type for
|
||||||
|
// "void" is still LLVM voidTy. Relevant discussion thread:
|
||||||
|
// https://github.com/onnx/onnx-mlir/issues/255.
|
||||||
|
SmallVector<Type, 1> outputTys;
|
||||||
|
auto outputTy = registry.at(apiId).outputTy;
|
||||||
|
if (!outputTy.isVoidTy())
|
||||||
|
outputTys.emplace_back(outputTy);
|
||||||
auto returnVals =
|
auto returnVals =
|
||||||
rewriter.create<LLVM::CallOp>(loc, registry.at(apiId).outputTy,
|
rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>(outputTys),
|
||||||
registry.at(apiId).symbolRef, ArrayRef<Value>(params));
|
registry.at(apiId).symbolRef, ArrayRef<Value>(params));
|
||||||
if (returnVals.getNumResults() == 1)
|
if (returnVals.getNumResults() == 1)
|
||||||
return returnVals.getResult(0);
|
return returnVals.getResult(0);
|
||||||
|
@ -642,7 +653,7 @@ private:
|
||||||
auto memRefTy = memRefPtrTy.getPointerElementTy();
|
auto memRefTy = memRefPtrTy.getPointerElementTy();
|
||||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(context);
|
auto int64Ty = LLVM::LLVMType::getInt64Ty(context);
|
||||||
|
|
||||||
Value memRef = rewriter.create<LLVM::LoadOp>(loc, memRefTy, ptrToMemRef);
|
Value memRef = rewriter.create<LLVM::UndefOp>(loc, memRefTy);
|
||||||
|
|
||||||
// Set dataPtr and alignedDataPtr;
|
// Set dataPtr and alignedDataPtr;
|
||||||
auto dataPtr =
|
auto dataPtr =
|
||||||
|
@ -859,6 +870,8 @@ void mlir::populateAffineAndKrnlToLLVMConversion(
|
||||||
populateAffineToStdConversionPatterns(patterns, ctx);
|
populateAffineToStdConversionPatterns(patterns, ctx);
|
||||||
populateLoopToStdConversionPatterns(patterns, ctx);
|
populateLoopToStdConversionPatterns(patterns, ctx);
|
||||||
populateShapeToStandardConversionPatterns(patterns, ctx);
|
populateShapeToStandardConversionPatterns(patterns, ctx);
|
||||||
|
populateVectorToLLVMMatrixConversionPatterns(typeConverter, patterns);
|
||||||
|
populateVectorToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
|
|
||||||
patterns.insert<KrnlGlobalOpLowering, KrnlPackedConstOpLowering>(
|
patterns.insert<KrnlGlobalOpLowering, KrnlPackedConstOpLowering>(
|
||||||
|
@ -883,6 +896,7 @@ void ConvertKrnlToLLVMPass::runOnOperation() {
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||||
|
target.addIllegalOp<LLVM::DialectCastOp>();
|
||||||
|
|
||||||
// Lower the MemRef types to a representation in LLVM.
|
// Lower the MemRef types to a representation in LLVM.
|
||||||
LowerToLLVMOptions options;
|
LowerToLLVMOptions options;
|
||||||
|
|
|
@ -26,7 +26,7 @@ func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> {
|
||||||
/// Volatile flag
|
/// Volatile flag
|
||||||
// CHECK: [[CONST0:%.+]] = llvm.mlir.constant(false) : !llvm.i1
|
// CHECK: [[CONST0:%.+]] = llvm.mlir.constant(false) : !llvm.i1
|
||||||
|
|
||||||
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> !llvm.void
|
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> ()
|
||||||
|
|
||||||
/// Prepare data for MemRef insertion.
|
/// Prepare data for MemRef insertion.
|
||||||
// CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr<array<3 x array<2 x float>>> to !llvm.ptr<float>
|
// CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr<array<3 x array<2 x float>>> to !llvm.ptr<float>
|
||||||
|
|
|
@ -22,6 +22,6 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi64>) -> tensor<*x
|
||||||
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm.ptr<float> to !llvm.ptr<i8>
|
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm.ptr<float> to !llvm.ptr<i8>
|
||||||
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
|
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
|
||||||
// CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(false) : !llvm.i1
|
// CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(false) : !llvm.i1
|
||||||
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> !llvm.void
|
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> ()
|
||||||
// CHECK: llvm.return [[RES]] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<4 x i64>, array<4 x i64>)>
|
// CHECK: llvm.return [[RES]] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<4 x i64>, array<4 x i64>)>
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue