[MLIR] Add support for reshape (#390)

* Add reshape op handling.

* Lower reshape to KRNL dialect.

* Add comments.

* Propagate reshape to KRNL IR.

* Lower KRNL reshape to affine and standard ops level dialects.

* Add lowering of reshape operation to Krnl and LLVM Dialects.

* Add test for LLVM IR dialect output for reshape.

* Fix rebase.

* Fix test variable.

* Emit errors during reshape shape inference. Address other reviewer comments.
This commit is contained in:
GHEORGHE-TEOD BERCEA 2019-12-13 15:28:56 -05:00 committed by Tian Jin
parent 5ed79083d5
commit e81a7654f9
13 changed files with 325 additions and 17 deletions

View File

@ -181,3 +181,16 @@ def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {
// Fully specified by traits. // Fully specified by traits.
let verifier = ?; let verifier = ?;
} }
def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
let summary = "Krnl memcpy operation";
let description = [{
In the KRNL dialect the reshape op doesn't generate a new memory entry and
treats a reshape like a cast.
}];
let arguments = (ins AnyMemRef:$dest, AnyMemRef:$src, AnyInteger:$size);
let parser = ?;
let printer = ?;
}

View File

@ -266,7 +266,7 @@ def gen_schema(schema) :
ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu', ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', 'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu', 'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
'Elu', 'Selu', 'HardSigmoid'] 'Elu', 'Selu', 'HardSigmoid', 'Reshape']
CanonicalList=['Add', 'Identity'] CanonicalList=['Add', 'Identity']
line_indent = ' ' line_indent = ' '

View File

