[MLIR][MHLO] Allow recursion in the shape_of mover
This allows it to push shape_of over a chain of ops all the way to the top. PiperOrigin-RevId: 362249009
This commit is contained in:
parent
67a770e4e0
commit
d77b556822
|
@ -45,7 +45,10 @@ bool IsShapeOfOpMovable(Value arg) {
|
|||
|
||||
struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
|
||||
explicit ShapeOfOpConversion(MLIRContext *context)
|
||||
: OpConversionPattern<shape::ShapeOfOp>(context) {}
|
||||
: OpConversionPattern<shape::ShapeOfOp>(context) {
|
||||
// Recursively reify until we hit an op that doesn't support it.
|
||||
setHasBoundedRewriteRecursion();
|
||||
}
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
shape::ShapeOfOp op, ArrayRef<Value> operands,
|
||||
|
|
|
@ -21,7 +21,8 @@ func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
|
|||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor<?x32xf16> -> tensor<2xindex>
|
||||
// CHECK: "use"(%[[SHAPE]])
|
||||
%0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
|
||||
%1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
|
||||
"use"(%1) : (tensor<?xindex>) -> ()
|
||||
%1 = mhlo.subtract %0, %arg1 : tensor<?x32xf16>
|
||||
%2 = shape.shape_of %1 : tensor<?x32xf16> -> tensor<?xindex>
|
||||
"use"(%2) : (tensor<?xindex>) -> ()
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue