Emit constant tensors as global constants (#66)
* Reorganize main function. * Follow review comments. * Emit constants are globals in Krnl and LLVM dialects. * Enable unique constant variable names. * Emit alloca for local array. Add tests. * Comment clean-up. * Simplify MemRef construction. * Fix output type.
This commit is contained in:
parent
b65e77305c
commit
f16e79d744
|
@ -47,7 +47,7 @@ struct FrontendToKrnlLoweringPass
|
||||||
} // end anonymous namespace.
|
} // end anonymous namespace.
|
||||||
|
|
||||||
void FrontendToKrnlLoweringPass::runOnModule() {
|
void FrontendToKrnlLoweringPass::runOnModule() {
|
||||||
auto module = getModule();
|
ModuleOp module = getModule();
|
||||||
|
|
||||||
// The first thing to define is the conversion target. This will define the
|
// The first thing to define is the conversion target. This will define the
|
||||||
// final target for this lowering.
|
// final target for this lowering.
|
||||||
|
|
|
@ -485,3 +485,7 @@ Value emitNegativeInfinityConstantOp(
|
||||||
|
|
||||||
return rewriter.create<ConstantOp>(loc, constantAttr);
|
return rewriter.create<ConstantOp>(loc, constantAttr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
|
||||||
|
return (a.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||||
|
}
|
|
@ -117,6 +117,8 @@ Value emitPositiveInfinityConstantOp(
|
||||||
Value emitNegativeInfinityConstantOp(
|
Value emitNegativeInfinityConstantOp(
|
||||||
ConversionPatternRewriter &rewriter, Location loc, Type type);
|
ConversionPatternRewriter &rewriter, Location loc, Type type);
|
||||||
|
|
||||||
|
int64_t ArrayAttrIntVal(ArrayAttr a, int i);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// This is to get a scalar operation of a given type for a specific operation.
|
// This is to get a scalar operation of a given type for a specific operation.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -12,40 +12,13 @@
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
template <typename ElementAttr>
|
|
||||||
void emitConstantAndStoreOpForDenseElementsAttr(
|
|
||||||
ConversionPatternRewriter &rewriter, Location loc,
|
|
||||||
DenseElementsAttr constantValue, ArrayRef<int64_t> valueShape,
|
|
||||||
ArrayRef<Value> constantIndices, Value alloc) {
|
|
||||||
// The following functor recursively walks the dimensions of the constant
|
|
||||||
// shape, generating a store when the recursion hits the base case.
|
|
||||||
SmallVector<Value, 2> indices;
|
|
||||||
auto valueIt = constantValue.getValues<ElementAttr>().begin();
|
|
||||||
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
|
|
||||||
// The last dimension is the base case of the recursion, at this point
|
|
||||||
// we store the element at the given index.
|
|
||||||
if (dimension == valueShape.size()) {
|
|
||||||
rewriter.create<AffineStoreOp>(loc,
|
|
||||||
rewriter.create<ConstantOp>(loc, *valueIt++), alloc,
|
|
||||||
llvm::makeArrayRef(indices));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, iterate over the current dimension and add the indices to
|
|
||||||
// the list.
|
|
||||||
for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
|
|
||||||
indices.push_back(constantIndices[i]);
|
|
||||||
storeElements(dimension + 1);
|
|
||||||
indices.pop_back();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
// Start the element storing recursion from the first dimension.
|
|
||||||
storeElements(/*dimension=*/0);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ONNXConstantOpLowering : public ConversionPattern {
|
struct ONNXConstantOpLowering : public ConversionPattern {
|
||||||
|
static int constantID;
|
||||||
|
|
||||||
ONNXConstantOpLowering(MLIRContext *ctx)
|
ONNXConstantOpLowering(MLIRContext *ctx)
|
||||||
: ConversionPattern(mlir::ONNXConstantOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(mlir::ONNXConstantOp::getOperationName(), 1, ctx) {
|
||||||
|
constantID = 0;
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
|
@ -58,42 +31,32 @@ struct ONNXConstantOpLowering : public ConversionPattern {
|
||||||
|
|
||||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||||
|
|
||||||
Value alloc;
|
// Shape based computations.
|
||||||
bool insertDealloc = checkInsertDealloc(op);
|
auto shape = memRefType.getShape();
|
||||||
|
int64_t numElements = 1;
|
||||||
|
for (int i=0; i<shape.size(); ++i)
|
||||||
|
numElements *= shape[i];
|
||||||
|
|
||||||
if (hasAllConstantDimensions(memRefType))
|
// Emit the constant global in Krnl dialect.
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc,
|
||||||
else
|
memRefType,
|
||||||
emitError(loc, "Unexpected output has non-Constant shape");
|
rewriter.getI64ArrayAttr(shape),
|
||||||
|
constantOp.value().getValue(),
|
||||||
|
rewriter.getStringAttr("constant_" + std::to_string(constantID)));
|
||||||
|
|
||||||
DenseElementsAttr constantValue =
|
// Increment constant ID:
|
||||||
constantOp.value().getValue().cast<DenseElementsAttr>();
|
constantID++;
|
||||||
|
|
||||||
auto valueShape = memRefType.getShape();
|
|
||||||
SmallVector<Value, 8> constantIndices;
|
|
||||||
for (auto i : llvm::seq<int64_t>(
|
|
||||||
0, *std::max_element(valueShape.begin(), valueShape.end())))
|
|
||||||
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
|
|
||||||
|
|
||||||
// The constant operation represents a multi-dimensional constant, so we
|
|
||||||
// will need to generate a store for each of the elements.
|
|
||||||
if (memRefType.getElementType().isa<IntegerType>()) {
|
|
||||||
emitConstantAndStoreOpForDenseElementsAttr<IntegerAttr>(
|
|
||||||
rewriter, loc, constantValue, valueShape, constantIndices, alloc);
|
|
||||||
} else if (memRefType.getElementType().isa<FloatType>()) {
|
|
||||||
emitConstantAndStoreOpForDenseElementsAttr<FloatAttr>(
|
|
||||||
rewriter, loc, constantValue, valueShape, constantIndices, alloc);
|
|
||||||
} else {
|
|
||||||
emitError(loc, "Unsupported output type");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replace this operation with the generated alloc.
|
// Replace this operation with the generated alloc.
|
||||||
rewriter.replaceOp(op, alloc);
|
// rewriter.replaceOp(op, alloc);
|
||||||
|
rewriter.replaceOp(op, constantGlobal.getResult());
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
int ONNXConstantOpLowering::constantID;
|
||||||
|
|
||||||
void populateLoweringONNXConstantOpPattern(
|
void populateLoweringONNXConstantOpPattern(
|
||||||
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||||
patterns.insert<ONNXConstantOpLowering>(ctx);
|
patterns.insert<ONNXConstantOpLowering>(ctx);
|
||||||
|
|
|
@ -192,3 +192,16 @@ def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
|
||||||
let parser = ?;
|
let parser = ?;
|
||||||
let printer = ?;
|
let printer = ?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def KrnlGlobalOp : Op<Krnl_Dialect, "global"> {
|
||||||
|
let summary = "Krnl global operation";
|
||||||
|
let description = [{
|
||||||
|
Operation for holding global data values.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins AnyAttr:$shape, AnyAttr:$value, StrAttr:$name);
|
||||||
|
let results = (outs AnyTypeOf<[AnyMemRef]>:$output);
|
||||||
|
|
||||||
|
let parser = ?;
|
||||||
|
let printer = ?;
|
||||||
|
}
|
||||||
|
|
|
@ -154,6 +154,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
|
||||||
target.addIllegalDialect<KrnlOpsDialect>();
|
target.addIllegalDialect<KrnlOpsDialect>();
|
||||||
target.addLegalOp<KrnlMemcpyOp>();
|
target.addLegalOp<KrnlMemcpyOp>();
|
||||||
target.addLegalOp<KrnlEntryPointOp>();
|
target.addLegalOp<KrnlEntryPointOp>();
|
||||||
|
target.addLegalOp<KrnlGlobalOp>();
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
|
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
|
||||||
|
|
|
@ -16,10 +16,12 @@
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
#include "mlir/Dialect/LoopOps/LoopOps.h"
|
#include "mlir/Dialect/LoopOps/LoopOps.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "llvm/ADT/Sequence.h"
|
#include "llvm/ADT/Sequence.h"
|
||||||
|
|
||||||
|
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
|
||||||
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
||||||
#include "src/Pass/Passes.hpp"
|
#include "src/Pass/Passes.hpp"
|
||||||
|
|
||||||
|
@ -60,6 +62,150 @@ static size_t getRankFromMemRefType(LLVM::LLVMType memRefTy) {
|
||||||
return memRefTy.getStructElementType(3).getArrayNumElements();
|
return memRefTy.getStructElementType(3).getArrayNumElements();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return a symbol reference to the memcpy function, inserting it into the
|
||||||
|
/// module if necessary.
|
||||||
|
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
|
||||||
|
ModuleOp module, LLVM::LLVMDialect *llvmDialect) {
|
||||||
|
auto *context = module.getContext();
|
||||||
|
if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
|
||||||
|
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
||||||
|
// Create a function declaration for memcpy, the signature is:
|
||||||
|
// * `void (i8*, i8* , i64, i1)`
|
||||||
|
auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
|
||||||
|
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||||
|
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||||
|
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
|
||||||
|
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
|
||||||
|
ArrayRef<mlir::LLVM::LLVMType>(
|
||||||
|
{llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
|
||||||
|
false);
|
||||||
|
|
||||||
|
// Insert the memcpy function into the body of the parent module.
|
||||||
|
PatternRewriter::InsertionGuard insertGuard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(module.getBody());
|
||||||
|
rewriter.create<LLVM::LLVMFuncOp>(
|
||||||
|
module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
|
||||||
|
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// KRNL to LLVM: KrnlGlobalOpLowering
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
|
||||||
|
public:
|
||||||
|
explicit KrnlGlobalOpLowering(MLIRContext *context,
|
||||||
|
LLVMTypeConverter &lowering_)
|
||||||
|
: ConvertToLLVMPattern(KrnlGlobalOp::getOperationName(), context,
|
||||||
|
lowering_) {}
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto *context = op->getContext();
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
auto *llvmDialect =
|
||||||
|
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||||
|
assert(llvmDialect && "expected llvm dialect to be registered");
|
||||||
|
|
||||||
|
auto krnlGlobalOp = llvm::dyn_cast<KrnlGlobalOp>(op);
|
||||||
|
|
||||||
|
// Get module.
|
||||||
|
ModuleOp module = op->getParentOfType<ModuleOp>();
|
||||||
|
|
||||||
|
// Global name.
|
||||||
|
auto name = krnlGlobalOp.name();
|
||||||
|
|
||||||
|
// Compute total number of elements.
|
||||||
|
auto shape = (krnlGlobalOp.shape()).dyn_cast<ArrayAttr>();
|
||||||
|
int64_t numElements = 1;
|
||||||
|
for (int i=0; i<shape.size(); ++i)
|
||||||
|
numElements *= ArrayAttrIntVal(shape, i);
|
||||||
|
|
||||||
|
// Create the global at the entry of the module.
|
||||||
|
LLVM::GlobalOp global;
|
||||||
|
auto type = op->getResult(0).getType();
|
||||||
|
auto memRefTy = type.cast<mlir::MemRefType>();
|
||||||
|
auto llvmMemRefType =
|
||||||
|
typeConverter.convertType(type).cast<LLVM::LLVMType>();
|
||||||
|
|
||||||
|
// The element type of the array.
|
||||||
|
auto constantElementType =
|
||||||
|
typeConverter.convertType(memRefTy.getElementType());
|
||||||
|
auto globalType = constantElementType;
|
||||||
|
for (int i=shape.size() - 1; i >= 0; i--)
|
||||||
|
globalType = LLVM::LLVMType::getArrayTy(
|
||||||
|
globalType.cast<LLVM::LLVMType>(), ArrayAttrIntVal(shape, i));
|
||||||
|
// The llvm type of the global (example: [2 x [8 x float]])
|
||||||
|
auto llvmGlobalType = globalType.cast<LLVM::LLVMType>();
|
||||||
|
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard insertGuard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(module.getBody());
|
||||||
|
|
||||||
|
global = rewriter.create<LLVM::GlobalOp>(loc,
|
||||||
|
llvmGlobalType, /*isConstant=*/true,
|
||||||
|
LLVM::Linkage::Internal, name, krnlGlobalOp.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some frequently used types.
|
||||||
|
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||||
|
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||||
|
|
||||||
|
// Allocate the memory where the constants will be used from.
|
||||||
|
// This is a region of local memory and needs to be emitted as an alloca.
|
||||||
|
auto one = rewriter.create<LLVM::ConstantOp>(loc,
|
||||||
|
llvmI64Ty, rewriter.getI64IntegerAttr(1));
|
||||||
|
auto alloc = rewriter.create<LLVM::AllocaOp>(
|
||||||
|
loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0);
|
||||||
|
|
||||||
|
// Copy constant value into the local alloca:
|
||||||
|
// - Bitcast alloc to i8*
|
||||||
|
Value int8PtrAlloc = rewriter.create<LLVM::BitcastOp>(
|
||||||
|
loc, llvmI8PtrTy, alloc);
|
||||||
|
// - Bitcast global to i8*
|
||||||
|
Value globalValue = rewriter.create<LLVM::AddressOfOp>(loc, global);
|
||||||
|
Value i8PtrGlobal = rewriter.create<LLVM::BitcastOp>(
|
||||||
|
loc, llvmI8PtrTy, globalValue);
|
||||||
|
// - Set size.
|
||||||
|
Value memRefElementSize = rewriter.create<LLVM::ConstantOp>(loc,
|
||||||
|
llvmI64Ty, rewriter.getI64IntegerAttr(
|
||||||
|
getMemRefEltSizeInBytes(memRefTy)));
|
||||||
|
Value numElementsValue = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements));
|
||||||
|
Value totalElementsSize = rewriter.create<LLVM::MulOp>(
|
||||||
|
loc, memRefElementSize, numElementsValue);
|
||||||
|
Value int64Size = rewriter.create<LLVM::SExtOp>(
|
||||||
|
loc, llvmI64Ty, totalElementsSize);
|
||||||
|
// - Set volatile.
|
||||||
|
Value isVolatile = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, LLVM::LLVMType::getInt1Ty(llvmDialect),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
||||||
|
// - Copy constant data into the alloca.
|
||||||
|
auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect);
|
||||||
|
rewriter.create<CallOp>(
|
||||||
|
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
|
||||||
|
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
|
||||||
|
|
||||||
|
// Prepare data to be inserted into MemRef.
|
||||||
|
auto llvmConstantElementType = constantElementType.cast<LLVM::LLVMType>();
|
||||||
|
Value typedAlloc = rewriter.create<LLVM::BitcastOp>(
|
||||||
|
loc, llvmConstantElementType.getPointerTo(), alloc);
|
||||||
|
|
||||||
|
// Create llvm MemRef from original MemRef and fill the data pointers.
|
||||||
|
auto llvmMemRef = MemRefDescriptor::fromStaticShape(
|
||||||
|
rewriter, loc, typeConverter, memRefTy, typedAlloc);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, {llvmMemRef});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
|
||||||
|
return (a.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// KRNL to LLVM: KrnlMemcpyOpLowering
|
// KRNL to LLVM: KrnlMemcpyOpLowering
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -120,33 +266,6 @@ public:
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
/// Return a symbol reference to the memcpy function, inserting it into the
|
|
||||||
/// module if necessary.
|
|
||||||
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
|
|
||||||
ModuleOp module, LLVM::LLVMDialect *llvmDialect) {
|
|
||||||
auto *context = module.getContext();
|
|
||||||
if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
|
|
||||||
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
|
||||||
// Create a function declaration for memcpy, the signature is:
|
|
||||||
// * `void (i8*, i8* , i64, i1)`
|
|
||||||
auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
|
|
||||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
|
||||||
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
|
||||||
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
|
|
||||||
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
|
|
||||||
ArrayRef<mlir::LLVM::LLVMType>(
|
|
||||||
{llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
|
|
||||||
false);
|
|
||||||
|
|
||||||
// Insert the memcpy function into the body of the parent module.
|
|
||||||
PatternRewriter::InsertionGuard insertGuard(rewriter);
|
|
||||||
rewriter.setInsertionPointToStart(module.getBody());
|
|
||||||
rewriter.create<LLVM::LLVMFuncOp>(
|
|
||||||
module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
|
|
||||||
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -506,6 +625,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
|
||||||
/*useAlloca=*/false,
|
/*useAlloca=*/false,
|
||||||
/*emitCWrapper=*/true);
|
/*emitCWrapper=*/true);
|
||||||
|
|
||||||
|
patterns.insert<KrnlGlobalOpLowering>(&getContext(), typeConverter);
|
||||||
|
|
||||||
// Lower from the `krnl` dialect i.e. the Reshape operation.
|
// Lower from the `krnl` dialect i.e. the Reshape operation.
|
||||||
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
|
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
|
||||||
&getContext());
|
&getContext());
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
// RUN: onnx-mlir-opt --shape-inference --lower-frontend --lower-krnl --lower-all-llvm %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
|
||||||
|
// CHECK: llvm.mlir.global internal constant [[GLOBAL_CONST:@.+]](dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>) : !llvm<"[3 x [2 x float]]">
|
||||||
|
// CHECK: llvm.func @test_constant({{.*}}) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
|
||||||
|
|
||||||
|
// CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : i64) : !llvm.i64
|
||||||
|
// CHECK: [[ALLOCA:%.+]] = llvm.alloca [[CONST1]] x !llvm<"[3 x [2 x float]]"> : (!llvm.i64) -> !llvm<"[3 x [2 x float]]*">
|
||||||
|
// CHECK: [[I8ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"i8*">
|
||||||
|
|
||||||
|
// CHECK: [[GLOBAL_ADDR:%.+]] = llvm.mlir.addressof [[GLOBAL_CONST]] : !llvm<"[3 x [2 x float]]*">
|
||||||
|
// CHECK: [[I8GLOBAL:%.+]] = llvm.bitcast [[GLOBAL_ADDR]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"i8*">
|
||||||
|
|
||||||
|
/// Size of the constant tensor in bytes.
|
||||||
|
// CHECK: [[CONST4:%.+]] = llvm.mlir.constant(4 : i64) : !llvm.i64
|
||||||
|
// CHECK: [[CONST6:%.+]] = llvm.mlir.constant(6 : i64) : !llvm.i64
|
||||||
|
// CHECK: [[CONST_MUL1:%.+]] = llvm.mul [[CONST4]], [[CONST6]] : !llvm.i64
|
||||||
|
// CHECK: [[GLOBAL_SIZE_BYTES:%.+]] = llvm.sext [[CONST_MUL1]] : !llvm.i64 to !llvm.i64
|
||||||
|
|
||||||
|
/// Volatile flag
|
||||||
|
// CHECK: [[CONST0:%.+]] = llvm.mlir.constant(0 : i1) : !llvm.i1
|
||||||
|
|
||||||
|
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> !llvm.void
|
||||||
|
|
||||||
|
/// Prepare data for MemRef insertion.
|
||||||
|
// CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"float*">
|
||||||
|
|
||||||
|
/// Insert the constant value in the local MemRef.
|
||||||
|
// CHECK: [[LOCAL_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
// CHECK: [[LOCAL_MEMREF0:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
// CHECK: [[LOCAL_MEMREF1:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
|
||||||
|
/// Insert offset.
|
||||||
|
// CHECK: [[CONST00:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64
|
||||||
|
// CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[CONST00]], [[LOCAL_MEMREF1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
|
||||||
|
/// Insert sizes and strides.
|
||||||
|
// CHECK: [[CONST3:%.+]] = llvm.mlir.constant(3 : index) : !llvm.i64
|
||||||
|
// CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
// CHECK: [[CONST1:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64
|
||||||
|
// CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF2]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
|
||||||
|
// CHECK: [[CONST2:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64
|
||||||
|
// CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF3]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
// CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||||
|
// CHECK: [[MEMREF5:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF4]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
|
||||||
|
// CHECK: llvm.return [[MEMREF5]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
}
|
|
@ -1678,22 +1678,7 @@ func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32>
|
%0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
// CHECK-LABEL: test_constant_dense_2d_value
|
// CHECK-LABEL: test_constant_dense_2d_value
|
||||||
// CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32>
|
// CHECK: [[RES:%.+]] = "krnl.global"() {name = "constant_0", shape = [3, 2], value = dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>} : () -> memref<3x2xf32>
|
||||||
// CHECK: %[[INDEX_0:.+]] = constant 0 : index
|
|
||||||
// CHECK: %[[INDEX_1:.+]] = constant 1 : index
|
|
||||||
// CHECK: %[[INDEX_2:.+]] = constant 2 : index
|
|
||||||
// CHECK: [[CONSTANT_0:%.+]] = constant 0.000000e+00 : f32
|
|
||||||
// CHECK: affine.store [[CONSTANT_0]], %0[%[[INDEX_0]], %[[INDEX_0]]] : memref<3x2xf32>
|
|
||||||
// CHECK: [[CONSTANT_1:%.+]] = constant 0.000000e+00 : f32
|
|
||||||
// CHECK: affine.store [[CONSTANT_1]], %0[%[[INDEX_0]], %[[INDEX_1]]] : memref<3x2xf32>
|
|
||||||
// CHECK: [[CONSTANT_2:%.+]] = constant 1.000000e+00 : f32
|
|
||||||
// CHECK: affine.store [[CONSTANT_2]], %0[%[[INDEX_1]], %[[INDEX_0]]] : memref<3x2xf32>
|
|
||||||
// CHECK: [[CONSTANT_3:%.+]] = constant 1.100000e+00 : f32
|
|
||||||
// CHECK: affine.store [[CONSTANT_3]], %0[%[[INDEX_1]], %[[INDEX_1]]] : memref<3x2xf32>
|
|
||||||
// CHECK: [[CONSTANT_4:%.+]] = constant 2.000000e+00 : f32
|
|
||||||
// CHECK: affine.store [[CONSTANT_4]], %0[%[[INDEX_2]], %[[INDEX_0]]] : memref<3x2xf32>
|
|
||||||
// CHECK: [[CONSTANT_5:%.+]] = constant 2.100000e+00 : f32
|
|
||||||
// CHECK: affine.store [[CONSTANT_5]], %0[%[[INDEX_2]], %[[INDEX_1]]] : memref<3x2xf32>
|
|
||||||
// CHECK: return [[RES]] : memref<3x2xf32>
|
// CHECK: return [[RES]] : memref<3x2xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue