[MLIR][KernelGen] Simplify rank specialization tests with smaller target rank

For the tests rank specialize only up to rank 3. The remaining cases for higher
ranks are analogous.

PiperOrigin-RevId: 377024370
This commit is contained in:
A. Unique TensorFlower 2021-06-02 03:47:08 -07:00 committed by TensorFlow MLIR Team
parent 968d4b8709
commit 557e56362e
1 changed files with 25 additions and 205 deletions

View File

@ -1,5 +1,5 @@
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s // RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf=max-target-rank=8 | FileCheck %s --check-prefix CHECK-SCF // RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf=max-target-rank=3 | FileCheck %s --check-prefix CHECK-SCF
// CHECK-LABEL: @add_mul // CHECK-LABEL: @add_mul
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>) // CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
@ -24,19 +24,9 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
// CHECK-SCF-DAG: %[[C1:.*]] = constant 1 // CHECK-SCF-DAG: %[[C1:.*]] = constant 1
// CHECK-SCF-DAG: %[[C2:.*]] = constant 2 // CHECK-SCF-DAG: %[[C2:.*]] = constant 2
// CHECK-SCF-DAG: %[[C3:.*]] = constant 3 // CHECK-SCF-DAG: %[[C3:.*]] = constant 3
// CHECK-SCF-DAG: %[[C4:.*]] = constant 4
// CHECK-SCF-DAG: %[[C5:.*]] = constant 5
// CHECK-SCF-DAG: %[[C6:.*]] = constant 6
// CHECK-SCF-DAG: %[[C7:.*]] = constant 7
// CHECK-SCF-DAG: %[[C8:.*]] = constant 8
// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1] // CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_2:.*]] = shape.const_shape [1, 1] // CHECK-SCF-DAG: %[[ONE_SHAPE_2:.*]] = shape.const_shape [1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] // CHECK-SCF-DAG: %[[ONE_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_7:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_8:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1, 1]
// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]] // CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]] // CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
// CHECK-SCF-DAG: %[[SHAPE_ARG2:.*]] = shape.shape_of %[[ARG2]] // CHECK-SCF-DAG: %[[SHAPE_ARG2:.*]] = shape.shape_of %[[ARG2]]
@ -102,110 +92,20 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
// CHECK-SCF: else // CHECK-SCF: else
// Generic case 3: // Generic case 3:
// CHECK-SCF: %[[MAX_RED_RANK_LE_3:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C3]] // CHECK-SCF: %[[MAX_RED_RANK_LE_3:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C3]]
// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_LE_3]] // CHECK-SCF: assert %[[MAX_RED_RANK_LE_3]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 3"
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_3]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_3]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) // CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) // CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]]) // CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) // CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) // CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]] // CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 4:
// CHECK-SCF: %[[MAX_RED_RANK_LE_4:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C4]]
// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_LE_4]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_4]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 5:
// CHECK-SCF: %[[MAX_RED_RANK_LE_5:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C5]]
// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_LE_5]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_5]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 6:
// CHECK-SCF: %[[MAX_RED_RANK_LE_6:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C6]]
// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_LE_6]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_6]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 7:
// CHECK-SCF: %[[MAX_RED_RANK_LE_7:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C7]]
// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_LE_7]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_7]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 8:
// CHECK-SCF: %[[MAX_RED_RANK_LE_8:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C8]]
// CHECK-SCF: assert %[[MAX_RED_RANK_LE_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_8]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_7]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_6]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_5]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_4]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_3]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]] // CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_1]] // CHECK-SCF: scf.yield %[[UNSHAPED_RES_1]]
// Reshape the result. // Reshape the result.
@ -544,19 +444,9 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF-DAG: %[[C1:.*]] = constant 1 // CHECK-SCF-DAG: %[[C1:.*]] = constant 1
// CHECK-SCF-DAG: %[[C2:.*]] = constant 2 // CHECK-SCF-DAG: %[[C2:.*]] = constant 2
// CHECK-SCF-DAG: %[[C3:.*]] = constant 3 // CHECK-SCF-DAG: %[[C3:.*]] = constant 3
// CHECK-SCF-DAG: %[[C4:.*]] = constant 4
// CHECK-SCF-DAG: %[[C5:.*]] = constant 5
// CHECK-SCF-DAG: %[[C6:.*]] = constant 6
// CHECK-SCF-DAG: %[[C7:.*]] = constant 7
// CHECK-SCF-DAG: %[[C8:.*]] = constant 8
// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1] // CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_2:.*]] = shape.const_shape [1, 1] // CHECK-SCF-DAG: %[[ONE_SHAPE_2:.*]] = shape.const_shape [1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] // CHECK-SCF-DAG: %[[ONE_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_7:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_8:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1, 1]
// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]] // CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]] // CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
// Lhs scalar case: // Lhs scalar case:
@ -629,86 +519,16 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF: else // CHECK-SCF: else
// Generic case 3: // Generic case 3:
// CHECK-SCF: %[[MAX_RED_RANK_LE_3:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C3]] // CHECK-SCF: %[[MAX_RED_RANK_LE_3:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C3]]
// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_LE_3]] // CHECK-SCF: assert %[[MAX_RED_RANK_LE_3]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 3"
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]] // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]]) // CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]]) // CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) // CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]] // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]] // CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 4:
// CHECK-SCF: %[[MAX_RED_RANK_LE_4:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C4]]
// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_LE_4]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 5:
// CHECK-SCF: %[[MAX_RED_RANK_LE_5:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C5]]
// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_LE_5]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 6:
// CHECK-SCF: %[[MAX_RED_RANK_LE_6:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C6]]
// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_LE_6]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 7:
// CHECK-SCF: %[[MAX_RED_RANK_LE_7:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C7]]
// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_LE_7]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 8:
// CHECK-SCF: %[[MAX_RED_RANK_LE_8:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C8]]
// CHECK-SCF: assert %[[MAX_RED_RANK_LE_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_7]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_6]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_5]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_4]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_3]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]] // CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_1]] // CHECK-SCF: scf.yield %[[UNSHAPED_RES_1]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_EQ_SHAPES]] // CHECK-SCF: scf.yield %[[UNSHAPED_RES_EQ_SHAPES]]