[MLIR] Add example test case for `move-up-dynamic-broadcasts-for-fusion` pass

Add exemplary test case as it appears in the lowering of two subsequent `tf.Sub`
ops.

PiperOrigin-RevId: 366219139
This commit is contained in:
A. Unique TensorFlower 2021-04-01 03:23:47 -07:00 committed by TensorFlow MLIR Team
parent eb4d20ba04
commit c23be1841c
1 changed files with 52 additions and 0 deletions

View File

@ -285,3 +285,55 @@ func @merge_assuming_ops(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
} }
return %7 : tensor<?x?x32xf16> return %7 : tensor<?x?x32xf16>
} }
// -----
// Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops.
// CHECK-LABEL: @sub_sub
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
%arg2: tensor<?x?x32xf16>) -> tensor<?x?x32xf16> {
// CHECK: %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
// CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
// CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE1]]
// CHECK: %[[ASSUMING_RESULTS:.*]]:4 = shape.assuming %[[WITNESS]]
// CHECK-SAME: {
// CHECK: %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
// CHECK: %[[PARTIALLY_BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[BCASTED_SHAPE01]]) {broadcast_dimensions = dense<[0, 1]>
// CHECK: %[[PARTIALLY_BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[BCASTED_SHAPE01]]) {broadcast_dimensions = dense<[0, 1]>
// CHECK: %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
// CHECK: %[[BCASTED_ARG2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG2]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[0, 1, 2]>
// CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[PARTIALLY_BCASTED_ARG0]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[PARTIALLY_BCASTED_ARG1]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: %[[TMP:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]]
// CHECK: %[[RESULT:.*]] = mhlo.subtract %[[BCASTED_ARG2]], %[[TMP]]
// CHECK: shape.assuming_yield %{{.*}}, %{{.*}}, %{{.*}}, %[[RESULT]]
// CHECK: }
// CHECK: return %[[ASSUMING_RESULTS]]#3
%0 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
%1 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
%2 = shape.cstr_broadcastable %0, %1 : tensor<2xindex>, tensor<2xindex>
%3 = shape.assuming %2 -> (tensor<?x32xf16>) {
%8 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
%9 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
%10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
%12 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
%13 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
%14 = mhlo.subtract %12, %13 : tensor<?x32xf16>
shape.assuming_yield %14 : tensor<?x32xf16>
}
%4 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
%5 = shape.shape_of %3 : tensor<?x32xf16> -> tensor<2xindex>
%6 = shape.cstr_broadcastable %4, %5 : tensor<3xindex>, tensor<2xindex>
%7 = shape.assuming %6 -> (tensor<?x?x32xf16>) {
%8 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
%9 = shape.shape_of %3 : tensor<?x32xf16> -> tensor<2xindex>
%10 = shape.broadcast %8, %9 : tensor<3xindex>, tensor<2xindex> -> tensor<3xindex>
%12 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %10) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
%13 = "mhlo.dynamic_broadcast_in_dim"(%3, %10) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
%14 = mhlo.subtract %12, %13 : tensor<?x?x32xf16>
shape.assuming_yield %14 : tensor<?x?x32xf16>
}
return %7 : tensor<?x?x32xf16>
}