diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp index 4822491..e278f66 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp @@ -45,7 +45,8 @@ MemRefType convertToMemRefType(Type type) { /// Insert an allocation and deallocation for the given MemRefType. Value insertAllocAndDealloc(MemRefType type, Location loc, - PatternRewriter &rewriter, bool insertDealloc, ArrayRef operands) { + PatternRewriter &rewriter, bool insertDealloc, ArrayRef operands, + int64_t alignment) { // Put together alloc operands for any dynamic dimensions of the memref. AllocOp alloc; if (!operands.empty()) { @@ -87,9 +88,26 @@ Value insertAllocAndDealloc(MemRefType type, Location loc, for (int i = 0; i < rank; ++i) if (memRefShape[i] < 0) allocOperands.push_back(fromOperands[i]); - alloc = rewriter.create(loc, type, allocOperands); + // Set alignment attribute. Default value is `-1`, which does not set + // alignment. + if (alignment >= 0) { + IntegerAttr constAlignAttr = rewriter.getI64IntegerAttr(alignment); + alloc = + rewriter.create(loc, type, allocOperands, constAlignAttr); + } else { + alloc = rewriter.create(loc, type, allocOperands); + } } else { - alloc = rewriter.create(loc, type); + // Set alignment attribute. Default value is `-1`, which does not set + // alignment. + if (alignment >= 0) { + SmallVector allocOperandsEmpty; + IntegerAttr constAlignAttr = rewriter.getI64IntegerAttr(alignment); + alloc = rewriter.create( + loc, type, allocOperandsEmpty, constAlignAttr); + } else { + alloc = rewriter.create(loc, type); + } } // Make sure to allocate at the beginning of the block if diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp index 2d12677..f03f8ba 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp @@ -45,7 +45,7 @@ MemRefType convertToMemRefType(Type type); /// Insert an allocation and deallocation for the given MemRefType. Value insertAllocAndDealloc(MemRefType type, Location loc, PatternRewriter &rewriter, bool insertDealloc, - ArrayRef operands = {}); + ArrayRef operands = {}, int64_t alignment = -1); // Determine if current function returns the result value of the // current op being lowered. If it does then dealloc should not be