From e81a7654f9c7126ee2122be1df3cbff8209e38ac Mon Sep 17 00:00:00 2001
From: GHEORGHE-TEOD BERCEA
Date: Fri, 13 Dec 2019 15:28:56 -0500
Subject: [PATCH] [MLIR] Add support for reshape (#390)

* Add reshape op handling.
* Lower reshape to KRNL dialect.
* Add comments.
* Propagate reshape to KRNL IR.
* Lower KRNL reshape to affine and standard ops level dialects.
* Add lowering of reshape operation to Krnl and LLVM Dialects.
* Add test for LLVM IR dialect output for reshape.
* Fix rebase.
* Fix test variable.
* Emit errors during reshape shape inference. Address other reviewer comments.
---
 src/compiler/dialect/krnl/krnl_ops.td        |  13 ++
 src/compiler/dialect/onnx/gen_doc.py         |   2 +-
 src/compiler/dialect/onnx/onnx_ops.cpp       |  42 ++++--
 src/compiler/dialect/onnx/onnxop.inc         |   2 +-
 src/compiler/pass/lower_frontend_to_krnl.cpp |  75 +++++++++-
 src/compiler/pass/passes.hpp                 |   3 +
 src/compiler/pass/shape_inference_pass.cpp   |   5 +-
 src/compiler/transform/CMakeLists.txt        |   4 +-
 src/compiler/transform/lower_krnl.cpp        |   1 +
 src/compiler/transform/lower_to_llvm.cpp     | 146 +++++++++++++++++++
 src/main.cpp                                 |   2 +-
 test/mlir/krnl/reshape.mlir                  |  16 ++
 test/mlir/onnx/onnx_lowering.mlir            |  31 ++++
 13 files changed, 325 insertions(+), 17 deletions(-)
 create mode 100644 src/compiler/transform/lower_to_llvm.cpp
 create mode 100644 test/mlir/krnl/reshape.mlir

diff --git a/src/compiler/dialect/krnl/krnl_ops.td b/src/compiler/dialect/krnl/krnl_ops.td
index 76a38d9..c410c70 100644
--- a/src/compiler/dialect/krnl/krnl_ops.td
+++ b/src/compiler/dialect/krnl/krnl_ops.td
@@ -181,3 +181,16 @@ def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {
   // Fully specified by traits.
   let verifier = ?;
 }
+
+def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
+  let summary = "Krnl memcpy operation";
+  let description = [{
+    Copies `size` bytes from the `src` memref into the `dest` memref. With
+    this op, reshape in the KRNL dialect needs no new memory entry: it is
+    treated like a cast of the underlying data.
+  }];
+
+  let arguments = (ins AnyMemRef:$dest, AnyMemRef:$src, AnyInteger:$size);
+
+  let parser = ?;
+  let printer = ?;
+}
diff --git a/src/compiler/dialect/onnx/gen_doc.py b/src/compiler/dialect/onnx/gen_doc.py
index f5c0ded..54998a4 100644
--- a/src/compiler/dialect/onnx/gen_doc.py
+++ b/src/compiler/dialect/onnx/gen_doc.py
@@ -266,7 +266,7 @@ def gen_schema(schema) :
     ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
                    'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
                    'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
-                   'Elu', 'Selu', 'HardSigmoid']
+                   'Elu', 'Selu', 'HardSigmoid', 'Reshape']
     CanonicalList=['Add', 'Identity']
     line_indent = '  '
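A note on the new op's semantics: `krnl.memcpy` takes a destination memref, a source memref, and a byte count, and performs a raw copy between the two buffers. As a rough C++ analogy (hypothetical helper for illustration only, not part of the patch):

#include <cstdint>
#include <cstring>

// Illustrative only: the runtime effect of "krnl.memcpy"(%dest, %src, %size)
// is a raw byte copy between the two buffers; no new memory is allocated.
void krnlMemcpySemantics(void* dest, const void* src, int64_t sizeInBytes) {
  std::memcpy(dest, src, static_cast<size_t>(sizeInBytes));
}
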
diff --git a/src/compiler/dialect/onnx/onnx_ops.cpp b/src/compiler/dialect/onnx/onnx_ops.cpp
index a900e18..00d53da 100644
--- a/src/compiler/dialect/onnx/onnx_ops.cpp
+++ b/src/compiler/dialect/onnx/onnx_ops.cpp
@@ -42,9 +42,7 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
 // Exp
 /// Infer the output shape of the ONNXExpOp. This method is required by the
 /// shape inference interface.
-void ONNXExpOp::inferShapes() {
-  getResult()->setType(getOperand()->getType());
-}
+void ONNXExpOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
 
 //===----------------------------------------------------------------------===//
 // Tanh
@@ -90,9 +88,7 @@ void ONNXSigmoidOp::inferShapes() {
 // Elu
 /// Infer the output shape of the ONNXEluOp. This method is required by the
 /// shape inference interface.
-void ONNXEluOp::inferShapes() {
-  getResult()->setType(getOperand()->getType());
-}
+void ONNXEluOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
 
 //===----------------------------------------------------------------------===//
 // Relu
@@ -162,9 +158,7 @@ void ONNXAndOp::inferShapes() {
 // Or
 /// Infer the output shape of the ONNXOrOp. This method is required by the
 /// shape inference interface.
-void ONNXOrOp::inferShapes() {
-  getResult()->setType(getOperand(0)->getType());
-}
+void ONNXOrOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
 
 //===----------------------------------------------------------------------===//
 // Xor
@@ -257,6 +251,36 @@ void ONNXFullGemmOp::inferShapes() {
   // Verify that matrix sizes are valid for multiplication and addition.
   // Take into account the dimensionality of the matrix.
 
+//===----------------------------------------------------------------------===//
+
+// Reshape
+
+void ONNXReshapeOp::inferShapes() {
+  // Cannot infer shape if no shape tensor is specified.
+  if (!getOperand(1)->getType().isa<RankedTensorType>())
+    emitError("Shape tensor not ranked.");
+
+  auto inputTensorTy = getOperand(0)->getType().cast<RankedTensorType>();
+  auto shapeTensorTy = getOperand(1)->getType().cast<RankedTensorType>();
+
+  // Only rank-1 shape tensors are supported.
+  if (shapeTensorTy.getShape().size() != 1)
+    emitError("Shape tensor must have rank one.");
+
+  int64_t outputRank = shapeTensorTy.getShape()[0];
+
+  // Shape tensor must have constant shape.
+  if (outputRank < 0)
+    emitError("Shape tensor must have constant shape.");
+
+  SmallVector<int64_t, 2> dims;
+  for (int i = 0; i < outputRank; ++i)
+    dims.emplace_back(-1);
+
+  getResult()->setType(
+      RankedTensorType::get(dims, inputTensorTy.getElementType()));
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/src/compiler/dialect/onnx/onnxop.inc b/src/compiler/dialect/onnx/onnxop.inc
index 7d18c13..1cdef19 100644
--- a/src/compiler/dialect/onnx/onnxop.inc
+++ b/src/compiler/dialect/onnx/onnxop.inc
@@ -2197,7 +2197,7 @@ def ONNXReluOp:ONNX_Op<"Relu",
 }
 
 def ONNXReshapeOp:ONNX_Op<"Reshape",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Reshape operation";
   let description = [{
   "Reshape the input tensor similar to numpy.reshape."
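A note on the inference rule added above: only the static length N of the rank-1 shape tensor is known at compile time, so the result is typed as an N-dimensional tensor whose dimensions are all dynamic. A standalone sketch of that rule (simplified illustration, not code from this patch):

#include <cstdint>
#include <vector>

// Given the static length of the rank-1 shape tensor, every output
// dimension is recorded as -1 (dynamic); e.g. outputRank = 4 corresponds
// to tensor<?x?x?x?xf32>.
std::vector<int64_t> inferredReshapeDims(int64_t outputRank) {
  return std::vector<int64_t>(outputRank, -1);
}
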
diff --git a/src/compiler/pass/lower_frontend_to_krnl.cpp b/src/compiler/pass/lower_frontend_to_krnl.cpp
index 9416934..30f132d 100644
--- a/src/compiler/pass/lower_frontend_to_krnl.cpp
+++ b/src/compiler/pass/lower_frontend_to_krnl.cpp
@@ -86,7 +86,7 @@ static bool checkInsertDealloc(Operation* currentOp) {
     // If there is at least one result to investigate.
     if (currentOp->getNumResults() > 0) {
       auto result = currentOp->getResult(0);
-      for (auto operand : op.getOperands())
+      for (const auto& operand : op.getOperands())
         if (operand == result)
           insertDealloc = false;
     }
@@ -95,6 +95,20 @@ static bool checkInsertDealloc(Operation* currentOp) {
   return insertDealloc;
 }
 
+unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
+  auto elementType = memRefType.getElementType();
+
+  unsigned sizeInBits;
+  if (elementType.isIntOrFloat()) {
+    sizeInBits = elementType.getIntOrFloatBitWidth();
+  } else {
+    auto vectorType = elementType.cast<VectorType>();
+    sizeInBits =
+        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+  }
+  return llvm::divideCeil(sizeInBits, 8);
+}
+
 namespace {
 
 template
@@ -655,6 +669,62 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
   }
 };
 
+struct ONNXReshapeOpLowering : public ConversionPattern {
+  ONNXReshapeOpLowering(MLIRContext* ctx)
+      : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
+
+  PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
+      ConversionPatternRewriter& rewriter) const final {
+    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+    auto loc = op->getLoc();
+
+    // Insert an allocation and deallocation for the result of this operation.
+    auto memRefType = convertTensorToMemRef(tensorType);
+    Value* alloc;
+
+    // Compute size in bytes.
+    Value* tensorSize = rewriter.create<ConstantOp>(loc,
+        rewriter.getIntegerAttr(
+            rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType)));
+    bool insertDealloc = checkInsertDealloc(op);
+    if (hasAllConstantDimensions(memRefType)) {
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+    } else {
+      auto memRefShape = memRefType.getShape();
+      SmallVector<Value*, 4> allocOperands;
+      for (int i = 0; i < memRefShape.size(); ++i) {
+        // The shape array can always be used to construct shape information
+        // of the result.
+        Value* index = rewriter.create<ConstantOp>(
+            loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
+        Value* loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
+        Value* int64LoadedVal = rewriter.create<ZeroExtendIOp>(
+            loc, loadedVal, rewriter.getIntegerType(64));
+        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
+        allocOperands.push_back(rewriter.create<IndexCastOp>(
+            loc, loadedVal, rewriter.getIndexType()));
+      }
+      AllocOp allocateMemref =
+          rewriter.create<AllocOp>(loc, memRefType, allocOperands);
+
+      // Make sure the memref is deallocated before the end of the
+      // enclosing block when a deallocation is required.
+      auto* parentBlock = allocateMemref.getOperation()->getBlock();
+      if (insertDealloc) {
+        auto dealloc = rewriter.create<DeallocOp>(loc, allocateMemref);
+        dealloc.getOperation()->moveBefore(&parentBlock->back());
+      }
+
+      alloc = allocateMemref;
+    }
+
+    rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
+    rewriter.replaceOp(op, alloc);
+
+    return matchSuccess();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Conversion from Tensor type to the Standard dialect MemRef type.
 //===----------------------------------------------------------------------===//
@@ -754,7 +824,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
       ONNXElementwiseVariadicOpLowering,
       ONNXElementwiseVariadicOpLowering,
       ONNXElementwiseVariadicOpLowering,
-      ONNXElementwiseVariadicOpLowering>(&getContext());
+      ONNXElementwiseVariadicOpLowering,
+      ONNXReshapeOpLowering>(&getContext());
 
   // With the target and rewrite patterns defined, we can now attempt the
   // conversion. The conversion will signal failure if any of our `illegal`
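The lowering above emits IR (constant, load, zexti, muli) that computes the allocation size as the element size times the product of the runtime shape values. The same computation written as plain C++, for clarity (a sketch under the patch's assumptions, not code that ships):

#include <cstdint>
#include <vector>

// Mirror of the emitted IR: start from the element size in bytes, then
// zero-extend each i32 shape entry to i64 and multiply it in.
int64_t reshapeAllocSizeInBytes(unsigned eltSizeInBytes,
                                const std::vector<int32_t>& shapeValues) {
  int64_t size = eltSizeInBytes;
  for (int32_t dim : shapeValues)
    size *= static_cast<int64_t>(static_cast<uint32_t>(dim)); // zexti + muli
  return size;
}
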
diff --git a/src/compiler/pass/passes.hpp b/src/compiler/pass/passes.hpp
index 89d08a1..26fa543 100644
--- a/src/compiler/pass/passes.hpp
+++ b/src/compiler/pass/passes.hpp
@@ -23,4 +23,7 @@ std::unique_ptr<Pass> createLowerToKrnlPass();
 /// Pass for lowering frontend dialects to Krnl IR dialect.
 std::unique_ptr<Pass> createLowerKrnlPass();
 
+/// Pass for lowering Krnl dialect to LLVM dialect.
+std::unique_ptr<Pass> createKrnlLowerToLLVMPass();
+
 } // end namespace mlir
diff --git a/src/compiler/pass/shape_inference_pass.cpp b/src/compiler/pass/shape_inference_pass.cpp
index 678f861..f44bed9 100644
--- a/src/compiler/pass/shape_inference_pass.cpp
+++ b/src/compiler/pass/shape_inference_pass.cpp
@@ -75,7 +75,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
     if (auto terminator_op = f.getBody().back().getTerminator()) {
       auto results = terminator_op->getOperandTypes();
       f.setType(FunctionType::get(f.getType().getInputs(),
-        std::vector<Type>(results.begin(), results.end()), f.getContext()));
+          std::vector<Type>(results.begin(), results.end()), f.getContext()));
     }
   }
@@ -110,7 +110,8 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
            op->getName().getStringRef() != "onnx.Min" &&
            op->getName().getStringRef() != "onnx.MatMul" &&
            op->getName().getStringRef() != "onnx.Gemm" &&
-           op->getName().getStringRef() != "onnx.FullGemm")
+           op->getName().getStringRef() != "onnx.FullGemm" &&
+           op->getName().getStringRef() != "onnx.Reshape")
       return false;
     return llvm::any_of(op->getResultTypes(),
         [](Type result_type) { return !result_type.isa<RankedTensorType>(); });
diff --git a/src/compiler/transform/CMakeLists.txt b/src/compiler/transform/CMakeLists.txt
index 65f8130..b811513 100644
--- a/src/compiler/transform/CMakeLists.txt
+++ b/src/compiler/transform/CMakeLists.txt
@@ -1,4 +1,6 @@
-add_library(onnf_transform lower_krnl.cpp)
+add_library(onnf_transform
+  lower_krnl.cpp
+  lower_to_llvm.cpp)
 
 target_include_directories(onnf_transform
     PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
diff --git a/src/compiler/transform/lower_krnl.cpp b/src/compiler/transform/lower_krnl.cpp
index 921124b..36abb1c 100644
--- a/src/compiler/transform/lower_krnl.cpp
+++ b/src/compiler/transform/lower_krnl.cpp
@@ -142,6 +142,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
   target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
   // We expect IR to be free of Krnl Dialect Ops.
   target.addIllegalDialect<KrnlOpsDialect>();
+  target.addLegalOp<KrnlMemcpyOp>();
 
   OwningRewritePatternList patterns;
   patterns.insert
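The new lower_to_llvm.cpp below works on the LLVM-dialect struct that a lowered memref becomes. As a rough C-level picture of that descriptor for the `!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">` type appearing in the tests (hypothetical layout sketch, for orientation only):

#include <cstdint>

// Field 1, the aligned data pointer, is what the memcpy lowering extracts
// (llvm.extractvalue [...][1]) and bitcasts to i8*.
struct MemRefDescriptor2DF32 {
  float* allocated;   // field 0: allocated pointer
  float* aligned;     // field 1: aligned pointer, source/dest of the copy
  int64_t offset;     // field 2
  int64_t sizes[2];   // field 3
  int64_t strides[2]; // field 4
};
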
diff --git a/src/compiler/transform/lower_to_llvm.cpp b/src/compiler/transform/lower_to_llvm.cpp
new file mode 100644
--- /dev/null
+++ b/src/compiler/transform/lower_to_llvm.cpp
@@ -0,0 +1,146 @@
+namespace {
+
+class KrnlMemcpyOpLowering : public ConversionPattern {
+ public:
+  explicit KrnlMemcpyOpLowering(MLIRContext* context)
+      : ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
+      ConversionPatternRewriter& rewriter) const override {
+    auto* context = op->getContext();
+    auto loc = op->getLoc();
+    auto* llvmDialect =
+        op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+    assert(llvmDialect && "expected llvm dialect to be registered");
+
+    // Get a symbol reference to the memcpy function, inserting it if necessary.
+    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
+    auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule, llvmDialect);
+
+    // First operand.
+    Type dstType =
+        operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
+    Value* alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
+        loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
+    Value* alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
+        loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
+
+    // Second operand.
+    Type srcType =
+        operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
+    Value* alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
+        loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
+    Value* alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
+        loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
+
+    // Size.
+    Value* int64Size = rewriter.create<LLVM::SExtOp>(
+        loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
+
+    // Memcpy call.
+    rewriter.create<CallOp>(loc, memcpyRef,
+        LLVM::LLVMType::getVoidTy(llvmDialect),
+        ArrayRef<Value*>(
+            {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));
+
+    rewriter.eraseOp(op);
+    return matchSuccess();
+  }
+
+ private:
+  /// Return a symbol reference to the memcpy function, inserting it into the
+  /// module if necessary.
+  static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter& rewriter,
+      ModuleOp module, LLVM::LLVMDialect* llvmDialect) {
+    auto* context = module.getContext();
+    if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
+      return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
+    // Create a function declaration for memcpy; the signature is:
+    //   * `void (i8*, i8*, i64)`
+    auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
+    auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+    auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
+    auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
+        ArrayRef<LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty}),
+        false);
+
+    // Insert the memcpy function into the body of the parent module.
+    PatternRewriter::InsertionGuard insertGuard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    rewriter.create<LLVM::LLVMFuncOp>(
+        module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
+    return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
+  }
+};
+} // end namespace
+
+//===----------------------------------------------------------------------===//
+// KRNL + Standard + Affine dialects lowering to LLVM.
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct KrnlToLLVMLoweringPass : public ModulePass<KrnlToLLVMLoweringPass> {
+  void runOnModule() final;
+};
+} // end anonymous namespace
+
+void KrnlToLLVMLoweringPass::runOnModule() {
+  // Define the target for this lowering, i.e. the LLVM dialect.
+  ConversionTarget target(getContext());
+  target.addLegalDialect<LLVM::LLVMDialect>();
+  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+
+  // Lower the MemRef types to a representation in LLVM.
+  LLVMTypeConverter typeConverter(&getContext());
+
+  // We have a combination of `krnl`, `affine`, and `std` operations. We
+  // lower in stages until all the code is in the LLVM dialect.
+  OwningRewritePatternList patterns;
+  populateAffineToStdConversionPatterns(patterns, &getContext());
+  populateLoopToStdConversionPatterns(patterns, &getContext());
+  populateStdToLLVMConversionPatterns(typeConverter, patterns);
+
+  // Lower the remaining `krnl` ops, i.e. the memcpy emitted for Reshape.
+  patterns.insert<KrnlMemcpyOpLowering>(&getContext());
+
+  // We want to completely lower to LLVM, so we use a `FullConversion`. This
+  // ensures that only legal operations will remain after the conversion.
+  auto module = getModule();
+  if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
+    signalPassFailure();
+}
+
+/// Create the pass for lowering `Krnl`, `Affine` and `Std` dialects to LLVM.
+std::unique_ptr<Pass> mlir::createKrnlLowerToLLVMPass() {
+  return std::make_unique<KrnlToLLVMLoweringPass>();
+}
+
+static PassRegistration<KrnlToLLVMLoweringPass> pass(
+    "lower-all-llvm", "Lower the Krnl, Affine and Std dialects to LLVM.");
diff --git a/src/main.cpp b/src/main.cpp
index 7306aa1..3767280 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -131,7 +131,7 @@ int main(int ac, char* av[]) {
   pm.addPass(mlir::createLowerKrnlPass());
   pm.addPass(mlir::createLowerAffinePass());
   pm.addPass(mlir::createLowerToCFGPass());
-  pm.addPass(mlir::createLowerToLLVMPass());
+  pm.addPass(mlir::createKrnlLowerToLLVMPass());
   pm.addPass(mlir::createCanonicalizerPass());
 
   pm.run(*module);
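Read together with the pass registration above, main.cpp now drives a staged pipeline; annotated recap of the calls shown in the hunk (comments added here for orientation):

// ONNX ops -> Krnl -> Affine/Std (krnl.memcpy stays legal) -> LLVM dialect.
pm.addPass(mlir::createLowerToKrnlPass());
pm.addPass(mlir::createLowerKrnlPass());       // keeps krnl.memcpy legal
pm.addPass(mlir::createLowerAffinePass());
pm.addPass(mlir::createLowerToCFGPass());
pm.addPass(mlir::createKrnlLowerToLLVMPass()); // lowers memcpy to llvm.call
pm.addPass(mlir::createCanonicalizerPass());
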
diff --git a/test/mlir/krnl/reshape.mlir b/test/mlir/krnl/reshape.mlir
new file mode 100644
index 0000000..b94c74b
--- /dev/null
+++ b/test/mlir/krnl/reshape.mlir
@@ -0,0 +1,16 @@
+// RUN: dlc-opt --shape-inference --lower-frontend --lower-krnl --lower-all-llvm %s -split-input-file | FileCheck %s
+
+func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
+  %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64)
+  // CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
+  // CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
+  // CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
+  // CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
+  // CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
+  // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64) -> !llvm.void
+  // CHECK: llvm.return [[RES]] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
+}
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 4ee9453..17d3609 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -279,6 +279,37 @@ func @test_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK: return [[RES]] : memref<?x10xf32>
 }
 
+func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
+  %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_reshape
+  // CHECK: [[TYPE_IN_BYTES:%.+]] = constant 4 : i64
+  // CHECK: %[[INDEX_0:.+]] = constant 0 : index
+  // CHECK: [[LOAD_0:%.+]] = load %arg1[%[[INDEX_0]]] : memref<4xi32>
+  // CHECK: [[EXT_0:%.+]] = zexti [[LOAD_0]] : i32 to i64
+  // CHECK: [[MUL_0:%.+]] = muli [[TYPE_IN_BYTES]], [[EXT_0]] : i64
+  // CHECK: [[CAST_0:%.+]] = index_cast [[LOAD_0]] : i32 to index
+  // CHECK: %[[INDEX_1:.+]] = constant 1 : index
+  // CHECK: [[LOAD_1:%.+]] = load %arg1[%[[INDEX_1]]] : memref<4xi32>
+  // CHECK: [[EXT_1:%.+]] = zexti [[LOAD_1]] : i32 to i64
+  // CHECK: [[MUL_1:%.+]] = muli [[MUL_0]], [[EXT_1]] : i64
+  // CHECK: [[CAST_1:%.+]] = index_cast [[LOAD_1]] : i32 to index
+  // CHECK: %[[INDEX_2:.+]] = constant 2 : index
+  // CHECK: [[LOAD_2:%.+]] = load %arg1[%[[INDEX_2]]] : memref<4xi32>
+  // CHECK: [[EXT_2:%.+]] = zexti [[LOAD_2]] : i32 to i64
+  // CHECK: [[MUL_2:%.+]] = muli [[MUL_1]], [[EXT_2]] : i64
+  // CHECK: [[CAST_2:%.+]] = index_cast [[LOAD_2]] : i32 to index
+  // CHECK: %[[INDEX_3:.+]] = constant 3 : index
+  // CHECK: [[LOAD_3:%.+]] = load %arg1[%[[INDEX_3]]] : memref<4xi32>
+  // CHECK: [[EXT_3:%.+]] = zexti [[LOAD_3]] : i32 to i64
+  // CHECK: [[MUL_3:%.+]] = muli [[MUL_2]], [[EXT_3]] : i64
+  // CHECK: [[CAST_3:%.+]] = index_cast [[LOAD_3]] : i32 to index
+  // CHECK: [[ALLOC:%.+]] = alloc([[CAST_0]], [[CAST_1]], [[CAST_2]], [[CAST_3]]) : memref<?x?x?x?xf32>
+  // CHECK: "krnl.memcpy"([[ALLOC]], %arg0, [[MUL_3]]) : (memref<?x?x?x?xf32>, memref<?x10xf32>, i64) -> ()
+  // CHECK: return [[ALLOC]] : memref<?x?x?x?xf32>
+}
+
 func @test_sum(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Sum"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()