@ -42,9 +42,7 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
// Exp // Exp
/// Infer the output shape of the ONNXExpOp. This method is required by the /// Infer the output shape of the ONNXExpOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXExpOp::inferShapes() { void ONNXExpOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
getResult()->setType(getOperand()->getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Tanh // Tanh
@ -90,9 +88,7 @@ void ONNXSigmoidOp::inferShapes() {
// Elu // Elu
/// Infer the output shape of the ONNXEluOp. This method is required by the /// Infer the output shape of the ONNXEluOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXEluOp::inferShapes() { void ONNXEluOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
getResult()->setType(getOperand()->getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Relu // Relu
@ -162,9 +158,7 @@ void ONNXAndOp::inferShapes() {
// Or // Or
/// Infer the output shape of the ONNXOrOp. This method is required by the /// Infer the output shape of the ONNXOrOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXOrOp::inferShapes() { void ONNXOrOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
getResult()->setType(getOperand(0)->getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Xor // Xor
@ -257,6 +251,36 @@ void ONNXFullGemmOp::inferShapes() {
// Verify that matrix sizes are valid for multiplication and addition. // Verify that matrix sizes are valid for multiplication and addition.
// Take into account the dimensionality of the matrix. // Take into account the dimensionality of the matrix.
//===----------------------------------------------------------------------===//
// Reshape
void ONNXReshapeOp::inferShapes() {
// Cannot infer shape if no shape tensor is specified.
if (!getOperand(1)->getType().isa<RankedTensorType>())
emitError("Shape tensor not ranked.");
auto inputTensorTy = getOperand(0)->getType().cast<RankedTensorType>();
auto shapeTensorTy = getOperand(1)->getType().cast<RankedTensorType>();
// Only rank 1 shape tensors are supported.
if (shapeTensorTy.getShape().size() != 1)
emitError("Shape tensor must have rank one.");
int64_t outputRank = shapeTensorTy.getShape()[0];
// Shape tensor must have constant shape.
if (outputRank < 0)
emitError("Shape tensor must have constant shape.");
SmallVector<int64_t, 2> dims;
for (int i = 0; i < outputRank; ++i)
dims.emplace_back(-1);
getResult()->setType(
RankedTensorType::get(dims, inputTensorTy.getElementType()));
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TableGen'd op method definitions // TableGen'd op method definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -2197,7 +2197,7 @@ def ONNXReluOp:ONNX_Op<"Relu",
} }
def ONNXReshapeOp:ONNX_Op<"Reshape", def ONNXReshapeOp:ONNX_Op<"Reshape",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Reshape operation"; let summary = "ONNX Reshape operation";
let description = [{ let description = [{
"Reshape the input tensor similar to numpy.reshape." "Reshape the input tensor similar to numpy.reshape."

View File

@ -86,7 +86,7 @@ static bool checkInsertDealloc(Operation* currentOp) {
// If there is at least one result to investigate. // If there is at least one result to investigate.
if (currentOp->getNumResults() > 0) { if (currentOp->getNumResults() > 0) {
auto result = currentOp->getResult(0); auto result = currentOp->getResult(0);
for (auto operand : op.getOperands()) for (const auto& operand : op.getOperands())
if (operand == result) if (operand == result)
insertDealloc = false; insertDealloc = false;
} }
@ -95,6 +95,20 @@ static bool checkInsertDealloc(Operation* currentOp) {
return insertDealloc; return insertDealloc;
} }
unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
auto elementType = memRefType.getElementType();
unsigned sizeInBits;
if (elementType.isIntOrFloat()) {
sizeInBits = elementType.getIntOrFloatBitWidth();
} else {
auto vectorType = elementType.cast<VectorType>();
sizeInBits =
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
}
return llvm::divideCeil(sizeInBits, 8);
}
namespace { namespace {
template <typename ElementwiseNaryOp> template <typename ElementwiseNaryOp>
@ -655,6 +669,62 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
} }
}; };
struct ONNXReshapeOpLowering : public ConversionPattern {
ONNXReshapeOpLowering(MLIRContext* ctx)
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
Value* alloc;
// Compute size in bytes.
Value* tensorSize = rewriter.create<ConstantOp>(loc,
rewriter.getIntegerAttr(
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType)));
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType)) {
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
} else {
auto memRefShape = memRefType.getShape();
SmallVector<Value*, 4> allocOperands;
for (int i = 0; i < memRefShape.size(); ++i) {
// The shape array can always be used to construct shape information of
// the result.
Value* index = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
Value* loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
Value* int64LoadedVal = rewriter.create<ZeroExtendIOp>(
loc, loadedVal, rewriter.getIntegerType(64));
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
allocOperands.push_back(rewriter.create<IndexCastOp>(
loc, loadedVal, rewriter.getIndexType()));
}
AllocOp allocateMemref =
rewriter.create<AllocOp>(loc, memRefType, allocOperands);
// Make sure to allocate at the beginning of the block if
// all dimensions are known.
auto* parentBlock = allocateMemref.getOperation()->getBlock();
if (insertDealloc) {
auto dealloc = rewriter.create<DeallocOp>(loc, allocateMemref);
dealloc.getOperation()->moveBefore(&parentBlock->back());
}
alloc = allocateMemref;
}
rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
rewriter.replaceOp(op, alloc);
return matchSuccess();
}
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type. // Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -754,7 +824,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>>(&getContext()); ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
ONNXReshapeOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the // With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal` // conversion. The conversion will signal failure if any of our `illegal`

View File

@ -23,4 +23,7 @@ std::unique_ptr<Pass> createLowerToKrnlPass();
/// Pass for lowering frontend dialects to Krnl IR dialect. /// Pass for lowering frontend dialects to Krnl IR dialect.
std::unique_ptr<Pass> createLowerKrnlPass(); std::unique_ptr<Pass> createLowerKrnlPass();
/// Pass for lowering Krnl dialect to LLVM dialect.
std::unique_ptr<Pass> createKrnlLowerToLLVMPass();
} // end namespace mlir } // end namespace mlir

View File

@ -110,7 +110,8 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
op->getName().getStringRef() != "onnx.Min" && op->getName().getStringRef() != "onnx.Min" &&
op->getName().getStringRef() != "onnx.MatMul" && op->getName().getStringRef() != "onnx.MatMul" &&
op->getName().getStringRef() != "onnx.Gemm" && op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.FullGemm") op->getName().getStringRef() != "onnx.FullGemm" &&
op->getName().getStringRef() != "onnx.Reshape")
return false; return false;
return llvm::any_of(op->getResultTypes(), return llvm::any_of(op->getResultTypes(),
[](Type result_type) { return !result_type.isa<RankedTensorType>(); }); [](Type result_type) { return !result_type.isa<RankedTensorType>(); });

View File

@ -1,4 +1,6 @@
add_library(onnf_transform lower_krnl.cpp) add_library(onnf_transform
lower_krnl.cpp
lower_to_llvm.cpp)
target_include_directories(onnf_transform target_include_directories(onnf_transform
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT} PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}

View File

@ -142,6 +142,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>(); target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
// We expect IR to be free of Krnl Dialect Ops. // We expect IR to be free of Krnl Dialect Ops.
target.addIllegalDialect<KrnlOpsDialect>(); target.addIllegalDialect<KrnlOpsDialect>();
target.addLegalOp<KrnlMemcpyOp>();
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering, patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,

View File

@ -0,0 +1,146 @@
//====- LowerToLLVM.cpp - Lowering from KRNL+Affine+Std to LLVM -----------===//
//
// Copyright 2019 The DLC Authors.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/Sequence.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
#include "src/compiler/pass/passes.hpp"
using namespace mlir;
namespace {
//===----------------------------------------------------------------------===//
// KRNL to LLVM: patterns which need a direct lowering to LLVM.
//===----------------------------------------------------------------------===//
class KrnlMemcpyOpLowering : public ConversionPattern {
public:
explicit KrnlMemcpyOpLowering(MLIRContext* context)
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const override {
auto* context = op->getContext();
auto loc = op->getLoc();
auto* llvmDialect =
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(llvmDialect && "expected llvm dialect to be registered");
// Get a symbol reference to the memcpy function, inserting it if necessary.
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule, llvmDialect);
// First operand.
Type dstType =
operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
Value* alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
Value* alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
// Second operand.
Type srcType =
operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
Value* alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
Value* alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
// Size.
Value* int64Size = rewriter.create<LLVM::SExtOp>(
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
// Memcpy call
rewriter.create<CallOp>(loc, memcpyRef,
LLVM::LLVMType::getVoidTy(llvmDialect),
ArrayRef<Value*>(
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));
rewriter.eraseOp(op);
return matchSuccess();
}
private:
/// Return a symbol reference to the memcpy function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter& rewriter,
ModuleOp module, LLVM::LLVMDialect* llvmDialect) {
auto* context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
// Create a function declaration for memcpy, the signature is:
// * `void (i8*, i8* , i64, i1)`
auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty}),
false);
// Insert the memcpy function into the body of the parent module.
PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<LLVM::LLVMFuncOp>(
module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
}
};
} // end namespace
//===----------------------------------------------------------------------===//
// KRNL + Stadard + Affine dialects lowering to LLVM.
//===----------------------------------------------------------------------===//
namespace {
struct KrnlToLLVMLoweringPass : public ModulePass<KrnlToLLVMLoweringPass> {
void runOnModule() final;
};
} // end anonymous namespace
void KrnlToLLVMLoweringPass::runOnModule() {
// Define the target for this lowering i.e. the LLVM dialect.
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
// Lower the MemRef types to a representation in LLVM.
LLVMTypeConverter typeConverter(&getContext());
// We have a combination of `krnl`, `affine`, and `std` operations. We
// lower in stages until all the code is in the LLVM dialect.
OwningRewritePatternList patterns;
populateAffineToStdConversionPatterns(patterns, &getContext());
populateLoopToStdConversionPatterns(patterns, &getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns);
// Lower from the `krnl` dialect i.e. the Reshape operation.
patterns.insert<KrnlMemcpyOpLowering>(&getContext());
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
auto module = getModule();
if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
signalPassFailure();
}
/// Create the pass for lowering `Krnl`, `Affine` and `Std` dialects to LLVM.
std::unique_ptr<mlir::Pass> mlir::createKrnlLowerToLLVMPass() {
return std::make_unique<KrnlToLLVMLoweringPass>();
}
static PassRegistration<KrnlToLLVMLoweringPass> pass(
"lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");

