[MLIR][MHLO] Allow recursion in the shape_of mover
This allows it to push shape_of over a chain of ops all the way to the top. PiperOrigin-RevId: 362249009
This commit is contained in:
parent
67a770e4e0
commit
d77b556822
|
@ -45,7 +45,10 @@ bool IsShapeOfOpMovable(Value arg) {
|
||||||
|
|
||||||
struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
|
struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
|
||||||
explicit ShapeOfOpConversion(MLIRContext *context)
|
explicit ShapeOfOpConversion(MLIRContext *context)
|
||||||
: OpConversionPattern<shape::ShapeOfOp>(context) {}
|
: OpConversionPattern<shape::ShapeOfOp>(context) {
|
||||||
|
// Recursively reify until we hit an op that doesn't support it.
|
||||||
|
setHasBoundedRewriteRecursion();
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
shape::ShapeOfOp op, ArrayRef<Value> operands,
|
shape::ShapeOfOp op, ArrayRef<Value> operands,
|
||||||
|
|
|
@ -21,7 +21,8 @@ func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
|
||||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor<?x32xf16> -> tensor<2xindex>
|
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor<?x32xf16> -> tensor<2xindex>
|
||||||
// CHECK: "use"(%[[SHAPE]])
|
// CHECK: "use"(%[[SHAPE]])
|
||||||
%0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
|
%0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
|
||||||
%1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
|
%1 = mhlo.subtract %0, %arg1 : tensor<?x32xf16>
|
||||||
"use"(%1) : (tensor<?xindex>) -> ()
|
%2 = shape.shape_of %1 : tensor<?x32xf16> -> tensor<?xindex>
|
||||||
|
"use"(%2) : (tensor<?xindex>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue