Emit constant tensors as global constants (#66)

* Reorganize main function.

* Follow review comments.

* Emit constants as globals in Krnl and LLVM dialects.

* Enable unique constant variable names.

* Emit alloca for local array. Add tests.

* Comment clean-up.

* Simplify MemRef construction.

* Fix output type.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-04-01 13:51:06 -04:00 committed by GitHub
parent b65e77305c
commit f16e79d744
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 245 additions and 103 deletions

View File

@ -47,7 +47,7 @@ struct FrontendToKrnlLoweringPass
} // end anonymous namespace. } // end anonymous namespace.
void FrontendToKrnlLoweringPass::runOnModule() { void FrontendToKrnlLoweringPass::runOnModule() {
auto module = getModule(); ModuleOp module = getModule();
// The first thing to define is the conversion target. This will define the // The first thing to define is the conversion target. This will define the
// final target for this lowering. // final target for this lowering.

View File

@ -485,3 +485,7 @@ Value emitNegativeInfinityConstantOp(
return rewriter.create<ConstantOp>(loc, constantAttr); return rewriter.create<ConstantOp>(loc, constantAttr);
} }
/// Returns the integer stored at position `i` of the array attribute `a`.
///
/// The element is converted with `cast<IntegerAttr>()`, so it is expected to
/// be an IntegerAttr. No bounds check is performed on `i` here.
int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
  auto element = a.getValue()[i].cast<IntegerAttr>();
  return element.getInt();
}

View File

