PR #49919: [MLIR][DISC] pattern conversion from tf2mhlo: ConvertUnpackOpDynamic, ConvertSignOpDynamic, ConvertSigmoidGradOpDynamic
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49919 We are porting our MLIR-based dynamic shape compiler to tf community (From OP def, Patttern, to Optimization pass, etc). This is the 5th PR about tf2mhlo pattern conversion, which including ConvertUnpackOpDynamic, ConvertSignOpDynamic, ConvertSigmoidGradOpDynamic. The rest pattern conversions we will add: - ConvertSqueezeOpxxx - ConvertStridedSliceOpxxx - ConvertPrintOp Copybara import of the project: -- 21b3c3eb05b12956bcdb8b98cc54d9371dbf034d by azazhu <azazhu@gmail.com>: [MLIR][DISC] pattern conversion from tf2mhlo: ConvertUnpackOpDynamic, ConvertSignOpDynamic, ConvertSigmoidGradOpDynamic -- 634630a4e2e426357290650bd579b35efecab5b3 by azazhu <azazhu@gmail.com>: [MLIR][DISC] refine ConvertUnpackOpDynamic, ConvertSignOpDynamic, ConvertSigmoidGradOpDynamic -- 39a2bedd6dafb369ae960c5197b7a352bfdfbc80 by azazhu <azazhu@gmail.com>: add RealDynamicSliceOp's canonicalize and fix CI -- a1c38dd0963d602ed4812da0d77a096a95920ddb by azazhu <azazhu@gmail.com>: fix CI for ConvertUnpackOpDynamic -- 5a8b4eb389ed6dc554104356c37f2f1550802b8c by azazhu <azazhu@gmail.com>: fix typo in ConvertSigmoidGradOpDynamic PiperOrigin-RevId: 379521079
This commit is contained in:
parent
5fbdac34a9
commit
3afbe312f8
|
@ -2237,6 +2237,7 @@ def HLO_RealDynamicSliceOp: HLO_ShapedInterfaceOp<
|
||||||
HLO_DimensionTensor:$strides
|
HLO_DimensionTensor:$strides
|
||||||
);
|
);
|
||||||
let results = (outs HLO_Tensor:$result);
|
let results = (outs HLO_Tensor:$result);
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1689,6 +1689,75 @@ static LogicalResult Verify(RealDynamicSliceOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Canonicalizes RealDynamicSlice ops that can be replaced instead with Slice
|
||||||
|
// ops. This canonicalization is applied the case when the `begin` input values
|
||||||
|
// are compile time constants and thus can be made into a tensor.
|
||||||
|
struct RealDynamicSliceIsStatic : public OpRewritePattern<RealDynamicSliceOp> {
|
||||||
|
using OpRewritePattern<RealDynamicSliceOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(RealDynamicSliceOp real_dynamic_slice,
|
||||||
|
PatternRewriter& rewriter) const override {
|
||||||
|
Location loc = real_dynamic_slice.getLoc();
|
||||||
|
Value input = real_dynamic_slice.operand();
|
||||||
|
Value output = real_dynamic_slice.result();
|
||||||
|
auto input_ty = input.getType().dyn_cast<RankedTensorType>();
|
||||||
|
auto output_ty = output.getType().dyn_cast<RankedTensorType>();
|
||||||
|
|
||||||
|
if (!input_ty || !output_ty || !input_ty.hasStaticShape() ||
|
||||||
|
!output_ty.hasStaticShape()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t input_rank = input_ty.getRank();
|
||||||
|
|
||||||
|
auto start_val = real_dynamic_slice.start_indices();
|
||||||
|
auto limit_val = real_dynamic_slice.limit_indices();
|
||||||
|
auto stride_val = real_dynamic_slice.strides();
|
||||||
|
auto start_op = start_val.getDefiningOp<mlir::ConstantOp>();
|
||||||
|
auto limit_op = limit_val.getDefiningOp<mlir::ConstantOp>();
|
||||||
|
auto stride_op = stride_val.getDefiningOp<mlir::ConstantOp>();
|
||||||
|
if (!start_op || !limit_op || !stride_op) return failure();
|
||||||
|
|
||||||
|
auto start_attr =
|
||||||
|
start_op.getValue().dyn_cast_or_null<DenseIntElementsAttr>();
|
||||||
|
auto limit_attr =
|
||||||
|
limit_op.getValue().dyn_cast_or_null<DenseIntElementsAttr>();
|
||||||
|
auto stride_attr =
|
||||||
|
stride_op.getValue().dyn_cast_or_null<DenseIntElementsAttr>();
|
||||||
|
if (!start_attr || !limit_attr || !stride_attr) return failure();
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> temp_start_indices;
|
||||||
|
SmallVector<int64_t, 4> temp_limit_indices;
|
||||||
|
SmallVector<int64_t, 4> temp_stride;
|
||||||
|
for (int64_t dim_idx = 0; dim_idx < input_rank; dim_idx++) {
|
||||||
|
int64_t start = start_attr.getValue<IntegerAttr>(dim_idx).getInt();
|
||||||
|
temp_start_indices.push_back(start);
|
||||||
|
int64_t limit = limit_attr.getValue<IntegerAttr>(dim_idx).getInt();
|
||||||
|
temp_limit_indices.push_back(limit);
|
||||||
|
int64_t end = stride_attr.getValue<IntegerAttr>(dim_idx).getInt();
|
||||||
|
temp_stride.push_back(end);
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseIntElementsAttr slice_start_indices =
|
||||||
|
GetI64ElementsAttr(temp_start_indices, &rewriter);
|
||||||
|
DenseIntElementsAttr slice_limit_indices =
|
||||||
|
GetI64ElementsAttr(temp_limit_indices, &rewriter);
|
||||||
|
DenseIntElementsAttr slice_strides =
|
||||||
|
GetI64ElementsAttr(temp_stride, &rewriter);
|
||||||
|
auto result = rewriter.create<SliceOp>(loc, input, slice_start_indices,
|
||||||
|
slice_limit_indices, slice_strides);
|
||||||
|
rewriter.replaceOp(real_dynamic_slice, {result});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void RealDynamicSliceOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList& results, MLIRContext* context) {
|
||||||
|
results.insert<RealDynamicSliceIsStatic, RealDSliceToSlice>(context);
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes(
|
LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes(
|
||||||
OpBuilder& builder, ValueRange operands,
|
OpBuilder& builder, ValueRange operands,
|
||||||
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
|
|
|
@ -55,3 +55,15 @@ def DPadToPad: Pat<
|
||||||
(CastIntElementsAttr $edge_padding_low),
|
(CastIntElementsAttr $edge_padding_low),
|
||||||
(CastIntElementsAttr $edge_padding_high),
|
(CastIntElementsAttr $edge_padding_high),
|
||||||
(CastIntElementsAttr $interior_paddin))>;
|
(CastIntElementsAttr $interior_paddin))>;
|
||||||
|
|
||||||
|
// Convert RealDynamicSliceOp to SliceOp if start_indices, limit_indices and
|
||||||
|
// strides are HLO_ConstOp
|
||||||
|
def RealDSliceToSlice: Pat<
|
||||||
|
(HLO_RealDynamicSliceOp HLO_Tensor:$operand,
|
||||||
|
(HLO_ConstOp I64ElementsAttr:$start_indices),
|
||||||
|
(HLO_ConstOp I64ElementsAttr:$limit_indices),
|
||||||
|
(HLO_ConstOp I64ElementsAttr:$strides)),
|
||||||
|
(HLO_SliceOp $operand,
|
||||||
|
(CastIntElementsAttr $start_indices),
|
||||||
|
(CastIntElementsAttr $limit_indices),
|
||||||
|
(CastIntElementsAttr $strides))>;
|
||||||
|
|
Loading…
Reference in New Issue