diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
index e64ff71..86e73b9 100644
--- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir
+++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
@@ -285,3 +285,55 @@ func @merge_assuming_ops(%arg0: tensor<?x32xi16>, %arg1 : tensor<?x32xi16>,
   }
   return %7 : tensor<?x?x32xi16>
 }
+
+// -----
+
+// Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops.
+// CHECK-LABEL: @sub_sub
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xi16>, %[[ARG1:.*]]: tensor<?x32xi16>, %[[ARG2:.*]]: tensor<?x?x32xi16>)
+func @sub_sub(%arg0: tensor<?x32xi16>, %arg1 : tensor<?x32xi16>,
+    %arg2: tensor<?x?x32xi16>) -> tensor<?x?x32xi16> {
+  // CHECK: %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
+  // CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
+  // CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
+  // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE1]]
+  // CHECK: %[[ASSUMING_RESULTS:.*]]:4 = shape.assuming %[[WITNESS]]
+  // CHECK-SAME: {
+  // CHECK: %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
+  // CHECK: %[[PARTIALLY_BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[BCASTED_SHAPE01]]) {broadcast_dimensions = dense<[0, 1]>
+  // CHECK: %[[PARTIALLY_BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[BCASTED_SHAPE01]]) {broadcast_dimensions = dense<[0, 1]>
+  // CHECK: %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
+  // CHECK: %[[BCASTED_ARG2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG2]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[0, 1, 2]>
+  // CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[PARTIALLY_BCASTED_ARG0]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
+  // CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[PARTIALLY_BCASTED_ARG1]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
+  // CHECK: %[[TMP:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]]
+  // CHECK: %[[RESULT:.*]] = mhlo.subtract %[[BCASTED_ARG2]], %[[TMP]]
+  // CHECK: shape.assuming_yield %{{.*}}, %{{.*}}, %{{.*}}, %[[RESULT]]
+  // CHECK: }
+  // CHECK: return %[[ASSUMING_RESULTS]]#3
+  %0 = shape.shape_of %arg0 : tensor<?x32xi16> -> tensor<2xindex>
+  %1 = shape.shape_of %arg1 : tensor<?x32xi16> -> tensor<2xindex>
+  %2 = shape.cstr_broadcastable %0, %1 : tensor<2xindex>, tensor<2xindex>
+  %3 = shape.assuming %2 -> (tensor<?x32xi16>) {
+    %8 = shape.shape_of %arg0 : tensor<?x32xi16> -> tensor<2xindex>
+    %9 = shape.shape_of %arg1 : tensor<?x32xi16> -> tensor<2xindex>
+    %10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
+    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xi16>, tensor<2xindex>) -> tensor<?x32xi16>
+    %13 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xi16>, tensor<2xindex>) -> tensor<?x32xi16>
+    %14 = mhlo.subtract %12, %13 : tensor<?x32xi16>
+    shape.assuming_yield %14 : tensor<?x32xi16>
+  }
+  %4 = shape.shape_of %arg2 : tensor<?x?x32xi16> -> tensor<3xindex>
+  %5 = shape.shape_of %3 : tensor<?x32xi16> -> tensor<2xindex>
+  %6 = shape.cstr_broadcastable %4, %5 : tensor<3xindex>, tensor<2xindex>
+  %7 = shape.assuming %6 -> (tensor<?x?x32xi16>) {
+    %8 = shape.shape_of %arg2 : tensor<?x?x32xi16> -> tensor<3xindex>
+    %9 = shape.shape_of %3 : tensor<?x32xi16> -> tensor<2xindex>
+    %10 = shape.broadcast %8, %9 : tensor<3xindex>, tensor<2xindex> -> tensor<3xindex>
+    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %10) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x32xi16>, tensor<3xindex>) -> tensor<?x?x32xi16>
+    %13 = "mhlo.dynamic_broadcast_in_dim"(%3, %10) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x32xi16>, tensor<3xindex>) -> tensor<?x?x32xi16>
+    %14 = mhlo.subtract %12, %13 : tensor<?x?x32xi16>
+    shape.assuming_yield %14 : tensor<?x?x32xi16>
+  }
+  return %7 : tensor<?x?x32xi16>
+}