[MLIR][HLO] Eliminate duplicate broadcastable constraints
PiperOrigin-RevId: 376715240
parent 511a1db4f3
commit 0f341012c6
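Summary: this change drops the pass-local CanonicalizeCastedShapeOfOpPattern and CanonicalizeBroadcastPattern rewrites, registering the upstream shape.broadcast canonicalization patterns in their place, and introduces an EliminateDuplicateCstrBroadcastableOps pattern that replaces a repeated shape.cstr_broadcastable constraint with the witness of an earlier, identical one. A regression test for the new pattern is added.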
@@ -306,58 +306,25 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
   }
 };
 
-// Eliminate casted extent tensors. Instead, produce the concrete extent tensor
-// type where possible.
-struct CanonicalizeCastedShapeOfOpPattern
-    : public OpRewritePattern<tensor::CastOp> {
-  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::CastOp op,
-                                PatternRewriter &rewriter) const override {
-    // Only merge tensor cast into `shape_of` ops.
-    auto shape_of_op = op.source().getDefiningOp<shape::ShapeOfOp>();
-    if (!shape_of_op) return failure();
-
-    // Desired type must be an extent tensor type.
-    auto result_ty = op.getType().dyn_cast<RankedTensorType>();
-    if (!result_ty || result_ty.getRank() != 1 ||
-        !result_ty.getElementType().isIndex())
-      return failure();
-
-    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op, result_ty,
-                                                  shape_of_op.arg());
-    if (shape_of_op->getUses().empty()) rewriter.eraseOp(shape_of_op);
-    return success();
-  }
-};
-
-// TODO(frgossen): Remove this once it has landed upstream.
-struct CanonicalizeBroadcastPattern
-    : public OpRewritePattern<shape::BroadcastOp> {
-  using OpRewritePattern<shape::BroadcastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(shape::BroadcastOp op,
-                                PatternRewriter &rewriter) const override {
-    // Only concretize dynamic extent tensor result types.
-    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
-    if (!resultTy || !resultTy.isDynamicDim(0)) return failure();
-
-    // Infer resulting shape rank if possible.
-    int64_t maxRank = 0;
-    for (Value shape : op.shapes()) {
-      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
-        // Cannot infer resulting shape rank if any operand is dynamically
-        // ranked.
-        if (extentTensorTy.isDynamicDim(0)) return failure();
-        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
-      }
-    }
-
-    auto newOp = rewriter.create<shape::BroadcastOp>(
-        op.getLoc(), RankedTensorType::get({maxRank}, rewriter.getIndexType()),
-        op.shapes());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
-    return success();
-  }
-};
+struct EliminateDuplicateCstrBroadcastableOps
+    : public OpRewritePattern<shape::CstrBroadcastableOp> {
+  using OpRewritePattern<shape::CstrBroadcastableOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
+                                PatternRewriter &rewriter) const override {
+    // Search for a previous occurrence of the same constraint.
+    Operation *it = op->getPrevNode();
+    while (it != nullptr) {
+      if (auto candidate = llvm::dyn_cast<shape::CstrBroadcastableOp>(it)) {
+        if (candidate.shapes() == op.shapes()) {
+          rewriter.replaceOp(op, candidate.result());
+          return success();
+        }
+      }
+      it = it->getPrevNode();
+    }
+    return failure();
+  }
+};
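A minimal sketch of the new pattern's effect on hypothetical IR (the values %a and %b are illustrative, not taken from the commit):

// Before: two identical broadcastability constraints in one block.
%w0 = shape.cstr_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
%w1 = shape.cstr_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
%all = shape.assuming_all %w0, %w1

// After: uses of the duplicate are rewired to the earlier witness; the
// now-redundant operand can then be folded by existing canonicalizations.
%w0 = shape.cstr_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
%all = shape.assuming_all %w0, %w0

Note that the search walks op->getPrevNode() within the enclosing block only, so identical constraints in other blocks are untouched, and the shapes() operand lists must match exactly, including their order.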
@@ -432,8 +399,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
   patterns->insert<
-      CanonicalizeBroadcastPattern,
-      CanonicalizeCastedShapeOfOpPattern,
+      EliminateDuplicateCstrBroadcastableOps,
       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
       MergeAssumingOpsPattern,
       MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
@@ -443,6 +409,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
       MoveUpBroadcastInDimOpPattern,
       ShapeReificationPattern>(context);
   // clang-format on
+  shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
   tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
 }
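The removed CanonicalizeBroadcastPattern is expected to be covered by the newly registered upstream shape.broadcast canonicalizations (its TODO asked for exactly this once the rewrite landed upstream). A rough sketch of the concretization it performed, on illustrative operands %s0 and %s1, assuming the upstream pattern mirrors the removed code:

// Before: the extent tensor's size, i.e. the rank of the broadcasted
// shape, is unknown.
%bc = shape.broadcast %s0, %s1
    : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>

// After: with all operand sizes static, the result size is their maximum
// (here 3); a cast preserves the original type for existing users.
%bc2 = shape.broadcast %s0, %s1
    : tensor<2xindex>, tensor<3xindex> -> tensor<3xindex>
%bc = tensor.cast %bc2 : tensor<3xindex> to tensor<?xindex>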
@@ -374,3 +374,21 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
   }
   return %7 : tensor<?x?x32xf16>
 }
+
+// -----
+
+// CHECK-LABEL: @redundant_cstr_broadcastable
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<?xindex>)
+func @redundant_cstr_broadcastable(%arg0: tensor<?xindex>,
+    %arg1 : tensor<?xindex>) {
+  // CHECK-DAG: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG0]], %[[ARG1]]
+  // CHECK:     shape.assuming %[[WITNESS]]
+  %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
+  %1 = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
+  %2 = shape.assuming_all %0, %1
+  shape.assuming %2 -> () {
+    "some.op"() : () -> ()
+    shape.assuming_yield
+  }
+  return
+}
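The test exercises the pattern end to end: the two identical shape.cstr_broadcastable ops collapse into the single witness matched by the CHECK-DAG line, which the shape.assuming region then consumes. The // ----- separator starts a new split-input-file case; the file's RUN line (not shown in this diff) presumably pipes the input through the move-up-dynamic-broadcasts-for-fusion pass and FileCheck.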