Emit constant tensors as global constants (#66)
* Reorganize main function. * Follow review comments. * Emit constants as globals in Krnl and LLVM dialects. * Enable unique constant variable names. * Emit alloca for local array. Add tests. * Comment clean-up. * Simplify MemRef construction. * Fix output type.
This commit is contained in:
		
							parent
							
								
									b65e77305c
								
							
						
					
					
						commit
						f16e79d744
					
				| 
						 | 
					@ -47,7 +47,7 @@ struct FrontendToKrnlLoweringPass
 | 
				
			||||||
} // end anonymous namespace.
 | 
					} // end anonymous namespace.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void FrontendToKrnlLoweringPass::runOnModule() {
 | 
					void FrontendToKrnlLoweringPass::runOnModule() {
 | 
				
			||||||
  auto module = getModule();
 | 
					  ModuleOp module = getModule();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // The first thing to define is the conversion target. This will define the
 | 
					  // The first thing to define is the conversion target. This will define the
 | 
				
			||||||
  // final target for this lowering.
 | 
					  // final target for this lowering.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -485,3 +485,7 @@ Value emitNegativeInfinityConstantOp(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return rewriter.create<ConstantOp>(loc, constantAttr);
 | 
					  return rewriter.create<ConstantOp>(loc, constantAttr);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
 | 
				
			||||||
 | 
					  return (a.getValue()[i]).cast<IntegerAttr>().getInt();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -117,6 +117,8 @@ Value emitPositiveInfinityConstantOp(
 | 
				
			||||||
Value emitNegativeInfinityConstantOp(
 | 
					Value emitNegativeInfinityConstantOp(
 | 
				
			||||||
    ConversionPatternRewriter &rewriter, Location loc, Type type);
 | 
					    ConversionPatternRewriter &rewriter, Location loc, Type type);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int64_t ArrayAttrIntVal(ArrayAttr a, int i);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
// This is to get a scalar operation of a given type for a specific operation.
 | 
					// This is to get a scalar operation of a given type for a specific operation.
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -12,40 +12,13 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
using namespace mlir;
 | 
					using namespace mlir;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename ElementAttr>
 | 
					 | 
				
			||||||
void emitConstantAndStoreOpForDenseElementsAttr(
 | 
					 | 
				
			||||||
    ConversionPatternRewriter &rewriter, Location loc,
 | 
					 | 
				
			||||||
    DenseElementsAttr constantValue, ArrayRef<int64_t> valueShape,
 | 
					 | 
				
			||||||
    ArrayRef<Value> constantIndices, Value alloc) {
 | 
					 | 
				
			||||||
  // The following functor recursively walks the dimensions of the constant
 | 
					 | 
				
			||||||
  // shape, generating a store when the recursion hits the base case.
 | 
					 | 
				
			||||||
  SmallVector<Value, 2> indices;
 | 
					 | 
				
			||||||
  auto valueIt = constantValue.getValues<ElementAttr>().begin();
 | 
					 | 
				
			||||||
  std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
 | 
					 | 
				
			||||||
    // The last dimension is the base case of the recursion, at this point
 | 
					 | 
				
			||||||
    // we store the element at the given index.
 | 
					 | 
				
			||||||
    if (dimension == valueShape.size()) {
 | 
					 | 
				
			||||||
      rewriter.create<AffineStoreOp>(loc,
 | 
					 | 
				
			||||||
          rewriter.create<ConstantOp>(loc, *valueIt++), alloc,
 | 
					 | 
				
			||||||
          llvm::makeArrayRef(indices));
 | 
					 | 
				
			||||||
      return;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Otherwise, iterate over the current dimension and add the indices to
 | 
					 | 
				
			||||||
    // the list.
 | 
					 | 
				
			||||||
    for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
 | 
					 | 
				
			||||||
      indices.push_back(constantIndices[i]);
 | 
					 | 
				
			||||||
      storeElements(dimension + 1);
 | 
					 | 
				
			||||||
      indices.pop_back();
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  };
 | 
					 | 
				
			||||||
  // Start the element storing recursion from the first dimension.
 | 
					 | 
				
			||||||
  storeElements(/*dimension=*/0);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
struct ONNXConstantOpLowering : public ConversionPattern {
 | 
					struct ONNXConstantOpLowering : public ConversionPattern {
 | 
				
			||||||
 | 
					  static int constantID;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  ONNXConstantOpLowering(MLIRContext *ctx)
 | 
					  ONNXConstantOpLowering(MLIRContext *ctx)
 | 
				
			||||||
      : ConversionPattern(mlir::ONNXConstantOp::getOperationName(), 1, ctx) {}
 | 
					      : ConversionPattern(mlir::ONNXConstantOp::getOperationName(), 1, ctx) {
 | 
				
			||||||
 | 
					        constantID = 0;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
 | 
					  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
 | 
				
			||||||
      ConversionPatternRewriter &rewriter) const final {
 | 
					      ConversionPatternRewriter &rewriter) const final {
 | 
				
			||||||
| 
						 | 
					@ -58,42 +31,32 @@ struct ONNXConstantOpLowering : public ConversionPattern {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto memRefType = convertToMemRefType(*op->result_type_begin());
 | 
					    auto memRefType = convertToMemRefType(*op->result_type_begin());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Value alloc;
 | 
					    // Shape based computations.
 | 
				
			||||||
    bool insertDealloc = checkInsertDealloc(op);
 | 
					    auto shape = memRefType.getShape();
 | 
				
			||||||
 | 
					    int64_t numElements = 1;
 | 
				
			||||||
 | 
					    for (int i=0; i<shape.size(); ++i)
 | 
				
			||||||
 | 
					      numElements *= shape[i];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (hasAllConstantDimensions(memRefType))
 | 
					    // Emit the constant global in Krnl dialect.
 | 
				
			||||||
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
 | 
					    auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc,
 | 
				
			||||||
    else
 | 
					        memRefType,
 | 
				
			||||||
      emitError(loc, "Unexpected output has non-Constant shape");
 | 
					        rewriter.getI64ArrayAttr(shape),
 | 
				
			||||||
 | 
					        constantOp.value().getValue(),
 | 
				
			||||||
 | 
					        rewriter.getStringAttr("constant_" + std::to_string(constantID)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    DenseElementsAttr constantValue =
 | 
					    // Increment constant ID:
 | 
				
			||||||
        constantOp.value().getValue().cast<DenseElementsAttr>();
 | 
					    constantID++;
 | 
				
			||||||
 | 
					 | 
				
			||||||
    auto valueShape = memRefType.getShape();
 | 
					 | 
				
			||||||
    SmallVector<Value, 8> constantIndices;
 | 
					 | 
				
			||||||
    for (auto i : llvm::seq<int64_t>(
 | 
					 | 
				
			||||||
             0, *std::max_element(valueShape.begin(), valueShape.end())))
 | 
					 | 
				
			||||||
      constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // The constant operation represents a multi-dimensional constant, so we
 | 
					 | 
				
			||||||
    // will need to generate a store for each of the elements.
 | 
					 | 
				
			||||||
    if (memRefType.getElementType().isa<IntegerType>()) {
 | 
					 | 
				
			||||||
      emitConstantAndStoreOpForDenseElementsAttr<IntegerAttr>(
 | 
					 | 
				
			||||||
          rewriter, loc, constantValue, valueShape, constantIndices, alloc);
 | 
					 | 
				
			||||||
    } else if (memRefType.getElementType().isa<FloatType>()) {
 | 
					 | 
				
			||||||
      emitConstantAndStoreOpForDenseElementsAttr<FloatAttr>(
 | 
					 | 
				
			||||||
          rewriter, loc, constantValue, valueShape, constantIndices, alloc);
 | 
					 | 
				
			||||||
    } else {
 | 
					 | 
				
			||||||
      emitError(loc, "Unsupported output type");
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Replace this operation with the generated alloc.
 | 
					    // Replace this operation with the generated alloc.
 | 
				
			||||||
    rewriter.replaceOp(op, alloc);
 | 
					    // rewriter.replaceOp(op, alloc);
 | 
				
			||||||
 | 
					    rewriter.replaceOp(op, constantGlobal.getResult());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return success();
 | 
					    return success();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int ONNXConstantOpLowering::constantID;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void populateLoweringONNXConstantOpPattern(
 | 
					void populateLoweringONNXConstantOpPattern(
 | 
				
			||||||
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx) {
 | 
				
			||||||
  patterns.insert<ONNXConstantOpLowering>(ctx);
 | 
					  patterns.insert<ONNXConstantOpLowering>(ctx);
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -192,3 +192,16 @@ def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
 | 
				
			||||||
  let parser = ?;
 | 
					  let parser = ?;
 | 
				
			||||||
  let printer = ?;
 | 
					  let printer = ?;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def KrnlGlobalOp : Op<Krnl_Dialect, "global"> {
 | 
				
			||||||
 | 
					  let summary = "Krnl global operation";
 | 
				
			||||||
 | 
					  let description = [{
 | 
				
			||||||
 | 
					    Operation for holding global data values.
 | 
				
			||||||
 | 
					  }];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let arguments = (ins AnyAttr:$shape, AnyAttr:$value, StrAttr:$name);
 | 
				
			||||||
 | 
					  let results = (outs AnyTypeOf<[AnyMemRef]>:$output);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let parser = ?;
 | 
				
			||||||
 | 
					  let printer = ?;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -154,6 +154,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
 | 
				
			||||||
  target.addIllegalDialect<KrnlOpsDialect>();
 | 
					  target.addIllegalDialect<KrnlOpsDialect>();
 | 
				
			||||||
  target.addLegalOp<KrnlMemcpyOp>();
 | 
					  target.addLegalOp<KrnlMemcpyOp>();
 | 
				
			||||||
  target.addLegalOp<KrnlEntryPointOp>();
 | 
					  target.addLegalOp<KrnlEntryPointOp>();
 | 
				
			||||||
 | 
					  target.addLegalOp<KrnlGlobalOp>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  OwningRewritePatternList patterns;
 | 
					  OwningRewritePatternList patterns;
 | 
				
			||||||
  patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
 | 
					  patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -16,10 +16,12 @@
 | 
				
			||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 | 
					#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 | 
				
			||||||
#include "mlir/Dialect/LoopOps/LoopOps.h"
 | 
					#include "mlir/Dialect/LoopOps/LoopOps.h"
 | 
				
			||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
					#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
				
			||||||
 | 
					#include "mlir/Target/LLVMIR/ModuleTranslation.h"
 | 
				
			||||||
#include "mlir/Pass/Pass.h"
 | 
					#include "mlir/Pass/Pass.h"
 | 
				
			||||||
#include "mlir/Transforms/DialectConversion.h"
 | 
					#include "mlir/Transforms/DialectConversion.h"
 | 
				
			||||||
#include "llvm/ADT/Sequence.h"
 | 
					#include "llvm/ADT/Sequence.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
 | 
				
			||||||
#include "src/Dialect/Krnl/KrnlOps.hpp"
 | 
					#include "src/Dialect/Krnl/KrnlOps.hpp"
 | 
				
			||||||
#include "src/Pass/Passes.hpp"
 | 
					#include "src/Pass/Passes.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -60,6 +62,150 @@ static size_t getRankFromMemRefType(LLVM::LLVMType memRefTy) {
 | 
				
			||||||
    return memRefTy.getStructElementType(3).getArrayNumElements();
 | 
					    return memRefTy.getStructElementType(3).getArrayNumElements();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// Return a symbol reference to the memcpy function, inserting it into the
 | 
				
			||||||
 | 
					/// module if necessary.
 | 
				
			||||||
 | 
					static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
 | 
				
			||||||
 | 
					    ModuleOp module, LLVM::LLVMDialect *llvmDialect) {
 | 
				
			||||||
 | 
					  auto *context = module.getContext();
 | 
				
			||||||
 | 
					  if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
 | 
				
			||||||
 | 
					    return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
 | 
				
			||||||
 | 
					  // Create a function declaration for memcpy, the signature is:
 | 
				
			||||||
 | 
					  //   * `void (i8*, i8* , i64, i1)`
 | 
				
			||||||
 | 
					  auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
 | 
				
			||||||
 | 
					  auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
 | 
				
			||||||
 | 
					  auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
 | 
				
			||||||
 | 
					  auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
 | 
				
			||||||
 | 
					  auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
 | 
				
			||||||
 | 
					      ArrayRef<mlir::LLVM::LLVMType>(
 | 
				
			||||||
 | 
					          {llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
 | 
				
			||||||
 | 
					      false);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Insert the memcpy function into the body of the parent module.
 | 
				
			||||||
 | 
					  PatternRewriter::InsertionGuard insertGuard(rewriter);
 | 
				
			||||||
 | 
					  rewriter.setInsertionPointToStart(module.getBody());
 | 
				
			||||||
 | 
					  rewriter.create<LLVM::LLVMFuncOp>(
 | 
				
			||||||
 | 
					      module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
 | 
				
			||||||
 | 
					  return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// KRNL to LLVM: KrnlGlobalOpLowering
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
 | 
				
			||||||
 | 
					public:
 | 
				
			||||||
 | 
					  explicit KrnlGlobalOpLowering(MLIRContext *context,
 | 
				
			||||||
 | 
					                                LLVMTypeConverter &lowering_)
 | 
				
			||||||
 | 
					      : ConvertToLLVMPattern(KrnlGlobalOp::getOperationName(), context,
 | 
				
			||||||
 | 
					                             lowering_) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  LogicalResult
 | 
				
			||||||
 | 
					  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                  ConversionPatternRewriter &rewriter) const override {
 | 
				
			||||||
 | 
					    auto *context = op->getContext();
 | 
				
			||||||
 | 
					    auto loc = op->getLoc();
 | 
				
			||||||
 | 
					    auto *llvmDialect =
 | 
				
			||||||
 | 
					        op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
 | 
				
			||||||
 | 
					    assert(llvmDialect && "expected llvm dialect to be registered");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto krnlGlobalOp = llvm::dyn_cast<KrnlGlobalOp>(op);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Get module.
 | 
				
			||||||
 | 
					    ModuleOp module = op->getParentOfType<ModuleOp>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Global name.
 | 
				
			||||||
 | 
					    auto name = krnlGlobalOp.name();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Compute total number of elements.
 | 
				
			||||||
 | 
					    auto shape = (krnlGlobalOp.shape()).dyn_cast<ArrayAttr>();
 | 
				
			||||||
 | 
					    int64_t numElements = 1;
 | 
				
			||||||
 | 
					    for (int i=0; i<shape.size(); ++i)
 | 
				
			||||||
 | 
					      numElements *= ArrayAttrIntVal(shape, i);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Create the global at the entry of the module.
 | 
				
			||||||
 | 
					    LLVM::GlobalOp global;
 | 
				
			||||||
 | 
					    auto type = op->getResult(0).getType();
 | 
				
			||||||
 | 
					    auto memRefTy = type.cast<mlir::MemRefType>();
 | 
				
			||||||
 | 
					    auto llvmMemRefType =
 | 
				
			||||||
 | 
					        typeConverter.convertType(type).cast<LLVM::LLVMType>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // The element type of the array.
 | 
				
			||||||
 | 
					    auto constantElementType =
 | 
				
			||||||
 | 
					        typeConverter.convertType(memRefTy.getElementType());
 | 
				
			||||||
 | 
					    auto globalType = constantElementType;
 | 
				
			||||||
 | 
					    for (int i=shape.size() - 1; i >= 0; i--)
 | 
				
			||||||
 | 
					      globalType = LLVM::LLVMType::getArrayTy(
 | 
				
			||||||
 | 
					          globalType.cast<LLVM::LLVMType>(), ArrayAttrIntVal(shape, i));
 | 
				
			||||||
 | 
					    // The llvm type of the global (example: [2 x [8 x float]])
 | 
				
			||||||
 | 
					    auto llvmGlobalType = globalType.cast<LLVM::LLVMType>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					      OpBuilder::InsertionGuard insertGuard(rewriter);
 | 
				
			||||||
 | 
					      rewriter.setInsertionPointToStart(module.getBody());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      global = rewriter.create<LLVM::GlobalOp>(loc,
 | 
				
			||||||
 | 
					          llvmGlobalType, /*isConstant=*/true,
 | 
				
			||||||
 | 
					          LLVM::Linkage::Internal, name, krnlGlobalOp.value());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Some frequently used types.
 | 
				
			||||||
 | 
					    auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
 | 
				
			||||||
 | 
					    auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Allocate the memory where the constants will be used from.
 | 
				
			||||||
 | 
					    // This is a region of local memory and needs to be emitted as an alloca.
 | 
				
			||||||
 | 
					    auto one = rewriter.create<LLVM::ConstantOp>(loc,
 | 
				
			||||||
 | 
					        llvmI64Ty, rewriter.getI64IntegerAttr(1));
 | 
				
			||||||
 | 
					    auto alloc = rewriter.create<LLVM::AllocaOp>(
 | 
				
			||||||
 | 
					        loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Copy constant value into the local alloca:
 | 
				
			||||||
 | 
					    //  - Bitcast alloc to i8*
 | 
				
			||||||
 | 
					    Value int8PtrAlloc = rewriter.create<LLVM::BitcastOp>(
 | 
				
			||||||
 | 
					        loc, llvmI8PtrTy, alloc);
 | 
				
			||||||
 | 
					    //  - Bitcast global to i8*
 | 
				
			||||||
 | 
					    Value globalValue = rewriter.create<LLVM::AddressOfOp>(loc, global);
 | 
				
			||||||
 | 
					    Value i8PtrGlobal = rewriter.create<LLVM::BitcastOp>(
 | 
				
			||||||
 | 
					        loc, llvmI8PtrTy, globalValue);
 | 
				
			||||||
 | 
					    //  - Set size.
 | 
				
			||||||
 | 
					    Value memRefElementSize = rewriter.create<LLVM::ConstantOp>(loc,
 | 
				
			||||||
 | 
					        llvmI64Ty, rewriter.getI64IntegerAttr(
 | 
				
			||||||
 | 
					            getMemRefEltSizeInBytes(memRefTy)));
 | 
				
			||||||
 | 
					    Value numElementsValue = rewriter.create<LLVM::ConstantOp>(
 | 
				
			||||||
 | 
					        loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements));
 | 
				
			||||||
 | 
					    Value totalElementsSize = rewriter.create<LLVM::MulOp>(
 | 
				
			||||||
 | 
					        loc, memRefElementSize, numElementsValue);
 | 
				
			||||||
 | 
					    Value int64Size = rewriter.create<LLVM::SExtOp>(
 | 
				
			||||||
 | 
					        loc, llvmI64Ty, totalElementsSize);
 | 
				
			||||||
 | 
					    //  - Set volatile.
 | 
				
			||||||
 | 
					    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
 | 
				
			||||||
 | 
					        loc, LLVM::LLVMType::getInt1Ty(llvmDialect),
 | 
				
			||||||
 | 
					        rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
 | 
				
			||||||
 | 
					    //  - Copy constant data into the alloca.
 | 
				
			||||||
 | 
					    auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect);
 | 
				
			||||||
 | 
					    rewriter.create<CallOp>(
 | 
				
			||||||
 | 
					        loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
 | 
				
			||||||
 | 
					        ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Prepare data to be inserted into MemRef.
 | 
				
			||||||
 | 
					    auto llvmConstantElementType = constantElementType.cast<LLVM::LLVMType>();
 | 
				
			||||||
 | 
					    Value typedAlloc = rewriter.create<LLVM::BitcastOp>(
 | 
				
			||||||
 | 
					      loc, llvmConstantElementType.getPointerTo(), alloc);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Create llvm MemRef from original MemRef and fill the data pointers.
 | 
				
			||||||
 | 
					    auto llvmMemRef = MemRefDescriptor::fromStaticShape(
 | 
				
			||||||
 | 
					      rewriter, loc, typeConverter, memRefTy, typedAlloc);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    rewriter.replaceOp(op, {llvmMemRef});
 | 
				
			||||||
 | 
					    return success();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					private:
 | 
				
			||||||
 | 
					  static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
 | 
				
			||||||
 | 
					    return (a.getValue()[i]).cast<IntegerAttr>().getInt();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
// KRNL to LLVM: KrnlMemcpyOpLowering
 | 
					// KRNL to LLVM: KrnlMemcpyOpLowering
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					@ -120,33 +266,6 @@ public:
 | 
				
			||||||
    rewriter.eraseOp(op);
 | 
					    rewriter.eraseOp(op);
 | 
				
			||||||
    return success();
 | 
					    return success();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					 | 
				
			||||||
private:
 | 
					 | 
				
			||||||
  /// Return a symbol reference to the memcpy function, inserting it into the
 | 
					 | 
				
			||||||
  /// module if necessary.
 | 
					 | 
				
			||||||
  static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
 | 
					 | 
				
			||||||
      ModuleOp module, LLVM::LLVMDialect *llvmDialect) {
 | 
					 | 
				
			||||||
    auto *context = module.getContext();
 | 
					 | 
				
			||||||
    if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
 | 
					 | 
				
			||||||
      return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
 | 
					 | 
				
			||||||
    // Create a function declaration for memcpy, the signature is:
 | 
					 | 
				
			||||||
    //   * `void (i8*, i8* , i64, i1)`
 | 
					 | 
				
			||||||
    auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
 | 
					 | 
				
			||||||
    auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
 | 
					 | 
				
			||||||
    auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
 | 
					 | 
				
			||||||
    auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
 | 
					 | 
				
			||||||
    auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
 | 
					 | 
				
			||||||
        ArrayRef<mlir::LLVM::LLVMType>(
 | 
					 | 
				
			||||||
            {llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
 | 
					 | 
				
			||||||
        false);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Insert the memcpy function into the body of the parent module.
 | 
					 | 
				
			||||||
    PatternRewriter::InsertionGuard insertGuard(rewriter);
 | 
					 | 
				
			||||||
    rewriter.setInsertionPointToStart(module.getBody());
 | 
					 | 
				
			||||||
    rewriter.create<LLVM::LLVMFuncOp>(
 | 
					 | 
				
			||||||
        module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
 | 
					 | 
				
			||||||
    return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					@ -506,6 +625,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
 | 
				
			||||||
      /*useAlloca=*/false,
 | 
					      /*useAlloca=*/false,
 | 
				
			||||||
      /*emitCWrapper=*/true);
 | 
					      /*emitCWrapper=*/true);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  patterns.insert<KrnlGlobalOpLowering>(&getContext(), typeConverter);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Lower from the `krnl` dialect i.e. the Reshape operation.
 | 
					  // Lower from the `krnl` dialect i.e. the Reshape operation.
 | 
				
			||||||
  patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
 | 
					  patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
 | 
				
			||||||
      &getContext());
 | 
					      &getContext());
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,53 @@
 | 
				
			||||||
 | 
					// RUN: onnx-mlir-opt --shape-inference --lower-frontend --lower-krnl --lower-all-llvm %s -split-input-file | FileCheck %s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
 | 
				
			||||||
 | 
					  // CHECK: llvm.mlir.global internal constant [[GLOBAL_CONST:@.+]](dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>) : !llvm<"[3 x [2 x float]]">
 | 
				
			||||||
 | 
					  // CHECK: llvm.func @test_constant({{.*}}) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : i64) : !llvm.i64
 | 
				
			||||||
 | 
					  // CHECK: [[ALLOCA:%.+]] = llvm.alloca [[CONST1]] x !llvm<"[3 x [2 x float]]"> : (!llvm.i64) -> !llvm<"[3 x [2 x float]]*">
 | 
				
			||||||
 | 
					  // CHECK: [[I8ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"i8*">
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					  // CHECK: [[GLOBAL_ADDR:%.+]] = llvm.mlir.addressof [[GLOBAL_CONST]] : !llvm<"[3 x [2 x float]]*">
 | 
				
			||||||
 | 
					  // CHECK: [[I8GLOBAL:%.+]] = llvm.bitcast [[GLOBAL_ADDR]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"i8*">
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Size of the constant tensor in bytes.
 | 
				
			||||||
 | 
					  // CHECK: [[CONST4:%.+]] = llvm.mlir.constant(4 : i64) : !llvm.i64
 | 
				
			||||||
 | 
					  // CHECK: [[CONST6:%.+]] = llvm.mlir.constant(6 : i64) : !llvm.i64
 | 
				
			||||||
 | 
					  // CHECK: [[CONST_MUL1:%.+]] = llvm.mul [[CONST4]], [[CONST6]] : !llvm.i64
 | 
				
			||||||
 | 
					  // CHECK: [[GLOBAL_SIZE_BYTES:%.+]] = llvm.sext [[CONST_MUL1]] : !llvm.i64 to !llvm.i64
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Volatile flag
 | 
				
			||||||
 | 
					  // CHECK: [[CONST0:%.+]] = llvm.mlir.constant(0 : i1) : !llvm.i1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> !llvm.void
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Prepare data for MemRef insertion.
 | 
				
			||||||
 | 
					  // CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"float*">
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Insert the constant value in the local MemRef.
 | 
				
			||||||
 | 
					  // CHECK: [[LOCAL_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 | 
				
			||||||
 | 
					  // CHECK: [[LOCAL_MEMREF0:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 | 
				
			||||||
 | 
					  // CHECK: [[LOCAL_MEMREF1:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Insert offset.
 | 
				
			||||||
 | 
					  // CHECK: [[CONST00:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64
 | 
				
			||||||
 | 
					  // CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[CONST00]], [[LOCAL_MEMREF1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Insert sizes and strides.
 | 
				
			||||||
 | 
					  // CHECK: [[CONST3:%.+]] = llvm.mlir.constant(3 : index) : !llvm.i64
 | 
				
			||||||
 | 
					  // CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 | 
				
			||||||
 | 
					  // CHECK: [[CONST1:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64
 | 
				
			||||||
 | 
					  // CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF2]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK: [[CONST2:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64
 | 
				
			||||||
 | 
					  // CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF3]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 | 
				
			||||||
 | 
					  // CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
 | 
				
			||||||
 | 
					  // CHECK: [[MEMREF5:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF4]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK: llvm.return [[MEMREF5]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -1678,22 +1678,7 @@ func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> {
 | 
				
			||||||
  %0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32>
 | 
					  %0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32>
 | 
				
			||||||
  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
					  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
				
			||||||
  // CHECK-LABEL: test_constant_dense_2d_value
 | 
					  // CHECK-LABEL: test_constant_dense_2d_value
 | 
				
			||||||
  // CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32>
 | 
					  // CHECK: [[RES:%.+]] = "krnl.global"() {name = "constant_0", shape = [3, 2], value = dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>} : () -> memref<3x2xf32>
 | 
				
			||||||
  // CHECK: %[[INDEX_0:.+]] = constant 0 : index
 | 
					 | 
				
			||||||
  // CHECK: %[[INDEX_1:.+]] = constant 1 : index
 | 
					 | 
				
			||||||
  // CHECK: %[[INDEX_2:.+]] = constant 2 : index
 | 
					 | 
				
			||||||
  // CHECK: [[CONSTANT_0:%.+]] = constant 0.000000e+00 : f32
 | 
					 | 
				
			||||||
  // CHECK: affine.store [[CONSTANT_0]], %0[%[[INDEX_0]], %[[INDEX_0]]] : memref<3x2xf32>
 | 
					 | 
				
			||||||
  // CHECK: [[CONSTANT_1:%.+]] = constant 0.000000e+00 : f32
 | 
					 | 
				
			||||||
  // CHECK: affine.store [[CONSTANT_1]], %0[%[[INDEX_0]], %[[INDEX_1]]] : memref<3x2xf32>
 | 
					 | 
				
			||||||
  // CHECK: [[CONSTANT_2:%.+]] = constant 1.000000e+00 : f32
 | 
					 | 
				
			||||||
  // CHECK: affine.store [[CONSTANT_2]], %0[%[[INDEX_1]], %[[INDEX_0]]] : memref<3x2xf32>
 | 
					 | 
				
			||||||
  // CHECK: [[CONSTANT_3:%.+]] = constant 1.100000e+00 : f32
 | 
					 | 
				
			||||||
  // CHECK: affine.store [[CONSTANT_3]], %0[%[[INDEX_1]], %[[INDEX_1]]] : memref<3x2xf32>
 | 
					 | 
				
			||||||
  // CHECK: [[CONSTANT_4:%.+]] = constant 2.000000e+00 : f32
 | 
					 | 
				
			||||||
  // CHECK: affine.store [[CONSTANT_4]], %0[%[[INDEX_2]], %[[INDEX_0]]] : memref<3x2xf32>
 | 
					 | 
				
			||||||
  // CHECK: [[CONSTANT_5:%.+]] = constant 2.100000e+00 : f32
 | 
					 | 
				
			||||||
  // CHECK: affine.store [[CONSTANT_5]], %0[%[[INDEX_2]], %[[INDEX_1]]] : memref<3x2xf32>
 | 
					 | 
				
			||||||
  // CHECK: return [[RES]] : memref<3x2xf32>
 | 
					  // CHECK: return [[RES]] : memref<3x2xf32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue