Add support for lowering mhlo.slice to subtensor.
PiperOrigin-RevId: 359297978
This commit is contained in:
parent
b478bdf00e
commit
475b4a06a5
|
@ -1035,16 +1035,16 @@ class ReverseConverter
|
|||
}
|
||||
};
|
||||
|
||||
class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
|
||||
template <typename OpTy, bool isLHLO = true>
|
||||
class SliceConverter : public OpConversionPattern<OpTy> {
|
||||
public:
|
||||
using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;
|
||||
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
lmhlo::SliceOp slice_op, ArrayRef<Value> args,
|
||||
OpTy slice_op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto loc = slice_op.getLoc();
|
||||
auto arg_type =
|
||||
slice_op.getOperand(0).getType().template dyn_cast<ShapedType>();
|
||||
auto arg_type = args[0].getType().template dyn_cast<ShapedType>();
|
||||
if (!arg_type || !arg_type.hasRank()) {
|
||||
emitError(loc, "lhlo to linalg conversion expects known-rank args");
|
||||
return failure();
|
||||
|
@ -1053,17 +1053,22 @@ class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
|
|||
SmallVector<OpFoldResult, 3> offsets, sizes, strides;
|
||||
for (int i = 0, e = arg_type.getRank(); i < e; ++i) {
|
||||
offsets.push_back(rewriter.getI64IntegerAttr(
|
||||
slice_op.start_indices().getValue<int64_t>(i)));
|
||||
slice_op.start_indices().template getValue<int64_t>(i)));
|
||||
sizes.push_back(rewriter.getI64IntegerAttr(
|
||||
slice_op.limit_indices().getValue<int64_t>(i) -
|
||||
slice_op.start_indices().getValue<int64_t>(i)));
|
||||
strides.push_back(
|
||||
rewriter.getI64IntegerAttr(slice_op.strides().getValue<int64_t>(i)));
|
||||
slice_op.limit_indices().template getValue<int64_t>(i) -
|
||||
slice_op.start_indices().template getValue<int64_t>(i)));
|
||||
strides.push_back(rewriter.getI64IntegerAttr(
|
||||
slice_op.strides().template getValue<int64_t>(i)));
|
||||
}
|
||||
if (isLHLO) {
|
||||
auto linalg_op =
|
||||
rewriter.create<SubViewOp>(loc, args[0], offsets, sizes, strides);
|
||||
rewriter.create<linalg::CopyOp>(loc, linalg_op, args[1]);
|
||||
rewriter.eraseOp(slice_op);
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<SubTensorOp>(slice_op, args[0], offsets,
|
||||
sizes, strides);
|
||||
}
|
||||
auto linalg_slice = rewriter.create<SubViewOp>(loc, slice_op.getOperand(0),
|
||||
offsets, sizes, strides);
|
||||
rewriter.create<linalg::CopyOp>(loc, linalg_slice, slice_op.getOperand(1));
|
||||
rewriter.eraseOp(slice_op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -1430,7 +1435,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
ReverseConverter<lmhlo::ReverseOp>,
|
||||
ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
|
||||
ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
|
||||
SliceConverter,
|
||||
SliceConverter<lmhlo::SliceOp>,
|
||||
TransposeConverter<lmhlo::TransposeOp>
|
||||
>(context);
|
||||
// clang-format on
|
||||
|
@ -1554,6 +1559,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
PointwiseToLinalgConverter<mhlo::XorOp, false>,
|
||||
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
||||
ReverseConverter<mhlo::ReverseOp, false>,
|
||||
SliceConverter<mhlo::SliceOp, false>,
|
||||
TransposeConverter<mhlo::TransposeOp, false>, DotOpOnTensorsConversion,
|
||||
DotGeneralOpOnTensorsConversion, ReduceOnTensorsConversion>(context);
|
||||
patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
|
||||
|
|
|
@ -1256,3 +1256,29 @@ func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32
|
|||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
||||
|
||||
// -----
|
||||
|
||||
func @slice_whole_stride(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
|
||||
%0 = "mhlo.slice"(%arg0) {
|
||||
start_indices = dense<[1, 0]> : tensor<2xi64>,
|
||||
limit_indices = dense<[2, 4]> : tensor<2xi64>,
|
||||
strides = dense<1> : tensor<2xi64>
|
||||
} : (tensor<3x4xi32>) -> tensor<1x4xi32>
|
||||
return %0 : tensor<1x4xi32>
|
||||
}
|
||||
// CHECK-LABEL: func @slice_whole_stride
|
||||
// CHECK: subtensor %{{.*}}[1, 0] [1, 4] [1, 1] : tensor<3x4xi32> to tensor<1x4xi32>
|
||||
|
||||
// -----
|
||||
|
||||
func @slice_stride_part(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
|
||||
%0 = "mhlo.slice"(%arg0) {
|
||||
start_indices = dense<[1, 1]> : tensor<2xi64>,
|
||||
limit_indices = dense<[2, 3]> : tensor<2xi64>,
|
||||
strides = dense<1> : tensor<2xi64>
|
||||
} : (tensor<3x4xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
// CHECK-LABEL: func @slice_stride_part
|
||||
// CHECK: subtensor %{{.*}}[1, 1] [1, 2] [1, 1] : tensor<3x4xi32> to tensor<1x2xi32>
|
||||
|
|
Loading…
Reference in New Issue