diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 44ab6d2..8b338e4 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -1257,6 +1257,7 @@ def HLO_TriangularSolveOp: HLO_Op<"triangular_solve", def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ RecursiveSideEffects, + SameVariadicOperandSize, SingleBlockImplicitTerminator<"ReturnOp"> ]>, BASE_HLO_ReduceWindowOp { @@ -1264,8 +1265,8 @@ def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ // attributes are 1-d. Attributes' leading dimension should match rank of the // inputs. let arguments = (ins - HLO_Tensor:$operand, - HLO_Tensor:$init_value, + Variadic:$inputs, + Variadic:$init_values, I64ElementsAttr:$window_dimensions, // If strides or dilations attributes are missing then the default value is // one for each of the input dimensions. Similarly, padding values are zero @@ -1276,15 +1277,36 @@ def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ OptionalAttr:$padding ); - let results = (outs HLO_Tensor); + let results = (outs Variadic); // TODO(hinsu): Verify that the attached body arguments and results are // compatible with reduce op's operands. let regions = (region SizedRegion<1>:$body); - let hasCustomHLOConverter = 1; + // Builder for non-variadic version of the operation. + let builders = [ + OpBuilder<(ins "Type":$result_type, "Value":$operand, + "Value":$init_value, + "DenseIntElementsAttr":$window_dimensions, + "DenseIntElementsAttr":$window_strides, + "DenseIntElementsAttr":$base_dilations, + "DenseIntElementsAttr":$window_dilations, + "DenseIntElementsAttr":$padding), + [{ + build($_builder, $_state, TypeRange(result_type), ValueRange(operand), + ValueRange(init_value), window_dimensions, window_strides, + base_dilations, window_dilations, padding); + }]> + ]; + let hasCustomHLOConverter = 1; + let verifier = [{ return Verify(*this); }]; // TODO(hinsu): Implement custom printer and parser. 
+ + let extraClassDeclaration = [{ + // Get the operation used for reduction applied to `result_index`th result. + Operation *getReductionOp(int result_index); + }]; } def HLO_ReturnOp : HLO_Op<"return", [NoSideEffect, Terminator]> { diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index baa1e0a..bc11a02 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -216,12 +216,12 @@ def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_Reduce let hasCanonicalizer = 1; } -def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", []>, BASE_HLO_ReduceWindowOp { - +def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [SameVariadicOperandSize]>, + BASE_HLO_ReduceWindowOp { let arguments = (ins - Arg:$operand, - Arg:$init_value, - Arg:$out, + Arg, "", [MemRead]>:$inputs, + Arg, "", [MemRead]>:$init_values, + Arg, "", [MemWrite]>:$out, I64ElementsAttr:$window_dimensions, // If strides or dilations attributes are missing then the default value is // one for each of the input dimensions. Similarly, padding values are zero @@ -233,6 +233,7 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", []>, BASE_HLO_ReduceWindowOp { ); let regions = (region SizedRegion<1>:$body); + let verifier = [{ return Verify(*this); }]; } // TODO(timshen): Add a custom syntax for this. diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 1ccfed8..20a255f 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1728,6 +1728,38 @@ static LogicalResult Verify(RecvOp op) { OpFoldResult CopyOp::fold(ArrayRef operands) { return getOperand(); } +//===----------------------------------------------------------------------===// +// ReduceWindowOp +//===----------------------------------------------------------------------===// + +// For reduce-window, all `inputs` need to have compatible shapes. 
+static LogicalResult Verify(ReduceWindowOp op) { + if (failed(verifyCompatibleShapes(op.inputs().getTypes()))) + return op.emitOpError() << "requires same shape for all inputs"; + return success(); +} + +// Returns the binary op that reduces the `result_index`th result, i.e. the op +// consuming body arguments `result_index` and `result_index + inputs().size()` +// (swapped order only if the op has the checked trait); nullptr otherwise. +Operation* ReduceWindowOp::getReductionOp(int result_index) { + auto return_op = cast(body().front().getTerminator()); + Operation* compute_op = return_op.results()[result_index].getDefiningOp(); + if (!compute_op || compute_op->getNumOperands() != 2) return nullptr; + auto arg0 = compute_op->getOperand(0).dyn_cast(); + auto arg1 = compute_op->getOperand(1).dyn_cast(); + if (!arg0 || !arg1) return nullptr; + int arg0_num = arg0.getArgNumber(); + int arg1_num = arg1.getArgNumber(); + int other_arg_index = result_index + inputs().size(); + if (arg0_num == result_index && arg1_num == other_arg_index) + return compute_op; + if (arg0_num == other_arg_index && arg1_num == result_index && + compute_op->hasTrait()) + return compute_op; + return nullptr; +} + //===----------------------------------------------------------------------===// // ReverseOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc index 2c14eb8..af3d85a 100644 --- a/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -281,6 +281,17 @@ void ReduceOp::getCanonicalizationPatterns(OwningRewritePatternList& results, results.insert(context); } +//===----------------------------------------------------------------------===// +// ReduceWindowOp. +//===----------------------------------------------------------------------===// + +// For reduce-window, all `inputs` need to have compatible shapes. 
+static LogicalResult Verify(ReduceWindowOp op) { + if (failed(verifyCompatibleShapes(op.inputs().getTypes()))) + return op.emitOpError() << "requires same shape for all operands"; + return success(); +} + } // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index df77bad..9a403e1 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -1687,40 +1687,34 @@ struct ReduceWindowOpOnTensorsConversion /// the pooling is determined based on the body of the reduce window /// operation. This class enumerates the different variants. enum class PoolingType { + kInvalid, kMin, kMax, kAdd, }; - static PoolingType getPoolingType(Region& region) { - assert(region.getBlocks().size() == 1 && - "expected the region has exactlly one block"); - Block& block = region.front(); - assert(block.getOperations().size() == 2 && - "expected the block has exactlly two operations"); - auto op = block.begin(); - if (isa(op)) return PoolingType::kMin; - if (isa(op)) return PoolingType::kMax; - if (isa(op)) return PoolingType::kAdd; - - llvm_unreachable("unknown pooling type"); + static PoolingType getPoolingType(mhlo::ReduceWindowOp reduce_op, + int result_index) { + if (Operation* op = reduce_op.getReductionOp(result_index)) { + if (isa(*op)) return PoolingType::kMin; + if (isa(*op)) return PoolingType::kMax; + if (isa(*op)) return PoolingType::kAdd; + } + return PoolingType::kInvalid; } LogicalResult matchAndRewrite( mhlo::ReduceWindowOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const override { auto loc = op.getLoc(); - auto result_type = op.getResult().getType().cast(); - if (result_type.getRank() != 4) { + int rank = op.getResultTypes()[0].cast().getRank(); + if (rank != 4) { return rewriter.notifyMatchFailure(op, "expected NHWC pooling-based op"); } - // Create a fake window dimension. 
- SmallVector shapes; + SmallVector shapes; shapes.push_back(op.window_dimensions().getValue(1)); shapes.push_back(op.window_dimensions().getValue(2)); - auto fake_window_dims = rewriter.create( - loc, shapes, result_type.getElementType()); if (op.window_strides() && (op.window_strides().getValue().getValue(0) != 1 || @@ -1735,10 +1729,6 @@ struct ReduceWindowOpOnTensorsConversion op, "expected window_dimensions to be [1,x,y,1]"); } - if (!args[0].getType().cast().getElementType().isF32()) { - return rewriter.notifyMatchFailure(op, "expected element type to be f32"); - } - Attribute strides; if (op.window_stridesAttr()) { strides = rewriter.getI64VectorAttr( @@ -1756,39 +1746,62 @@ struct ReduceWindowOpOnTensorsConversion dilations = rewriter.getI64VectorAttr({1, 1}); } - Value init_tensor = rewriter.create( - loc, result_type.getShape(), result_type.getElementType()); - Value init_value = args[1]; - init_value = rewriter.create(loc, init_value); - Value filled_init_tensor = - rewriter.create(loc, init_tensor, init_value) - .getResult(0); - auto create_op = [&](auto* type_ptr) -> linalg::LinalgOp { - return cast( - rewriter - .create>( - loc, ArrayRef{result_type}, - ValueRange{args[0], fake_window_dims.getResult()}, - filled_init_tensor, dilations, strides) - .getOperation()); - }; - linalg::LinalgOp pooling_op; - PoolingType pooling_type = getPoolingType(op.body()); - switch (pooling_type) { - case PoolingType::kMin: { - pooling_op = create_op(static_cast(nullptr)); - break; + SmallVector pooling_ops; + + ArrayRef inputs = args.take_front(op.inputs().size()); + ArrayRef init_values = args.drop_front(op.inputs().size()); + for (auto it : llvm::zip(op.getResults(), inputs, init_values)) { + OpResult result = std::get<0>(it); + Value input = std::get<1>(it); + Value init_value = std::get<2>(it); + auto result_type = result.getType().cast(); + if (!input.getType().cast().getElementType().isF32()) { + return rewriter.notifyMatchFailure(op, + "expected element type to 
be f32"); } - case PoolingType::kMax: { - pooling_op = create_op(static_cast(nullptr)); - break; - } - case PoolingType::kAdd: { - pooling_op = create_op(static_cast(nullptr)); - break; + + // Create a fake window dimension. + auto fake_window_dims = rewriter.create( + loc, shapes, result_type.getElementType()); + Value init_tensor = rewriter.create( + loc, result_type.getShape(), result_type.getElementType()); + init_value = rewriter.create(loc, init_value); + Value filled_init_tensor = + rewriter.create(loc, init_tensor, init_value) + .getResult(0); + auto create_op = [&](auto* type_ptr) -> linalg::LinalgOp { + return cast( + rewriter + .create>( + loc, ArrayRef{result_type}, + ValueRange{input, fake_window_dims.getResult()}, + filled_init_tensor, dilations, strides) + .getOperation()); + }; + linalg::LinalgOp pooling_op; + PoolingType pooling_type = getPoolingType(op, result.getResultNumber()); + switch (pooling_type) { + case PoolingType::kMin: { + pooling_op = + create_op(static_cast(nullptr)); + break; + } + case PoolingType::kMax: { + pooling_op = + create_op(static_cast(nullptr)); + break; + } + case PoolingType::kAdd: { + pooling_op = + create_op(static_cast(nullptr)); + break; + } + case PoolingType::kInvalid: + return rewriter.notifyMatchFailure(op, "unknown reduction operation"); } + pooling_ops.push_back(pooling_op->getResult(0)); } - rewriter.replaceOp(op, pooling_op->getResult(0)); + rewriter.replaceOp(op, pooling_ops); return success(); } }; diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc index 01b8c00..777f3a1 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc @@ -93,8 +93,8 @@ struct MappedIvs { }; template -MappedIvs MapWindowIvsToInput(OpTy op, ValueRange ivs, ValueRange window_ivs, - OpBuilder* b) { +MappedIvs MapWindowIvsToInput(OpTy op, Value operand, 
ValueRange ivs, + ValueRange window_ivs, OpBuilder* b) { MappedIvs mapped_ivs; if (!op.window_strides().hasValue()) { @@ -108,7 +108,6 @@ MappedIvs MapWindowIvsToInput(OpTy op, ValueRange ivs, ValueRange window_ivs, auto padding = op.padding().getValue(); auto loc = op.getLoc(); - auto operand = op.operand(); auto operand_shape = operand.getType().template cast().getShape(); // `in_bounds` is false when the mapped indices are in the padding area. @@ -196,7 +195,7 @@ class ReduceOpConverter : public OpConversionPattern { LogicalResult matchAndRewrite( lmhlo::ReduceOp reduce_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { - // TODO(b/137624192) Implement variadic reduce. + // TODO(b/183977252) : Handle variadic ReduceOp/ReduceWindowOp if (reduce_op.out().size() != 1) return failure(); scf::ReduceOp scf_reduce_op = @@ -312,7 +311,7 @@ class ReduceOpConverter : public OpConversionPattern { // value = input[I] // else // value = neutral_value -// accumulator = reduction_operator(output[O], value) +// accumulator = reduction_operator(accumulator, value) // output[O] = accumulator // // Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp and a @@ -367,6 +366,9 @@ class ReduceWindowOpConverter LogicalResult matchAndRewrite( lmhlo::ReduceWindowOp reduce_window_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { + // TODO(b/183977252) : Handle variadic ReduceOp/ReduceWindowOp + if (reduce_window_op.out().size() != 1) return failure(); + scf::ParallelOp output_loop, window_loop; std::tie(output_loop, window_loop) = CreateParallelLoopsToTraverseOutputAndWindow(reduce_window_op, @@ -387,14 +389,14 @@ class ReduceWindowOpConverter lmhlo::ReduceWindowOp reduce_window_op, ConversionPatternRewriter* rewriter) const { auto loc = reduce_window_op.getLoc(); - Value init_value = - rewriter->create(loc, reduce_window_op.init_value()); + Value init_value = rewriter->create( + loc, reduce_window_op.init_values()[0]); Value zero = 
rewriter->create(loc, 0); Value one = rewriter->create(loc, 1); // Create an outer parallel loop that spans the output of ReduceWindowOp. - Value output = reduce_window_op.out(); + Value output = reduce_window_op.out()[0]; auto output_loop = MakeLoopOverShape(loc, output, rewriter); // Create a nested loop that traverses the window. @@ -429,22 +431,22 @@ class ReduceWindowOpConverter "`window_dilations` attributes yet. The attributes will be ignored."); } - Value operand = reduce_window_op.operand(); - auto operand_type = operand.getType().cast(); + Value input = reduce_window_op.inputs()[0]; + auto input_type = input.getType().cast(); // Compute ivs in 'arg' buffer and whether these ivs are in pad area or not. - MappedIvs mapped_ivs = - MapWindowIvsToInput(reduce_window_op, output_loop.getInductionVars(), - window_loop.getInductionVars(), rewriter); + MappedIvs mapped_ivs = MapWindowIvsToInput( + reduce_window_op, input, output_loop.getInductionVars(), + window_loop.getInductionVars(), rewriter); auto elem_or_init = rewriter->create( - loc, operand_type.getElementType(), mapped_ivs.in_bounds, + loc, input_type.getElementType(), mapped_ivs.in_bounds, /*withElseRegion=*/true); OpBuilder then_builder = elem_or_init.getThenBodyBuilder(rewriter->getListener()); - Value elem = then_builder.create( - loc, reduce_window_op.operand(), mapped_ivs.ivs); + Value elem = + then_builder.create(loc, input, mapped_ivs.ivs); then_builder.create(loc, elem); OpBuilder else_builder = @@ -611,9 +613,9 @@ class SelectAndScatterOpConverter OpBuilder::atBlockEnd(window_loops.inner_loop.getBody()); // Compute ivs in 'arg' buffer and whether these ivs are in the pad area. 
- MappedIvs mapped_ivs = - MapWindowIvsToInput(s_and_s_op, loop_over_src.getInductionVars(), - window_loops.window_ivs, &inner_loop_b); + MappedIvs mapped_ivs = MapWindowIvsToInput( + s_and_s_op, s_and_s_op.operand(), loop_over_src.getInductionVars(), + window_loops.window_ivs, &inner_loop_b); IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs()); diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index bb097eb..457147d 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -1863,6 +1863,43 @@ func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x18x18x64xf32>) -> tensor<1 // ----- +func @reduce_window_sum_max_nhwc(%arg0: tensor<1x18x18x64xf32>, + %arg1: tensor) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) { + %0:2 = "mhlo.reduce_window"(%arg0, %arg0, %arg1, %arg1) ( { + ^bb0(%arg2: tensor, %arg3 : tensor, %arg4: tensor, %arg5 : tensor): + %1 = mhlo.add %arg2, %arg4 : tensor + %2 = mhlo.maximum %arg3, %arg5 : tensor + "mhlo.return"(%1, %2) : (tensor, tensor) -> () + }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x18x18x64xf32>, tensor<1x18x18x64xf32>, tensor, tensor) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) + return %0#0, %0#1 : tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32> +} + +// CHECK-LABEL: func @reduce_window_sum_max_nhwc +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] +// CHECK: %[[WINDOW0:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32> +// CHECK: %[[INIT0:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32> +// CHECK: %[[INIT_VAL0:.+]] = tensor.extract %[[ARG1]][] : tensor +// CHECK: %[[FILL0:.+]] = linalg.fill(%[[INIT0]], %[[INIT_VAL0]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32> +// CHECK: %[[RES0:.+]] = linalg.pooling_nhwc_sum +// CHECK-SAME: {dilations = dense<1> : vector<2xi64> +// CHECK-SAME: strides = dense<2> : vector<2xi64>} +// 
CHECK-SAME: ins(%[[ARG0]], %[[WINDOW0]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>) +// CHECK-SAME: outs(%[[FILL0]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> +// CHECK: %[[WINDOW1:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32> +// CHECK: %[[INIT1:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32> +// CHECK: %[[INIT_VAL1:.+]] = tensor.extract %[[ARG1]][] : tensor +// CHECK: %[[FILL1:.+]] = linalg.fill(%[[INIT1]], %[[INIT_VAL1]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32> +// CHECK: %[[RES1:.+]] = linalg.pooling_nhwc_max +// CHECK-SAME: {dilations = dense<1> : vector<2xi64> +// CHECK-SAME: strides = dense<2> : vector<2xi64>} +// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW1]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>) +// CHECK-SAME: outs(%[[FILL1]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> +// CHECK: return %[[RES0]], %[[RES1]] + +// ----- + func @torch_select_index(%arg0: tensor<5x1x5xi32>, %arg1: tensor<2xi32>) -> tensor<2x1x5xi32> { %0 = "mhlo.torch_index_select"(%arg0, %arg1) { diff --git a/tests/ops.mlir b/tests/ops.mlir index 5c83761..604bd5a 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -1389,3 +1389,37 @@ func @custom_call_multiple_outputs(%x: tensor<2xf32>) -> tensor<2xf32> { %1 = "mhlo.add"(%0#0, %0#1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> return %1 : tensor<2xf32> } + +// ----- + +// CHECK: func @reduce_window +func @reduce_window(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor, %init1: tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) { + %0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ + ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, %b1: tensor): // no predecessors + %2 = mhlo.add %a0, %b0 : tensor + %3 = mhlo.add %a1, %b1 : tensor + %4 = "mhlo.tuple"(%2, %3) : (tensor, tensor) -> tuple, tensor> + "mhlo.return"(%4) : (tuple, tensor>) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : 
tensor<2xi64> } : (tensor<4x2xf32>, tensor<4x2xi32>, tensor, tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) + return %0#0, %0#1 : tensor<2x2xf32>, tensor<2x2xi32> +} + +// ----- + +func @reduce_window_invalid(%arg0: tensor<4x2xf32>, %arg1: tensor<4x3xi32>, %init0: tensor, %init1: tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) { + // expected-error @+1 {{requires same shape for all inputs}} + %0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ + ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, %b1: tensor): // no predecessors + %2 = mhlo.add %a0, %b0 : tensor + %3 = mhlo.add %a1, %b1 : tensor + %4 = "mhlo.tuple"(%2, %3) : (tensor, tensor) -> tuple, tensor> + "mhlo.return"(%4) : (tuple, tensor>) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : tensor<2xi64> } : (tensor<4x2xf32>, tensor<4x3xi32>, tensor, tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) + return %0#0, %0#1 : tensor<2x2xf32>, tensor<2x2xi32> +}