[MLIR][MHLO] Allow recursion in the shape_of mover

This allows it to push shape_of over a chain of ops all the way to the top.

PiperOrigin-RevId: 362249009
This commit is contained in:
Benjamin Kramer 2021-03-11 02:51:27 -08:00 committed by TensorFlow MLIR Team
parent 67a770e4e0
commit d77b556822
2 changed files with 7 additions and 3 deletions

View File

@ -45,7 +45,10 @@ bool IsShapeOfOpMovable(Value arg) {
struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> { struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
explicit ShapeOfOpConversion(MLIRContext *context) explicit ShapeOfOpConversion(MLIRContext *context)
: OpConversionPattern<shape::ShapeOfOp>(context) {} : OpConversionPattern<shape::ShapeOfOp>(context) {
// Recursively reify until we hit an op that doesn't support it.
setHasBoundedRewriteRecursion();
}
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
shape::ShapeOfOp op, ArrayRef<Value> operands, shape::ShapeOfOp op, ArrayRef<Value> operands,

View File

@ -21,7 +21,8 @@ func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor<?x32xf16> -> tensor<2xindex> // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor<?x32xf16> -> tensor<2xindex>
// CHECK: "use"(%[[SHAPE]]) // CHECK: "use"(%[[SHAPE]])
%0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16> %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
%1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex> %1 = mhlo.subtract %0, %arg1 : tensor<?x32xf16>
"use"(%1) : (tensor<?xindex>) -> () %2 = shape.shape_of %1 : tensor<?x32xf16> -> tensor<?xindex>
"use"(%2) : (tensor<?xindex>) -> ()
return return
} }