[MLIR][MHLO] Declare `shape_of` dynamically legal in move-up-dynamic-broadcasts
This allows shape reification to produce `shape_of` ops while they can still be moved up. PiperOrigin-RevId: 362075609
This commit is contained in:
parent
c217a6ef61
commit
e199df1dbf
|
@ -39,6 +39,10 @@ namespace mlir {
|
||||||
namespace mhlo {
|
namespace mhlo {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
bool IsShapeOfOpMovable(Value arg) {
|
||||||
|
return arg.getDefiningOp<InferShapedTypeOpInterface>();
|
||||||
|
}
|
||||||
|
|
||||||
struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
|
struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
|
||||||
explicit ShapeOfOpConversion(MLIRContext *context)
|
explicit ShapeOfOpConversion(MLIRContext *context)
|
||||||
: OpConversionPattern<shape::ShapeOfOp>(context) {}
|
: OpConversionPattern<shape::ShapeOfOp>(context) {}
|
||||||
|
@ -48,10 +52,11 @@ struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
shape::ShapeOfOp::Adaptor transformed(operands);
|
shape::ShapeOfOp::Adaptor transformed(operands);
|
||||||
|
|
||||||
auto shape_origin = llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(
|
// Only reify shape computation if operand allows for it.
|
||||||
transformed.arg().getDefiningOp());
|
if (!IsShapeOfOpMovable(transformed.arg())) return failure();
|
||||||
if (!shape_origin) return failure();
|
|
||||||
|
|
||||||
|
auto shape_origin =
|
||||||
|
transformed.arg().getDefiningOp<InferShapedTypeOpInterface>();
|
||||||
llvm::SmallVector<Value, 1> reified_shapes;
|
llvm::SmallVector<Value, 1> reified_shapes;
|
||||||
if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reified_shapes)))
|
if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reified_shapes)))
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -96,8 +101,10 @@ struct MoveUpDynamicBroadcastsForFusionPass
|
||||||
|
|
||||||
void PopulateMoveUpDynamicBroadcastsForFusionLegality(
|
void PopulateMoveUpDynamicBroadcastsForFusionLegality(
|
||||||
ConversionTarget *target) {
|
ConversionTarget *target) {
|
||||||
target->addLegalDialect<MhloDialect, StandardOpsDialect,
|
target->addLegalDialect<MhloDialect, StandardOpsDialect, shape::ShapeDialect,
|
||||||
tensor::TensorDialect>();
|
tensor::TensorDialect>();
|
||||||
|
target->addDynamicallyLegalOp<shape::ShapeOfOp>(
|
||||||
|
[](shape::ShapeOfOp op) { return !IsShapeOfOpMovable(op.arg()); });
|
||||||
}
|
}
|
||||||
|
|
||||||
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
||||||
|
|
Loading…
Reference in New Issue