[MLIR] Add example test case for `move-up-dynamic-broadcasts-for-fusion` pass
Add an exemplary test case as it appears in the lowering of two consecutive `tf.Sub` ops.

PiperOrigin-RevId: 366219139
parent eb4d20ba04
commit c23be1841c
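For context, the tests in this file are driven by `mlir-hlo-opt` and `FileCheck`, and the `// -----` separator below relies on the opt tool's split-input-file mode. A plausible `RUN` line for this file (the exact flag spelling is an assumption, not part of this commit) would be:

// RUN: mlir-hlo-opt --mhlo-move-up-dynamic-broadcasts-for-fusion --split-input-file %s | FileCheck %s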
@@ -285,3 +285,55 @@ func @merge_assuming_ops(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
  }
  return %7 : tensor<?x?x32xf16>
}

// -----

// Exemplary IR as it appears in the lowering of two consecutive `tf.Sub` ops.
// CHECK-LABEL: @sub_sub
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
    %arg2: tensor<?x?x32xf16>) -> tensor<?x?x32xf16> {
// CHECK: %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
// CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
// CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE1]]
// CHECK: %[[ASSUMING_RESULTS:.*]]:4 = shape.assuming %[[WITNESS]]
// CHECK-SAME: {
// CHECK: %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
// CHECK: %[[PARTIALLY_BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[BCASTED_SHAPE01]]) {broadcast_dimensions = dense<[0, 1]>
// CHECK: %[[PARTIALLY_BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[BCASTED_SHAPE01]]) {broadcast_dimensions = dense<[0, 1]>
// CHECK: %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
// CHECK: %[[BCASTED_ARG2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG2]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[0, 1, 2]>
// CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[PARTIALLY_BCASTED_ARG0]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[PARTIALLY_BCASTED_ARG1]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: %[[TMP:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]]
// CHECK: %[[RESULT:.*]] = mhlo.subtract %[[BCASTED_ARG2]], %[[TMP]]
// CHECK: shape.assuming_yield %{{.*}}, %{{.*}}, %{{.*}}, %[[RESULT]]
// CHECK: }
// CHECK: return %[[ASSUMING_RESULTS]]#3
  %0 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
  %1 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
  %2 = shape.cstr_broadcastable %0, %1 : tensor<2xindex>, tensor<2xindex>
  %3 = shape.assuming %2 -> (tensor<?x32xf16>) {
    %8 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
    %9 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
    %10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
    %13 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
    %14 = mhlo.subtract %12, %13 : tensor<?x32xf16>
    shape.assuming_yield %14 : tensor<?x32xf16>
  }
  %4 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
  %5 = shape.shape_of %3 : tensor<?x32xf16> -> tensor<2xindex>
  %6 = shape.cstr_broadcastable %4, %5 : tensor<3xindex>, tensor<2xindex>
  %7 = shape.assuming %6 -> (tensor<?x?x32xf16>) {
    %8 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
    %9 = shape.shape_of %3 : tensor<?x32xf16> -> tensor<2xindex>
    %10 = shape.broadcast %8, %9 : tensor<3xindex>, tensor<2xindex> -> tensor<3xindex>
    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %10) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
    %13 = "mhlo.dynamic_broadcast_in_dim"(%3, %10) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
    %14 = mhlo.subtract %12, %13 : tensor<?x?x32xf16>
    shape.assuming_yield %14 : tensor<?x?x32xf16>
  }
  return %7 : tensor<?x?x32xf16>
}
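As the CHECK lines above verify, the pass is expected to hoist the `shape.shape_of` and `shape.broadcast` computations out of the guarded regions and to merge the two `shape.assuming` ops into a single region with one combined `shape.cstr_broadcastable` witness, so that both `mhlo.subtract` ops end up side by side and become candidates for fusion.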