diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc
index 0f59779..7ae241d 100644
--- a/lib/Dialect/mhlo/transforms/rank_specialization.cc
+++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc
@@ -256,31 +256,66 @@ SmallVector<Value, 8> MaterializeFinalReshape(
   }));
 }
 
-SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
+Value MaterializeScalarRankSpecializationCase(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
-    Value non_scalar_operand) {
-  // Flatten the non-scalar operand.
-  Value flat_shape = b.create<tensor::FromElementsOp>(
-      loc, b.create<shape::NumElementsOp>(
-               loc, b.getIndexType(),
-               b.create<shape::ShapeOfOp>(loc, non_scalar_operand))
-               .result());
-  Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
-      loc, DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
-      non_scalar_operand, flat_shape);
-
-  // Materialize ranked variants for the element-wise operations.
-  BlockAndValueMapping bvm;
-  for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
-    Value operand = std::get<1>(it);
-    bvm.map(std::get<0>(it),
-            operand == non_scalar_operand ? flat_non_scalar_operand : operand);
+    const SmallVector<Value, 8> &shapes, int64_t non_scalar_idx,
+    function_ref<void(OpBuilder &, Location)> else_builder_fn) {
+  // Materialize predicate: all operands except the one at `non_scalar_idx`
+  // are scalars.
+  Value one = b.create<ConstantIndexOp>(loc, 1);
+  Value all_others_are_scalar;
+  for (auto it : llvm::enumerate(shapes)) {
+    if (it.index() == non_scalar_idx) continue;
+    auto literal =
+        b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                         b.create<shape::NumElementsOp>(loc, it.value()), one);
+    all_others_are_scalar =
+        all_others_are_scalar
+            ? b.create<mlir::AndOp>(loc, all_others_are_scalar, literal).getResult()
+            : literal.result();
   }
-  SmallVector<Value, 8> unshaped_results =
-      MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1);
-
-  // Restore the results' expected shape.
-  return MaterializeFinalReshape(b, loc, op, unshaped_results);
+
+  auto if_op = b.create<scf::IfOp>(
+      loc, op->getResultTypes(), all_others_are_scalar,
+      [&](OpBuilder &b, Location loc) {
+        // Flatten the non-scalar operand.
+        Value flat_shape = b.create<tensor::FromElementsOp>(
+            loc, b.create<shape::NumElementsOp>(loc, b.getIndexType(),
+                                                shapes[non_scalar_idx])
+                     .result());
+        Value non_scalar_operand = op.operands()[non_scalar_idx];
+        Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
+            loc,
+            DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
+            non_scalar_operand, flat_shape);
+
+        // Derive ranked operands.
+        auto ranked_operands =
+            llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
+              if (v == non_scalar_operand) return flat_non_scalar_operand;
+              return b
+                  .create<tensor::CastOp>(
+                      loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0), v)
+                  .dest();
+            }));
+
+        // Materialize ranked variants for the element-wise operations.
+        BlockAndValueMapping bvm;
+        for (auto it : llvm::zip(op.getBody()->getArguments(), ranked_operands))
+          bvm.map(std::get<0>(it), std::get<1>(it));
+        Value unshaped_result =
+            MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1)
+                .front();
+
+        // Return as unranked tensor for compatibility with the other cases.
+        b.create<scf::YieldOp>(
+            loc, b.create<tensor::CastOp>(
+                     loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
+                     unshaped_result)
+                     .dest());
+      },
+      else_builder_fn);
+
+  return if_op.results().front();
 }
 
 Value MaterializeEqualShapesRankSpecializationCase(
@@ -313,7 +348,7 @@ Value MaterializeEqualShapesRankSpecializationCase(
   Value flat_shape = b.create<tensor::FromElementsOp>(
       loc, b.create<shape::NumElementsOp>(loc, b.getIndexType(), shape)
               .result());
-  SmallVector<Value, 8> flat_operands =
+  auto flat_operands =
       llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
         return b
             .create<mhlo::DynamicReshapeOp>(
@@ -355,7 +390,7 @@ Value MaterializeTargetRankSpecializationCase(
       loc, extent_tensor_ty,
       mlir::DenseIntElementsAttr::get(extent_tensor_ty,
                                       SmallVector<int64_t>(target_rank, 1)));
-  SmallVector ranked_operands;
+  SmallVector ranked_operands;
   for (auto it : llvm::zip(op.operands(), shapes)) {
     Value operand, shape;
     std::tie(operand, shape) = it;
@@ -464,6 +499,72 @@ Value MaterializeGenericRankSpecializationCases(
       b, loc, op, reduced_shapes, max_rank, kMinTargetRank, kMaxTargetRank);
 }
 
+Value MaterializeDefaultRankSpecializationCases(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
+    const SmallVector<Value, 8> &shapes) {
+  return MaterializeEqualShapesRankSpecializationCase(
+      b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(
+            loc, MaterializeGenericRankSpecializationCases(b, loc, op, shapes));
+      });
+}
+
+Value MaterializeRankSpecializationForExactlyTwoOperands(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
+  assert(op->getNumOperands() == 2 && op.getNumResults() == 1 &&
+         "The rank specialization strategy for clusters with exactly two "
+         "operands supports only one result.");
+
+  auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
+    return b.create<shape::ShapeOfOp>(loc, v).result();
+  }));
+
+  // Materialize all the different cases.
+  Value unshaped_result = MaterializeScalarRankSpecializationCase(
+      b, loc, op, shapes, /*non_scalar_idx=*/1,
+      [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(
+            loc, MaterializeScalarRankSpecializationCase(
+                     b, loc, op, shapes, /*non_scalar_idx=*/0,
+                     [&](OpBuilder &b, Location loc) {
+                       b.create<scf::YieldOp>(
+                           loc, MaterializeDefaultRankSpecializationCases(
+                                    b, loc, op, shapes));
+                     }));
+      });
+
+  // Materialize final reshape once and for all rank specialization cases.
+  return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
+}
+
+SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
+    Value non_scalar_operand) {
+  // Flatten the non-scalar operand.
+  Value flat_shape = b.create<tensor::FromElementsOp>(
+      loc, b.create<shape::NumElementsOp>(
+               loc, b.getIndexType(),
+               b.create<shape::ShapeOfOp>(loc, non_scalar_operand))
+               .result());
+  Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
+      loc, DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
+      non_scalar_operand, flat_shape);
+
+  // Materialize ranked variants for the element-wise operations.
+  BlockAndValueMapping bvm;
+  for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
+    Value operand = std::get<1>(it);
+    bvm.map(std::get<0>(it),
+            operand == non_scalar_operand ? flat_non_scalar_operand : operand);
+  }
+  SmallVector<Value, 8> unshaped_results =
+      MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1);
+
+  // Restore the results' expected shape.
+  return MaterializeFinalReshape(b, loc, op, unshaped_results);
+}
+
+// Materialize default rank specialization.
 Value MaterializeDefaultRankSpecialization(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
   auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
@@ -471,11 +572,8 @@ Value MaterializeDefaultRankSpecialization(
   }));
 
   // Materialize all the different cases.
-  Value unshaped_result = MaterializeEqualShapesRankSpecializationCase(
-      b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
-        b.create<scf::YieldOp>(
-            loc, MaterializeGenericRankSpecializationCases(b, loc, op, shapes));
-      });
+  Value unshaped_result =
+      MaterializeDefaultRankSpecializationCases(b, loc, op, shapes);
 
   // Materialize final reshape once and for all rank specialization cases.
   return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
@@ -487,6 +585,11 @@ struct LowerRankSpecializationClusterPattern
 
   LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                 PatternRewriter &rewriter) const override {
+    // Restoring the result shape currently relies on all operands being used
+    // for a single result. The result shape is then the broadcasted shape of
+    // all operands.
+    if (op.getNumResults() != 1) return failure();
+
     // TODO(frgossen): If there is a single operand, we can flatten it
     // completely and apply a non-broadcasting operation.
 
@@ -501,10 +604,13 @@ struct LowerRankSpecializationClusterPattern
       return success();
     }
 
-    // Restoring the result shape currently relies on all operands being used
-    // for a single result. The result shape is then the broadcasted shape of
-    // all operands.
-    if (op.getNumResults() != 1) return failure();
+    // If there are only two operands, we can consider extra cases in which
+    // either operand is a scalar.
+    if (op->getNumOperands() == 2) {
+      rewriter.replaceOp(op, MaterializeRankSpecializationForExactlyTwoOperands(
+                                 rewriter, loc, op));
+      return success();
+    }
 
     // For all other cases, reshape the operands to match in rank, apply the
     // operation, and restore the expected shape.
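For context on the control flow this emits: for a two-operand cluster, MaterializeRankSpecializationForExactlyTwoOperands nests the cases as scf.if ops, testing the cheap scalar cases first and falling through to the shared default cases. The sketch below is hand-written MLIR that shows only the intended nesting; the value names are invented and most types and bodies are elided, so it is not verbatim pass output:

  %shape0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
  %shape1 = shape.shape_of %arg1 : tensor<*xf32> -> tensor<?xindex>
  %n0 = shape.num_elements %shape0 : tensor<?xindex> -> index
  %lhs_scalar = cmpi eq, %n0, %c1 : index
  %res = scf.if %lhs_scalar -> (tensor<*xf32>) {
    // Cast %arg0 to tensor<f32>, flatten %arg1 to tensor<?xf32>, apply the
    // cluster body once, and cast the result back to tensor<*xf32>.
    ...
  } else {
    %n1 = shape.num_elements %shape1 : tensor<?xindex> -> index
    %rhs_scalar = cmpi eq, %n1, %c1 : index
    %0 = scf.if %rhs_scalar -> (tensor<*xf32>) {
      ...  // Mirror image of the lhs-scalar case.
    } else {
      %eq = shape.shape_eq %shape0, %shape1 : tensor<?xindex>, tensor<?xindex>
      %1 = scf.if %eq -> (tensor<*xf32>) {
        ...  // Flatten both operands to tensor<?xf32> and apply the body once.
      } else {
        ...  // Generic rank cases 1..8, with an assert for ranks above 8.
      }
      scf.yield %1 : tensor<*xf32>
    }
    scf.yield %0 : tensor<*xf32>
  }
  // A single final mhlo.dynamic_reshape of %res restores the broadcasted
  // result shape, as tested below.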
diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir
index 1365e88..45001f1 100644
--- a/tests/rank-specialization.mlir
+++ b/tests/rank-specialization.mlir
@@ -346,12 +346,12 @@ func @mixed(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>)
 }
 
 // CHECK-SCF-LABEL: @mixed
-// CHECK-SCF-DAG: %[[TMP0:.*]] = chlo.tan %{{.*}} : tensor
-// CHECK-SCF-DAG: %[[TMP1:.*]] = "mhlo.sqrt"(%{{.*}}) : (tensor)
-// CHECK-SCF-DAG: %[[TMP2:.*]] = chlo.broadcast_multiply %[[TMP0]], %[[TMP1]] : (tensor, tensor)
-// CHECK-SCF-DAG: %[[TMP3:.*]] = chlo.broadcast_add %[[TMP2]], %{{.*}} : (tensor, tensor)
-// CHECK-SCF-DAG: %[[TMP4:.*]] = "mhlo.sqrt"(%[[TMP3]]) : (tensor)
-// CHECK-SCF: chlo.tan %[[TMP4]] : tensor
+// CHECK-SCF-DAG: %[[TMP0:.*]] = chlo.tan %{{.*}} : tensor
+// CHECK-SCF-DAG: %[[TMP1:.*]] = "mhlo.sqrt"(%{{.*}}) : (tensor)
+// CHECK-SCF-DAG: %[[TMP2:.*]] = chlo.broadcast_multiply %[[TMP0]], %[[TMP1]] : (tensor, tensor)
+// CHECK-SCF-DAG: %[[TMP3:.*]] = chlo.broadcast_add %[[TMP2]], %{{.*}} : (tensor, tensor)
+// CHECK-SCF-DAG: %[[TMP4:.*]] = "mhlo.sqrt"(%[[TMP3]]) : (tensor)
+// CHECK-SCF: chlo.tan %[[TMP4]] : tensor
 
 // -----
 
@@ -446,3 +446,198 @@ func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%{{.*}}) : (tensor<?xf32>)
 // CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %{{.*}}, %[[TMP0]] : (tensor<?xf32>, tensor<?xf32>)
 // CHECK-SCF: chlo.broadcast_select %[[PRED]], %{{.*}}, %[[TMP1]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
+
+// -----
+
+// CHECK-LABEL: @mul
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
+func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]])
+  // CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>):
+  // CHECK:   %[[TMP:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[ARG1_]]
+  // CHECK:   "chlo.rank_specialization_cluster_yield"(%[[TMP]])
+  // CHECK: return %[[RES]]
+  %0 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// CHECK-SCF-LABEL: @mul
+// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
+// CHECK-SCF-DAG: %[[C1:.*]] = constant 1
+// CHECK-SCF-DAG: %[[C2:.*]] = constant 2
+// CHECK-SCF-DAG: %[[C3:.*]] = constant 3
+// CHECK-SCF-DAG: %[[C4:.*]] = constant 4
+// CHECK-SCF-DAG: %[[C5:.*]] = constant 5
+// CHECK-SCF-DAG: %[[C6:.*]] = constant 6
+// CHECK-SCF-DAG: %[[C7:.*]] = constant 7
+// CHECK-SCF-DAG: %[[C8:.*]] = constant 8
+// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_2:.*]] = shape.const_shape [1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_7:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_8:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
+// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
+// Lhs scalar case:
+// CHECK-SCF-DAG: %[[LHS_N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[LHS_SCALAR:.*]] = cmpi eq, %[[LHS_N]], %[[C1]]
+// CHECK-SCF: %[[UNSHAPED_RES_LHS_SCALAR:.*]] = scf.if %[[LHS_SCALAR]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[SCALAR:.*]] = tensor.cast %[[ARG0]]
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[SCALAR]], %[[FLAT_NON_SCALAR]] : (tensor<f32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Rhs scalar case:
+// CHECK-SCF-DAG: %[[RHS_N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[RHS_SCALAR:.*]] = cmpi eq, %[[RHS_N]], %[[C1]]
+// CHECK-SCF: %[[UNSHAPED_RES_RHS_SCALAR:.*]] = scf.if %[[RHS_SCALAR]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[SCALAR:.*]] = tensor.cast %[[ARG1]]
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[FLAT_NON_SCALAR]], %[[SCALAR]] : (tensor<?xf32>, tensor<f32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Equal shapes case:
+// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF: %[[UNSHAPED_RES_EQ_SHAPES:.*]] = scf.if %[[SHAPES_EQ]]
+// CHECK-SCF-DAG: %[[SHAPE:.*]] = shape.any %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[FLAT_ARG0]], %[[FLAT_ARG1]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Find maximum reduced rank.
+// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#0
+// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#1
+// CHECK-SCF-DAG: %[[R0_GT_R1:.*]] = cmpi sgt, %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
+// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R0_GT_R1]], %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
+// Generic case 1:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_1:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C1]]
+// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_EQ_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 2:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_2:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C2]]
+// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_EQ_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 3:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_3:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C3]]
+// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_EQ_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 4:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_4:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C4]]
+// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_EQ_4]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 5:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_5:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C5]]
+// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_EQ_5]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 6:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_6:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C6]]
+// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_EQ_6]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 7:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_7:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C7]]
+// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_EQ_7]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 8:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_8:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C8]]
+// CHECK-SCF: assert %[[MAX_RED_RANK_EQ_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_7]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_6]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_5]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_4]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_3]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_1]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_EQ_SHAPES]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_RHS_SCALAR]]
+// Reshape the result.
+// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
+// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
+// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES_LHS_SCALAR]], %[[RES_SHAPE]])
+// CHECK-SCF: return %[[RES]]
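For readers tracing the @mul CHECK-SCF captures above: the lhs-scalar branch they verify corresponds to IR of roughly the following shape. This is a hand-written sketch that reuses the FileCheck capture names as value names; types and exact operand order follow the pass's intent but are not verbatim output:

  %n = shape.num_elements %shape_arg1 : tensor<?xindex> -> index
  %flat_shape = tensor.from_elements %n : tensor<1xindex>
  %flat_non_scalar = "mhlo.dynamic_reshape"(%arg1, %flat_shape)
      : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
  %scalar = tensor.cast %arg0 : tensor<*xf32> to tensor<f32>
  %inner_res = chlo.broadcast_multiply %scalar, %flat_non_scalar
      : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
  %inner_res_ = tensor.cast %inner_res : tensor<?xf32> to tensor<*xf32>
  scf.yield %inner_res_ : tensor<*xf32>

The trailing tensor.cast back to tensor<*xf32> is what lets every scf.if branch yield the same unranked type, so that the single mhlo.dynamic_reshape at the end can restore the broadcasted result shape for all cases at once.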