View File

@ -131,7 +131,7 @@ int main(int ac, char* av[]) {
pm.addPass(mlir::createLowerKrnlPass()); pm.addPass(mlir::createLowerKrnlPass());
pm.addPass(mlir::createLowerAffinePass()); pm.addPass(mlir::createLowerAffinePass());
pm.addPass(mlir::createLowerToCFGPass()); pm.addPass(mlir::createLowerToCFGPass());
pm.addPass(mlir::createLowerToLLVMPass()); pm.addPass(mlir::createKrnlLowerToLLVMPass());
pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCanonicalizerPass());
pm.run(*module); pm.run(*module);

View File

@ -0,0 +1,16 @@
// RUN: dlc-opt --shape-inference --lower-frontend --lower-krnl --lower-all-llvm %s -split-input-file | FileCheck %s
func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64)
// CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
// CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
// CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64) -> !llvm.void
// CHECK: llvm.return [[RES]] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
}

View File

@ -279,6 +279,37 @@ func @test_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
// CHECK: return [[RES]] : memref<?x10xf32> // CHECK: return [[RES]] : memref<?x10xf32>
} }
func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reshape
// CHECK: [[TYPE_IN_BYTES:%.+]] = constant 4 : i64
// CHECK: %[[INDEX_0:.+]] = constant 0 : index
// CHECK: [[LOAD_0:%.+]] = load %arg1[%[[INDEX_0]]] : memref<4xi32>
// CHECK: [[EXT_0:%.+]] = zexti [[LOAD_0]] : i32 to i64
// CHECK: [[MUL_0:%.+]] = muli [[TYPE_IN_BYTES]], [[EXT_0]] : i64
// CHECK: [[CAST_0:%.+]] = index_cast [[LOAD_0]] : i32 to index
// CHECK: %[[INDEX_1:.+]] = constant 1 : index
// CHECK: [[LOAD_1:%.+]] = load %arg1[%[[INDEX_1]]] : memref<4xi32>
// CHECK: [[EXT_1:%.+]] = zexti [[LOAD_1]] : i32 to i64
// CHECK: [[MUL_1:%.+]] = muli [[MUL_0]], [[EXT_1]] : i64
// CHECK: [[CAST_1:%.+]] = index_cast [[LOAD_1]] : i32 to index
// CHECK: %[[INDEX_2:.+]] = constant 2 : index
// CHECK: [[LOAD_2:%.+]] = load %arg1[%[[INDEX_2]]] : memref<4xi32>
// CHECK: [[EXT_2:%.+]] = zexti [[LOAD_2]] : i32 to i64
// CHECK: [[MUL_2:%.+]] = muli [[MUL_1]], [[EXT_2]] : i64
// CHECK: [[CAST_2:%.+]] = index_cast [[LOAD_2]] : i32 to index
// CHECK: %[[INDEX_3:.+]] = constant 3 : index
// CHECK: [[LOAD_3:%.+]] = load %arg1[%[[INDEX_3]]] : memref<4xi32>
// CHECK: [[EXT_3:%.+]] = zexti [[LOAD_3]] : i32 to i64
// CHECK: [[MUL_3:%.+]] = muli [[MUL_2]], [[EXT_3]] : i64
// CHECK: [[CAST_3:%.+]] = index_cast [[LOAD_3]] : i32 to index
// CHECK: [[ALLOC:%.+]] = alloc([[CAST_0]], [[CAST_1]], [[CAST_2]], [[CAST_3]]) : memref<?x?x?x?xf32>
// CHECK: "krnl.memcpy"([[ALLOC]], %arg0, [[MUL_3]]) : (memref<?x?x?x?xf32>, memref<?x10xf32>, i64) -> ()
// CHECK: return [[ALLOC]] : memref<?x?x?x?xf32>
}
func @test_sum(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_sum(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Sum"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.Sum"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()