[MLIR][MHLO] Declare `shape_of` dynamically legal in move-up-dynamic-broadcasts

This allows shape reification to produce `shape_of` ops while they can still be
moved up.

PiperOrigin-RevId: 362075609
This commit is contained in:
A. Unique TensorFlower 2021-03-10 09:58:15 -08:00 committed by TensorFlow MLIR Team
parent c217a6ef61
commit e199df1dbf
1 changed files with 11 additions and 4 deletions

View File

@ -39,6 +39,10 @@ namespace mlir {
namespace mhlo {
namespace {
bool IsShapeOfOpMovable(Value arg) {
return arg.getDefiningOp<InferShapedTypeOpInterface>();
}
struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
explicit ShapeOfOpConversion(MLIRContext *context)
: OpConversionPattern<shape::ShapeOfOp>(context) {}
@ -48,10 +52,11 @@ struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
ConversionPatternRewriter &rewriter) const override {
shape::ShapeOfOp::Adaptor transformed(operands);
auto shape_origin = llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(
transformed.arg().getDefiningOp());
if (!shape_origin) return failure();
// Only reify shape computation if operand allows for it.
if (!IsShapeOfOpMovable(transformed.arg())) return failure();
auto shape_origin =
transformed.arg().getDefiningOp<InferShapedTypeOpInterface>();
llvm::SmallVector<Value, 1> reified_shapes;
if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reified_shapes)))
return failure();
@ -96,8 +101,10 @@ struct MoveUpDynamicBroadcastsForFusionPass
void PopulateMoveUpDynamicBroadcastsForFusionLegality(
ConversionTarget *target) {
target->addLegalDialect<MhloDialect, StandardOpsDialect,
target->addLegalDialect<MhloDialect, StandardOpsDialect, shape::ShapeDialect,
tensor::TensorDialect>();
target->addDynamicallyLegalOp<shape::ShapeOfOp>(
[](shape::ShapeOfOp op) { return !IsShapeOfOpMovable(op.arg()); });
}
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(