Remove linalg.indexed_generic from mhlo lowerings to linalg

IndexedGeneric is going away. Transition to using linalg.Index instead.

PiperOrigin-RevId: 376002501
This commit is contained in:
Robert Suderman 2021-05-26 12:23:11 -07:00 committed by TensorFlow MLIR Team
parent ca09dabf1a
commit 26a0053d7d
3 changed files with 78 additions and 63 deletions

View File

@ -917,7 +917,7 @@ class IotaConverter : public OpConversionPattern<OpTy> {
? SmallVector<Value, 2>() ? SmallVector<Value, 2>()
: ExtractDynamicSizes( : ExtractDynamicSizes(
rewriter, loc, GetResultValue<isLHLO>(iota_op), shape_tensor); rewriter, loc, GetResultValue<isLHLO>(iota_op), shape_tensor);
auto linalg_op = rewriter.create<linalg::IndexedGenericOp>( auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, loc,
/*resultTensorTypes=*/ /*resultTensorTypes=*/
isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type}, isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
@ -928,10 +928,11 @@ class IotaConverter : public OpConversionPattern<OpTy> {
dyn_sizes)}, dyn_sizes)},
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs, [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
ValueRange args) { Value index_op = nested_builder.create<linalg::IndexOp>(
nested_loc, iota_op.iota_dimension());
Value cast_op = nested_builder.create<IndexCastOp>( Value cast_op = nested_builder.create<IndexCastOp>(
nested_loc, ivs[iota_op.iota_dimension()], nested_loc, index_op,
nested_builder.getIntegerType( nested_builder.getIntegerType(
result_element_type.getIntOrFloatBitWidth())); result_element_type.getIntOrFloatBitWidth()));
if (result_element_type.template isa<FloatType>()) { if (result_element_type.template isa<FloatType>()) {
@ -995,17 +996,23 @@ struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
// Generate a generic op to gather the elements of the concatenate. This is // Generate a generic op to gather the elements of the concatenate. This is
// awkward standalone but allows fusion with other generic ops. // awkward standalone but allows fusion with other generic ops.
unsigned nloops = result_type.getRank(); unsigned nloops = result_type.getRank();
auto linalg_op = b.create<linalg::IndexedGenericOp>( auto linalg_op = b.create<linalg::GenericOp>(
/*resultTensorTypes=*/result_type, /*resultTensorTypes=*/result_type,
/*inputs=*/ValueRange{}, /*outputBuffers=*/result, /*inputs=*/ValueRange{}, /*outputBuffers=*/result,
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nested_builder, Location loc, ValueRange ivs, [&](OpBuilder& nested_builder, Location loc, ValueRange) {
ValueRange) {
OpBuilder b = nested_builder; OpBuilder b = nested_builder;
Value concat_dim_size = zero; Value concat_dim_size = zero;
Value result; Value result;
auto extract_indices = llvm::to_vector<4>(ivs);
SmallVector<Value, 4> extract_indices;
extract_indices.reserve(nloops);
for (int i = 0; i < nloops; i++) {
extract_indices.push_back(b.create<linalg::IndexOp>(loc, i));
}
Value index_op = b.create<linalg::IndexOp>(loc, dim);
for (const Value& arg : args) { for (const Value& arg : args) {
Value new_concat_dim_size; Value new_concat_dim_size;
scf::IfOp if_op; scf::IfOp if_op;
@ -1015,7 +1022,7 @@ struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
new_concat_dim_size = b.create<AddIOp>( new_concat_dim_size = b.create<AddIOp>(
loc, concat_dim_size, b.create<memref::DimOp>(loc, arg, dim)); loc, concat_dim_size, b.create<memref::DimOp>(loc, arg, dim));
Value cmp = b.create<CmpIOp>(loc, rewriter.getI1Type(), Value cmp = b.create<CmpIOp>(loc, rewriter.getI1Type(),
CmpIPredicate::ult, ivs[dim], CmpIPredicate::ult, index_op,
new_concat_dim_size); new_concat_dim_size);
if_op = b.create<scf::IfOp>(loc, result_type.getElementType(), if_op = b.create<scf::IfOp>(loc, result_type.getElementType(),
cmp, true); cmp, true);
@ -1031,7 +1038,7 @@ struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
// Now adjust the index for the concatenated dimension to fit into // Now adjust the index for the concatenated dimension to fit into
// the selected tensor and do an extract at that position. // the selected tensor and do an extract at that position.
extract_indices[dim] = extract_indices[dim] =
b.create<SubIOp>(loc, ivs[dim], concat_dim_size); b.create<SubIOp>(loc, index_op, concat_dim_size);
Value extract = Value extract =
b.create<tensor::ExtractOp>(loc, arg, extract_indices); b.create<tensor::ExtractOp>(loc, arg, extract_indices);
b.create<scf::YieldOp>(loc, extract); b.create<scf::YieldOp>(loc, extract);
@ -2047,7 +2054,7 @@ struct TorchIndexSelectOpOnTensorsConversion
} }
Value init_op = rewriter.create<linalg::InitTensorOp>( Value init_op = rewriter.create<linalg::InitTensorOp>(
loc, dyn_sizes, result_type.getShape(), result_type.getElementType()); loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
auto linalg_op = rewriter.create<linalg::IndexedGenericOp>( auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, /*resultTensors=*/ArrayRef<Type>{result_type}, loc, /*resultTensors=*/ArrayRef<Type>{result_type},
/*inputs=*/adaptor.index(), /*inputs=*/adaptor.index(),
/*outputs=*/init_op, indexing_maps, GetNParallelLoopsAttrs(rank)); /*outputs=*/init_op, indexing_maps, GetNParallelLoopsAttrs(rank));
@ -2057,7 +2064,6 @@ struct TorchIndexSelectOpOnTensorsConversion
// Add a block to the region. // Add a block to the region.
auto* region = &linalg_op.region(); auto* region = &linalg_op.region();
auto* block = rewriter.createBlock(region, region->end()); auto* block = rewriter.createBlock(region, region->end());
body_arg_types.append(rank, rewriter.getIndexType());
for (auto block_args : linalg_op_args) { for (auto block_args : linalg_op_args) {
body_arg_types.push_back( body_arg_types.push_back(
block_args.getType().cast<ShapedType>().getElementType()); block_args.getType().cast<ShapedType>().getElementType());
@ -2067,17 +2073,17 @@ struct TorchIndexSelectOpOnTensorsConversion
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(block); rewriter.setInsertionPointToEnd(block);
SmallVector<Value, 4> indices;
Value casted_value = rewriter.create<IndexCastOp>( Value casted_value = rewriter.create<IndexCastOp>(
loc, block->getArgument(rank), rewriter.getIndexType()); loc, block->getArgument(0), rewriter.getIndexType());
SmallVector<Value, 4> indices;
for (int i = 0; i < axis; ++i) { for (int i = 0; i < axis; ++i) {
indices.push_back(block->getArgument(i)); indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
} }
indices.push_back(casted_value); indices.push_back(casted_value);
for (int i = axis + num_indices - batch; i < rank; ++i) { for (int i = axis + num_indices - batch; i < rank; ++i) {
indices.push_back(block->getArgument(i)); indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
} }
Value res = Value res =
rewriter.create<tensor::ExtractOp>(loc, adaptor.input(), indices); rewriter.create<tensor::ExtractOp>(loc, adaptor.input(), indices);
rewriter.create<linalg::YieldOp>(loc, res); rewriter.create<linalg::YieldOp>(loc, res);

View File

@ -890,10 +890,11 @@ func @iota() -> tensor<7x10xf32> {
return %result : tensor<7x10xf32> return %result : tensor<7x10xf32>
} }
// CHECK: linalg.init_tensor // CHECK: linalg.init_tensor
// CHECK: linalg.indexed_generic // CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] // CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %{{.*}}: f32): // CHECK-NEXT: ^bb0(%{{.*}}: f32):
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32 // CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[INDEX]] : index to i32
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32 // CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32 // CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
@ -911,10 +912,11 @@ func @iota(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> {
// CHECK: %[[E2:.*]] = tensor.extract %[[SHAPE]][%c1] : tensor<?xi32> // CHECK: %[[E2:.*]] = tensor.extract %[[SHAPE]][%c1] : tensor<?xi32>
// CHECK: %[[I2:.*]] = index_cast %[[E2]] : i32 to index // CHECK: %[[I2:.*]] = index_cast %[[E2]] : i32 to index
// CHECK: linalg.init_tensor [%[[I1]], %[[I2]], 8] : tensor<?x?x8xf32> // CHECK: linalg.init_tensor [%[[I1]], %[[I2]], 8] : tensor<?x?x8xf32>
// CHECK: linalg.indexed_generic // CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] // CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[D2:.*]]: index, %{{.*}}: f32): // CHECK-NEXT: ^bb0(%{{.*}}: f32):
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32 // CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[INDEX]] : index to i32
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32 // CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32 // CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
@ -2132,15 +2134,15 @@ func @torch_index_select(%arg0: tensor<5x1x5xi32>,
// CHECK: func @torch_index_select // CHECK: func @torch_index_select
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
// CHECK: linalg.indexed_generic { // CHECK: linalg.generic {
// CHECK-SAME: indexing_maps // CHECK-SAME: indexing_maps
// CHECK-SAME: #[[MAP0]], #[[MAP1]] // CHECK-SAME: #[[MAP0]], #[[MAP1]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[INDEX]] : tensor<2xi32>) // CHECK-SAME: ins(%[[INDEX]] : tensor<2xi32>)
// CHECK: ^{{.+}}( // CHECK: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: i32):
// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[K:.+]]: index
// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: i32):
// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index // CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
// CHECK: %[[J:.+]] = linalg.index 1
// CHECK: %[[K:.+]] = linalg.index 2
// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32> // CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32>
// CHECK: linalg.yield %[[VAL2]] : i32 // CHECK: linalg.yield %[[VAL2]] : i32
@ -2160,15 +2162,15 @@ func @torch_index_select_unsigned(%arg0: tensor<5x1x5xui32>,
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
// CHECK: %[[INPUT_SIGNLESS:.*]] = unrealized_conversion_cast %[[INPUT]] : tensor<5x1x5xui32> to tensor<5x1x5xi32> // CHECK: %[[INPUT_SIGNLESS:.*]] = unrealized_conversion_cast %[[INPUT]] : tensor<5x1x5xui32> to tensor<5x1x5xi32>
// CHECK: %[[RES:.+]] = linalg.indexed_generic { // CHECK: %[[RES:.+]] = linalg.generic {
// CHECK-SAME: indexing_maps // CHECK-SAME: indexing_maps
// CHECK-SAME: #[[MAP0]], #[[MAP1]] // CHECK-SAME: #[[MAP0]], #[[MAP1]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[INDEX]] : tensor<2xi32>) // CHECK-SAME: ins(%[[INDEX]] : tensor<2xi32>)
// CHECK: ^{{.+}}( // CHECK: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: i32):
// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[K:.+]]: index
// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: i32):
// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index // CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
// CHECK: %[[J:.+]] = linalg.index 1
// CHECK: %[[K:.+]] = linalg.index 2
// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT_SIGNLESS]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32> // CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT_SIGNLESS]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32>
// CHECK: linalg.yield %[[VAL2]] : i32 // CHECK: linalg.yield %[[VAL2]] : i32
// CHECK: %[[RES_UNSIGNED:.+]] = unrealized_conversion_cast %[[RES]] : tensor<2x1x5xi32> to tensor<2x1x5xui32> // CHECK: %[[RES_UNSIGNED:.+]] = unrealized_conversion_cast %[[RES]] : tensor<2x1x5xi32> to tensor<2x1x5xui32>
@ -2191,14 +2193,14 @@ func @torch_index_select_scalar(%arg0: tensor<4x8xf32>,
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
// CHECK: %[[T0:.+]] = linalg.init_tensor [8] : tensor<8xf32> // CHECK: %[[T0:.+]] = linalg.init_tensor [8] : tensor<8xf32>
// CHECK: linalg.indexed_generic { // CHECK: linalg.generic {
// CHECK-SAME: indexing_maps // CHECK-SAME: indexing_maps
// CHECK-SAME: #[[MAP0]], #[[MAP1]] // CHECK-SAME: #[[MAP0]], #[[MAP1]]
// CHECK-SAME: iterator_types = ["parallel"] // CHECK-SAME: iterator_types = ["parallel"]
// CHECK-SAME: ins(%[[INDEX]] : tensor<i32>) outs(%[[T0]] : tensor<8xf32>) // CHECK-SAME: ins(%[[INDEX]] : tensor<i32>) outs(%[[T0]] : tensor<8xf32>)
// CHECK: ^{{.+}}( // CHECK: ^{{.+}}(%[[VAL:[a-zA-Z0-9_]+]]: i32, %{{.+}}: f32):
// CHECK-SAME: %[[I:[a-zA-Z0-9_]+]]: index, %[[VAL:[a-zA-Z0-9_]+]]: i32, %{{.+}}: f32):
// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index // CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
// CHECK: %[[I:.+]] = linalg.index 0
// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[I]]] : tensor<4x8xf32> // CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[I]]] : tensor<4x8xf32>
// CHECK: linalg.yield %[[VAL2]] : f32 // CHECK: linalg.yield %[[VAL2]] : f32
@ -2217,16 +2219,16 @@ func @torch_index_select_batch(%arg0: tensor<4x7x8x2xf32>,
// CHECK: func @torch_index_select_batch // CHECK: func @torch_index_select_batch
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
// CHECK: linalg.indexed_generic { // CHECK: linalg.generic {
// CHECK-SAME: indexing_maps // CHECK-SAME: indexing_maps
// CHECK-SAME: #[[MAP0]], #[[MAP1]] // CHECK-SAME: #[[MAP0]], #[[MAP1]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[INDEX]] : tensor<4x1xi32>) // CHECK-SAME: ins(%[[INDEX]] : tensor<4x1xi32>)
// CHECK-NEXT: ^{{.+}}( // CHECK-NEXT: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: f32):
// CHECK-SAME: %[[I:[a-zA-Z0-9_]+]]: index, %[[J:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME: %[[K:[a-zA-Z0-9_]+]]: index, %[[L:.+]]: index,
// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: f32):
// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index // CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
// CHECK: %[[I:.+]] = linalg.index 0
// CHECK: %[[J:.+]] = linalg.index 1
// CHECK: %[[L:.+]] = linalg.index 3
// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[I]], %[[J]], %[[CAST]], %[[L]]] : tensor<4x7x8x2xf32> // CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[I]], %[[J]], %[[CAST]], %[[L]]] : tensor<4x7x8x2xf32>
// CHECK: linalg.yield %[[VAL2]] : f32 // CHECK: linalg.yield %[[VAL2]] : f32
@ -2254,19 +2256,19 @@ func @torch_index_select_dynamic(%input: tensor<?x?x?x?xf32>,
// CHECK: %[[C3:.+]] = constant 3 : index // CHECK: %[[C3:.+]] = constant 3 : index
// CHECK: %[[D3:.+]] = memref.dim %[[INPUT]], %[[C3]] // CHECK: %[[D3:.+]] = memref.dim %[[INPUT]], %[[C3]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]], %[[D3]]] // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]], %[[D3]]]
// CHECK: %[[RESULT:.+]] = linalg.indexed_generic // CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[INDEX]] : tensor<?x?xi32>) // CHECK-SAME: ins(%[[INDEX]] : tensor<?x?xi32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?x?xf32>) // CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?x?xf32>)
// CHECK: ^{{.+}}( // CHECK: ^{{.+}}(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index, // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: i32, %{{[a-zA-Z0-9_]+}}: f32)
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index // CHECK: %[[POS:.+]] = index_cast %[[ARG0]]
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index, // CHECK: %[[IDX0:.+]] = linalg.index 0
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: i32, %{{[a-zA-Z0-9_]+}}: f32) // CHECK: %[[IDX1:.+]] = linalg.index 1
// CHECK: %[[POS:.+]] = index_cast %[[ARG4]] // CHECK: %[[IDX3:.+]] = linalg.index 3
// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[ARG0]], %[[ARG1]], %[[POS]], %[[ARG3]]] // CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[IDX0]], %[[IDX1]], %[[POS]], %[[IDX3]]]
// CHECK: linalg.yield %[[YIELD]] // CHECK: linalg.yield %[[YIELD]]
// ----- // -----
@ -2287,27 +2289,30 @@ func @torch_index_select_dynamic(%input: tensor<?x?x?x?xf32>,
// CHECK: %[[VAL_12:.*]] = memref.dim %[[VAL_2]], %[[VAL_11]] : tensor<?x?xi32> // CHECK: %[[VAL_12:.*]] = memref.dim %[[VAL_2]], %[[VAL_11]] : tensor<?x?xi32>
// CHECK: %[[VAL_13:.*]] = addi %[[VAL_10]], %[[VAL_12]] : index // CHECK: %[[VAL_13:.*]] = addi %[[VAL_10]], %[[VAL_12]] : index
// CHECK: %[[VAL_14:.*]] = linalg.init_tensor [%[[VAL_5]], %[[VAL_13]]] : tensor<?x?xi32> // CHECK: %[[VAL_14:.*]] = linalg.init_tensor [%[[VAL_5]], %[[VAL_13]]] : tensor<?x?xi32>
// CHECK: %[[VAL_15:.*]] = linalg.indexed_generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) { // CHECK: %[[VAL_15:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
// CHECK: ^bb0(%[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index, %[[VAL_18:.*]]: i32): // CHECK: ^bb0(%[[VAL_18:.*]]: i32):
// CHECK: %[[VAL_16:.*]] = linalg.index 0
// CHECK: %[[VAL_17:.*]] = linalg.index 1
// CHECK: %[[DIM:.*]] = linalg.index 1
// CHECK: %[[VAL_19:.*]] = constant 1 : index // CHECK: %[[VAL_19:.*]] = constant 1 : index
// CHECK: %[[VAL_20:.*]] = memref.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xi32> // CHECK: %[[VAL_20:.*]] = memref.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xi32>
// CHECK: %[[VAL_21:.*]] = addi %[[VAL_3]], %[[VAL_20]] : index // CHECK: %[[VAL_21:.*]] = addi %[[VAL_3]], %[[VAL_20]] : index
// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_21]] : index // CHECK: %[[VAL_22:.*]] = cmpi ult, %[[DIM]], %[[VAL_21]] : index
// CHECK: %[[VAL_23:.*]] = scf.if %[[VAL_22]] -> (i32) { // CHECK: %[[VAL_23:.*]] = scf.if %[[VAL_22]] -> (i32) {
// CHECK: %[[VAL_24:.*]] = subi %[[VAL_17]], %[[VAL_3]] : index // CHECK: %[[VAL_24:.*]] = subi %[[DIM]], %[[VAL_3]] : index
// CHECK: %[[VAL_25:.*]] = tensor.extract %[[VAL_0]][%[[VAL_16]], %[[VAL_24]]] : tensor<?x?xi32> // CHECK: %[[VAL_25:.*]] = tensor.extract %[[VAL_0]][%[[VAL_16]], %[[VAL_24]]] : tensor<?x?xi32>
// CHECK: scf.yield %[[VAL_25]] : i32 // CHECK: scf.yield %[[VAL_25]] : i32
// CHECK: } else { // CHECK: } else {
// CHECK: %[[VAL_26:.*]] = constant 1 : index // CHECK: %[[VAL_26:.*]] = constant 1 : index
// CHECK: %[[VAL_27:.*]] = memref.dim %[[VAL_1]], %[[VAL_26]] : tensor<?x?xi32> // CHECK: %[[VAL_27:.*]] = memref.dim %[[VAL_1]], %[[VAL_26]] : tensor<?x?xi32>
// CHECK: %[[VAL_28:.*]] = addi %[[VAL_21]], %[[VAL_27]] : index // CHECK: %[[VAL_28:.*]] = addi %[[VAL_21]], %[[VAL_27]] : index
// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_28]] : index // CHECK: %[[VAL_29:.*]] = cmpi ult, %[[DIM]], %[[VAL_28]] : index
// CHECK: %[[VAL_30:.*]] = scf.if %[[VAL_29]] -> (i32) { // CHECK: %[[VAL_30:.*]] = scf.if %[[VAL_29]] -> (i32) {
// CHECK: %[[VAL_31:.*]] = subi %[[VAL_17]], %[[VAL_21]] : index // CHECK: %[[VAL_31:.*]] = subi %[[DIM]], %[[VAL_21]] : index
// CHECK: %[[VAL_32:.*]] = tensor.extract %[[VAL_1]][%[[VAL_16]], %[[VAL_31]]] : tensor<?x?xi32> // CHECK: %[[VAL_32:.*]] = tensor.extract %[[VAL_1]][%[[VAL_16]], %[[VAL_31]]] : tensor<?x?xi32>
// CHECK: scf.yield %[[VAL_32]] : i32 // CHECK: scf.yield %[[VAL_32]] : i32
// CHECK: } else { // CHECK: } else {
// CHECK: %[[VAL_33:.*]] = subi %[[VAL_17]], %[[VAL_28]] : index // CHECK: %[[VAL_33:.*]] = subi %[[DIM]], %[[VAL_28]] : index
// CHECK: %[[VAL_34:.*]] = tensor.extract %[[VAL_2]][%[[VAL_16]], %[[VAL_33]]] : tensor<?x?xi32> // CHECK: %[[VAL_34:.*]] = tensor.extract %[[VAL_2]][%[[VAL_16]], %[[VAL_33]]] : tensor<?x?xi32>
// CHECK: scf.yield %[[VAL_34]] : i32 // CHECK: scf.yield %[[VAL_34]] : i32
// CHECK: } // CHECK: }
@ -2345,27 +2350,30 @@ func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>)
// CHECK: %[[VAL_12:.*]] = memref.dim %[[C_SIGNLESS]], %[[VAL_11]] : tensor<?x?xi32> // CHECK: %[[VAL_12:.*]] = memref.dim %[[C_SIGNLESS]], %[[VAL_11]] : tensor<?x?xi32>
// CHECK: %[[VAL_13:.*]] = addi %[[VAL_10]], %[[VAL_12]] : index // CHECK: %[[VAL_13:.*]] = addi %[[VAL_10]], %[[VAL_12]] : index
// CHECK: %[[VAL_14:.*]] = linalg.init_tensor [%[[VAL_5]], %[[VAL_13]]] : tensor<?x?xi32> // CHECK: %[[VAL_14:.*]] = linalg.init_tensor [%[[VAL_5]], %[[VAL_13]]] : tensor<?x?xi32>
// CHECK: %[[RET_SIGNLESS:.*]] = linalg.indexed_generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) { // CHECK: %[[RET_SIGNLESS:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
// CHECK: ^bb0(%[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index, %[[VAL_18:.*]]: i32): // CHECK: ^bb0(%[[VAL_18:.*]]: i32):
// CHECK: %[[VAL_16:.*]] = linalg.index 0
// CHECK: %[[VAL_17:.*]] = linalg.index 1
// CHECK: %[[DIM:.*]] = linalg.index 1
// CHECK: %[[VAL_19:.*]] = constant 1 : index // CHECK: %[[VAL_19:.*]] = constant 1 : index
// CHECK: %[[VAL_20:.*]] = memref.dim %[[A_SIGNLESS]], %[[VAL_19]] : tensor<?x?xi32> // CHECK: %[[VAL_20:.*]] = memref.dim %[[A_SIGNLESS]], %[[VAL_19]] : tensor<?x?xi32>
// CHECK: %[[VAL_21:.*]] = addi %[[VAL_3]], %[[VAL_20]] : index // CHECK: %[[VAL_21:.*]] = addi %[[VAL_3]], %[[VAL_20]] : index
// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_21]] : index // CHECK: %[[VAL_22:.*]] = cmpi ult, %[[DIM]], %[[VAL_21]] : index
// CHECK: %[[VAL_23:.*]] = scf.if %[[VAL_22]] -> (i32) { // CHECK: %[[VAL_23:.*]] = scf.if %[[VAL_22]] -> (i32) {
// CHECK: %[[VAL_24:.*]] = subi %[[VAL_17]], %[[VAL_3]] : index // CHECK: %[[VAL_24:.*]] = subi %[[DIM]], %[[VAL_3]] : index
// CHECK: %[[VAL_25:.*]] = tensor.extract %[[A_SIGNLESS]][%[[VAL_16]], %[[VAL_24]]] : tensor<?x?xi32> // CHECK: %[[VAL_25:.*]] = tensor.extract %[[A_SIGNLESS]][%[[VAL_16]], %[[VAL_24]]] : tensor<?x?xi32>
// CHECK: scf.yield %[[VAL_25]] : i32 // CHECK: scf.yield %[[VAL_25]] : i32
// CHECK: } else { // CHECK: } else {
// CHECK: %[[VAL_26:.*]] = constant 1 : index // CHECK: %[[VAL_26:.*]] = constant 1 : index
// CHECK: %[[VAL_27:.*]] = memref.dim %[[B_SIGNLESS]], %[[VAL_26]] : tensor<?x?xi32> // CHECK: %[[VAL_27:.*]] = memref.dim %[[B_SIGNLESS]], %[[VAL_26]] : tensor<?x?xi32>
// CHECK: %[[VAL_28:.*]] = addi %[[VAL_21]], %[[VAL_27]] : index // CHECK: %[[VAL_28:.*]] = addi %[[VAL_21]], %[[VAL_27]] : index
// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_28]] : index // CHECK: %[[VAL_29:.*]] = cmpi ult, %[[DIM]], %[[VAL_28]] : index
// CHECK: %[[VAL_30:.*]] = scf.if %[[VAL_29]] -> (i32) { // CHECK: %[[VAL_30:.*]] = scf.if %[[VAL_29]] -> (i32) {
// CHECK: %[[VAL_31:.*]] = subi %[[VAL_17]], %[[VAL_21]] : index // CHECK: %[[VAL_31:.*]] = subi %[[DIM]], %[[VAL_21]] : index
// CHECK: %[[VAL_32:.*]] = tensor.extract %[[B_SIGNLESS]][%[[VAL_16]], %[[VAL_31]]] : tensor<?x?xi32> // CHECK: %[[VAL_32:.*]] = tensor.extract %[[B_SIGNLESS]][%[[VAL_16]], %[[VAL_31]]] : tensor<?x?xi32>
// CHECK: scf.yield %[[VAL_32]] : i32 // CHECK: scf.yield %[[VAL_32]] : i32
// CHECK: } else { // CHECK: } else {
// CHECK: %[[VAL_33:.*]] = subi %[[VAL_17]], %[[VAL_28]] : index // CHECK: %[[VAL_33:.*]] = subi %[[DIM]], %[[VAL_28]] : index
// CHECK: %[[VAL_34:.*]] = tensor.extract %[[C_SIGNLESS]][%[[VAL_16]], %[[VAL_33]]] : tensor<?x?xi32> // CHECK: %[[VAL_34:.*]] = tensor.extract %[[C_SIGNLESS]][%[[VAL_16]], %[[VAL_33]]] : tensor<?x?xi32>
// CHECK: scf.yield %[[VAL_34]] : i32 // CHECK: scf.yield %[[VAL_34]] : i32
// CHECK: } // CHECK: }

View File

@ -276,9 +276,10 @@ func @iota(%out: memref<7x10xf32>) {
"lmhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> () "lmhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> ()
return return
} }
// CHECK: linalg.indexed_generic // CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] // CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[RESULT:.*]]: f32): // CHECK-NEXT: ^bb0(%[[RESULT:.*]]: f32):
// CHECK-NEXT: %[[D1:.+]] = linalg.index 1
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32 // CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32 // CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32 // CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32