diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc
index 0b9f0f3..e50c4ee 100644
--- a/lib/Dialect/mhlo/transforms/rank_specialization.cc
+++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc
@@ -181,6 +181,10 @@ bool IsScalarTensorType(Type ty) {
   return ranked_ty && ranked_ty.getRank() == 0;
 }
 
+bool IsScalarShapeType(Type ty) {
+  return ty.cast<RankedTensorType>().getDimSize(0) == 0;
+}
+
 Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
   auto tensor_ty = ty.dyn_cast<TensorType>();
   if (!tensor_ty) return ty;
@@ -208,11 +212,16 @@ Optional<Value> FindUniqueNonScalar(ValueRange values) {
 
 SmallVector<Value, 8> MaterializeRankedOperations(
     OpBuilder &b, Location loc, BlockAndValueMapping &bvm,
-    chlo::RankSpecializationClusterOp op, int64_t target_rank) {
+    chlo::RankSpecializationClusterOp op) {
   // Create ranked operations.
   for (Operation &nested_op : op.getBody()->without_terminator()) {
     auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
         nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); }));
+    int64_t target_rank = 0;
+    for (Value v : mapped_operands) {
+      target_rank =
+          std::max(target_rank, v.getType().cast<RankedTensorType>().getRank());
+    }
     auto ranked_result_types = llvm::to_vector<2>(llvm::map_range(
         nested_op.getResultTypes(),
         [&](Type ty) { return DeriveRankedTensorTypes(ty, target_rank); }));
@@ -265,12 +274,15 @@ Value MaterializeScalarRankSpecializationCase(
   Value all_others_are_scalar;
   for (auto it : llvm::enumerate(shapes)) {
     if (it.index() == non_scalar_idx) continue;
+    // For statically known scalars, there is no need to test.
+    if (IsScalarTensorType(op.getOperand(it.index()).getType())) continue;
    auto literal = b.create<CmpIOp>(
        loc, CmpIPredicate::eq,
        b.create<shape::NumElementsOp>(loc, it.value()), one);
     all_others_are_scalar =
         all_others_are_scalar
-            ? b.create<mlir::AndOp>(loc, all_others_are_scalar, literal).getResult()
+            ? b.create<mlir::AndOp>(loc, all_others_are_scalar, literal)
+                  .getResult()
             : literal.result();
   }
@@ -303,8 +315,7 @@ Value MaterializeScalarRankSpecializationCase(
   for (auto it : llvm::zip(op.getBody()->getArguments(), ranked_operands))
     bvm.map(std::get<0>(it), std::get<1>(it));
   Value unshaped_result =
-      MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1)
-          .front();
+      MaterializeRankedOperations(b, loc, bvm, op).front();
 
   // Return as unranked tensor for compatibility with the other cases.
   b.create<scf::YieldOp>(
@@ -322,26 +333,29 @@ Value MaterializeEqualShapesRankSpecializationCase(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
     const SmallVector<Value, 8> &shapes,
     function_ref<void(OpBuilder &, Location)> else_builder_fn) {
-  assert(shapes.size() >= 2 &&
-         "This strategy should only be materialized if there are at least two "
-         "shapes involved.");
-
-  // Materialize all shapes equal predicate.
-  Value all_shapes_eq;
-  for (Value s : llvm::drop_begin(shapes)) {
-    auto literal = b.create<shape::ShapeEqOp>(loc, shapes.front(), s);
-    all_shapes_eq =
-        all_shapes_eq
-            ? b.create<mlir::AndOp>(loc, all_shapes_eq, literal).result()
+  Value all_shapes_eq_or_scalar;
+  auto non_scalar_shapes = llvm::to_vector<8>(llvm::make_filter_range(
+      shapes, [](Value v) { return !IsScalarShapeType(v.getType()); }));
+  assert(
+      non_scalar_shapes.size() >= 2 &&
+      "Equal shapes strategy requires at least two non-scalar operand shapes.");
+  for (Value s : llvm::drop_begin(non_scalar_shapes)) {
+    auto literal =
+        b.create<shape::ShapeEqOp>(loc, non_scalar_shapes.front(), s);
+    all_shapes_eq_or_scalar =
+        all_shapes_eq_or_scalar
+            ? b.create<mlir::AndOp>(loc, all_shapes_eq_or_scalar, literal)
+                  .result()
             : literal;
   }
 
   auto if_op = b.create<scf::IfOp>(
-      loc, op->getResultTypes(), all_shapes_eq,
+      loc, op->getResultTypes(), all_shapes_eq_or_scalar,
       [&](OpBuilder &b, Location loc) {
-        // Flatten operands.
-        Value shape = shapes.front();
-        for (Value s : llvm::drop_begin(shapes)) {
+        // Flatten non-scalar operands.
+        Value shape = non_scalar_shapes.front();
+        for (Value s : llvm::drop_begin(non_scalar_shapes)) {
           shape = b.create<shape::AnyOp>(loc, shape.getType(),
                                          ValueRange{shape, s});
         }
@@ -350,6 +364,7 @@ Value MaterializeEqualShapesRankSpecializationCase(
                 .result());
         auto flat_operands =
             llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
+              if (IsScalarTensorType(v.getType())) return v;
              return b
                  .create<mhlo::DynamicReshapeOp>(
                      loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/1), v,
@@ -358,13 +373,11 @@ Value MaterializeEqualShapesRankSpecializationCase(
                 }));
 
         // Materialize ranked variants for the element-wise operations.
-        // TODO(frgossen): Materializae non-broadcasting equivalents instead.
         BlockAndValueMapping bvm;
         for (auto it : llvm::zip(op.getBody()->getArguments(), flat_operands))
           bvm.map(std::get<0>(it), std::get<1>(it));
         Value unshaped_result =
-            MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1)
-                .front();
+            MaterializeRankedOperations(b, loc, bvm, op).front();
 
         // Return as unranked tensor for compatibility with the other cases.
         b.create<scf::YieldOp>(
@@ -382,8 +395,6 @@ Value MaterializeTargetRankSpecializationCase(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
     const SmallVector<Value, 8> &shapes, int64_t target_rank) {
   // Reshape operands to match the target rank.
-  llvm::SmallVector<int64_t, 8> ranked_ty_dynamic_dims(
-      target_rank, RankedTensorType::kDynamicSize);
   RankedTensorType extent_tensor_ty =
       shape::getExtentTensorType(b.getContext(), target_rank);
   Value all_ones_shape = b.create<shape::ConstShapeOp>(
@@ -394,16 +405,19 @@ Value MaterializeTargetRankSpecializationCase(
   for (auto it : llvm::zip(op.operands(), shapes)) {
     Value operand, shape;
     std::tie(operand, shape) = it;
+    if (operand.getType().isa<RankedTensorType>()) {
+      ranked_operands.push_back(operand);
+      continue;
+    }
     Value ranked_shape = b.create<tensor::CastOp>(
         loc, extent_tensor_ty,
         b.create<shape::BroadcastOp>(loc,
                                      shape::getExtentTensorType(b.getContext()),
                                      shape, all_ones_shape,
                                      /*error=*/nullptr));
-    Type element_ty = operand.getType().dyn_cast<TensorType>().getElementType();
-    auto ranked_ty = RankedTensorType::get(ranked_ty_dynamic_dims, element_ty);
     ranked_operands.push_back(b.create<mhlo::DynamicReshapeOp>(
-        loc, ranked_ty, operand, ranked_shape));
+        loc, DeriveRankedTensorTypes(operand.getType(), target_rank), operand,
+        ranked_shape));
   }
 
   // Materialize ranked versions of the element-wise operations.
@@ -412,8 +426,7 @@ Value MaterializeTargetRankSpecializationCase(
     bvm.map(std::get<0>(it), std::get<1>(it));
 
   // Return as unranked for compatibility with other target ranks.
-  auto unshaped_result =
-      MaterializeRankedOperations(b, loc, bvm, op, target_rank).front();
+  auto unshaped_result = MaterializeRankedOperations(b, loc, bvm, op).front();
   return b.create<tensor::CastOp>(
       loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
       unshaped_result);
@@ -460,25 +473,17 @@ Value MaterializeGenericRankSpecializationCases(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
     const SmallVector<Value, 8> &shapes) {
   // Get the minimum broadcast shapes of the operands.
-  ValueRange reduced_shapes =
-      b.create<chlo::MinimumBroadcastShapesOp>(
-           loc,
-           SmallVector<Type, 8>(shapes.size(),
-                                shape::getExtentTensorType(b.getContext())),
-           shapes)
-          .results();
-  // TODO(frgossen): Avoid this reshape if it is redundant in all cases.
-  SmallVector<Value, 8> reshaped_args;
-  for (auto it : llvm::zip(op.operands(), reduced_shapes)) {
-    Value arg = std::get<0>(it);
-    Value reduced_shape = std::get<1>(it);
-    reshaped_args.push_back(b.create<mhlo::DynamicReshapeOp>(
-        loc, arg.getType(), arg, reduced_shape));
-  }
+  auto non_scalar_shapes = llvm::to_vector<8>(llvm::make_filter_range(
+      shapes, [](Value v) { return !IsScalarShapeType(v.getType()); }));
+  auto min_bcast_shapes_op = b.create<chlo::MinimumBroadcastShapesOp>(
+      loc,
+      SmallVector<Type, 8>(non_scalar_shapes.size(),
+                           shape::getExtentTensorType(b.getContext())),
+      non_scalar_shapes);
 
   // Find the maximum rank among the reduced operand shapes.
   Value max_rank;
-  for (Value shape : reduced_shapes) {
+  for (Value shape : min_bcast_shapes_op.results()) {
     Value rank = b.create<shape::RankOp>(loc, b.getIndexType(), shape);
     if (!max_rank) {
       max_rank = rank;
@@ -489,6 +494,17 @@ Value MaterializeGenericRankSpecializationCases(
     }
   }
 
+  // Collect reduced shapes.
+  SmallVector<Value, 8> reduced_shapes;
+  auto it = min_bcast_shapes_op.result_begin();
+  for (Value s : shapes) {
+    if (IsScalarShapeType(s.getType())) {
+      reduced_shapes.push_back(s);
+    } else {
+      reduced_shapes.push_back(*it++);
+    }
+  }
+
   // Materialize rank specialization for ranks 1, ..., 8.
   // TODO(frgossen): For clusters w/o a select operation, consider only ranks
   // 1, ..., 5.
@@ -508,34 +524,6 @@ Value MaterializeDefaultRankSpecializationCases(
       });
 }
 
-Value MaterializeRankSpecializationForExactlyTwoOperands(
-    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
-  assert(op->getNumOperands() == 2 && op.getNumResults() == 1 &&
-         "The rank specialization strategy for clusters with exactly two "
-         "operands supports only one result.");
-
-  auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
-    return b.create<shape::ShapeOfOp>(loc, v).result();
-  }));
-
-  // Materialize all the different cases.
-  Value unshaped_result = MaterializeScalarRankSpecializationCase(
-      b, loc, op, shapes, /*non_scalar_idx=*/1,
-      [&](OpBuilder &b, Location loc) {
-        b.create<scf::YieldOp>(
-            loc, MaterializeScalarRankSpecializationCase(
-                     b, loc, op, shapes, /*non_scalar_idx=*/0,
-                     [&](OpBuilder &b, Location loc) {
-                       b.create<scf::YieldOp>(
-                           loc, MaterializeDefaultRankSpecializationCases(
-                                    b, loc, op, shapes));
-                     }));
-      });
-
-  // Materialize final reshape once and for all rank specialization cases.
-  return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
-}
-
 SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
     Value non_scalar_operand) {
@@ -552,17 +540,48 @@ SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
   // Materialize ranked variants for the element-wise operations.
   BlockAndValueMapping bvm;
   for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
-    Value operand = std::get<1>(it);
-    bvm.map(std::get<0>(it),
+    Value operand;
+    Value bb_arg;
+    std::tie(bb_arg, operand) = it;
+    bvm.map(bb_arg,
             operand == non_scalar_operand ? flat_non_scalar_operand : operand);
   }
   SmallVector<Value, 8> unshaped_results =
-      MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1);
+      MaterializeRankedOperations(b, loc, bvm, op);
 
   // Restore the results' expected shape.
   return MaterializeFinalReshape(b, loc, op, unshaped_results);
 }
 
+Value MaterializeRankSpecializationForTwoNonScalarOperands(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
+    ValueRange non_scalar_operands) {
+  assert(non_scalar_operands.size() == 2);
+
+  auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
+    return b.create<shape::ShapeOfOp>(loc, v).result();
+  }));
+  auto non_scalar_lhs = llvm::find(op.operands(), non_scalar_operands[0]);
+  auto non_scalar_rhs = llvm::find(op.operands(), non_scalar_operands[1]);
+
+  // Materialize all the different cases.
+  Value unshaped_result = MaterializeScalarRankSpecializationCase(
+      b, loc, op, shapes, non_scalar_rhs.getIndex(),
+      [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(
+            loc, MaterializeScalarRankSpecializationCase(
+                     b, loc, op, shapes, non_scalar_lhs.getIndex(),
+                     [&](OpBuilder &b, Location loc) {
+                       b.create<scf::YieldOp>(
+                           loc, MaterializeDefaultRankSpecializationCases(
+                                    b, loc, op, shapes));
+                     }));
+      });
+
+  // Materialize final reshape once and for all rank specialization cases.
+  return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
+}
+
 // Materialize rank generic rank specialization.
 Value MaterializeDefaultRankSpecialization(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
@@ -584,31 +603,36 @@ struct LowerRankSpecializationClusterPattern
   LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                 PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-
     // Restoring the result shape currently relies on all operands being used
     // for a single result. The result shape is then the broadcasted shape of
     // all operands.
     if (op.getNumResults() != 1) return failure();
 
-    // TODO(frgossen): If there is a single operand, we can flatten it
-    // completely and apply a non-broadcasting operation.
-
-    // If there is only one unranked operand and all others are known scalars,
-    // we can flatten the operands to rank 1.
-    if (Optional<Value> non_scalar_operand =
-            FindUniqueNonScalar(op.operands())) {
+    // If there is only a single non-scalar operand, we can flatten that
+    // operand completely.
+    Location loc = op.getLoc();
+    auto non_scalar_operands =
+        llvm::to_vector<2>(llvm::make_filter_range(op.operands(), [](Value v) {
+          return !IsScalarTensorType(v.getType());
+        }));
+    if (non_scalar_operands.size() == 1) {
       rewriter.replaceOp(op,
                          MaterializeRankSpecializationForSingleNonScalarOperand(
-                             rewriter, loc, op, *non_scalar_operand));
+                             rewriter, loc, op, non_scalar_operands.front()));
       return success();
     }
 
-    // If there are only two operands, we can consider extra cases in which
-    // either operand is scalar.
-    if (op->getNumOperands() == 2) {
-      rewriter.replaceOp(op, MaterializeRankSpecializationForExactlyTwoOperands(
-                                 rewriter, loc, op));
+    // If there are exactly two unranked operands and all others are known to
+    // be scalars, we can consider two extra cases: If either of the unranked
+    // operands turns out to be a scalar at runtime, we can, again, apply the
+    // trick for a single non-scalar operand.
+    if (non_scalar_operands.size() == 2 &&
+        llvm::all_of(non_scalar_operands, [](Value v) {
+          return v.getType().isa<UnrankedTensorType>();
+        })) {
+      rewriter.replaceOp(op,
+                         MaterializeRankSpecializationForTwoNonScalarOperands(
+                             rewriter, loc, op, non_scalar_operands));
       return success();
     }
diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir
index 5040b6f..5c84eb6 100644
--- a/tests/rank-specialization.mlir
+++ b/tests/rank-specialization.mlir
@@ -267,7 +267,7 @@ func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> {
 
 // -----
 
-// Ternary operation.
+// Operation with mixed ranked and unranked operands.
 // CHECK-LABEL: @select_mixed
 // CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<2xf32>)
 func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>,
@@ -283,7 +283,8 @@ func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>,
 }
 
 // CHECK-SCF-LABEL: @select_mixed
 // CHECK-SCF: chlo.broadcast_select %{{.*}}, %{{.*}}, %{{.*}} : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF: return
 
 // -----
 
@@ -417,6 +418,7 @@ func @angle(%arg : tensor<*xcomplex<f32>>) -> tensor<*xf32> {
 
 // -----
 
+// Scalar cluster operand.
 // CHECK-LABEL: @xlogy
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
 func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
@@ -441,11 +443,87 @@ func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
   return %5 : tensor<*xf32>
 }
 
-// CHECK-SCF-LABEL: @xlogy
-// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %{{.*}}, %{{.*}} {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?xf32>)
-// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%{{.*}}) : (tensor<?xf32>)
-// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %{{.*}}, %[[TMP0]] : (tensor<?xf32>, tensor<?xf32>)
-// CHECK-SCF: chlo.broadcast_select %[[PRED]], %{{.*}}, %[[TMP1]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF: @xlogy
+// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
+// CHECK-SCF-DAG: %[[C1:.*]] = constant 1
+// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1]
+// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
+// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
+// CHECK-SCF-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.00{{.*}}>
+// Lhs scalar case:
+// CHECK-SCF-DAG: %[[LHS_N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[LHS_SCALAR:.*]] = cmpi eq, %[[LHS_N]], %[[C1]]
+// CHECK-SCF: %[[UNSHAPED_RES:.*]] = scf.if %[[LHS_SCALAR]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[SCALAR:.*]] = "mhlo.reshape"(%[[ARG0]])
+// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[SCALAR]], %[[ZERO]] {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>)
+// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%[[FLAT_NON_SCALAR]]) : (tensor<?xf32>)
+// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[SCALAR]], %[[TMP0]] : (tensor<f32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor<i1>, tensor<f32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Rhs scalar case:
+// CHECK-SCF-DAG: %[[RHS_N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[RHS_SCALAR:.*]] = cmpi eq, %[[RHS_N]], %[[C1]]
+// CHECK-SCF: %{{.*}} = scf.if %[[RHS_SCALAR]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[SCALAR:.*]] = "mhlo.reshape"(%[[ARG1]])
+// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FLAT_NON_SCALAR]], %[[ZERO]] {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<f32>)
+// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%[[SCALAR]]) : (tensor<f32>)
+// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[FLAT_NON_SCALAR]], %[[TMP0]] : (tensor<?xf32>, tensor<f32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor<?xi1>, tensor<f32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Equal shapes case:
+// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF: %{{.*}} = scf.if %[[SHAPES_EQ]]
+// CHECK-SCF-DAG: %[[SHAPE:.*]] = shape.any %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FLAT_ARG0]], %[[ZERO]] {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<f32>)
+// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%[[FLAT_ARG1]]) : (tensor<?xf32>)
+// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[FLAT_ARG0]], %[[TMP0]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor<?xi1>, tensor<f32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Find maximum reduced rank.
+// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#0
+// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#1
+// CHECK-SCF-DAG: %[[R0_GT_R1:.*]] = cmpi sgt, %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
+// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R0_GT_R1]], %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
+// Generic case 1:
+// CHECK-SCF: %[[MAX_RED_RANK_LE_1:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C1]]
+// CHECK-SCF: %{{.*}} = scf.if %[[MAX_RED_RANK_LE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[REDUCED_ARG0]], %[[ZERO]] {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<f32>)
+// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%[[REDUCED_ARG1]]) : (tensor<?xf32>)
+// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[TMP0]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor<?xi1>, tensor<f32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// ...
+// Reshape the result.
+// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
+// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
+// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]])
+// CHECK-SCF: return %[[RES]]
 
 // -----
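
Note for reviewers: the cascade that the new
MaterializeRankSpecializationForTwoNonScalarOperands emits is easiest to read
as pseudo-IR. The sketch below is hand-assembled from the CHECK-SCF lines
above; SSA names are invented and most operands and types are elided, so it
illustrates the structure rather than reproducing verbatim pass output:

    %c1 = constant 1 : index
    %n0 = shape.num_elements %shape_arg0
    %lhs_scalar = cmpi eq, %n0, %c1 : index
    %res = scf.if %lhs_scalar -> (tensor<*xf32>) {
      // %arg0 is reshaped to tensor<f32>, %arg1 is flattened to tensor<?xf32>,
      // the cluster body is materialized at rank 1, and the result is cast
      // back to tensor<*xf32> with tensor.cast.
      scf.yield %lhs_case : tensor<*xf32>
    } else {
      // The same scalar test with %arg1 in the scalar role, then the
      // equal-shapes case, then the generic cases for reduced ranks 1..8.
      scf.yield %other_cases : tensor<*xf32>
    }
    // One final shape.broadcast of the operand shapes feeds a single
    // "mhlo.dynamic_reshape" that restores the result shape for all cases.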