922 lines
39 KiB
C++
922 lines
39 KiB
C++
//===------ LowerToLLVM.cpp - Lowering from KRNL+Affine+Std to LLVM -------===//
|
|
//
|
|
// Copyright 2019-2020 The IBM Research Authors.
|
|
//
|
|
// =============================================================================
|
|
//
|
|
//
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
|
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
|
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
|
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/SCF/SCF.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "onnx/onnx_pb.h"
|
|
#include "llvm/ADT/Sequence.h"
|
|
|
|
#include "src/Conversion/KrnlToLLVM/KrnlToLLVM.hpp"
|
|
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
|
|
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
|
#include "src/Pass/Passes.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
/// Map the element type of a lowered MemRef (an LLVM dialect type) to the
/// corresponding ONNX tensor element type.
///
/// The builtin signed/unsigned queries are checked first. Per the note below,
/// the LLVM dialect only has signless integers, so those queries presumably
/// never match LLVM dialect types (NOTE(review): confirm). The signless
/// `isIntegerTy` checks are what match LLVM dialect integers; `i1` maps to
/// BOOL, and the other widths map to the signed ONNX integer types.
///
/// Dumps the type and aborts via llvm_unreachable when there is no ONNX
/// equivalent.
static onnx::TensorProto::DataType llvmTypeToOnnxType(
    mlir::LLVM::LLVMType elemType) {
  if (elemType.isFloatTy())
    return onnx::TensorProto::FLOAT;
  if (elemType.isUnsignedInteger(8))
    return onnx::TensorProto::UINT8;
  if (elemType.isSignedInteger(8))
    return onnx::TensorProto::INT8;
  if (elemType.isUnsignedInteger(16))
    return onnx::TensorProto::UINT16;
  if (elemType.isSignedInteger(16))
    return onnx::TensorProto::INT16;
  if (elemType.isSignedInteger(32))
    return onnx::TensorProto::INT32;
  if (elemType.isSignedInteger(64))
    return onnx::TensorProto::INT64;
  // TODO, wait for Tong's input about how string is represented in MLIR.
  if (elemType.isInteger(1))
    return onnx::TensorProto::BOOL;
  if (elemType.isHalfTy())
    return onnx::TensorProto::FLOAT16;
  if (elemType.isDoubleTy())
    return onnx::TensorProto::DOUBLE;
  if (elemType.isUnsignedInteger(32))
    return onnx::TensorProto::UINT32;
  if (elemType.isUnsignedInteger(64))
    return onnx::TensorProto::UINT64; // Previously (wrongly) INT64.

  // LLVM Dialect does not have signed/unsigned int, only signless int.
  // `isInteger(1)` above is a builtin-type query, so the LLVM dialect `i1`
  // has to be matched here; otherwise boolean outputs would hit the
  // unreachable below.
  if (elemType.isIntegerTy(1))
    return onnx::TensorProto::BOOL;
  if (elemType.isIntegerTy(8))
    return onnx::TensorProto::INT8;
  if (elemType.isIntegerTy(16))
    return onnx::TensorProto::INT16;
  if (elemType.isIntegerTy(32))
    return onnx::TensorProto::INT32;
  if (elemType.isIntegerTy(64))
    return onnx::TensorProto::INT64;

  // Complex types don't seem to exist in LLVM Dialect.
  elemType.dump();
  llvm_unreachable("Unexpected LLVM type, cannot be converted to ONNX type.");
}
|
|
|
|
/// Return a symbol reference to the LLVM function `funcName`.
///
/// If `module` has no LLVM function with that name, an external declaration
/// of type `funcType` is inserted at the start of the module body. The
/// rewriter's insertion point is restored afterwards. If a function already
/// exists, debug builds assert that its type equals `funcType`.
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
    ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) {
  MLIRContext *ctx = module.getContext();

  auto existing = module.lookupSymbol<LLVM::LLVMFuncOp>(funcName);
  if (!existing) {
    // Declare the function at the top of the module. The guard puts the
    // insertion point back where the caller had it.
    PatternRewriter::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());
    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), funcName, funcType);
  } else {
    assert(existing.getType() == funcType && "wrong symbol type");
  }

  return SymbolRefAttr::get(funcName, ctx);
}
|
|
|
|
/// Return the rank of a MemRef, given the LLVM dialect struct type of its
/// descriptor.
///
/// A descriptor is usually a 5-field struct. Fields 3 and 4 are the sizes and
/// strides arrays, each with one entry per dimension. For a scalar (rank-0)
/// MemRef those arrays would have length 0, so the struct degenerates to
/// 3 fields. For more information, refer to
/// https://github.com/llvm/llvm-project/blob/master/mlir/docs/ConversionToLLVMDialect.md#memref-types.
static size_t getRankFromMemRefType(LLVM::LLVMType memRefTy) {
  const auto fieldCount = memRefTy.getStructNumElements();
  assert((fieldCount == 3 || fieldCount == 5) &&
         "Expect MemRef type to contain either 3 or 5 elements.");

  // No sizes/strides arrays means a scalar. Otherwise the rank is the length
  // of the sizes array.
  return fieldCount == 3
             ? 0
             : memRefTy.getStructElementType(3).getArrayNumElements();
}
|
|
|
|
/// Return a symbol reference to the memcpy function, inserting it into the
|
|
/// module if necessary.
|
|
static FlatSymbolRefAttr getOrInsertMemcpy(
|
|
PatternRewriter &rewriter, ModuleOp module) {
|
|
auto *context = module.getContext();
|
|
if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
|
|
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
|
// Create a function declaration for memcpy, the signature is:
|
|
// * `void (i8*, i8* , i64, i1)`
|
|
auto llvmVoidTy = LLVM::LLVMType::getVoidTy(context);
|
|
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
|
|
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(context);
|
|
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(context);
|
|
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
|
|
ArrayRef<mlir::LLVM::LLVMType>(
|
|
{llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
|
|
false);
|
|
|
|
// Insert the memcpy function into the body of the parent module.
|
|
PatternRewriter::InsertionGuard insertGuard(rewriter);
|
|
rewriter.setInsertionPointToStart(module.getBody());
|
|
rewriter.create<LLVM::LLVMFuncOp>(
|
|
module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
|
|
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// KRNL to LLVM: KrnlGetRefOpLowering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
class KrnlGetRefOpLowering : public ConvertToLLVMPattern {
|
|
public:
|
|
explicit KrnlGetRefOpLowering(
|
|
MLIRContext *context, LLVMTypeConverter &lowering_)
|
|
: ConvertToLLVMPattern(
|
|
KrnlGetRefOp::getOperationName(), context, lowering_) {}
|
|
|
|
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto *context = op->getContext();
|
|
auto loc = op->getLoc();
|
|
|
|
KrnlGetRefOpAdaptor operandAdaptor(operands);
|
|
|
|
// This is the type of the krnl.getref output. This type is used
|
|
// for the type of the internal MemRef.
|
|
auto type = op->getResult(0).getType();
|
|
auto memRefTy = type.cast<mlir::MemRefType>();
|
|
auto llvmMemRefType =
|
|
typeConverter.convertType(type).cast<LLVM::LLVMType>();
|
|
auto outputElementType =
|
|
typeConverter.convertType(memRefTy.getElementType());
|
|
|
|
// This is the start of the memory pool containing the output MemRef.
|
|
Type memPoolType = operandAdaptor.mempool()
|
|
.getType()
|
|
.cast<LLVM::LLVMType>()
|
|
.getStructElementType(1);
|
|
Value alignedMemPoolBase = rewriter.create<LLVM::ExtractValueOp>(loc,
|
|
memPoolType, operandAdaptor.mempool(), rewriter.getI64ArrayAttr(1));
|
|
|
|
// Get pointer using the offset.
|
|
auto offset = operandAdaptor.offset();
|
|
auto llvmMemPoolType =
|
|
typeConverter.convertType(memPoolType).cast<LLVM::LLVMType>();
|
|
auto outputMemPoolTypePtrAlloc = rewriter.create<LLVM::GEPOp>(
|
|
loc, llvmMemPoolType, alignedMemPoolBase, ArrayRef<Value>({offset}));
|
|
|
|
// Bitcast to output MemRef type i.e. from i8* to the element type
|
|
// of the output MemRef.
|
|
auto llvmOutputElementType = outputElementType.cast<LLVM::LLVMType>();
|
|
Value outputTypedPtrAlloc = rewriter.create<LLVM::BitcastOp>(
|
|
loc, llvmOutputElementType.getPointerTo(), outputMemPoolTypePtrAlloc);
|
|
|
|
// Create llvm MemRef from original MemRef and fill the data pointers.
|
|
auto llvmMemRef = MemRefDescriptor::fromStaticShape(
|
|
rewriter, loc, typeConverter, memRefTy, outputTypedPtrAlloc);
|
|
|
|
rewriter.replaceOp(op, {llvmMemRef});
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
|
|
return (a.getValue()[i]).cast<IntegerAttr>().getInt();
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// KRNL to LLVM: KrnlGlobalOpLowering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
|
|
public:
|
|
explicit KrnlGlobalOpLowering(
|
|
MLIRContext *context, LLVMTypeConverter &lowering_)
|
|
: ConvertToLLVMPattern(
|
|
KrnlGlobalOp::getOperationName(), context, lowering_) {}
|
|
|
|
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto *context = op->getContext();
|
|
auto loc = op->getLoc();
|
|
|
|
auto krnlGlobalOp = llvm::dyn_cast<KrnlGlobalOp>(op);
|
|
|
|
// Get module.
|
|
ModuleOp module = op->getParentOfType<ModuleOp>();
|
|
|
|
// Global name.
|
|
auto name = krnlGlobalOp.name();
|
|
|
|
// Compute total number of elements.
|
|
auto shape = (krnlGlobalOp.shape()).dyn_cast<ArrayAttr>();
|
|
int64_t numElements = 1;
|
|
for (int i = 0; i < shape.size(); ++i)
|
|
numElements *= ArrayAttrIntVal(shape, i);
|
|
|
|
// Create the global at the entry of the module.
|
|
LLVM::GlobalOp global;
|
|
auto type = op->getResult(0).getType();
|
|
auto memRefTy = type.cast<mlir::MemRefType>();
|
|
auto llvmMemRefType =
|
|
typeConverter.convertType(type).cast<LLVM::LLVMType>();
|
|
|
|
// The element type of the array.
|
|
auto constantElementType =
|
|
typeConverter.convertType(memRefTy.getElementType());
|
|
auto globalType = constantElementType;
|
|
for (int i = shape.size() - 1; i >= 0; i--)
|
|
globalType = LLVM::LLVMType::getArrayTy(
|
|
globalType.cast<LLVM::LLVMType>(), ArrayAttrIntVal(shape, i));
|
|
// The llvm type of the global (example: [2 x [8 x float]])
|
|
auto llvmGlobalType = globalType.cast<LLVM::LLVMType>();
|
|
|
|
mlir::Value alloc;
|
|
if (krnlGlobalOp.value().hasValue()) {
|
|
{
|
|
OpBuilder::InsertionGuard insertGuard(rewriter);
|
|
rewriter.setInsertionPointToStart(module.getBody());
|
|
|
|
assert(krnlGlobalOp.value().hasValue() &&
|
|
"Krnl Global must always have a value");
|
|
global = rewriter.create<LLVM::GlobalOp>(loc, llvmGlobalType,
|
|
/*isConstant=*/true, LLVM::Linkage::Internal, name,
|
|
krnlGlobalOp.value().getValue());
|
|
}
|
|
|
|
// Some frequently used types.
|
|
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
|
|
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(context);
|
|
|
|
// Allocate the memory where the constants will be used from.
|
|
// This is a region of local memory and needs to be emitted as an alloca.
|
|
auto one = rewriter.create<LLVM::ConstantOp>(
|
|
loc, llvmI64Ty, rewriter.getI64IntegerAttr(1));
|
|
alloc = rewriter.create<LLVM::AllocaOp>(
|
|
loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0);
|
|
|
|
// Copy constant value into the local alloca:
|
|
// - Bitcast alloc to i8*
|
|
Value int8PtrAlloc =
|
|
rewriter.create<LLVM::BitcastOp>(loc, llvmI8PtrTy, alloc);
|
|
// - Bitcast global to i8*
|
|
Value globalValue = rewriter.create<LLVM::AddressOfOp>(loc, global);
|
|
Value i8PtrGlobal =
|
|
rewriter.create<LLVM::BitcastOp>(loc, llvmI8PtrTy, globalValue);
|
|
// - Set size.
|
|
Value memRefElementSize =
|
|
rewriter.create<LLVM::ConstantOp>(loc, llvmI64Ty,
|
|
rewriter.getI64IntegerAttr(getMemRefEltSizeInBytes(memRefTy)));
|
|
Value numElementsValue = rewriter.create<LLVM::ConstantOp>(
|
|
loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements));
|
|
Value totalElementsSize = rewriter.create<LLVM::MulOp>(
|
|
loc, memRefElementSize, numElementsValue);
|
|
Value int64Size =
|
|
rewriter.create<LLVM::SExtOp>(loc, llvmI64Ty, totalElementsSize);
|
|
// - Set volatile.
|
|
Value isVolatile = rewriter.create<LLVM::ConstantOp>(loc,
|
|
LLVM::LLVMType::getInt1Ty(context),
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
|
// - Copy constant data into the alloca.
|
|
auto memcpyRef = getOrInsertMemcpy(rewriter, module);
|
|
rewriter.create<CallOp>(loc, memcpyRef, ArrayRef<Type>({}),
|
|
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
|
|
} else {
|
|
// Some frequently used types.
|
|
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
|
|
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(context);
|
|
|
|
// Allocate the memory where the constants will be used from.
|
|
// This is a region of local memory and needs to be emitted as an alloca.
|
|
auto one = rewriter.create<LLVM::ConstantOp>(
|
|
loc, llvmI64Ty, rewriter.getI64IntegerAttr(1));
|
|
|
|
auto base = module.lookupSymbol<LLVM::GlobalOp>("packedConst");
|
|
assert(base && "Cannot find symbol packedConst.");
|
|
|
|
Value constPackBasePtrAddr =
|
|
rewriter.create<LLVM::AddressOfOp>(loc, base);
|
|
Value constPackBasePtr = rewriter.create<LLVM::LoadOp>(
|
|
loc, base.getType(), constPackBasePtrAddr);
|
|
auto offset = rewriter.create<LLVM::ConstantOp>(loc, llvmI64Ty,
|
|
rewriter.getI64IntegerAttr(
|
|
krnlGlobalOp.offsetAttr().getValue().getSExtValue()));
|
|
alloc = rewriter.create<LLVM::GEPOp>(
|
|
loc, llvmI8PtrTy, constPackBasePtr, ValueRange({offset}));
|
|
}
|
|
// Prepare data to be inserted into MemRef.
|
|
auto llvmConstantElementType = constantElementType.cast<LLVM::LLVMType>();
|
|
Value typedAlloc = rewriter.create<LLVM::BitcastOp>(
|
|
loc, llvmConstantElementType.getPointerTo(), alloc);
|
|
|
|
// Create llvm MemRef from original MemRef and fill the data pointers.
|
|
auto llvmMemRef = MemRefDescriptor::fromStaticShape(
|
|
rewriter, loc, typeConverter, memRefTy, typedAlloc);
|
|
|
|
rewriter.replaceOp(op, {llvmMemRef});
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
|
|
return (a.getValue()[i]).cast<IntegerAttr>().getInt();
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// KRNL to LLVM: KrnlMemcpyOpLowering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
class KrnlMemcpyOpLowering : public ConversionPattern {
|
|
public:
|
|
explicit KrnlMemcpyOpLowering(MLIRContext *context)
|
|
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
|
|
|
|
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto *context = op->getContext();
|
|
KrnlMemcpyOpAdaptor operandAdaptor(operands);
|
|
auto loc = op->getLoc();
|
|
|
|
// Get a symbol reference to the memcpy function, inserting it if necessary.
|
|
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
|
|
auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule);
|
|
|
|
// First operand.
|
|
Type dstType = operandAdaptor.dest()
|
|
.getType()
|
|
.cast<LLVM::LLVMType>()
|
|
.getStructElementType(1);
|
|
Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
|
|
loc, dstType, operandAdaptor.dest(), rewriter.getI64ArrayAttr(1));
|
|
Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
|
|
loc, LLVM::LLVMType::getInt8PtrTy(context), alignedDstMemory);
|
|
|
|
// Second operand.
|
|
Type srcType = operandAdaptor.src()
|
|
.getType()
|
|
.cast<LLVM::LLVMType>()
|
|
.getStructElementType(1);
|
|
Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
|
|
loc, srcType, operandAdaptor.src(), rewriter.getI64ArrayAttr(1));
|
|
Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
|
|
loc, LLVM::LLVMType::getInt8PtrTy(context), alignedSrcMemory);
|
|
|
|
// Size.
|
|
Value int64Size = rewriter.create<LLVM::SExtOp>(
|
|
loc, LLVM::LLVMType::getInt64Ty(context), operandAdaptor.size());
|
|
|
|
// Is volatile (set to false).
|
|
Value isVolatile = rewriter.create<LLVM::ConstantOp>(loc,
|
|
LLVM::LLVMType::getInt1Ty(context),
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
|
|
|
// Memcpy call
|
|
rewriter.create<CallOp>(loc, memcpyRef, ArrayRef<Type>({}),
|
|
ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
|
|
int64Size, isVolatile}));
|
|
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// KRNL to LLVM: KrnlEntryPointOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Lowers `krnl.entry_point` to a type-erased LLVM entry function.
///
/// The generated function, `_dyn_entry_point_<func>`, has type `i8* (i8*)`:
///  1. It unpacks each input runtime memref from the opaque input into a
///     stack-allocated MemRef descriptor.
///  2. It calls `_mlir_ciface_<func lowercased>` with pointers to those
///     descriptors.
///  3. It packs every returned MemRef descriptor into a runtime memref and
///     stores them in a newly created ordered runtime memref dictionary.
///  4. It returns that dictionary as the opaque output.
/// The runtime functions used are declared by RegisterAllApis.
class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
public:
  using OpRewritePattern<KrnlEntryPointOp>::OpRewritePattern;

  /// Runtime API functions called by the generated entry point.
  enum class API {
    CREATE_ORDERED_DYN_MEM_REF_DICT,
    CREATE_DYN_MEM_REF,
    GET_DYN_MEM_REF,
    SET_DYN_MEM_REF,
    GET_DATA,
    SET_DATA,
    GET_SIZES,
    GET_STRIDES,
    SET_DTYPE,
    GET_DTYPE,
  };

  /// One runtime API function: its identity, symbol name and LLVM signature.
  /// `symbolRef` is filled in once the function is declared in the module.
  struct ApiSpec {
    API id;
    std::string name;
    FlatSymbolRefAttr symbolRef;
    LLVM::LLVMType outputTy;
    SmallVector<LLVM::LLVMType, 4> inputTys;

    ApiSpec(API id, const std::string &name, LLVM::LLVMType outputTy,
        ArrayRef<LLVM::LLVMType> inputTys)
        : id(id), name(name), outputTy(outputTy),
          inputTys(inputTys.begin(), inputTys.end()) {}

    LLVM::LLVMType funcTy() {
      return LLVM::LLVMType::getFunctionTy(outputTy, inputTys,
          /*isVarArg=*/false);
    }
  };

  LogicalResult matchAndRewrite(
      KrnlEntryPointOp op, PatternRewriter &rewriter) const override {

    auto module = op.getParentOfType<ModuleOp>();
    auto *context = module.getContext();
    auto apiRegistry = RegisterAllApis(module, rewriter);
    auto loc = op.getLoc();
    auto numOutputs =
        op.getAttrOfType<IntegerAttr>(KrnlEntryPointOp::getNumOutputsAttrName())
            .getInt();

    using LLVMType = LLVM::LLVMType;
    auto opaquePtrTy = LLVMType::getInt8PtrTy(context);
    auto int32Ty = LLVMType::getInt32Ty(context);

    // Rewrite Krnl Entry Point Operation to an LLVM function with a dynamic
    // signature. The signature is dynamic because it remains the same no matter
    // what the model input/output schema look like. Such dynamic signature
    // takes an opaque ptr as input, representing a ptr to a data structure
    // containing a set of dynamic memrefs wrapped in a vector; similarly the
    // output is also an opaque ptr to a data structure with output memrefs
    // wrapped within it.
    auto staticEntryPointFuncName =
        op.getAttrOfType<SymbolRefAttr>(
              KrnlEntryPointOp::getEntryPointFuncAttrName())
            .getLeafReference();
    // Materialize the name as a std::string rather than keeping a Twine.
    std::string dynEntryPointName =
        ("_dyn_entry_point_" + staticEntryPointFuncName).str();
    assert(module.lookupSymbol(dynEntryPointName) == nullptr &&
           "dynamic entry point name is not unique");
    rewriter.eraseOp(op);
    auto dynEntryPointFuncTy =
        LLVMType::getFunctionTy(opaquePtrTy, {opaquePtrTy}, false);
    auto dynamicEntryPointFunc = rewriter.create<LLVM::LLVMFuncOp>(
        loc, dynEntryPointName, dynEntryPointFuncTy);
    auto &entryPointEntryBlock =
        createEntryBlock(dynEntryPointFuncTy, dynamicEntryPointFunc);
    rewriter.setInsertionPointToStart(&entryPointEntryBlock);

    // Based on the static entry point type signature, unpack dynamic memory
    // refs to corresponding static memory refs.
    auto wrappedStaticEntryPointFuncName =
        "_mlir_ciface_" + staticEntryPointFuncName.lower();
    auto *staticEntryPointFunc =
        module.lookupSymbol(wrappedStaticEntryPointFuncName);
    assert(staticEntryPointFunc &&
           isa<LLVM::LLVMFuncOp>(staticEntryPointFunc) &&
           "entry point func must exist and be an llvm func op");
    auto staticEntryPointTy = dyn_cast<LLVM::LLVMFuncOp>(staticEntryPointFunc)
                                  .getType()
                                  .dyn_cast<LLVMType>();

    // Retrieve dynamic mem refs from wrapped input, and convert every one of
    // them to static mem refs.
    SmallVector<Value, 4> staticInputs;
    auto wrappedInput = entryPointEntryBlock.getArgument(0);
    for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) {
      // Call API function to retrieve the i-th dynamic memref.
      auto idxVal = rewriter.create<LLVM::ConstantOp>(
          loc, int32Ty, rewriter.getI32IntegerAttr(i));
      auto dynMemRef = callApi(rewriter, loc, apiRegistry, API::GET_DYN_MEM_REF,
          {wrappedInput, idxVal});

      // Allocate on the stack a (static) memref descriptor of the type the
      // i-th parameter of the inference function points to.
      auto memRefPtrTy = staticEntryPointTy.getFunctionParamType(i);
      auto one = rewriter.create<LLVM::ConstantOp>(
          loc, int32Ty, rewriter.getI32IntegerAttr(1));
      Value ptrToMemRef = rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one,
          /*alignment=*/0);

      // Fill in the memref underlying ptrToMemRef with information extracted
      // from dynMemRef.
      fillPtrToMemRefWithRtMemRef(
          dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, module);

      // ptrToMemRef will be an input to main computation graph function.
      staticInputs.emplace_back(ptrToMemRef);
    }

    // Call static entry point with the memref ptrs created, and get output.
    auto outMemRefs =
        rewriter
            .create<LLVM::CallOp>(loc,
                staticEntryPointTy.getFunctionResultType(),
                rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
                staticInputs)
            .getResult(0);
    auto outMemRefsType = outMemRefs.getType().dyn_cast<LLVMType>();

    std::vector<mlir::Value> outMemRefList;
    if (numOutputs == 1) {
      // If only one output tensor exists, the tensor's corresponding memref
      // descriptor will be returned as is.
      outMemRefList.emplace_back(outMemRefs);
    } else {
      // Otherwise, if multiple tensors are to be returned, the returned value
      // is a struct. Multiple tensors' memref descriptors are packed into the
      // same struct. So we unpack them iteratively to outMemRefList.
      for (int i = 0; i < numOutputs; i++) {
        auto position = rewriter.getArrayAttr({rewriter.getI64IntegerAttr(i)});
        auto type = outMemRefsType.getStructElementType(i);
        auto extractOp = rewriter.create<LLVM::ExtractValueOp>(loc,
            /*res=*/type,
            /*type=*/outMemRefs,
            /*position=*/position);
        outMemRefList.emplace_back(extractOp.getResult());
      }
    }

    // Create wrapped output.
    auto wrappedOutput = callApi(
        rewriter, loc, apiRegistry, API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});

    for (size_t i = 0; i < outMemRefList.size(); i++) {
      // Get the i-th memref returned, convert to a dynamic memref and store it
      // in the wrappedOutput.
      auto memRef = outMemRefList.at(i);
      auto outMemRefTy = memRef.getType().dyn_cast<LLVMType>();
      auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
      auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
          loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
      auto outRtMemRef = callApi(rewriter, loc, apiRegistry,
          API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
      fillRtMemRefWithMemRef(
          memRef, outRtMemRef, rewriter, loc, apiRegistry, module);
      auto idx = rewriter.create<LLVM::ConstantOp>(
          loc, int32Ty, rewriter.getI32IntegerAttr(i));
      callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
          {wrappedOutput, idx, outRtMemRef});
    }
    // Return wrapped output.
    rewriter.create<LLVM::ReturnOp>(
        loc, SmallVector<Value, 1>({wrappedOutput}));
    return success();
  }

