Add support for lowering mhlo.torch_index_select to Linalg on tensors.
The change upstreams the pattern from the IREE repo to the MHLO repo.

PiperOrigin-RevId: 363406294
This commit is contained in:
parent 1336c95920
commit 2e0ee7759b
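For orientation before the diff itself: the pattern converts mhlo.torch_index_select into a linalg.indexed_generic whose body index_casts the gathered index value and does a tensor.extract from the input. Below is a hand-written sketch of the lowering for the simplest new test case (dim = 0, batch_dims = 0); it abbreviates what the CHECK lines in the test diff verify, and the value names (%input, %index, %init, %result) are illustrative, not compiler output.

// Source op (see @torch_select_index in the test changes below):
%0 = "mhlo.torch_index_select"(%input, %index) {
  dim = 0 : i64,
  batch_dims = 0 : i64
} : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32>

// Approximate result of the conversion:
%init = linalg.init_tensor [2, 1, 5] : tensor<2x1x5xi32>
%result = linalg.indexed_generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%index : tensor<2xi32>) outs(%init : tensor<2x1x5xi32>) {
    ^bb0(%i: index, %j: index, %k: index, %idx: i32, %out: i32):
      %cast = index_cast %idx : i32 to index
      %val = tensor.extract %input[%cast, %j, %k] : tensor<5x1x5xi32>
      linalg.yield %val : i32
  } -> tensor<2x1x5xi32>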
@@ -1773,6 +1773,99 @@ struct ReduceWindowOpOnTensorsConversion
  }
};

/// Converts xla-hlo.torch_index_select op to a linalg.indexed_generic op.
struct TorchIndexSelectOpOnTensorsConversion
    : public OpConversionPattern<mhlo::TorchIndexSelectOp> {
  using OpConversionPattern<mhlo::TorchIndexSelectOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::TorchIndexSelectOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    mhlo::TorchIndexSelectOp::Adaptor adaptor(args);
    int axis = static_cast<int>(op.dim());
    int batch = static_cast<int>(op.batch_dims());
    auto index_shaped_type = adaptor.index().getType().cast<ShapedType>();
    int num_indices = static_cast<int>(index_shaped_type.getRank());
    auto input_shaped_type = adaptor.input().getType().cast<ShapedType>();
    if (axis < 0) axis += static_cast<int>(input_shaped_type.getRank());
    if (batch < 0) batch += num_indices;

    Location loc = op.getLoc();
    auto result_type = op.getResult().getType().cast<ShapedType>();
    int rank = static_cast<int>(result_type.getRank());

    SmallVector<AffineMap, 2> indexing_maps;
    SmallVector<AffineExpr, 4> exprs;
    for (int i = 0; i < batch; ++i) {
      exprs.push_back(rewriter.getAffineDimExpr(i));
    }
    for (int i = 0, e = num_indices - batch; i < e; ++i) {
      exprs.push_back(rewriter.getAffineDimExpr(axis + i));
    }
    indexing_maps.emplace_back(
        AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext()));
    indexing_maps.emplace_back(rewriter.getMultiDimIdentityMap(rank));

    // The output shape is
    // `params[:axis] + indices[batch_dims:] + params[axis + 1:]`
    SmallVector<Value, 4> dyn_sizes;
    for (int i = 0; i < rank; ++i) {
      if (!result_type.isDynamicDim(i)) continue;
      if (i < axis) {
        dyn_sizes.push_back(
            rewriter.create<memref::DimOp>(loc, adaptor.input(), i));
      } else if (i < (axis + num_indices - batch)) {
        int idx = i - axis + batch;
        dyn_sizes.push_back(
            rewriter.create<memref::DimOp>(loc, adaptor.index(), idx));
      } else {
        int idx = i - (axis + num_indices - batch) + axis + 1;
        dyn_sizes.push_back(
            rewriter.create<memref::DimOp>(loc, adaptor.input(), idx));
      }
    }
    Value init_op = rewriter.create<linalg::InitTensorOp>(
        loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
    auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
        loc, /*resultTensors=*/ArrayRef<Type>{result_type},
        /*inputs=*/adaptor.index(),
        /*outputs=*/init_op, indexing_maps, GetNParallelLoopsAttrs(rank));

    SmallVector<Type, 4> body_arg_types;
    SmallVector<Value, 2> linalg_op_args = {adaptor.index()};
    // Add a block to the region.
    auto* region = &linalg_op.region();
    auto* block = rewriter.createBlock(region, region->end());
    body_arg_types.append(rank, rewriter.getIndexType());
    for (auto block_args : linalg_op_args) {
      body_arg_types.push_back(
          block_args.getType().cast<ShapedType>().getElementType());
    }
    block->addArguments(body_arg_types);
    block->addArguments(result_type.getElementType());
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToEnd(block);

    SmallVector<Value, 4> indices;
    Value casted_value = rewriter.create<IndexCastOp>(
        loc, block->getArgument(rank), rewriter.getIndexType());
    for (int i = 0; i < axis; ++i) {
      indices.push_back(block->getArgument(i));
    }
    indices.push_back(casted_value);
    for (int i = axis + num_indices - batch; i < rank; ++i) {
      indices.push_back(block->getArgument(i));
    }

    Value res =
        rewriter.create<tensor::ExtractOp>(loc, adaptor.input(), indices);
    rewriter.create<linalg::YieldOp>(loc, res);

    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};

void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                           OwningRewritePatternList* patterns) {
  // clang-format off
@@ -1968,6 +2061,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
                   DepthwiseConvOpOnTensorsConversion,
                   ReduceOnTensorsConversion,
                   ReduceWindowOpOnTensorsConversion,
                   TorchIndexSelectOpOnTensorsConversion,
                   PadOpOnTensorsConversion>(context);
  // clang-format on
  patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
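The output-shape comment in the pattern (`params[:axis] + indices[batch_dims:] + params[axis + 1:]`) can be sanity-checked against the batched test case added below. The concrete shapes come from that test; the value names (%params, %indices, %res) are illustrative only.

// params  : tensor<4x7x8x2xf32>, dim (axis) = 2
// indices : tensor<4x1xi32>,     batch_dims = 1
// result  = params[:2] + indices[1:] + params[3:]
//         = [4, 7]     + [1]         + [2]          ==> tensor<4x7x1x2xf32>
%res = "mhlo.torch_index_select"(%params, %indices) {
  dim = 2 : i64,
  batch_dims = 1 : i64
} : (tensor<4x7x8x2xf32>, tensor<4x1xi32>) -> tensor<4x7x1x2xf32>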
@@ -1826,7 +1826,6 @@ func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x18x18x64xf32>) -> tensor<1x8x8x64xf32>
  return %1 : tensor<1x8x8x64xf32>
}

// -----
// CHECK-LABEL: func @reduce_window_max_nhwc
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-DAG: %[[CST:.+]] = constant dense<0xFF800000> : tensor<f32>
@@ -1839,3 +1838,125 @@ func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x18x18x64xf32>) -> tensor<1x8x8x64xf32>
// CHECK-SAME: strides = dense<2> : vector<2xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>

// -----

func @torch_select_index(%arg0: tensor<5x1x5xi32>,
                         %arg1: tensor<2xi32>) -> tensor<2x1x5xi32> {
  %0 = "mhlo.torch_index_select"(%arg0, %arg1) {
    dim = 0 : i64,
    batch_dims = 0 : i64
  } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32>
  return %0 : tensor<2x1x5xi32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @torch_select_index
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
// CHECK: linalg.indexed_generic {
// CHECK-SAME: indexing_maps
// CHECK-SAME: #[[MAP0]], #[[MAP1]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[INDEX]] : tensor<2xi32>)
// CHECK: ^{{.+}}(
// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[K:.+]]: index
// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: i32):
// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32>
// CHECK: linalg.yield %[[VAL2]] : i32

// -----

func @torch_select_index_scalar(%arg0: tensor<4x8xf32>,
                                %arg1: tensor<i32>) -> tensor<8xf32> {
  %0 = "mhlo.torch_index_select"(%arg0, %arg1) {
    batch_dims = 0 : i64,
    dim = 0 : i64
  } : (tensor<4x8xf32>, tensor<i32>) -> tensor<8xf32>
  return %0 : tensor<8xf32>
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
// CHECK: func @torch_select_index_scalar
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
// CHECK: %[[T0:.+]] = linalg.init_tensor [8] : tensor<8xf32>
// CHECK: linalg.indexed_generic {
// CHECK-SAME: indexing_maps
// CHECK-SAME: #[[MAP0]], #[[MAP1]]
// CHECK-SAME: iterator_types = ["parallel"]
// CHECK-SAME: ins(%[[INDEX]] : tensor<i32>) outs(%[[T0]] : tensor<8xf32>)
// CHECK: ^{{.+}}(
// CHECK-SAME: %[[I:[a-zA-Z0-9_]+]]: index, %[[VAL:[a-zA-Z0-9_]+]]: i32, %{{.+}}: f32):
// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[I]]] : tensor<4x8xf32>
// CHECK: linalg.yield %[[VAL2]] : f32

// -----

func @torch_select_index_batch(%arg0: tensor<4x7x8x2xf32>,
                               %arg1: tensor<4x1xi32>) -> tensor<4x7x1x2xf32> {
  %0 = "mhlo.torch_index_select"(%arg0, %arg1) {
    dim = 2 : i64,
    batch_dims = 1 : i64
  } : (tensor<4x7x8x2xf32>, tensor<4x1xi32>) -> tensor<4x7x1x2xf32>
  return %0 : tensor<4x7x1x2xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @torch_select_index_batch
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
// CHECK: linalg.indexed_generic {
// CHECK-SAME: indexing_maps
// CHECK-SAME: #[[MAP0]], #[[MAP1]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[INDEX]] : tensor<4x1xi32>)
// CHECK-NEXT: ^{{.+}}(
// CHECK-SAME: %[[I:[a-zA-Z0-9_]+]]: index, %[[J:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME: %[[K:[a-zA-Z0-9_]+]]: index, %[[L:.+]]: index,
// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: f32):
// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[I]], %[[J]], %[[CAST]], %[[L]]] : tensor<4x7x8x2xf32>
// CHECK: linalg.yield %[[VAL2]] : f32

// -----

func @torch_index_select_dynamic(%input: tensor<?x?x?x?xf32>,
                                 %index: tensor<?x?xi32>) -> tensor<?x?x?x?xf32> {
  %0 = "mhlo.torch_index_select"(%input, %index) {
    batch_dims = 1 : i64,
    dim = 2 : i64
  } : (tensor<?x?x?x?xf32>, tensor<?x?xi32>) -> tensor<?x?x?x?xf32>
  return %0 : tensor<?x?x?x?xf32>
}
// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @torch_index_select_dynamic
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[D0:.+]] = memref.dim %[[INPUT]], %[[C0]]
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[D1:.+]] = memref.dim %[[INPUT]], %[[C1]]
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[D2:.+]] = memref.dim %[[INDEX]], %[[C1]]
// CHECK: %[[C3:.+]] = constant 3 : index
// CHECK: %[[D3:.+]] = memref.dim %[[INPUT]], %[[C3]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]], %[[D3]]]
// CHECK: %[[RESULT:.+]] = linalg.indexed_generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[INDEX]] : tensor<?x?xi32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?x?xf32>)
// CHECK: ^{{.+}}(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: i32, %{{[a-zA-Z0-9_]+}}: f32)
// CHECK: %[[POS:.+]] = index_cast %[[ARG4]]
// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[ARG0]], %[[ARG1]], %[[POS]], %[[ARG3]]]
// CHECK: linalg.yield %[[YIELD]]
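The new test functions rely on the RUN line already at the top of this test file, which is not part of the diff. Assuming the file's usual setup, a line along these lines drives them through the lowering pass and FileCheck; the exact tool name and pass flag here are an assumption, not taken from this change:

// RUN: mlir-hlo-opt -hlo-legalize-to-linalg -split-input-file %s | FileCheck %s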