@ -117,6 +117,8 @@ Value emitPositiveInfinityConstantOp(
Value emitNegativeInfinityConstantOp( Value emitNegativeInfinityConstantOp(
ConversionPatternRewriter &rewriter, Location loc, Type type); ConversionPatternRewriter &rewriter, Location loc, Type type);
int64_t ArrayAttrIntVal(ArrayAttr a, int i);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// This is to get a scalar operation of a given type for a specific operation. // This is to get a scalar operation of a given type for a specific operation.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -12,40 +12,13 @@
using namespace mlir; using namespace mlir;
/// Emits one constant + affine store pair per element of `constantValue`,
/// writing into `alloc`.
///
/// \tparam ElementAttr attribute type used to iterate the dense elements.
/// \param rewriter        rewriter used to create the new operations.
/// \param loc             location attached to every created operation.
/// \param constantValue   dense attribute whose elements are stored.
/// \param valueShape      extents of each dimension of the constant.
/// \param constantIndices pre-built index values; entry `k` is used as the
///                        index `k` along whichever dimension is visited.
/// \param alloc           destination buffer of the stores.
template <typename ElementAttr>
void emitConstantAndStoreOpForDenseElementsAttr(
    ConversionPatternRewriter &rewriter, Location loc,
    DenseElementsAttr constantValue, ArrayRef<int64_t> valueShape,
    ArrayRef<Value> constantIndices, Value alloc) {
  // Index values of the element currently being visited, one per dimension
  // already entered by the recursion below.
  SmallVector<Value, 2> currentIndices;
  // Elements are consumed in iteration order of the dense attribute, which the
  // nested dimension walk below follows (outermost dimension varies slowest).
  auto nextElement = constantValue.getValues<ElementAttr>().begin();
  const uint64_t rank = valueShape.size();

  std::function<void(uint64_t)> visitDimension = [&](uint64_t depth) {
    if (depth != rank) {
      // Not at the innermost level yet: fix each index of this dimension in
      // turn and descend into the next one.
      const uint64_t extent = valueShape[depth];
      for (uint64_t pos = 0; pos != extent; ++pos) {
        currentIndices.push_back(constantIndices[pos]);
        visitDimension(depth + 1);
        currentIndices.pop_back();
      }
      return;
    }

    // All dimensions are fixed: materialize the element and store it.
    Value element = rewriter.create<ConstantOp>(loc, *nextElement++);
    rewriter.create<AffineStoreOp>(
        loc, element, alloc, llvm::makeArrayRef(currentIndices));
  };

  // Kick off the walk at the outermost dimension.
  visitDimension(/*depth=*/0);
}
struct ONNXConstantOpLowering : public ConversionPattern { struct ONNXConstantOpLowering : public ConversionPattern {
static int constantID;
ONNXConstantOpLowering(MLIRContext *ctx) ONNXConstantOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXConstantOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXConstantOp::getOperationName(), 1, ctx) {
constantID = 0;
}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
@ -58,42 +31,32 @@ struct ONNXConstantOpLowering : public ConversionPattern {
auto memRefType = convertToMemRefType(*op->result_type_begin()); auto memRefType = convertToMemRefType(*op->result_type_begin());
Value alloc; // Shape based computations.
bool insertDealloc = checkInsertDealloc(op); auto shape = memRefType.getShape();
int64_t numElements = 1;
for (int i=0; i<shape.size(); ++i)
numElements *= shape[i];
if (hasAllConstantDimensions(memRefType)) // Emit the constant global in Krnl dialect.
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc,
else memRefType,
emitError(loc, "Unexpected output has non-Constant shape"); rewriter.getI64ArrayAttr(shape),
constantOp.value().getValue(),
rewriter.getStringAttr("constant_" + std::to_string(constantID)));
DenseElementsAttr constantValue = // Increment constant ID:
constantOp.value().getValue().cast<DenseElementsAttr>(); constantID++;
auto valueShape = memRefType.getShape();
SmallVector<Value, 8> constantIndices;
for (auto i : llvm::seq<int64_t>(
0, *std::max_element(valueShape.begin(), valueShape.end())))
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
// The constant operation represents a multi-dimensional constant, so we
// will need to generate a store for each of the elements.
if (memRefType.getElementType().isa<IntegerType>()) {
emitConstantAndStoreOpForDenseElementsAttr<IntegerAttr>(
rewriter, loc, constantValue, valueShape, constantIndices, alloc);
} else if (memRefType.getElementType().isa<FloatType>()) {
emitConstantAndStoreOpForDenseElementsAttr<FloatAttr>(
rewriter, loc, constantValue, valueShape, constantIndices, alloc);
} else {
emitError(loc, "Unsupported output type");
}
// Replace this operation with the generated alloc. // Replace this operation with the generated alloc.
rewriter.replaceOp(op, alloc); // rewriter.replaceOp(op, alloc);
rewriter.replaceOp(op, constantGlobal.getResult());
return success(); return success();
} }
}; };
int ONNXConstantOpLowering::constantID;
void populateLoweringONNXConstantOpPattern( void populateLoweringONNXConstantOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) { OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ONNXConstantOpLowering>(ctx); patterns.insert<ONNXConstantOpLowering>(ctx);

View File

@ -192,3 +192,16 @@ def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
let parser = ?; let parser = ?;
let printer = ?; let printer = ?;
} }
// ODS record for the `global` operation of the Krnl dialect: an operation
// that carries constant data as attributes and produces a memref result.
def KrnlGlobalOp : Op<Krnl_Dialect, "global"> {
let summary = "Krnl global operation";
let description = [{
Operation for holding global data values.
}];
// Attributes only, no SSA operands:
//   $shape - declared as AnyAttr (no constraint is enforced at this level).
//   $value - declared as AnyAttr; the data held by the global.
//   $name  - string attribute naming the global.
let arguments = (ins AnyAttr:$shape, AnyAttr:$value, StrAttr:$name);
// Single result constrained to a memref type.
let results = (outs AnyTypeOf<[AnyMemRef]>:$output);
// `?` leaves both fields unset, i.e. no custom parser/printer is specified
// for this op in this record.
let parser = ?;
let printer = ?;
}

View File

@ -154,6 +154,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
target.addIllegalDialect<KrnlOpsDialect>(); target.addIllegalDialect<KrnlOpsDialect>();
target.addLegalOp<KrnlMemcpyOp>(); target.addLegalOp<KrnlMemcpyOp>();
target.addLegalOp<KrnlEntryPointOp>(); target.addLegalOp<KrnlEntryPointOp>();
target.addLegalOp<KrnlGlobalOp>();
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering, patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,

View File

@ -16,10 +16,12 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h" #include "llvm/ADT/Sequence.h"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Pass/Passes.hpp" #include "src/Pass/Passes.hpp"
@ -60,6 +62,150 @@ static size_t getRankFromMemRefType(LLVM::LLVMType memRefTy) {
return memRefTy.getStructElementType(3).getArrayNumElements(); return memRefTy.getStructElementType(3).getArrayNumElements();
} }
/// Looks up the `llvm.memcpy.p0i8.p0i8.i64` function in `module` and returns a
/// symbol reference to it. When no LLVMFuncOp with that name is found, a
/// declaration is first created at the start of the module body.
///
/// \param rewriter    rewriter used to create the declaration; its insertion
///                    point is restored before returning.
/// \param module      module that is searched and, if needed, extended.
/// \param llvmDialect LLVM dialect used to build the LLVM types.
/// \returns a flat symbol reference naming the memcpy function.
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
    ModuleOp module, LLVM::LLVMDialect *llvmDialect) {
  const char *const memcpyName = "llvm.memcpy.p0i8.p0i8.i64";
  auto *context = module.getContext();

  if (!module.lookupSymbol<LLVM::LLVMFuncOp>(memcpyName)) {
    // Declared signature: `void (i8*, i8*, i64, i1)`.
    auto voidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
    auto i8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
    auto i64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
    auto i1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
    auto memcpyFnTy = LLVM::LLVMType::getFunctionTy(voidTy,
        ArrayRef<mlir::LLVM::LLVMType>({i8PtrTy, i8PtrTy, i64Ty, i1Ty}),
        false);

    // Place the declaration at the top of the module; the guard puts the
    // rewriter back where the caller left it.
    PatternRewriter::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());
    rewriter.create<LLVM::LLVMFuncOp>(
        module.getLoc(), memcpyName, memcpyFnTy);
  }

  return SymbolRefAttr::get(memcpyName, context);
}
//===----------------------------------------------------------------------===//
// KRNL to LLVM: KrnlGlobalOpLowering
//===----------------------------------------------------------------------===//
/// Lowers `krnl.global` to the LLVM dialect.
///
/// For each op the pattern:
///   1. creates an internal, constant `llvm.mlir.global` at the start of the
///      enclosing module holding the op's `value` attribute, typed as a nested
///      LLVM array built from the `shape` attribute;
///   2. emits an `llvm.alloca` of the same array type at the op's position;
///   3. copies the global into the alloca through `llvm.memcpy.p0i8.p0i8.i64`;
///   4. replaces the op with a memref descriptor built over the alloca.
class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
public:
  explicit KrnlGlobalOpLowering(MLIRContext *context,
      LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(KrnlGlobalOp::getOperationName(), context,
            lowering_) {}

  /// Rewrites one `krnl.global` operation.
  ///
  /// \param op       the operation to lower (a KrnlGlobalOp).
  /// \param operands converted operands; unused, the op only has attributes.
  /// \param rewriter rewriter used to create the replacement IR.
  /// \returns success once the op has been replaced, failure (with a
  ///          diagnostic) if the `shape` attribute is not an array attribute.
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto *llvmDialect =
        op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
    assert(llvmDialect && "expected llvm dialect to be registered");

    // The pattern is registered under KrnlGlobalOp's operation name, so a
    // checked cast<> is used instead of an unchecked dyn_cast<> result.
    auto krnlGlobalOp = llvm::cast<KrnlGlobalOp>(op);

    // `shape` is declared as AnyAttr on the op, so validate it before any IR
    // is created instead of dereferencing a null ArrayAttr.
    auto shape = krnlGlobalOp.shape().dyn_cast<ArrayAttr>();
    if (!shape) {
      op->emitError("expected 'shape' attribute to be an array attribute");
      return failure();
    }

    // Get module.
    ModuleOp module = op->getParentOfType<ModuleOp>();

    // Global name.
    auto name = krnlGlobalOp.name();

    // Compute total number of elements.
    const int rank = shape.size();
    int64_t numElements = 1;
    for (int i = 0; i < rank; ++i)
      numElements *= ArrayAttrIntVal(shape, i);

    auto type = op->getResult(0).getType();
    auto memRefTy = type.cast<mlir::MemRefType>();

    // The element type of the array.
    auto constantElementType =
        typeConverter.convertType(memRefTy.getElementType());

    // Wrap the element type from the innermost dimension outwards so that the
    // outermost dimension ends up as the outermost array.
    auto globalType = constantElementType;
    for (int i = rank - 1; i >= 0; i--)
      globalType = LLVM::LLVMType::getArrayTy(
          globalType.cast<LLVM::LLVMType>(), ArrayAttrIntVal(shape, i));
    // The llvm type of the global (example: [2 x [8 x float]])
    auto llvmGlobalType = globalType.cast<LLVM::LLVMType>();

    // Create the global at the entry of the module. The guard restores the
    // insertion point to the op being rewritten afterwards.
    LLVM::GlobalOp global;
    {
      OpBuilder::InsertionGuard insertGuard(rewriter);
      rewriter.setInsertionPointToStart(module.getBody());
      global = rewriter.create<LLVM::GlobalOp>(loc,
          llvmGlobalType, /*isConstant=*/true,
          LLVM::Linkage::Internal, name, krnlGlobalOp.value());
    }

    // Some frequently used types.
    auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
    auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);

    // Allocate the memory where the constants will be used from.
    // This is a region of local memory and needs to be emitted as an alloca.
    auto one = rewriter.create<LLVM::ConstantOp>(loc,
        llvmI64Ty, rewriter.getI64IntegerAttr(1));
    auto alloc = rewriter.create<LLVM::AllocaOp>(
        loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0);

    // Copy constant value into the local alloca:
    //  - Bitcast alloc to i8*
    Value int8PtrAlloc = rewriter.create<LLVM::BitcastOp>(
        loc, llvmI8PtrTy, alloc);
    //  - Bitcast global to i8*
    Value globalValue = rewriter.create<LLVM::AddressOfOp>(loc, global);
    Value i8PtrGlobal = rewriter.create<LLVM::BitcastOp>(
        loc, llvmI8PtrTy, globalValue);
    //  - Set size (element size in bytes * number of elements).
    Value memRefElementSize = rewriter.create<LLVM::ConstantOp>(loc,
        llvmI64Ty, rewriter.getI64IntegerAttr(
            getMemRefEltSizeInBytes(memRefTy)));
    Value numElementsValue = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements));
    Value totalElementsSize = rewriter.create<LLVM::MulOp>(
        loc, memRefElementSize, numElementsValue);
    // Both factors are already i64, so this extension does not widen the
    // value; it is retained so the emitted op sequence stays unchanged.
    Value int64Size = rewriter.create<LLVM::SExtOp>(
        loc, llvmI64Ty, totalElementsSize);
    //  - Set volatile (false).
    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, LLVM::LLVMType::getInt1Ty(llvmDialect),
        rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
    //  - Copy constant data into the alloca.
    auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect);
    rewriter.create<CallOp>(
        loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
        ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));

    // Prepare data to be inserted into MemRef: view the alloca as a pointer
    // to the element type.
    auto llvmConstantElementType = constantElementType.cast<LLVM::LLVMType>();
    Value typedAlloc = rewriter.create<LLVM::BitcastOp>(
        loc, llvmConstantElementType.getPointerTo(), alloc);

    // Create llvm MemRef from original MemRef and fill the data pointers.
    auto llvmMemRef = MemRefDescriptor::fromStaticShape(
        rewriter, loc, typeConverter, memRefTy, typedAlloc);

    rewriter.replaceOp(op, {llvmMemRef});
    return success();
  }