private:
  using ApiRegistry = std::map<API, ApiSpec>;

  /// Declare every runtime API function in `module` (unless already present).
  /// Returns a registry mapping each API id to its spec and symbol reference.
  ApiRegistry RegisterAllApis(
      ModuleOp &module, PatternRewriter &rewriter) const {
    auto *context = module.getContext();
    using LLVMType = LLVM::LLVMType;
    auto voidTy = LLVMType::getVoidTy(context);
    auto opaquePtrTy = LLVMType::getInt8PtrTy(context);
    auto int32Ty = LLVMType::getInt32Ty(context);
    auto int64Ty = LLVMType::getInt64Ty(context);
    auto int64PtrTy = int64Ty.getPointerTo();

    // Declare API type as an enum value, its string name and an LLVM Type
    // specifying its signature.
    // clang-format off
    std::vector<ApiSpec> apiSpecs = {
        ApiSpec(API::CREATE_ORDERED_DYN_MEM_REF_DICT, "createOrderedRtMemRefDict", opaquePtrTy, {}),
        ApiSpec(API::CREATE_DYN_MEM_REF, "createRtMemRef", opaquePtrTy, {int32Ty}),
        ApiSpec(API::GET_DATA, "getData", opaquePtrTy, {opaquePtrTy}),
        ApiSpec(API::SET_DATA, "setData", voidTy, {opaquePtrTy, opaquePtrTy}),
        ApiSpec(API::GET_DYN_MEM_REF, "getRtMemRef", opaquePtrTy, {opaquePtrTy, int32Ty}),
        ApiSpec(API::SET_DYN_MEM_REF, "setRtMemRef", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy}),
        ApiSpec(API::GET_SIZES, "getSizes", int64PtrTy, {opaquePtrTy}),
        ApiSpec(API::GET_STRIDES, "getStrides", int64PtrTy, {opaquePtrTy}),
        ApiSpec(API::GET_DTYPE, "getDType", int32Ty, {opaquePtrTy}),
        ApiSpec(API::SET_DTYPE, "setDType", voidTy, {opaquePtrTy, int32Ty}),
    };
    // clang-format on

    // Declare APIs in the current module and build an API registry mapping api
    // identities to a symbol reference to the API function.
    ApiRegistry registry;
    for (auto &apiSpec : apiSpecs) {
      apiSpec.symbolRef = getOrInsertExternFunc(
          apiSpec.name, module, apiSpec.funcTy(), rewriter);
      registry.emplace(apiSpec.id, apiSpec);
    }

    return registry;
  }

  /// Call a registered API. Returns the call's result if it has exactly one,
  /// otherwise nullptr (void APIs).
  ///
  /// `registry` is taken by const reference. The previous by-value parameter
  /// copied the whole std::map on every call.
  Value callApi(PatternRewriter &rewriter, Location loc,
      const ApiRegistry &registry, API apiId, ArrayRef<Value> params) const {
    const ApiSpec &spec = registry.at(apiId);
    // To be used as parameters in LLVM::CallOp, voidTy must be converted
    // to empty list to avoid emission of an SSA value with voidTy. However,
    // we still keep using LLVM voidTy (as opposed to empty list) when recording
    // API function signatures in API registry because when declaring API
    // functions in LLVM IR, the correct way to indicate an output type for
    // "void" is still LLVM voidTy. Relevant discussion thread:
    // https://github.com/onnx/onnx-mlir/issues/255.
    SmallVector<Type, 1> outputTys;
    LLVM::LLVMType outputTy = spec.outputTy;
    if (!outputTy.isVoidTy())
      outputTys.emplace_back(outputTy);
    auto returnVals = rewriter.create<LLVM::CallOp>(loc,
        ArrayRef<Type>(outputTys), spec.symbolRef, ArrayRef<Value>(params));
    if (returnVals.getNumResults() == 1)
      return returnVals.getResult(0);
    return nullptr;
  }

  // Helper function to insert an entry block to LLVM function.
  // (TODO): upstream this to MLIR.
  Block &createEntryBlock(LLVM::LLVMType &dynEntryPointFuncType,
      LLVM::LLVMFuncOp &dynamicEntryPointFunc) const {
    // Add entry block; the function's region takes ownership of it.
    auto *entryPointEntryBlock = new Block();
    dynamicEntryPointFunc.push_back(entryPointEntryBlock);
    llvm::SmallVector<Type, 4> argTypes;
    for (size_t i = 0; i < dynEntryPointFuncType.getFunctionNumParams(); i++)
      argTypes.emplace_back(dynEntryPointFuncType.getFunctionParamType(i));
    entryPointEntryBlock->addArguments(argTypes);
    return *entryPointEntryBlock;
  }

  /// Build a MemRef descriptor from runtime memref `dynMemRef` and store it
  /// through `ptrToMemRef`.
  ///
  /// The descriptor holds:
  ///  * the data pointer, used as both the allocated and aligned pointer;
  ///  * a zero offset;
  ///  * per-dimension sizes and strides read from the runtime's size and
  ///    stride arrays.
  void fillPtrToMemRefWithRtMemRef(Value &dynMemRef, Value &ptrToMemRef,
      PatternRewriter &rewriter, const Location &loc,
      const ApiRegistry &apiRegistry, ModuleOp &module) const {
    auto *context = module.getContext();
    auto memRefPtrTy = ptrToMemRef.getType().dyn_cast<LLVM::LLVMType>();
    auto memRefTy = memRefPtrTy.getPointerElementTy();
    auto int64Ty = LLVM::LLVMType::getInt64Ty(context);

    Value memRef = rewriter.create<LLVM::UndefOp>(loc, memRefTy);

    // Set dataPtr and alignedDataPtr;
    auto dataPtr =
        callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef});
    dataPtr = rewriter.create<LLVM::BitcastOp>(
        loc, memRefTy.getStructElementType(0), dataPtr);
    memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
        dataPtr, rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)}));
    memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
        dataPtr, rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)}));

    // Use zero offset now.
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, int64Ty, rewriter.getI64IntegerAttr(0));
    memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef, zero,
        rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));

    // Get rank, sizes array ptr and strides array ptr.
    auto rank = getRankFromMemRefType(memRefTy);
    auto sizesArrayPtr =
        callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef});
    auto stridesArrayPtr =
        callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {dynMemRef});

    for (decltype(rank) i = 0; i < rank; i++) {
      auto dimIdx = rewriter.create<LLVM::ConstantOp>(
          loc, int64Ty, rewriter.getI64IntegerAttr(i));

      // Insert size of the dimension. The loaded value is an i64; the
      // LoadOp result type must be i64, not i64*.
      auto dimSizePtr = rewriter.create<LLVM::GEPOp>(loc,
          int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef<Value>({dimIdx}));
      auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty, dimSizePtr);
      memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
          dimSize,
          rewriter.getArrayAttr(
              {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));

      // Insert stride of the dimension. Strides come from the strides array;
      // the previous code read them from the sizes array.
      auto dimStridePtr = rewriter.create<LLVM::GEPOp>(loc,
          int64Ty.getPointerTo(), stridesArrayPtr, ArrayRef<Value>({dimIdx}));
      auto dimStride =
          rewriter.create<LLVM::LoadOp>(loc, int64Ty, dimStridePtr);
      memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
          dimStride,
          rewriter.getArrayAttr(
              {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
    }

    rewriter.create<LLVM::StoreOp>(loc, memRef, ptrToMemRef);
  }

  /// Copy a MemRef descriptor `outMemRef` into runtime memref `outRtMemRef`.
  ///
  /// The following are transferred:
  ///  * the data pointer (field 0);
  ///  * the ONNX dtype derived from the element type;
  ///  * each dimension's size (field 3) and stride (field 4), written into
  ///    the runtime's size and stride arrays.
  void fillRtMemRefWithMemRef(Value &outMemRef, Value &outRtMemRef,
      PatternRewriter &rewriter, const Location &loc,
      const ApiRegistry &apiRegistry, ModuleOp &module) const {
    auto *context = module.getContext();
    auto outMemRefTy = outMemRef.getType().dyn_cast<LLVM::LLVMType>();
    auto int64Ty = LLVM::LLVMType::getInt64Ty(context);
    auto int32Ty = LLVM::LLVMType::getInt32Ty(context);

    // Extract the data pointer, and record it in dynamic mem ref created.
    Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
        outMemRefTy.getStructElementType(0), outMemRef,
        rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
    outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
        loc, LLVM::LLVMType::getInt8PtrTy(context), outMemRefDataPtr);
    callApi(rewriter, loc, apiRegistry, API::SET_DATA,
        {outRtMemRef, outMemRefDataPtr});
    auto elemTy = outMemRefTy.getStructElementType(0).getPointerElementTy();
    auto onnxTy = llvmTypeToOnnxType(elemTy);
    auto onnxTyVal = rewriter.create<LLVM::ConstantOp>(
        loc, int32Ty, rewriter.getI32IntegerAttr(onnxTy));
    callApi(
        rewriter, loc, apiRegistry, API::SET_DTYPE, {outRtMemRef, onnxTyVal});

    auto rank = getRankFromMemRefType(outMemRefTy);
    auto sizesArrayPtr =
        callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outRtMemRef});
    auto stridesArrayPtr =
        callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {outRtMemRef});

    for (decltype(rank) i = 0; i < rank; i++) {
      auto dimIdx = rewriter.create<LLVM::ConstantOp>(
          loc, int64Ty, rewriter.getI64IntegerAttr(i));

      // Transfer size of dimension from memref to dynamic memref.
      auto dimSize = rewriter.create<LLVM::ExtractValueOp>(loc, int64Ty,
          outMemRef,
          rewriter.getArrayAttr(
              {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
      auto dimSizePtr = rewriter.create<LLVM::GEPOp>(loc,
          int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef<Value>({dimIdx}));
      rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr);

      // Transfer stride of dimension from memref to dynamic memref.
      auto dimStride = rewriter.create<LLVM::ExtractValueOp>(loc, int64Ty,
          outMemRef,
          rewriter.getArrayAttr(
              {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
      auto dimStridePtr = rewriter.create<LLVM::GEPOp>(loc,
          int64Ty.getPointerTo(), stridesArrayPtr, ArrayRef<Value>({dimIdx}));
      rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr);
    }
  }
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// KRNL to LLVM: KrnlPackedConstOpLowering
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
class KrnlPackedConstOpLowering : public ConvertToLLVMPattern {
|
|
public:
|
|
explicit KrnlPackedConstOpLowering(
|
|
MLIRContext *context, LLVMTypeConverter &lowering_)
|
|
: ConvertToLLVMPattern(
|
|
KrnlPackedConstantOp::getOperationName(), context, lowering_) {}
|
|
|
|
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto *context = op->getContext();
|
|
ModuleOp module = op->getParentOfType<ModuleOp>();
|
|
auto loc = op->getLoc();
|
|
|
|
auto packedConstOp = llvm::dyn_cast<KrnlPackedConstantOp>(op);
|
|
LLVM::GlobalOp globalBase;
|
|
// Some frequently used types.
|
|
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context);
|
|
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(context);
|
|
{
|
|
OpBuilder::InsertionGuard insertGuard(rewriter);
|
|
rewriter.setInsertionPointToStart(module.getBody());
|
|
|
|
globalBase = rewriter.create<LLVM::GlobalOp>(loc, llvmI8PtrTy,
|
|
/*isConstant=*/false, LLVM::Linkage::Internal, "packedConst",
|
|
nullptr);
|
|
}
|
|
|
|
auto mainFunc = module.lookupSymbol<FuncOp>("main_graph");
|
|
assert(mainFunc);
|
|
|
|
rewriter.setInsertionPoint(
|
|
&mainFunc.getBody().front(), mainFunc.getBody().front().begin());
|
|
|
|
// - Initialize the global constant base.
|
|
Value basePtrAddr = rewriter.create<LLVM::AddressOfOp>(loc, globalBase);
|
|
auto getEmbeddedConstPoolRef = getOrInsertExternFunc(
|
|
KrnlPackedConstantOp::getEmbeddedDataLoaderMethodName(), module,
|
|
LLVM::LLVMType::getFunctionTy(
|
|
llvmI8PtrTy, {llvmI64Ty}, /*isVarArg=*/false),
|
|
rewriter);
|
|
auto constPackSize = rewriter.create<LLVM::ConstantOp>(loc,
|
|
LLVM::LLVMType::getInt64Ty(context), packedConstOp.size_in_bytesAttr());
|
|
Value alloc = rewriter
|
|
.create<CallOp>(loc, getEmbeddedConstPoolRef, llvmI8PtrTy,
|
|
ArrayRef<Value>({constPackSize}))
|
|
.getResult(0);
|
|
rewriter.create<LLVM::StoreOp>(loc, alloc, basePtrAddr);
|
|
{
|
|
OpBuilder::InsertionGuard insertGuard(rewriter);
|
|
rewriter.setInsertionPointToStart(module.getBody());
|
|
|
|
// Record constant pack *file path* as a global variable (by recording the
|
|
// file path string's underlying char array + its length).
|
|
const auto &fileNameAttr = packedConstOp.file_nameAttr();
|
|
auto type = LLVM::LLVMType::getArrayTy(
|
|
LLVM::LLVMType::getInt8Ty(context), fileNameAttr.getValue().size());
|
|
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
|
|
LLVM::Linkage::External,
|
|
mlir::KrnlPackedConstantOp::getConstPackFilePathSymbolName(),
|
|
fileNameAttr);
|
|
type = LLVM::LLVMType::getInt64Ty(context);
|
|
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
|
|
LLVM::Linkage::External,
|
|
mlir::KrnlPackedConstantOp::getConstPackFilePathStrLenSymbolName(),
|
|
rewriter.getI64IntegerAttr(fileNameAttr.getValue().size()));
|
|
|
|
// Record constant pack *file name* as a global variable (by recording the
|
|
// file name string's underlying char array + its length).
|
|
auto constPackFileName =
|
|
llvm::sys::path::filename(fileNameAttr.getValue());
|
|
type = LLVM::LLVMType::getArrayTy(
|
|
LLVM::LLVMType::getInt8Ty(context), constPackFileName.size());
|
|
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
|
|
LLVM::Linkage::External,
|
|
mlir::KrnlPackedConstantOp::getConstPackFileNameSymbolName(),
|
|
rewriter.getStringAttr(constPackFileName));
|
|
type = LLVM::LLVMType::getInt64Ty(context);
|
|
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
|
|
LLVM::Linkage::External,
|
|
mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName(),
|
|
rewriter.getI64IntegerAttr(constPackFileName.size()));
|
|
|
|
type = LLVM::LLVMType::getInt8Ty(context);
|
|
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
|
|
LLVM::Linkage::External,
|
|
mlir::KrnlPackedConstantOp::getConstPackIsLESymbolName(),
|
|
rewriter.getI8IntegerAttr(packedConstOp.is_le()));
|
|
}
|
|
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
|
|
return (a.getValue()[i]).cast<IntegerAttr>().getInt();
|
|
}
|
|
};
|
|
} // end namespace
|
|
|
|
/// Collect every rewrite pattern needed to take a mix of `affine`, `scf`,
/// `shape`, `vector`, `std` and `krnl` operations down to the LLVM dialect.
///
/// `typeConverter` is shared by the vector, std and krnl patterns that are
/// constructed with it.
void mlir::populateAffineAndKrnlToLLVMConversion(
    OwningRewritePatternList &patterns, MLIRContext *ctx,
    LLVMTypeConverter &typeConverter) {
  // Lowerings into `std`; these only need the context.
  populateAffineToStdConversionPatterns(patterns, ctx);
  populateLoopToStdConversionPatterns(patterns, ctx);
  populateShapeToStandardConversionPatterns(patterns, ctx);

  // Lowerings into the LLVM dialect driven by the type converter.
  populateVectorToLLVMMatrixConversionPatterns(typeConverter, patterns);
  populateVectorToLLVMConversionPatterns(typeConverter, patterns);
  populateStdToLLVMConversionPatterns(typeConverter, patterns);

  // Krnl lowerings: the first three are built with the type converter, the
  // last two from the context alone. Insertion order matches the original.
  patterns.insert<KrnlGlobalOpLowering, KrnlPackedConstOpLowering,
      KrnlGetRefOpLowering>(ctx, typeConverter);
  patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(ctx);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// KRNL + Standard + Affine dialects lowering to LLVM.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
struct ConvertKrnlToLLVMPass
|
|
: public PassWrapper<ConvertKrnlToLLVMPass, OperationPass<ModuleOp>> {
|
|
void runOnOperation() final;
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
void ConvertKrnlToLLVMPass::runOnOperation() {
|
|
// Define the target for this lowering i.e. the LLVM dialect.
|
|
ConversionTarget target(getContext());
|
|
target.addLegalDialect<LLVM::LLVMDialect>();
|
|
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
|
target.addIllegalOp<LLVM::DialectCastOp>();
|
|
|
|
// Lower the MemRef types to a representation in LLVM.
|
|
LowerToLLVMOptions options;
|
|
options.emitCWrappers = true;
|
|
LLVMTypeConverter typeConverter(&getContext(), options);
|
|
|
|
// We have a combination of `krnl`, `affine`, and `std` operations. We
|
|
// lower in stages until all the code is in the LLVM dialect.
|
|
OwningRewritePatternList patterns;
|
|
populateAffineAndKrnlToLLVMConversion(patterns, &getContext(), typeConverter);
|
|
|
|
// We want to completely lower to LLVM, so we use a `FullConversion`. This
|
|
// ensures that only legal operations will remain after the conversion.
|
|
if (failed(applyFullConversion(getOperation(), target, patterns))) {
|
|
signalPassFailure();
|
|
}
|
|
}
|
|
|
|
/// Create the pass for lowering `Krnl`, `Affine` and `Std` dialects to LLVM.
|
|
std::unique_ptr<mlir::Pass> mlir::createConvertKrnlToLLVMPass() {
|
|
return std::make_unique<ConvertKrnlToLLVMPass>();
|
|
}
|