diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index ea599c6..f713beb 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -2237,6 +2237,7 @@ def HLO_RealDynamicSliceOp: HLO_ShapedInterfaceOp< HLO_DimensionTensor:$strides ); let results = (outs HLO_Tensor:$result); + let hasCanonicalizer = 1; let hasCustomHLOConverter = 1; } diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index b66235d..1101cf2 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1689,6 +1689,75 @@ static LogicalResult Verify(RealDynamicSliceOp op) { return success(); } +namespace { +// Canonicalizes RealDynamicSlice ops that can be replaced instead with Slice +// ops. This canonicalization is applied the case when the `begin` input values +// are compile time constants and thus can be made into a tensor. +struct RealDynamicSliceIsStatic : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(RealDynamicSliceOp real_dynamic_slice, + PatternRewriter& rewriter) const override { + Location loc = real_dynamic_slice.getLoc(); + Value input = real_dynamic_slice.operand(); + Value output = real_dynamic_slice.result(); + auto input_ty = input.getType().dyn_cast(); + auto output_ty = output.getType().dyn_cast(); + + if (!input_ty || !output_ty || !input_ty.hasStaticShape() || + !output_ty.hasStaticShape()) { + return failure(); + } + + int64_t input_rank = input_ty.getRank(); + + auto start_val = real_dynamic_slice.start_indices(); + auto limit_val = real_dynamic_slice.limit_indices(); + auto stride_val = real_dynamic_slice.strides(); + auto start_op = start_val.getDefiningOp(); + auto limit_op = limit_val.getDefiningOp(); + auto stride_op = stride_val.getDefiningOp(); + if (!start_op || !limit_op || !stride_op) return failure(); + + auto start_attr = + start_op.getValue().dyn_cast_or_null(); + auto limit_attr = + limit_op.getValue().dyn_cast_or_null(); + auto stride_attr = + stride_op.getValue().dyn_cast_or_null(); + if (!start_attr || !limit_attr || !stride_attr) return failure(); + + SmallVector temp_start_indices; + SmallVector temp_limit_indices; + SmallVector temp_stride; + for (int64_t dim_idx = 0; dim_idx < input_rank; dim_idx++) { + int64_t start = start_attr.getValue(dim_idx).getInt(); + temp_start_indices.push_back(start); + int64_t limit = limit_attr.getValue(dim_idx).getInt(); + temp_limit_indices.push_back(limit); + int64_t end = stride_attr.getValue(dim_idx).getInt(); + temp_stride.push_back(end); + } + + DenseIntElementsAttr slice_start_indices = + GetI64ElementsAttr(temp_start_indices, &rewriter); + DenseIntElementsAttr slice_limit_indices = + GetI64ElementsAttr(temp_limit_indices, &rewriter); + DenseIntElementsAttr slice_strides = + GetI64ElementsAttr(temp_stride, &rewriter); + auto result = rewriter.create(loc, input, slice_start_indices, + slice_limit_indices, slice_strides); + rewriter.replaceOp(real_dynamic_slice, {result}); + return success(); + } +}; +} // namespace + +void RealDynamicSliceOp::getCanonicalizationPatterns( + OwningRewritePatternList& results, MLIRContext* context) { + results.insert(context); +} + LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes( OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { diff --git a/lib/Dialect/mhlo/IR/mhlo_canonicalize.td b/lib/Dialect/mhlo/IR/mhlo_canonicalize.td index 429e5a1..b43b41a 100644 --- a/lib/Dialect/mhlo/IR/mhlo_canonicalize.td +++ b/lib/Dialect/mhlo/IR/mhlo_canonicalize.td @@ -55,3 +55,15 @@ def DPadToPad: Pat< (CastIntElementsAttr $edge_padding_low), (CastIntElementsAttr $edge_padding_high), (CastIntElementsAttr $interior_paddin))>; + +// Convert RealDynamicSliceOp to SliceOp if start_indices, limit_indices and +// strides are HLO_ConstOp +def RealDSliceToSlice: Pat< + (HLO_RealDynamicSliceOp HLO_Tensor:$operand, + (HLO_ConstOp I64ElementsAttr:$start_indices), + (HLO_ConstOp I64ElementsAttr:$limit_indices), + (HLO_ConstOp I64ElementsAttr:$strides)), + (HLO_SliceOp $operand, + (CastIntElementsAttr $start_indices), + (CastIntElementsAttr $limit_indices), + (CastIntElementsAttr $strides))>;