diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 8dff2f5..5244bf3 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -1035,16 +1035,16 @@ class ReverseConverter } }; -class SliceConverter : public OpConversionPattern { +template +class SliceConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - lmhlo::SliceOp slice_op, ArrayRef args, + OpTy slice_op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = slice_op.getLoc(); - auto arg_type = - slice_op.getOperand(0).getType().template dyn_cast(); + auto arg_type = args[0].getType().template dyn_cast(); if (!arg_type || !arg_type.hasRank()) { emitError(loc, "lhlo to linalg conversion expects known-rank args"); return failure(); @@ -1053,17 +1053,22 @@ class SliceConverter : public OpConversionPattern { SmallVector offsets, sizes, strides; for (int i = 0, e = arg_type.getRank(); i < e; ++i) { offsets.push_back(rewriter.getI64IntegerAttr( - slice_op.start_indices().getValue(i))); + slice_op.start_indices().template getValue(i))); sizes.push_back(rewriter.getI64IntegerAttr( - slice_op.limit_indices().getValue(i) - - slice_op.start_indices().getValue(i))); - strides.push_back( - rewriter.getI64IntegerAttr(slice_op.strides().getValue(i))); + slice_op.limit_indices().template getValue(i) - + slice_op.start_indices().template getValue(i))); + strides.push_back(rewriter.getI64IntegerAttr( + slice_op.strides().template getValue(i))); + } + if (isLHLO) { + auto linalg_op = + rewriter.create(loc, args[0], offsets, sizes, strides); + rewriter.create(loc, linalg_op, args[1]); + rewriter.eraseOp(slice_op); + } else { + rewriter.replaceOpWithNewOp(slice_op, args[0], offsets, + sizes, strides); } - auto linalg_slice = rewriter.create(loc, slice_op.getOperand(0), - offsets, sizes, strides); - rewriter.create(loc, linalg_slice, slice_op.getOperand(1)); - rewriter.eraseOp(slice_op); return success(); } }; @@ -1430,7 +1435,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, ReverseConverter, ScalarPointwiseToStandardConverter, ScalarPointwiseToStandardConverter, - SliceConverter, + SliceConverter, TransposeConverter >(context); // clang-format on @@ -1554,6 +1559,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, ReshapeOpConverter, ReverseConverter, + SliceConverter, TransposeConverter, DotOpOnTensorsConversion, DotGeneralOpOnTensorsConversion, ReduceOnTensorsConversion>(context); patterns->insert, diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 99415e0..7846f2d 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -1256,3 +1256,29 @@ func @reduce_dynamic(%arg0: tensor, %arg1: tensor) -> tensor) -> tensor<1x4xi32> { + %0 = "mhlo.slice"(%arg0) { + start_indices = dense<[1, 0]> : tensor<2xi64>, + limit_indices = dense<[2, 4]> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } : (tensor<3x4xi32>) -> tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} +// CHECK-LABEL: func @slice_whole_stride +// CHECK: subtensor %{{.*}}[1, 0] [1, 4] [1, 1] : tensor<3x4xi32> to tensor<1x4xi32> + +// ----- + +func @slice_stride_part(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> { + %0 = "mhlo.slice"(%arg0) { + start_indices = dense<[1, 1]> : tensor<2xi64>, + limit_indices = dense<[2, 3]> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } : (tensor<3x4xi32>) -> tensor<1x2xi32> + return %0 : tensor<1x2xi32> +} +// CHECK-LABEL: func @slice_stride_part +// CHECK: subtensor %{{.*}}[1, 1] [1, 2] [1, 1] : tensor<3x4xi32> to tensor<1x2xi32>