private:
  /// Returns element `i` of `a` as an integer; the element must be an
  /// IntegerAttr (enforced by `cast<>`).
  static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
    return (a.getValue()[i]).cast<IntegerAttr>().getInt();
  }
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// KRNL to LLVM: KrnlMemcpyOpLowering // KRNL to LLVM: KrnlMemcpyOpLowering
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -120,33 +266,6 @@ public:
rewriter.eraseOp(op); rewriter.eraseOp(op);
return success(); return success();
} }
private:
/// Returns a symbol reference to `llvm.memcpy.p0i8.p0i8.i64`. If `module` has
/// no LLVMFuncOp with that name yet, a declaration is inserted at the start of
/// the module body first; the rewriter's insertion point is left unchanged for
/// the caller.
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
    ModuleOp module, LLVM::LLVMDialect *llvmDialect) {
  static constexpr char kMemcpyName[] = "llvm.memcpy.p0i8.p0i8.i64";
  auto *context = module.getContext();

  const bool alreadyDeclared =
      static_cast<bool>(module.lookupSymbol<LLVM::LLVMFuncOp>(kMemcpyName));
  if (!alreadyDeclared) {
    // Build the function type `void (i8*, i8*, i64, i1)`.
    auto retTy = LLVM::LLVMType::getVoidTy(llvmDialect);
    auto bytePtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
    auto sizeTy = LLVM::LLVMType::getInt64Ty(llvmDialect);
    auto flagTy = LLVM::LLVMType::getInt1Ty(llvmDialect);
    auto fnTy = LLVM::LLVMType::getFunctionTy(retTy,
        ArrayRef<mlir::LLVM::LLVMType>({bytePtrTy, bytePtrTy, sizeTy, flagTy}),
        false);

    // Declare it at the top of the parent module; the guard restores the
    // previous insertion point when this scope ends.
    PatternRewriter::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());
    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), kMemcpyName, fnTy);
  }

  return SymbolRefAttr::get(kMemcpyName, context);
}
}; };
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -506,6 +625,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
/*useAlloca=*/false, /*useAlloca=*/false,
/*emitCWrapper=*/true); /*emitCWrapper=*/true);
patterns.insert<KrnlGlobalOpLowering>(&getContext(), typeConverter);
// Lower from the `krnl` dialect i.e. the Reshape operation. // Lower from the `krnl` dialect i.e. the Reshape operation.
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>( patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
&getContext()); &getContext());

