From 3afbe312f86dfb7f88783395aa6f6bd818859fe3 Mon Sep 17 00:00:00 2001 From: Feiwen Date: Tue, 15 Jun 2021 10:32:32 -0700 Subject: [PATCH] PR #49919: [MLIR][DISC] pattern conversion from tf2mhlo: ConvertUnpackOpDynamic, ConvertSignOpDynamic, ConvertSigmoidGradOpDynamic Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49919 We are porting our MLIR-based dynamic shape compiler to tf community (From OP def, Patttern, to Optimization pass, etc). This is the 5th PR about tf2mhlo pattern conversion, which including ConvertUnpackOpDynamic, ConvertSignOpDynamic, ConvertSigmoidGradOpDynamic. The rest pattern conversions we will add: - ConvertSqueezeOpxxx - ConvertStridedSliceOpxxx - ConvertPrintOp Copybara import of the project: -- 21b3c3eb05b12956bcdb8b98cc54d9371dbf034d by azazhu : [MLIR][DISC] pattern conversion from tf2mhlo: ConvertUnpackOpDynamic, ConvertSignOpDynamic, ConvertSigmoidGradOpDynamic -- 634630a4e2e426357290650bd579b35efecab5b3 by azazhu : [MLIR][DISC] refine ConvertUnpackOpDynamic, ConvertSignOpDynamic, ConvertSigmoidGradOpDynamic -- 39a2bedd6dafb369ae960c5197b7a352bfdfbc80 by azazhu : add RealDynamicSliceOp's canonicalize and fix CI -- a1c38dd0963d602ed4812da0d77a096a95920ddb by azazhu : fix CI for ConvertUnpackOpDynamic -- 5a8b4eb389ed6dc554104356c37f2f1550802b8c by azazhu : fix typo in ConvertSigmoidGradOpDynamic PiperOrigin-RevId: 379521079 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 1 + lib/Dialect/mhlo/IR/hlo_ops.cc | 69 +++++++++++++++++++++ lib/Dialect/mhlo/IR/mhlo_canonicalize.td | 12 ++++ 3 files changed, 82 insertions(+) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index ea599c6..f713beb 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -2237,6 +2237,7 @@ def HLO_RealDynamicSliceOp: HLO_ShapedInterfaceOp< HLO_DimensionTensor:$strides ); let results = (outs HLO_Tensor:$result); + let hasCanonicalizer = 1; let hasCustomHLOConverter = 1; } diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index b66235d..1101cf2 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1689,6 +1689,75 @@ static LogicalResult Verify(RealDynamicSliceOp op) { return success(); } +namespace { +// Canonicalizes RealDynamicSlice ops that can be replaced instead with Slice +// ops. This canonicalization is applied the case when the `begin` input values +// are compile time constants and thus can be made into a tensor. +struct RealDynamicSliceIsStatic : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(RealDynamicSliceOp real_dynamic_slice, + PatternRewriter& rewriter) const override { + Location loc = real_dynamic_slice.getLoc(); + Value input = real_dynamic_slice.operand(); + Value output = real_dynamic_slice.result(); + auto input_ty = input.getType().dyn_cast(); + auto output_ty = output.getType().dyn_cast(); + + if (!input_ty || !output_ty || !input_ty.hasStaticShape() || + !output_ty.hasStaticShape()) { + return failure(); + } + + int64_t input_rank = input_ty.getRank(); + + auto start_val = real_dynamic_slice.start_indices(); + auto limit_val = real_dynamic_slice.limit_indices(); + auto stride_val = real_dynamic_slice.strides(); + auto start_op = start_val.getDefiningOp(); + auto limit_op = limit_val.getDefiningOp(); + auto stride_op = stride_val.getDefiningOp(); + if (!start_op || !limit_op || !stride_op) return failure(); + + auto start_attr = + start_op.getValue().dyn_cast_or_null(); + auto limit_attr = + limit_op.getValue().dyn_cast_or_null(); + auto stride_attr = + stride_op.getValue().dyn_cast_or_null(); + if (!start_attr || !limit_attr || !stride_attr) return failure(); + + SmallVector temp_start_indices; + SmallVector temp_limit_indices; + SmallVector temp_stride; + for (int64_t dim_idx = 0; dim_idx < input_rank; dim_idx++) { + int64_t start = start_attr.getValue(dim_idx).getInt(); + temp_start_indices.push_back(start); + int64_t limit = limit_attr.getValue(dim_idx).getInt(); + temp_limit_indices.push_back(limit); + int64_t end = stride_attr.getValue(dim_idx).getInt(); + temp_stride.push_back(end); + } + + DenseIntElementsAttr slice_start_indices = + GetI64ElementsAttr(temp_start_indices, &rewriter); + DenseIntElementsAttr slice_limit_indices = + GetI64ElementsAttr(temp_limit_indices, &rewriter); + DenseIntElementsAttr slice_strides = + GetI64ElementsAttr(temp_stride, &rewriter); + auto result = rewriter.create(loc, input, slice_start_indices, + slice_limit_indices, slice_strides); + rewriter.replaceOp(real_dynamic_slice, {result}); + return success(); + } +}; +} // namespace + +void RealDynamicSliceOp::getCanonicalizationPatterns( + OwningRewritePatternList& results, MLIRContext* context) { + results.insert(context); +} + LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes( OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { diff --git a/lib/Dialect/mhlo/IR/mhlo_canonicalize.td b/lib/Dialect/mhlo/IR/mhlo_canonicalize.td index 429e5a1..b43b41a 100644 --- a/lib/Dialect/mhlo/IR/mhlo_canonicalize.td +++ b/lib/Dialect/mhlo/IR/mhlo_canonicalize.td @@ -55,3 +55,15 @@ def DPadToPad: Pat< (CastIntElementsAttr $edge_padding_low), (CastIntElementsAttr $edge_padding_high), (CastIntElementsAttr $interior_paddin))>; + +// Convert RealDynamicSliceOp to SliceOp if start_indices, limit_indices and +// strides are HLO_ConstOp +def RealDSliceToSlice: Pat< + (HLO_RealDynamicSliceOp HLO_Tensor:$operand, + (HLO_ConstOp I64ElementsAttr:$start_indices), + (HLO_ConstOp I64ElementsAttr:$limit_indices), + (HLO_ConstOp I64ElementsAttr:$strides)), + (HLO_SliceOp $operand, + (CastIntElementsAttr $start_indices), + (CastIntElementsAttr $limit_indices), + (CastIntElementsAttr $strides))>;