Remove linalg.indexed_generic from mhlo lowerings to linalg
linalg.indexed_generic is going away. Transition to using linalg.index instead.

PiperOrigin-RevId: 376002501
parent ca09dabf1a
commit 26a0053d7d
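For context, a minimal before/after sketch of the rewrite this change applies, modeled on the iota tests further down; the op spellings (std-dialect index_cast/sitofp) match this revision, while the value and map names are illustrative. linalg.indexed_generic handed every loop induction variable to the region as leading index block arguments; with linalg.generic the region materializes just the coordinates it needs via linalg.index:

#map = affine_map<(d0, d1) -> (d0, d1)>

// Before: induction variables arrive as leading index block arguments.
%before = linalg.indexed_generic
    {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]}
    outs(%init : tensor<7x10xf32>) {
^bb0(%d0: index, %d1: index, %out: f32):
  %int = index_cast %d1 : index to i32
  %float = sitofp %int : i32 to f32
  linalg.yield %float : f32
} -> tensor<7x10xf32>

// After: the body asks only for the coordinate it actually uses.
%after = linalg.generic
    {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]}
    outs(%init : tensor<7x10xf32>) {
^bb0(%out: f32):
  %d1 = linalg.index 1 : index
  %int = index_cast %d1 : index to i32
  %float = sitofp %int : i32 to f32
  linalg.yield %float : f32
} -> tensor<7x10xf32>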
@@ -917,7 +917,7 @@ class IotaConverter : public OpConversionPattern<OpTy> {
             ? SmallVector<Value, 2>()
             : ExtractDynamicSizes(
                   rewriter, loc, GetResultValue<isLHLO>(iota_op), shape_tensor);
-    auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
+    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc,
        /*resultTensorTypes=*/
        isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
@@ -928,10 +928,11 @@ class IotaConverter : public OpConversionPattern<OpTy> {
                dyn_sizes)},
        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
        GetNParallelLoopsAttrs(nloops),
-       [&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
-           ValueRange args) {
+       [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
+         Value index_op = nested_builder.create<linalg::IndexOp>(
+             nested_loc, iota_op.iota_dimension());
          Value cast_op = nested_builder.create<IndexCastOp>(
-             nested_loc, ivs[iota_op.iota_dimension()],
+             nested_loc, index_op,
              nested_builder.getIntegerType(
                  result_element_type.getIntOrFloatBitWidth()));
          if (result_element_type.template isa<FloatType>()) {
@@ -995,17 +996,23 @@ struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
    // Generate a generic op to gather the elements of the concatenate. This is
    // awkward standalone but allows fusion with other generic ops.
    unsigned nloops = result_type.getRank();
-   auto linalg_op = b.create<linalg::IndexedGenericOp>(
+   auto linalg_op = b.create<linalg::GenericOp>(
        /*resultTensorTypes=*/result_type,
        /*inputs=*/ValueRange{}, /*outputBuffers=*/result,
        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
        GetNParallelLoopsAttrs(nloops),
-       [&](OpBuilder& nested_builder, Location loc, ValueRange ivs,
-           ValueRange) {
+       [&](OpBuilder& nested_builder, Location loc, ValueRange) {
          OpBuilder b = nested_builder;
          Value concat_dim_size = zero;
          Value result;
-         auto extract_indices = llvm::to_vector<4>(ivs);
+         SmallVector<Value, 4> extract_indices;
+         extract_indices.reserve(nloops);
+         for (int i = 0; i < nloops; i++) {
+           extract_indices.push_back(b.create<linalg::IndexOp>(loc, i));
+         }
+
+         Value index_op = b.create<linalg::IndexOp>(loc, dim);
          for (const Value& arg : args) {
            Value new_concat_dim_size;
            scf::IfOp if_op;
@@ -1015,7 +1022,7 @@ struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
            new_concat_dim_size = b.create<AddIOp>(
                loc, concat_dim_size, b.create<memref::DimOp>(loc, arg, dim));
            Value cmp = b.create<CmpIOp>(loc, rewriter.getI1Type(),
-                                        CmpIPredicate::ult, ivs[dim],
+                                        CmpIPredicate::ult, index_op,
                                         new_concat_dim_size);
            if_op = b.create<scf::IfOp>(loc, result_type.getElementType(),
                                        cmp, true);
@@ -1031,7 +1038,7 @@ struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
              // Now adjust the index for the concatenated dimension to fit into
              // the selected tensor and do an extract at that position.
              extract_indices[dim] =
-                 b.create<SubIOp>(loc, ivs[dim], concat_dim_size);
+                 b.create<SubIOp>(loc, index_op, concat_dim_size);
              Value extract =
                  b.create<tensor::ExtractOp>(loc, arg, extract_indices);
              b.create<scf::YieldOp>(loc, extract);
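The net effect of the three concatenate hunks, sketched below as the region body the converter now emits for a simplified two-operand concatenate along dim = 1 (value names are invented; the real tests further down cover three operands): each coordinate previously read from ivs is recreated with linalg.index, and the concat-dimension index drives the scf.if chain exactly as before.

^bb0(%out_el: i32):
  // Coordinates of the current output element, recovered in the body.
  %i0 = linalg.index 0 : index
  %i1 = linalg.index 1 : index
  %c1 = constant 1 : index
  %size_a = memref.dim %a, %c1 : tensor<?x?xi32>
  // Does the concat-dim coordinate fall inside the first operand?
  %in_a = cmpi ult, %i1, %size_a : index
  %val = scf.if %in_a -> (i32) {
    %e = tensor.extract %a[%i0, %i1] : tensor<?x?xi32>
    scf.yield %e : i32
  } else {
    // Rebase the coordinate into the second operand.
    %local = subi %i1, %size_a : index
    %e = tensor.extract %b[%i0, %local] : tensor<?x?xi32>
    scf.yield %e : i32
  }
  linalg.yield %val : i32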
@@ -2047,7 +2054,7 @@ struct TorchIndexSelectOpOnTensorsConversion
    }
    Value init_op = rewriter.create<linalg::InitTensorOp>(
        loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
-   auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
+   auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensors=*/ArrayRef<Type>{result_type},
        /*inputs=*/adaptor.index(),
        /*outputs=*/init_op, indexing_maps, GetNParallelLoopsAttrs(rank));
@@ -2057,7 +2064,6 @@ struct TorchIndexSelectOpOnTensorsConversion
    // Add a block to the region.
    auto* region = &linalg_op.region();
    auto* block = rewriter.createBlock(region, region->end());
-   body_arg_types.append(rank, rewriter.getIndexType());
    for (auto block_args : linalg_op_args) {
      body_arg_types.push_back(
          block_args.getType().cast<ShapedType>().getElementType());
@@ -2067,17 +2073,17 @@ struct TorchIndexSelectOpOnTensorsConversion
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToEnd(block);

-   SmallVector<Value, 4> indices;
    Value casted_value = rewriter.create<IndexCastOp>(
-       loc, block->getArgument(rank), rewriter.getIndexType());
+       loc, block->getArgument(0), rewriter.getIndexType());

+   SmallVector<Value, 4> indices;
    for (int i = 0; i < axis; ++i) {
-     indices.push_back(block->getArgument(i));
+     indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
    }
    indices.push_back(casted_value);
    for (int i = axis + num_indices - batch; i < rank; ++i) {
-     indices.push_back(block->getArgument(i));
+     indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
    }

    Value res =
        rewriter.create<tensor::ExtractOp>(loc, adaptor.input(), indices);
    rewriter.create<linalg::YieldOp>(loc, res);
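As a quick reference for the test updates that follow: after this change, the torch_index_select region receives only the index-tensor element and the output element as block arguments, and the remaining extract coordinates come from linalg.index. A sketch of the generated block for the static case (axis = 0; names illustrative, mirroring the CHECK lines below):

^bb0(%index_el: i32, %out_el: i32):
  %cast = index_cast %index_el : i32 to index
  %j = linalg.index 1 : index
  %k = linalg.index 2 : index
  %v = tensor.extract %input[%cast, %j, %k] : tensor<5x1x5xi32>
  linalg.yield %v : i32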
@@ -890,10 +890,11 @@ func @iota() -> tensor<7x10xf32> {
   return %result : tensor<7x10xf32>
 }
 // CHECK: linalg.init_tensor
-// CHECK: linalg.indexed_generic
+// CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
-// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %{{.*}}: f32):
-// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
+// CHECK-NEXT: ^bb0(%{{.*}}: f32):
+// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1
+// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[INDEX]] : index to i32
 // CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
 // CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
@@ -911,10 +912,11 @@ func @iota(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> {
 // CHECK: %[[E2:.*]] = tensor.extract %[[SHAPE]][%c1] : tensor<?xi32>
 // CHECK: %[[I2:.*]] = index_cast %[[E2]] : i32 to index
 // CHECK: linalg.init_tensor [%[[I1]], %[[I2]], 8] : tensor<?x?x8xf32>
-// CHECK: linalg.indexed_generic
+// CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
-// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[D2:.*]]: index, %{{.*}}: f32):
-// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
+// CHECK-NEXT: ^bb0(%{{.*}}: f32):
+// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1
+// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[INDEX]] : index to i32
 // CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
 // CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
@@ -2132,15 +2134,15 @@ func @torch_index_select(%arg0: tensor<5x1x5xi32>,
 // CHECK: func @torch_index_select
 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
 // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
-// CHECK: linalg.indexed_generic {
+// CHECK: linalg.generic {
 // CHECK-SAME: indexing_maps
 // CHECK-SAME: #[[MAP0]], #[[MAP1]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[INDEX]] : tensor<2xi32>)
-// CHECK: ^{{.+}}(
-// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[K:.+]]: index
-// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: i32):
+// CHECK: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: i32):
 // CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
+// CHECK: %[[J:.+]] = linalg.index 1
+// CHECK: %[[K:.+]] = linalg.index 2
 // CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32>
 // CHECK: linalg.yield %[[VAL2]] : i32
@@ -2160,15 +2162,15 @@ func @torch_index_select_unsigned(%arg0: tensor<5x1x5xui32>,
 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
 // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
 // CHECK: %[[INPUT_SIGNLESS:.*]] = unrealized_conversion_cast %[[INPUT]] : tensor<5x1x5xui32> to tensor<5x1x5xi32>
-// CHECK: %[[RES:.+]] = linalg.indexed_generic {
+// CHECK: %[[RES:.+]] = linalg.generic {
 // CHECK-SAME: indexing_maps
 // CHECK-SAME: #[[MAP0]], #[[MAP1]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[INDEX]] : tensor<2xi32>)
-// CHECK: ^{{.+}}(
-// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[K:.+]]: index
-// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: i32):
+// CHECK: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: i32):
 // CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
+// CHECK: %[[J:.+]] = linalg.index 1
+// CHECK: %[[K:.+]] = linalg.index 2
 // CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT_SIGNLESS]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32>
 // CHECK: linalg.yield %[[VAL2]] : i32
 // CHECK: %[[RES_UNSIGNED:.+]] = unrealized_conversion_cast %[[RES]] : tensor<2x1x5xi32> to tensor<2x1x5xui32>
@@ -2191,14 +2193,14 @@ func @torch_index_select_scalar(%arg0: tensor<4x8xf32>,
 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
 // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
 // CHECK: %[[T0:.+]] = linalg.init_tensor [8] : tensor<8xf32>
-// CHECK: linalg.indexed_generic {
+// CHECK: linalg.generic {
 // CHECK-SAME: indexing_maps
 // CHECK-SAME: #[[MAP0]], #[[MAP1]]
 // CHECK-SAME: iterator_types = ["parallel"]
 // CHECK-SAME: ins(%[[INDEX]] : tensor<i32>) outs(%[[T0]] : tensor<8xf32>)
-// CHECK: ^{{.+}}(
-// CHECK-SAME: %[[I:[a-zA-Z0-9_]+]]: index, %[[VAL:[a-zA-Z0-9_]+]]: i32, %{{.+}}: f32):
+// CHECK: ^{{.+}}(%[[VAL:[a-zA-Z0-9_]+]]: i32, %{{.+}}: f32):
 // CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
+// CHECK: %[[I:.+]] = linalg.index 0
 // CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[I]]] : tensor<4x8xf32>
 // CHECK: linalg.yield %[[VAL2]] : f32
@@ -2217,16 +2219,16 @@ func @torch_index_select_batch(%arg0: tensor<4x7x8x2xf32>,
 // CHECK: func @torch_index_select_batch
 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
 // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
-// CHECK: linalg.indexed_generic {
+// CHECK: linalg.generic {
 // CHECK-SAME: indexing_maps
 // CHECK-SAME: #[[MAP0]], #[[MAP1]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[INDEX]] : tensor<4x1xi32>)
-// CHECK-NEXT: ^{{.+}}(
-// CHECK-SAME: %[[I:[a-zA-Z0-9_]+]]: index, %[[J:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME: %[[K:[a-zA-Z0-9_]+]]: index, %[[L:.+]]: index,
-// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: f32):
+// CHECK-NEXT: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: f32):
 // CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
+// CHECK: %[[I:.+]] = linalg.index 0
+// CHECK: %[[J:.+]] = linalg.index 1
+// CHECK: %[[L:.+]] = linalg.index 3
 // CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[I]], %[[J]], %[[CAST]], %[[L]]] : tensor<4x7x8x2xf32>
 // CHECK: linalg.yield %[[VAL2]] : f32
|
@ -2254,19 +2256,19 @@ func @torch_index_select_dynamic(%input: tensor<?x?x?x?xf32>,
|
||||||
// CHECK: %[[C3:.+]] = constant 3 : index
|
// CHECK: %[[C3:.+]] = constant 3 : index
|
||||||
// CHECK: %[[D3:.+]] = memref.dim %[[INPUT]], %[[C3]]
|
// CHECK: %[[D3:.+]] = memref.dim %[[INPUT]], %[[C3]]
|
||||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]], %[[D3]]]
|
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]], %[[D3]]]
|
||||||
// CHECK: %[[RESULT:.+]] = linalg.indexed_generic
|
// CHECK: %[[RESULT:.+]] = linalg.generic
|
||||||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
|
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
|
||||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
|
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
|
||||||
// CHECK-SAME: ins(%[[INDEX]] : tensor<?x?xi32>)
|
// CHECK-SAME: ins(%[[INDEX]] : tensor<?x?xi32>)
|
||||||
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?x?xf32>)
|
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?x?xf32>)
|
||||||
// CHECK: ^{{.+}}(
|
// CHECK: ^{{.+}}(
|
||||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index,
|
|
||||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index,
|
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: i32, %{{[a-zA-Z0-9_]+}}: f32)
|
||||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
|
// CHECK: %[[POS:.+]] = index_cast %[[ARG0]]
|
||||||
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index,
|
// CHECK: %[[IDX0:.+]] = linalg.index 0
|
||||||
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: i32, %{{[a-zA-Z0-9_]+}}: f32)
|
// CHECK: %[[IDX1:.+]] = linalg.index 1
|
||||||
// CHECK: %[[POS:.+]] = index_cast %[[ARG4]]
|
// CHECK: %[[IDX3:.+]] = linalg.index 3
|
||||||
// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[ARG0]], %[[ARG1]], %[[POS]], %[[ARG3]]]
|
// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[IDX0]], %[[IDX1]], %[[POS]], %[[IDX3]]]
|
||||||
// CHECK: linalg.yield %[[YIELD]]
|
// CHECK: linalg.yield %[[YIELD]]
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@@ -2287,27 +2289,30 @@ func @torch_index_select_dynamic(%input: tensor<?x?x?x?xf32>,
 // CHECK: %[[VAL_12:.*]] = memref.dim %[[VAL_2]], %[[VAL_11]] : tensor<?x?xi32>
 // CHECK: %[[VAL_13:.*]] = addi %[[VAL_10]], %[[VAL_12]] : index
 // CHECK: %[[VAL_14:.*]] = linalg.init_tensor [%[[VAL_5]], %[[VAL_13]]] : tensor<?x?xi32>
-// CHECK: %[[VAL_15:.*]] = linalg.indexed_generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
-// CHECK: ^bb0(%[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index, %[[VAL_18:.*]]: i32):
+// CHECK: %[[VAL_15:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
+// CHECK: ^bb0(%[[VAL_18:.*]]: i32):
+// CHECK: %[[VAL_16:.*]] = linalg.index 0
+// CHECK: %[[VAL_17:.*]] = linalg.index 1
+// CHECK: %[[DIM:.*]] = linalg.index 1
 // CHECK: %[[VAL_19:.*]] = constant 1 : index
 // CHECK: %[[VAL_20:.*]] = memref.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xi32>
 // CHECK: %[[VAL_21:.*]] = addi %[[VAL_3]], %[[VAL_20]] : index
-// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_21]] : index
+// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[DIM]], %[[VAL_21]] : index
 // CHECK: %[[VAL_23:.*]] = scf.if %[[VAL_22]] -> (i32) {
-// CHECK: %[[VAL_24:.*]] = subi %[[VAL_17]], %[[VAL_3]] : index
+// CHECK: %[[VAL_24:.*]] = subi %[[DIM]], %[[VAL_3]] : index
 // CHECK: %[[VAL_25:.*]] = tensor.extract %[[VAL_0]][%[[VAL_16]], %[[VAL_24]]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[VAL_25]] : i32
 // CHECK: } else {
 // CHECK: %[[VAL_26:.*]] = constant 1 : index
 // CHECK: %[[VAL_27:.*]] = memref.dim %[[VAL_1]], %[[VAL_26]] : tensor<?x?xi32>
 // CHECK: %[[VAL_28:.*]] = addi %[[VAL_21]], %[[VAL_27]] : index
-// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_28]] : index
+// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[DIM]], %[[VAL_28]] : index
 // CHECK: %[[VAL_30:.*]] = scf.if %[[VAL_29]] -> (i32) {
-// CHECK: %[[VAL_31:.*]] = subi %[[VAL_17]], %[[VAL_21]] : index
+// CHECK: %[[VAL_31:.*]] = subi %[[DIM]], %[[VAL_21]] : index
 // CHECK: %[[VAL_32:.*]] = tensor.extract %[[VAL_1]][%[[VAL_16]], %[[VAL_31]]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[VAL_32]] : i32
 // CHECK: } else {
-// CHECK: %[[VAL_33:.*]] = subi %[[VAL_17]], %[[VAL_28]] : index
+// CHECK: %[[VAL_33:.*]] = subi %[[DIM]], %[[VAL_28]] : index
 // CHECK: %[[VAL_34:.*]] = tensor.extract %[[VAL_2]][%[[VAL_16]], %[[VAL_33]]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[VAL_34]] : i32
 // CHECK: }
@@ -2345,27 +2350,30 @@ func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>)
 // CHECK: %[[VAL_12:.*]] = memref.dim %[[C_SIGNLESS]], %[[VAL_11]] : tensor<?x?xi32>
 // CHECK: %[[VAL_13:.*]] = addi %[[VAL_10]], %[[VAL_12]] : index
 // CHECK: %[[VAL_14:.*]] = linalg.init_tensor [%[[VAL_5]], %[[VAL_13]]] : tensor<?x?xi32>
-// CHECK: %[[RET_SIGNLESS:.*]] = linalg.indexed_generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
-// CHECK: ^bb0(%[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index, %[[VAL_18:.*]]: i32):
+// CHECK: %[[RET_SIGNLESS:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
+// CHECK: ^bb0(%[[VAL_18:.*]]: i32):
+// CHECK: %[[VAL_16:.*]] = linalg.index 0
+// CHECK: %[[VAL_17:.*]] = linalg.index 1
+// CHECK: %[[DIM:.*]] = linalg.index 1
 // CHECK: %[[VAL_19:.*]] = constant 1 : index
 // CHECK: %[[VAL_20:.*]] = memref.dim %[[A_SIGNLESS]], %[[VAL_19]] : tensor<?x?xi32>
 // CHECK: %[[VAL_21:.*]] = addi %[[VAL_3]], %[[VAL_20]] : index
-// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_21]] : index
+// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[DIM]], %[[VAL_21]] : index
 // CHECK: %[[VAL_23:.*]] = scf.if %[[VAL_22]] -> (i32) {
-// CHECK: %[[VAL_24:.*]] = subi %[[VAL_17]], %[[VAL_3]] : index
+// CHECK: %[[VAL_24:.*]] = subi %[[DIM]], %[[VAL_3]] : index
 // CHECK: %[[VAL_25:.*]] = tensor.extract %[[A_SIGNLESS]][%[[VAL_16]], %[[VAL_24]]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[VAL_25]] : i32
 // CHECK: } else {
 // CHECK: %[[VAL_26:.*]] = constant 1 : index
 // CHECK: %[[VAL_27:.*]] = memref.dim %[[B_SIGNLESS]], %[[VAL_26]] : tensor<?x?xi32>
 // CHECK: %[[VAL_28:.*]] = addi %[[VAL_21]], %[[VAL_27]] : index
-// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_28]] : index
+// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[DIM]], %[[VAL_28]] : index
 // CHECK: %[[VAL_30:.*]] = scf.if %[[VAL_29]] -> (i32) {
-// CHECK: %[[VAL_31:.*]] = subi %[[VAL_17]], %[[VAL_21]] : index
+// CHECK: %[[VAL_31:.*]] = subi %[[DIM]], %[[VAL_21]] : index
 // CHECK: %[[VAL_32:.*]] = tensor.extract %[[B_SIGNLESS]][%[[VAL_16]], %[[VAL_31]]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[VAL_32]] : i32
 // CHECK: } else {
-// CHECK: %[[VAL_33:.*]] = subi %[[VAL_17]], %[[VAL_28]] : index
+// CHECK: %[[VAL_33:.*]] = subi %[[DIM]], %[[VAL_28]] : index
 // CHECK: %[[VAL_34:.*]] = tensor.extract %[[C_SIGNLESS]][%[[VAL_16]], %[[VAL_33]]] : tensor<?x?xi32>
 // CHECK: scf.yield %[[VAL_34]] : i32
 // CHECK: }
@@ -276,9 +276,10 @@ func @iota(%out: memref<7x10xf32>) {
   "lmhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> ()
   return
 }
-// CHECK: linalg.indexed_generic
+// CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
-// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[RESULT:.*]]: f32):
+// CHECK-NEXT: ^bb0(%[[RESULT:.*]]: f32):
+// CHECK-NEXT: %[[D1:.+]] = linalg.index 1
 // CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
 // CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
 // CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32