View File

@ -0,0 +1,53 @@
// RUN: onnx-mlir-opt --shape-inference --lower-frontend --lower-krnl --lower-all-llvm %s -split-input-file | FileCheck %s
/// Lowers a 3x2 f32 onnx.Constant down to the LLVM dialect and verifies the
/// emitted sequence: the memcpy declaration, an internal constant global
/// holding the data, a stack alloca of the same array type, a memcpy from the
/// global into the alloca, and the construction of the returned memref
/// descriptor.
func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> {
%0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
/// Module-level symbols: memcpy declaration followed by the constant global.
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
// CHECK: llvm.mlir.global internal constant [[GLOBAL_CONST:@.+]](dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>) : !llvm<"[3 x [2 x float]]">
// CHECK: llvm.func @test_constant({{.*}}) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
/// Local storage: one element of the full array type, then both source and
/// destination are viewed as i8* for the copy.
// CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : i64) : !llvm.i64
// CHECK: [[ALLOCA:%.+]] = llvm.alloca [[CONST1]] x !llvm<"[3 x [2 x float]]"> : (!llvm.i64) -> !llvm<"[3 x [2 x float]]*">
// CHECK: [[I8ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"i8*">
// CHECK: [[GLOBAL_ADDR:%.+]] = llvm.mlir.addressof [[GLOBAL_CONST]] : !llvm<"[3 x [2 x float]]*">
// CHECK: [[I8GLOBAL:%.+]] = llvm.bitcast [[GLOBAL_ADDR]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"i8*">
/// Size of the constant tensor in bytes.
/// (4 bytes per f32 element times 6 elements.)
// CHECK: [[CONST4:%.+]] = llvm.mlir.constant(4 : i64) : !llvm.i64
// CHECK: [[CONST6:%.+]] = llvm.mlir.constant(6 : i64) : !llvm.i64
// CHECK: [[CONST_MUL1:%.+]] = llvm.mul [[CONST4]], [[CONST6]] : !llvm.i64
// CHECK: [[GLOBAL_SIZE_BYTES:%.+]] = llvm.sext [[CONST_MUL1]] : !llvm.i64 to !llvm.i64
/// Volatile flag
// CHECK: [[CONST0:%.+]] = llvm.mlir.constant(0 : i1) : !llvm.i1
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> !llvm.void
/// Prepare data for MemRef insertion.
// CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm<"[3 x [2 x float]]*"> to !llvm<"float*">
/// Insert the constant value in the local MemRef.
/// (The same typed pointer fills descriptor fields 0 and 1.)
// CHECK: [[LOCAL_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[LOCAL_MEMREF0:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[LOCAL_MEMREF1:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
/// Insert offset.
// CHECK: [[CONST00:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[CONST00]], [[LOCAL_MEMREF1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
/// Insert sizes and strides.
/// NOTE: the CONST1 capture is re-bound below (first to the `2 : index`
/// constant, later to `1 : index`), so its name does not reflect the value it
/// matches from here on.
// CHECK: [[CONST3:%.+]] = llvm.mlir.constant(3 : index) : !llvm.i64
// CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[CONST1:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF2]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[CONST2:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF3]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[MEMREF5:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF4]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.return [[MEMREF5]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
}

