[MLIR][KernelGen] Simplify rank specialization tests with smaller target rank
For these tests, rank specialize only up to rank 3. The remaining cases for higher ranks are analogous.

PiperOrigin-RevId: 377024370
commit 557e56362e
parent 968d4b8709
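For context, the test drives the two passes named in the RUN lines over unranked CHLO ops. A minimal input in the spirit of the checked @mul function is sketched below; only the signature appears in the hunk headers, so the body is an assumed reconstruction from the chlo ops in the CHECK-SCF lines, not a copy of the test file.

// Illustrative sketch only (assumed body; signature taken from the hunk header).
// The cluster pass wraps this op in a rank specialization cluster, and the SCF
// pass then expands the cluster into the scf.if cascade checked below, now
// only up to rank 3.
func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
  %0 = chlo.broadcast_multiply %arg0, %arg1
      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

Run through mlir-hlo-opt with --mhlo-rank-specialization-cluster and --mhlo-rank-specialization-to-scf=max-target-rank=3 as in the updated RUN line, this lowers to generic cases for ranks 1 to 3, with the rank-3 case guarded only by the assert checked below instead of a further scf.if.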
@@ -1,5 +1,5 @@
 // RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s
-// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf=max-target-rank=8 | FileCheck %s --check-prefix CHECK-SCF
+// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf=max-target-rank=3 | FileCheck %s --check-prefix CHECK-SCF
 
 // CHECK-LABEL: @add_mul
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
@@ -24,19 +24,9 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
 // CHECK-SCF-DAG: %[[C1:.*]] = constant 1
 // CHECK-SCF-DAG: %[[C2:.*]] = constant 2
 // CHECK-SCF-DAG: %[[C3:.*]] = constant 3
-// CHECK-SCF-DAG: %[[C4:.*]] = constant 4
-// CHECK-SCF-DAG: %[[C5:.*]] = constant 5
-// CHECK-SCF-DAG: %[[C6:.*]] = constant 6
-// CHECK-SCF-DAG: %[[C7:.*]] = constant 7
-// CHECK-SCF-DAG: %[[C8:.*]] = constant 8
 // CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1]
 // CHECK-SCF-DAG: %[[ONE_SHAPE_2:.*]] = shape.const_shape [1, 1]
 // CHECK-SCF-DAG: %[[ONE_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
-// CHECK-SCF-DAG: %[[ONE_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
-// CHECK-SCF-DAG: %[[ONE_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
-// CHECK-SCF-DAG: %[[ONE_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
-// CHECK-SCF-DAG: %[[ONE_SHAPE_7:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1]
-// CHECK-SCF-DAG: %[[ONE_SHAPE_8:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1, 1]
 // CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
 // CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
 // CHECK-SCF-DAG: %[[SHAPE_ARG2:.*]] = shape.shape_of %[[ARG2]]
@@ -102,7 +92,7 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
 // CHECK-SCF: else
 // Generic case 3:
 // CHECK-SCF: %[[MAX_RED_RANK_LE_3:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C3]]
-// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_LE_3]]
+// CHECK-SCF: assert %[[MAX_RED_RANK_LE_3]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 3"
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_3]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
@@ -116,96 +106,6 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
 // CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
 // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
-// CHECK-SCF: else
-// Generic case 4:
-// CHECK-SCF: %[[MAX_RED_RANK_LE_4:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C4]]
-// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_LE_4]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_4]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
-// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
-// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
-// CHECK-SCF: scf.yield %[[INNER_RES_]]
-// CHECK-SCF: else
-// Generic case 5:
-// CHECK-SCF: %[[MAX_RED_RANK_LE_5:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C5]]
-// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_LE_5]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_5]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
-// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
-// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
-// CHECK-SCF: scf.yield %[[INNER_RES_]]
-// CHECK-SCF: else
-// Generic case 6:
-// CHECK-SCF: %[[MAX_RED_RANK_LE_6:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C6]]
-// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_LE_6]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_6]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
-// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
-// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
-// CHECK-SCF: scf.yield %[[INNER_RES_]]
-// CHECK-SCF: else
-// Generic case 7:
-// CHECK-SCF: %[[MAX_RED_RANK_LE_7:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C7]]
-// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_LE_7]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_7]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
-// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
-// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
-// CHECK-SCF: scf.yield %[[INNER_RES_]]
-// CHECK-SCF: else
-// Generic case 8:
-// CHECK-SCF: %[[MAX_RED_RANK_LE_8:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C8]]
-// CHECK-SCF: assert %[[MAX_RED_RANK_LE_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_8]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
-// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
-// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
-// CHECK-SCF: scf.yield %[[INNER_RES_]]
-// CHECK-SCF: scf.yield %[[UNSHAPED_RES_7]]
-// CHECK-SCF: scf.yield %[[UNSHAPED_RES_6]]
-// CHECK-SCF: scf.yield %[[UNSHAPED_RES_5]]
-// CHECK-SCF: scf.yield %[[UNSHAPED_RES_4]]
-// CHECK-SCF: scf.yield %[[UNSHAPED_RES_3]]
 // CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]]
 // CHECK-SCF: scf.yield %[[UNSHAPED_RES_1]]
 // Reshape the result.
@@ -544,19 +444,9 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF-DAG: %[[C1:.*]] = constant 1
 // CHECK-SCF-DAG: %[[C2:.*]] = constant 2
 // CHECK-SCF-DAG: %[[C3:.*]] = constant 3
-// CHECK-SCF-DAG: %[[C4:.*]] = constant 4
-// CHECK-SCF-DAG: %[[C5:.*]] = constant 5
-// CHECK-SCF-DAG: %[[C6:.*]] = constant 6
-// CHECK-SCF-DAG: %[[C7:.*]] = constant 7
-// CHECK-SCF-DAG: %[[C8:.*]] = constant 8
 // CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1]
 // CHECK-SCF-DAG: %[[ONE_SHAPE_2:.*]] = shape.const_shape [1, 1]
 // CHECK-SCF-DAG: %[[ONE_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
-// CHECK-SCF-DAG: %[[ONE_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
-// CHECK-SCF-DAG: %[[ONE_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
-// CHECK-SCF-DAG: %[[ONE_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
-// CHECK-SCF-DAG: %[[ONE_SHAPE_7:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1]
-// CHECK-SCF-DAG: %[[ONE_SHAPE_8:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1, 1]
 // CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
 // CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
 // Lhs scalar case:
@@ -629,7 +519,7 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF: else
 // Generic case 3:
 // CHECK-SCF: %[[MAX_RED_RANK_LE_3:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C3]]
-// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_LE_3]]
+// CHECK-SCF: assert %[[MAX_RED_RANK_LE_3]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 3"
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
@@ -639,76 +529,6 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
 // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
-// CHECK-SCF: else
-// Generic case 4:
-// CHECK-SCF: %[[MAX_RED_RANK_LE_4:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C4]]
-// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_LE_4]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
-// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
-// CHECK-SCF: scf.yield %[[INNER_RES_]]
-// CHECK-SCF: else
-// Generic case 5:
-// CHECK-SCF: %[[MAX_RED_RANK_LE_5:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C5]]
-// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_LE_5]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
-// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
-// CHECK-SCF: scf.yield %[[INNER_RES_]]
-// CHECK-SCF: else
-// Generic case 6:
-// CHECK-SCF: %[[MAX_RED_RANK_LE_6:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C6]]
-// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_LE_6]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
-// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
-// CHECK-SCF: scf.yield %[[INNER_RES_]]
-// CHECK-SCF: else
-// Generic case 7:
-// CHECK-SCF: %[[MAX_RED_RANK_LE_7:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C7]]
-// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_LE_7]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
-// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
-// CHECK-SCF: scf.yield %[[INNER_RES_]]
-// CHECK-SCF: else
-// Generic case 8:
-// CHECK-SCF: %[[MAX_RED_RANK_LE_8:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C8]]
-// CHECK-SCF: assert %[[MAX_RED_RANK_LE_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
-// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
-// CHECK-SCF: scf.yield %[[INNER_RES_]]
-// CHECK-SCF: scf.yield %[[UNSHAPED_RES_7]]
-// CHECK-SCF: scf.yield %[[UNSHAPED_RES_6]]
-// CHECK-SCF: scf.yield %[[UNSHAPED_RES_5]]
-// CHECK-SCF: scf.yield %[[UNSHAPED_RES_4]]
-// CHECK-SCF: scf.yield %[[UNSHAPED_RES_3]]
 // CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]]
 // CHECK-SCF: scf.yield %[[UNSHAPED_RES_1]]
 // CHECK-SCF: scf.yield %[[UNSHAPED_RES_EQ_SHAPES]]