Add support for lowering mhlo.slice to subtensor.

PiperOrigin-RevId: 359297978
This commit is contained in:
Hanhan Wang 2021-02-24 09:04:32 -08:00 committed by TensorFlow MLIR Team
parent b478bdf00e
commit 475b4a06a5
2 changed files with 47 additions and 15 deletions

View File

@ -1035,16 +1035,16 @@ class ReverseConverter
}
};
class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
template <typename OpTy, bool isLHLO = true>
class SliceConverter : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
lmhlo::SliceOp slice_op, ArrayRef<Value> args,
OpTy slice_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = slice_op.getLoc();
auto arg_type =
slice_op.getOperand(0).getType().template dyn_cast<ShapedType>();
auto arg_type = args[0].getType().template dyn_cast<ShapedType>();
if (!arg_type || !arg_type.hasRank()) {
emitError(loc, "lhlo to linalg conversion expects known-rank args");
return failure();
@ -1053,17 +1053,22 @@ class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
SmallVector<OpFoldResult, 3> offsets, sizes, strides;
for (int i = 0, e = arg_type.getRank(); i < e; ++i) {
offsets.push_back(rewriter.getI64IntegerAttr(
slice_op.start_indices().getValue<int64_t>(i)));
slice_op.start_indices().template getValue<int64_t>(i)));
sizes.push_back(rewriter.getI64IntegerAttr(
slice_op.limit_indices().getValue<int64_t>(i) -
slice_op.start_indices().getValue<int64_t>(i)));
strides.push_back(
rewriter.getI64IntegerAttr(slice_op.strides().getValue<int64_t>(i)));
slice_op.limit_indices().template getValue<int64_t>(i) -
slice_op.start_indices().template getValue<int64_t>(i)));
strides.push_back(rewriter.getI64IntegerAttr(
slice_op.strides().template getValue<int64_t>(i)));
}
auto linalg_slice = rewriter.create<SubViewOp>(loc, slice_op.getOperand(0),
offsets, sizes, strides);
rewriter.create<linalg::CopyOp>(loc, linalg_slice, slice_op.getOperand(1));
if (isLHLO) {
auto linalg_op =
rewriter.create<SubViewOp>(loc, args[0], offsets, sizes, strides);
rewriter.create<linalg::CopyOp>(loc, linalg_op, args[1]);
rewriter.eraseOp(slice_op);
} else {
rewriter.replaceOpWithNewOp<SubTensorOp>(slice_op, args[0], offsets,
sizes, strides);
}
return success();
}
};
@ -1430,7 +1435,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
ReverseConverter<lmhlo::ReverseOp>,
ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
SliceConverter,
SliceConverter<lmhlo::SliceOp>,
TransposeConverter<lmhlo::TransposeOp>
>(context);
// clang-format on
@ -1554,6 +1559,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<mhlo::XorOp, false>,
ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<mhlo::ReverseOp, false>,
SliceConverter<mhlo::SliceOp, false>,
TransposeConverter<mhlo::TransposeOp, false>, DotOpOnTensorsConversion,
DotGeneralOpOnTensorsConversion, ReduceOnTensorsConversion>(context);
patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,

View File

@ -1256,3 +1256,29 @@ func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = addi %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
func @slice_whole_stride(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
%0 = "mhlo.slice"(%arg0) {
start_indices = dense<[1, 0]> : tensor<2xi64>,
limit_indices = dense<[2, 4]> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi32>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}
// CHECK-LABEL: func @slice_whole_stride
// CHECK: subtensor %{{.*}}[1, 0] [1, 4] [1, 1] : tensor<3x4xi32> to tensor<1x4xi32>
// -----
func @slice_stride_part(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
%0 = "mhlo.slice"(%arg0) {
start_indices = dense<[1, 1]> : tensor<2xi64>,
limit_indices = dense<[2, 3]> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
// CHECK-LABEL: func @slice_stride_part
// CHECK: subtensor %{{.*}}[1, 1] [1, 2] [1, 1] : tensor<3x4xi32> to tensor<1x2xi32>