View File

@ -1678,22 +1678,7 @@ func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> {
%0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32> %0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_constant_dense_2d_value // CHECK-LABEL: test_constant_dense_2d_value
// CHECK: [[RES:%.+]] = alloc() : memref<3x2xf32> // CHECK: [[RES:%.+]] = "krnl.global"() {name = "constant_0", shape = [3, 2], value = dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>} : () -> memref<3x2xf32>
// CHECK: %[[INDEX_0:.+]] = constant 0 : index
// CHECK: %[[INDEX_1:.+]] = constant 1 : index
// CHECK: %[[INDEX_2:.+]] = constant 2 : index
// CHECK: [[CONSTANT_0:%.+]] = constant 0.000000e+00 : f32
// CHECK: affine.store [[CONSTANT_0]], %0[%[[INDEX_0]], %[[INDEX_0]]] : memref<3x2xf32>
// CHECK: [[CONSTANT_1:%.+]] = constant 0.000000e+00 : f32
// CHECK: affine.store [[CONSTANT_1]], %0[%[[INDEX_0]], %[[INDEX_1]]] : memref<3x2xf32>
// CHECK: [[CONSTANT_2:%.+]] = constant 1.000000e+00 : f32
// CHECK: affine.store [[CONSTANT_2]], %0[%[[INDEX_1]], %[[INDEX_0]]] : memref<3x2xf32>
// CHECK: [[CONSTANT_3:%.+]] = constant 1.100000e+00 : f32
// CHECK: affine.store [[CONSTANT_3]], %0[%[[INDEX_1]], %[[INDEX_1]]] : memref<3x2xf32>
// CHECK: [[CONSTANT_4:%.+]] = constant 2.000000e+00 : f32
// CHECK: affine.store [[CONSTANT_4]], %0[%[[INDEX_2]], %[[INDEX_0]]] : memref<3x2xf32>
// CHECK: [[CONSTANT_5:%.+]] = constant 2.100000e+00 : f32
// CHECK: affine.store [[CONSTANT_5]], %0[%[[INDEX_2]], %[[INDEX_1]]] : memref<3x2xf32>
// CHECK: return [[RES]] : memref<3x2xf32> // CHECK: return [[RES]] : memref<3x2xf32>
} }