diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 3347137..3daadf8 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -917,7 +917,7 @@ class IotaConverter : public OpConversionPattern<OpTy> {
             ? SmallVector<Value>()
             : ExtractDynamicSizes(
                   rewriter, loc, GetResultValue(iota_op), shape_tensor);
-    auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
+    auto linalg_op = rewriter.create<linalg::GenericOp>(
         loc,
         /*resultTensorTypes=*/
         isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
@@ -928,10 +928,11 @@
                            dyn_sizes)},
         llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
         GetNParallelLoopsAttrs(nloops),
-        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
-            ValueRange args) {
+        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
+          Value index_op = nested_builder.create<linalg::IndexOp>(
+              nested_loc, iota_op.iota_dimension());
           Value cast_op = nested_builder.create<IndexCastOp>(
-              nested_loc, ivs[iota_op.iota_dimension()],
+              nested_loc, index_op,
              nested_builder.getIntegerType(
                  result_element_type.getIntOrFloatBitWidth()));
           if (result_element_type.template isa<FloatType>()) {
@@ -995,17 +996,23 @@ struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
     // Generate a generic op to gather the elements of the concatenate. This is
     // awkward standalone but allows fusion with other generic ops.
     unsigned nloops = result_type.getRank();
-    auto linalg_op = b.create<linalg::IndexedGenericOp>(
+    auto linalg_op = b.create<linalg::GenericOp>(
         /*resultTensorTypes=*/result_type,
         /*inputs=*/ValueRange{}, /*outputBuffers=*/result,
         llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
         GetNParallelLoopsAttrs(nloops),
-        [&](OpBuilder& nested_builder, Location loc, ValueRange ivs,
-            ValueRange) {
+        [&](OpBuilder& nested_builder, Location loc, ValueRange) {
           OpBuilder b = nested_builder;
           Value concat_dim_size = zero;
           Value result;
-          auto extract_indices = llvm::to_vector<4>(ivs);
+
+          SmallVector<Value, 4> extract_indices;
+          extract_indices.reserve(nloops);
+          for (int i = 0; i < nloops; i++) {
+            extract_indices.push_back(b.create<linalg::IndexOp>(loc, i));
+          }
+
+          Value index_op = b.create<linalg::IndexOp>(loc, dim);
           for (const Value& arg : args) {
             Value new_concat_dim_size;
             scf::IfOp if_op;
@@ -1015,7 +1022,7 @@
             new_concat_dim_size = b.create<AddIOp>(
                 loc, concat_dim_size, b.create<memref::DimOp>(loc, arg, dim));
             Value cmp = b.create<CmpIOp>(loc, rewriter.getI1Type(),
-                                         CmpIPredicate::ult, ivs[dim],
+                                         CmpIPredicate::ult, index_op,
                                          new_concat_dim_size);
             if_op = b.create<scf::IfOp>(loc, result_type.getElementType(), cmp,
                                         true);
@@ -1031,7 +1038,7 @@
               // Now adjust the index for the concatenated dimension to fit into
               // the selected tensor and do an extract at that position.
               extract_indices[dim] =
-                  b.create<SubIOp>(loc, ivs[dim], concat_dim_size);
+                  b.create<SubIOp>(loc, index_op, concat_dim_size);
               Value extract =
                   b.create<tensor::ExtractOp>(loc, arg, extract_indices);
               b.create<scf::YieldOp>(loc, extract);
@@ -2047,7 +2054,7 @@ struct TorchIndexSelectOpOnTensorsConversion
     }
     Value init_op = rewriter.create<linalg::InitTensorOp>(
         loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
-    auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
+    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensors=*/ArrayRef<Type>{result_type},
        /*inputs=*/adaptor.index(),
        /*outputs=*/init_op, indexing_maps, GetNParallelLoopsAttrs(rank));
@@ -2057,7 +2064,6 @@
     // Add a block to the region.
     auto* region = &linalg_op.region();
     auto* block = rewriter.createBlock(region, region->end());
-    body_arg_types.append(rank, rewriter.getIndexType());
     for (auto block_args : linalg_op_args) {
       body_arg_types.push_back(
           block_args.getType().cast<ShapedType>().getElementType());
@@ -2067,17 +2073,17 @@
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointToEnd(block);
-    SmallVector<Value, 4> indices;
     Value casted_value = rewriter.create<IndexCastOp>(
-        loc, block->getArgument(rank), rewriter.getIndexType());
+        loc, block->getArgument(0), rewriter.getIndexType());
+
+    SmallVector<Value, 4> indices;
     for (int i = 0; i < axis; ++i) {
-      indices.push_back(block->getArgument(i));
+      indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
     }
     indices.push_back(casted_value);
     for (int i = axis + num_indices - batch; i < rank; ++i) {
-      indices.push_back(block->getArgument(i));
+      indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
     }
-
     Value res = rewriter.create<tensor::ExtractOp>(loc, adaptor.input(), indices);
     rewriter.create<linalg::YieldOp>(loc, res);
diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir
index 08c158e..f8dbfe3 100644
--- a/tests/hlo-legalize-to-linalg.mlir
+++ b/tests/hlo-legalize-to-linalg.mlir
@@ -890,10 +890,11 @@ func @iota() -> tensor<7x10xf32> {
   return %result : tensor<7x10xf32>
 }
 // CHECK: linalg.init_tensor
-// CHECK: linalg.indexed_generic
+// CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
-// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %{{.*}}: f32):
-// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
+// CHECK-NEXT: ^bb0(%{{.*}}: f32):
+// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1
+// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[INDEX]] : index to i32
 // CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
 // CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32

@@ -911,10 +912,11 @@ func @iota(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> {
 // CHECK: %[[E2:.*]] = tensor.extract %[[SHAPE]][%c1] : tensor<?xi32>
 // CHECK: %[[I2:.*]] = index_cast %[[E2]] : i32 to index
 // CHECK: linalg.init_tensor [%[[I1]], %[[I2]], 8] : tensor<?x?x8xf32>
-// CHECK: linalg.indexed_generic
+// CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
-// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[D2:.*]]: index, %{{.*}}: f32):
-// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
+// CHECK-NEXT: ^bb0(%{{.*}}: f32):
+// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1
+// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[INDEX]] : index to i32
 // CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
 // CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32

@@ -2132,15 +2134,15 @@ func @torch_index_select(%arg0: tensor<5x1x5xi32>,
 // CHECK: func @torch_index_select
 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
 // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
-// CHECK: linalg.indexed_generic {
+// CHECK: linalg.generic {
 // CHECK-SAME: indexing_maps
 // CHECK-SAME: #[[MAP0]], #[[MAP1]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[INDEX]] : tensor<2xi32>)
-// CHECK: ^{{.+}}(
-// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[K:.+]]: index
-// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: i32):
+// CHECK: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: i32):
 // CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
+// CHECK: %[[J:.+]] = linalg.index 1
+// CHECK: %[[K:.+]] = linalg.index 2
 // CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32>
 // CHECK: linalg.yield %[[VAL2]] : i32

@@ -2160,15 +2162,15 @@ func @torch_index_select_unsigned(%arg0: tensor<5x1x5xui32>,
 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
 // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
 // CHECK: %[[INPUT_SIGNLESS:.*]] = unrealized_conversion_cast %[[INPUT]] : tensor<5x1x5xui32> to tensor<5x1x5xi32>
-// CHECK: %[[RES:.+]] = linalg.indexed_generic {
+// CHECK: %[[RES:.+]] = linalg.generic {
 // CHECK-SAME: indexing_maps
 // CHECK-SAME: #[[MAP0]], #[[MAP1]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[INDEX]] : tensor<2xi32>)
-// CHECK: ^{{.+}}(
-// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[K:.+]]: index
-// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: i32):
+// CHECK: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: i32):
 // CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
+// CHECK: %[[J:.+]] = linalg.index 1
+// CHECK: %[[K:.+]] = linalg.index 2
 // CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT_SIGNLESS]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32>
 // CHECK: linalg.yield %[[VAL2]] : i32
 // CHECK: %[[RES_UNSIGNED:.+]] = unrealized_conversion_cast %[[RES]] : tensor<2x1x5xi32> to tensor<2x1x5xui32>

@@ -2191,14 +2193,14 @@ func @torch_index_select_scalar(%arg0: tensor<4x8xf32>,
 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
 // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
 // CHECK: %[[T0:.+]] = linalg.init_tensor [8] : tensor<8xf32>
-// CHECK: linalg.indexed_generic {
+// CHECK: linalg.generic {
 // CHECK-SAME: indexing_maps
 // CHECK-SAME: #[[MAP0]], #[[MAP1]]
 // CHECK-SAME: iterator_types = ["parallel"]
 // CHECK-SAME: ins(%[[INDEX]] : tensor<i32>) outs(%[[T0]] : tensor<8xf32>)
-// CHECK: ^{{.+}}(
-// CHECK-SAME: %[[I:[a-zA-Z0-9_]+]]: index, %[[VAL:[a-zA-Z0-9_]+]]: i32, %{{.+}}: f32):
+// CHECK: ^{{.+}}(%[[VAL:[a-zA-Z0-9_]+]]: i32, %{{.+}}: f32):
 // CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
+// CHECK: %[[I:.+]] = linalg.index 0
 // CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[I]]] : tensor<4x8xf32>
 // CHECK: linalg.yield %[[VAL2]] : f32

@@ -2217,16 +2219,16 @@ func @torch_index_select_batch(%arg0: tensor<4x7x8x2xf32>,
 // CHECK: func @torch_index_select_batch
 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
 // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
-// CHECK: linalg.indexed_generic {
+// CHECK: linalg.generic {
 // CHECK-SAME: indexing_maps
 // CHECK-SAME: #[[MAP0]], #[[MAP1]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[INDEX]] : tensor<4x1xi32>)
-// CHECK-NEXT: ^{{.+}}(
-// CHECK-SAME: %[[I:[a-zA-Z0-9_]+]]: index, %[[J:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME: %[[K:[a-zA-Z0-9_]+]]: index, %[[L:.+]]: index,
-// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: f32):
+// CHECK-NEXT: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: f32):
 // CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
+// CHECK: %[[I:.+]] = linalg.index 0
+// CHECK: %[[J:.+]] = linalg.index 1
+// CHECK: %[[L:.+]] = linalg.index 3
 // CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[I]], %[[J]], %[[CAST]], %[[L]]] : tensor<4x7x8x2xf32>
 // CHECK: linalg.yield %[[VAL2]] : f32

@@ -2254,19 +2256,19 @@ func @torch_index_select_dynamic(%input: tensor<?x?x?x?xf32>,
 // CHECK: %[[C3:.+]] = constant 3 : index
 // CHECK: %[[D3:.+]] = memref.dim %[[INPUT]], %[[C3]]
 // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]], %[[D3]]]
-// CHECK: %[[RESULT:.+]] = linalg.indexed_generic
+// CHECK: %[[RESULT:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[INDEX]] : tensor<?x?xi32>)
 // CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?x?xf32>)
 // CHECK: ^{{.+}}(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: i32, %{{[a-zA-Z0-9_]+}}: f32)
-// CHECK: %[[POS:.+]] = index_cast %[[ARG4]]
-// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[ARG0]], %[[ARG1]], %[[POS]], %[[ARG3]]]
+
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: i32, %{{[a-zA-Z0-9_]+}}: f32)
+// CHECK: %[[POS:.+]] = index_cast %[[ARG0]]
+// CHECK: %[[IDX0:.+]] = linalg.index 0
+// CHECK: %[[IDX1:.+]] = linalg.index 1
+// CHECK: %[[IDX3:.+]] = linalg.index 3
+// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[IDX0]], %[[IDX1]], %[[POS]], %[[IDX3]]]
 // CHECK: linalg.yield %[[YIELD]]

 // -----

@@ -2287,27 +2289,30 @@ func @torch_index_select_dynamic(%input: tensor<?x?x?x?xf32>,
 // CHECK: %[[VAL_12:.*]] = memref.dim %[[VAL_2]], %[[VAL_11]] : tensor<?x?xi32>
 // CHECK: %[[VAL_13:.*]] = addi %[[VAL_10]], %[[VAL_12]] : index
 // CHECK: %[[VAL_14:.*]] = linalg.init_tensor [%[[VAL_5]], %[[VAL_13]]] : tensor<?x?xi32>
-// CHECK: %[[VAL_15:.*]] = linalg.indexed_generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
-// CHECK: ^bb0(%[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index, %[[VAL_18:.*]]: i32):
+// CHECK: %[[VAL_15:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
+// CHECK: ^bb0(%[[VAL_18:.*]]: i32):
+// CHECK: %[[VAL_16:.*]] = linalg.index 0
+// CHECK: %[[VAL_17:.*]] = linalg.index 1
+// CHECK: %[[DIM:.*]] = linalg.index 1
 // CHECK: %[[VAL_19:.*]] = constant 1 : index
 // CHECK: %[[VAL_20:.*]] = memref.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xi32>
 // CHECK: %[[VAL_21:.*]] = addi %[[VAL_3]], %[[VAL_20]] : index
-// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_21]] : index
+// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[DIM]], %[[VAL_21]] : index
 // CHECK: %[[VAL_23:.*]] = scf.if %[[VAL_22]] -> (i32) {
-// CHECK: %[[VAL_24:.*]] = subi %[[VAL_17]], %[[VAL_3]] : index
+// CHECK: %[[VAL_24:.*]] = subi %[[DIM]], %[[VAL_3]] : index
 // CHECK: %[[VAL_25:.*]] = tensor.extract %[[VAL_0]][%[[VAL_16]], %[[VAL_24]]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[VAL_25]] : i32
 // CHECK: } else {
 // CHECK: %[[VAL_26:.*]] = constant 1 : index
 // CHECK: %[[VAL_27:.*]] = memref.dim %[[VAL_1]], %[[VAL_26]] : tensor<?x?xi32>
 // CHECK: %[[VAL_28:.*]] = addi %[[VAL_21]], %[[VAL_27]] : index
-// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_28]] : index
+// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[DIM]], %[[VAL_28]] : index
 // CHECK: %[[VAL_30:.*]] = scf.if %[[VAL_29]] -> (i32) {
-// CHECK: %[[VAL_31:.*]] = subi %[[VAL_17]], %[[VAL_21]] : index
+// CHECK: %[[VAL_31:.*]] = subi %[[DIM]], %[[VAL_21]] : index
 // CHECK: %[[VAL_32:.*]] = tensor.extract %[[VAL_1]][%[[VAL_16]], %[[VAL_31]]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[VAL_32]] : i32
 // CHECK: } else {
-// CHECK: %[[VAL_33:.*]] = subi %[[VAL_17]], %[[VAL_28]] : index
+// CHECK: %[[VAL_33:.*]] = subi %[[DIM]], %[[VAL_28]] : index
 // CHECK: %[[VAL_34:.*]] = tensor.extract %[[VAL_2]][%[[VAL_16]], %[[VAL_33]]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[VAL_34]] : i32
 // CHECK: }
@@ -2345,27 +2350,30 @@ func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>)
 // CHECK: %[[VAL_12:.*]] = memref.dim %[[C_SIGNLESS]], %[[VAL_11]] : tensor<?x?xi32>
 // CHECK: %[[VAL_13:.*]] = addi %[[VAL_10]], %[[VAL_12]] : index
 // CHECK: %[[VAL_14:.*]] = linalg.init_tensor [%[[VAL_5]], %[[VAL_13]]] : tensor<?x?xi32>
-// CHECK: %[[RET_SIGNLESS:.*]] = linalg.indexed_generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
-// CHECK: ^bb0(%[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index, %[[VAL_18:.*]]: i32):
+// CHECK: %[[RET_SIGNLESS:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
+// CHECK: ^bb0(%[[VAL_18:.*]]: i32):
+// CHECK: %[[VAL_16:.*]] = linalg.index 0
+// CHECK: %[[VAL_17:.*]] = linalg.index 1
+// CHECK: %[[DIM:.*]] = linalg.index 1
 // CHECK: %[[VAL_19:.*]] = constant 1 : index
 // CHECK: %[[VAL_20:.*]] = memref.dim %[[A_SIGNLESS]], %[[VAL_19]] : tensor<?x?xi32>
 // CHECK: %[[VAL_21:.*]] = addi %[[VAL_3]], %[[VAL_20]] : index
-// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_21]] : index
+// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[DIM]], %[[VAL_21]] : index
 // CHECK: %[[VAL_23:.*]] = scf.if %[[VAL_22]] -> (i32) {
-// CHECK: %[[VAL_24:.*]] = subi %[[VAL_17]], %[[VAL_3]] : index
+// CHECK: %[[VAL_24:.*]] = subi %[[DIM]], %[[VAL_3]] : index
 // CHECK: %[[VAL_25:.*]] = tensor.extract %[[A_SIGNLESS]][%[[VAL_16]], %[[VAL_24]]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[VAL_25]] : i32
 // CHECK: } else {
 // CHECK: %[[VAL_26:.*]] = constant 1 : index
 // CHECK: %[[VAL_27:.*]] = memref.dim %[[B_SIGNLESS]], %[[VAL_26]] : tensor<?x?xi32>
 // CHECK: %[[VAL_28:.*]] = addi %[[VAL_21]], %[[VAL_27]] : index
-// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_28]] : index
+// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[DIM]], %[[VAL_28]] : index
 // CHECK: %[[VAL_30:.*]] = scf.if %[[VAL_29]] -> (i32) {
-// CHECK: %[[VAL_31:.*]] = subi %[[VAL_17]], %[[VAL_21]] : index
+// CHECK: %[[VAL_31:.*]] = subi %[[DIM]], %[[VAL_21]] : index
 // CHECK: %[[VAL_32:.*]] = tensor.extract %[[B_SIGNLESS]][%[[VAL_16]], %[[VAL_31]]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[VAL_32]] : i32
 // CHECK: } else {
-// CHECK: %[[VAL_33:.*]] = subi %[[VAL_17]], %[[VAL_28]] : index
+// CHECK: %[[VAL_33:.*]] = subi %[[DIM]], %[[VAL_28]] : index
 // CHECK: %[[VAL_34:.*]] = tensor.extract %[[C_SIGNLESS]][%[[VAL_16]], %[[VAL_33]]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[VAL_34]] : i32
 // CHECK: }
diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir
index d980782..a5cbd47 100644
--- a/tests/lhlo-legalize-to-linalg.mlir
+++ b/tests/lhlo-legalize-to-linalg.mlir
@@ -276,9 +276,10 @@ func @iota(%out: memref<7x10xf32>) {
   "lmhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> ()
   return
 }
-// CHECK: linalg.indexed_generic
+// CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
-// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[RESULT:.*]]: f32):
+// CHECK-NEXT: ^bb0(%[[RESULT:.*]]: f32):
+// CHECK-NEXT: %[[D1:.+]] = linalg.index 1
 // CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
 // CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
 // CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
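Illustration (not part of the patch; SSA value names below are made up for readability): every conversion in this change follows the same pattern. The leading index block arguments of linalg.indexed_generic disappear from the region signature, and each induction variable the body still needs is recreated inside the body with linalg.index <dim>. A minimal sketch for the static iota lowering, before:

  %init = linalg.init_tensor [7, 10] : tensor<7x10xf32>
  %result = linalg.indexed_generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      outs(%init : tensor<7x10xf32>) {
    ^bb0(%i: index, %j: index, %out: f32):  // induction variables arrive as block arguments
      %0 = index_cast %j : index to i32
      %1 = sitofp %0 : i32 to f32
      linalg.yield %1 : f32
  } -> tensor<7x10xf32>

and after:

  %init = linalg.init_tensor [7, 10] : tensor<7x10xf32>
  %result = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      outs(%init : tensor<7x10xf32>) {
    ^bb0(%out: f32):               // only the tensor operand remains
      %j = linalg.index 1 : index  // materialize the dimension-1 induction variable
      %0 = index_cast %j : index to i32
      %1 = sitofp %0 : i32 to f32
      linalg.yield %1 : f32
  } -> tensor<7x10xf32>

Because the region signature no longer depends on the iteration space, linalg.generic ops produced this way fuse and canonicalize uniformly with the rest of the linalg-on-tensors pipeline, which is the motivation for the awkward-looking standalone concatenate lowering above.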