Add support for lowering mhlo.slice to subtensor.
PiperOrigin-RevId: 359297978
This commit is contained in:
parent
b478bdf00e
commit
475b4a06a5
|
@ -1035,16 +1035,16 @@ class ReverseConverter
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
|
template <typename OpTy, bool isLHLO = true>
|
||||||
|
class SliceConverter : public OpConversionPattern<OpTy> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
lmhlo::SliceOp slice_op, ArrayRef<Value> args,
|
OpTy slice_op, ArrayRef<Value> args,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
auto loc = slice_op.getLoc();
|
auto loc = slice_op.getLoc();
|
||||||
auto arg_type =
|
auto arg_type = args[0].getType().template dyn_cast<ShapedType>();
|
||||||
slice_op.getOperand(0).getType().template dyn_cast<ShapedType>();
|
|
||||||
if (!arg_type || !arg_type.hasRank()) {
|
if (!arg_type || !arg_type.hasRank()) {
|
||||||
emitError(loc, "lhlo to linalg conversion expects known-rank args");
|
emitError(loc, "lhlo to linalg conversion expects known-rank args");
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -1053,17 +1053,22 @@ class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
|
||||||
SmallVector<OpFoldResult, 3> offsets, sizes, strides;
|
SmallVector<OpFoldResult, 3> offsets, sizes, strides;
|
||||||
for (int i = 0, e = arg_type.getRank(); i < e; ++i) {
|
for (int i = 0, e = arg_type.getRank(); i < e; ++i) {
|
||||||
offsets.push_back(rewriter.getI64IntegerAttr(
|
offsets.push_back(rewriter.getI64IntegerAttr(
|
||||||
slice_op.start_indices().getValue<int64_t>(i)));
|
slice_op.start_indices().template getValue<int64_t>(i)));
|
||||||
sizes.push_back(rewriter.getI64IntegerAttr(
|
sizes.push_back(rewriter.getI64IntegerAttr(
|
||||||
slice_op.limit_indices().getValue<int64_t>(i) -
|
slice_op.limit_indices().template getValue<int64_t>(i) -
|
||||||
slice_op.start_indices().getValue<int64_t>(i)));
|
slice_op.start_indices().template getValue<int64_t>(i)));
|
||||||
strides.push_back(
|
strides.push_back(rewriter.getI64IntegerAttr(
|
||||||
rewriter.getI64IntegerAttr(slice_op.strides().getValue<int64_t>(i)));
|
slice_op.strides().template getValue<int64_t>(i)));
|
||||||
}
|
}
|
||||||
auto linalg_slice = rewriter.create<SubViewOp>(loc, slice_op.getOperand(0),
|
if (isLHLO) {
|
||||||
offsets, sizes, strides);
|
auto linalg_op =
|
||||||
rewriter.create<linalg::CopyOp>(loc, linalg_slice, slice_op.getOperand(1));
|
rewriter.create<SubViewOp>(loc, args[0], offsets, sizes, strides);
|
||||||
|
rewriter.create<linalg::CopyOp>(loc, linalg_op, args[1]);
|
||||||
rewriter.eraseOp(slice_op);
|
rewriter.eraseOp(slice_op);
|
||||||
|
} else {
|
||||||
|
rewriter.replaceOpWithNewOp<SubTensorOp>(slice_op, args[0], offsets,
|
||||||
|
sizes, strides);
|
||||||
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1430,7 +1435,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
ReverseConverter<lmhlo::ReverseOp>,
|
ReverseConverter<lmhlo::ReverseOp>,
|
||||||
ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
|
ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
|
||||||
ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
|
ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
|
||||||
SliceConverter,
|
SliceConverter<lmhlo::SliceOp>,
|
||||||
TransposeConverter<lmhlo::TransposeOp>
|
TransposeConverter<lmhlo::TransposeOp>
|
||||||
>(context);
|
>(context);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
@ -1554,6 +1559,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<mhlo::XorOp, false>,
|
PointwiseToLinalgConverter<mhlo::XorOp, false>,
|
||||||
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
||||||
ReverseConverter<mhlo::ReverseOp, false>,
|
ReverseConverter<mhlo::ReverseOp, false>,
|
||||||
|
SliceConverter<mhlo::SliceOp, false>,
|
||||||
TransposeConverter<mhlo::TransposeOp, false>, DotOpOnTensorsConversion,
|
TransposeConverter<mhlo::TransposeOp, false>, DotOpOnTensorsConversion,
|
||||||
DotGeneralOpOnTensorsConversion, ReduceOnTensorsConversion>(context);
|
DotGeneralOpOnTensorsConversion, ReduceOnTensorsConversion>(context);
|
||||||
patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
|
patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
|
||||||
|
|
|
@ -1256,3 +1256,29 @@ func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32
|
||||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
|
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
|
||||||
// CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
|
// CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
|
||||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @slice_whole_stride(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
|
||||||
|
%0 = "mhlo.slice"(%arg0) {
|
||||||
|
start_indices = dense<[1, 0]> : tensor<2xi64>,
|
||||||
|
limit_indices = dense<[2, 4]> : tensor<2xi64>,
|
||||||
|
strides = dense<1> : tensor<2xi64>
|
||||||
|
} : (tensor<3x4xi32>) -> tensor<1x4xi32>
|
||||||
|
return %0 : tensor<1x4xi32>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @slice_whole_stride
|
||||||
|
// CHECK: subtensor %{{.*}}[1, 0] [1, 4] [1, 1] : tensor<3x4xi32> to tensor<1x4xi32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @slice_stride_part(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
|
||||||
|
%0 = "mhlo.slice"(%arg0) {
|
||||||
|
start_indices = dense<[1, 1]> : tensor<2xi64>,
|
||||||
|
limit_indices = dense<[2, 3]> : tensor<2xi64>,
|
||||||
|
strides = dense<1> : tensor<2xi64>
|
||||||
|
} : (tensor<3x4xi32>) -> tensor<1x2xi32>
|
||||||
|
return %0 : tensor<1x2xi32>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @slice_stride_part
|
||||||
|
// CHECK: subtensor %{{.*}}[1, 1] [1, 2] [1, 1] : tensor<3x4xi32> to tensor<1x2xi32>
|
||||||
|
|
Loading…
Reference in New Issue