Add special cases for SelectOp rank specialization.
We now use the same special cases for all ops with arity >= 2. For binary ops, we now have only one special case if at least one of the operands has exactly one element. In that case, we reshape both operands to rank 1. Before, we had separate special cases whether the left-hand side or the right-hand side have a scalar shape. PiperOrigin-RevId: 366005835
This commit is contained in:
parent
9206805c58
commit
4033a56750
|
@ -183,7 +183,9 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
|
||||||
Value size_tensor =
|
Value size_tensor =
|
||||||
rewriter.create<tensor::FromElementsOp>(loc, num_elements);
|
rewriter.create<tensor::FromElementsOp>(loc, num_elements);
|
||||||
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
|
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||||
loc, RankedTensorType::get({-1}, scalar_element_type),
|
loc,
|
||||||
|
RankedTensorType::get({RankedTensorType::kDynamicSize},
|
||||||
|
scalar_element_type),
|
||||||
lhs_is_scalar ? rhs : lhs, size_tensor);
|
lhs_is_scalar ? rhs : lhs, size_tensor);
|
||||||
|
|
||||||
// Create a new ranked Chlo op that will be further lowered by other
|
// Create a new ranked Chlo op that will be further lowered by other
|
||||||
|
@ -191,7 +193,9 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
|
||||||
SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
|
SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
|
||||||
rhs_is_scalar ? rhs : reshaped};
|
rhs_is_scalar ? rhs : reshaped};
|
||||||
Value computed = rewriter.create<ChloOpTy>(
|
Value computed = rewriter.create<ChloOpTy>(
|
||||||
loc, TypeRange{RankedTensorType::get({-1}, result_element_type)},
|
loc,
|
||||||
|
TypeRange{RankedTensorType::get({RankedTensorType::kDynamicSize},
|
||||||
|
result_element_type)},
|
||||||
new_operands, op->getAttrs());
|
new_operands, op->getAttrs());
|
||||||
|
|
||||||
// Reshape the result back into an unranked tensor.
|
// Reshape the result back into an unranked tensor.
|
||||||
|
@ -202,7 +206,7 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename ChloOpTy, typename HloOpTy>
|
template <typename ChloOpTy>
|
||||||
struct ConvertUnrankedDynamicBroadcastOpHelper {
|
struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||||
// Returns the dynamic result of checking the given value is effectively a
|
// Returns the dynamic result of checking the given value is effectively a
|
||||||
// scalar shape (i.e. the number of elements is 1).
|
// scalar shape (i.e. the number of elements is 1).
|
||||||
|
@ -378,14 +382,15 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||||
|
|
||||||
// Handles lowering of the following pattern to patterns that will be further
|
// Handles lowering of the following pattern to patterns that will be further
|
||||||
// matched by other patterns until they result in LHLO:
|
// matched by other patterns until they result in LHLO:
|
||||||
// %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
|
// %result = "chlo.op"(%op0, %op1, ...) : (<*xTy>, <*xTy>, ...) -> <*xTy>
|
||||||
//
|
//
|
||||||
// The sequence of specializations this handles is:
|
// The sequence of specializations this handles is:
|
||||||
// - Either operand being scalar
|
// - At most one operand has a shape that does not consist of exactly one
|
||||||
// - Operands having equal shapes
|
// element.
|
||||||
// - The resulting value being any of ranks [2,6]
|
// - All operands having equal shapes
|
||||||
|
// - The resulting minimized shapes being any of ranks [1,5]
|
||||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||||
struct ConvertUnrankedDynamicBroadcastBinaryOp
|
struct ConvertUnrankedDynamicBroadcastNaryOp
|
||||||
: public OpConversionPattern<ChloOpTy> {
|
: public OpConversionPattern<ChloOpTy> {
|
||||||
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
|
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
|
||||||
|
|
||||||
|
@ -394,76 +399,96 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
typename ChloOpTy::Adaptor transformed(operands);
|
typename ChloOpTy::Adaptor transformed(operands);
|
||||||
Value lhs = transformed.lhs();
|
ValueRange transformed_operands = transformed.getOperands();
|
||||||
Value rhs = transformed.rhs();
|
auto num_operands = transformed_operands.size();
|
||||||
auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
|
llvm::SmallVector<UnrankedTensorType, 3> operand_types;
|
||||||
auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
|
operand_types.reserve(num_operands);
|
||||||
|
for (int i = 0; i < num_operands; ++i) {
|
||||||
|
auto type =
|
||||||
|
transformed_operands[i].getType().dyn_cast<UnrankedTensorType>();
|
||||||
|
// Only support unranked operands.
|
||||||
|
if (!type) return failure();
|
||||||
|
operand_types.push_back(type);
|
||||||
|
}
|
||||||
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
||||||
|
|
||||||
// Only support unranked operands. If either operand is ranked, another
|
llvm::SmallVector<Value> shapes;
|
||||||
// pattern will handle the lowering.
|
shapes.reserve(num_operands);
|
||||||
if (!lhs_type || !rhs_type) return failure();
|
for (int i = 0; i < num_operands; ++i) {
|
||||||
|
shapes.push_back(
|
||||||
|
rewriter.create<shape::ShapeOfOp>(loc, transformed_operands[i]));
|
||||||
|
}
|
||||||
|
|
||||||
Value shape_of_lhs = rewriter.create<shape::ShapeOfOp>(loc, lhs);
|
// If at most one shape does not have exactly one element
|
||||||
Value shape_of_rhs = rewriter.create<shape::ShapeOfOp>(loc, rhs);
|
Value counter = rewriter.create<ConstantIndexOp>(loc, 0);
|
||||||
|
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
|
||||||
|
for (int i = 0; i < num_operands; ++i) {
|
||||||
|
Value is_scalar_like = IsSingleElementShape(rewriter, op, shapes[i]);
|
||||||
|
Value counter_plus_one = rewriter.create<AddIOp>(loc, counter, one);
|
||||||
|
counter = rewriter.create<SelectOp>(loc, is_scalar_like, counter_plus_one,
|
||||||
|
counter);
|
||||||
|
}
|
||||||
|
Value num_operands_minus_one =
|
||||||
|
rewriter.create<ConstantIndexOp>(loc, num_operands - 1);
|
||||||
|
Value at_most_one_non_scalar =
|
||||||
|
rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::uge,
|
||||||
|
counter, num_operands_minus_one);
|
||||||
|
|
||||||
// If lhs has exactly one element
|
auto if_op = rewriter.create<scf::IfOp>(loc, result_type,
|
||||||
auto if_op = rewriter.create<scf::IfOp>(
|
at_most_one_non_scalar, true);
|
||||||
loc, result_type, IsSingleElementShape(rewriter, op, shape_of_lhs),
|
OpBuilder if_at_most_one_non_scalar_builder =
|
||||||
true);
|
|
||||||
OpBuilder if_lhs_scalar_builder =
|
|
||||||
if_op.getThenBodyBuilder(rewriter.getListener());
|
if_op.getThenBodyBuilder(rewriter.getListener());
|
||||||
Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>(
|
|
||||||
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
|
|
||||||
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
|
|
||||||
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
|
|
||||||
op->getAttrs());
|
|
||||||
Value extended_if_lhs_scalar_result =
|
|
||||||
extendToBroadcastShape(if_lhs_scalar_builder, loc, if_lhs_scalar_result,
|
|
||||||
shape_of_lhs, shape_of_rhs);
|
|
||||||
if_lhs_scalar_builder.create<scf::YieldOp>(loc,
|
|
||||||
extended_if_lhs_scalar_result);
|
|
||||||
|
|
||||||
// If lhs does not have exactly one element
|
llvm::SmallVector<Value, 3> reshaped_operands;
|
||||||
|
reshaped_operands.reserve(num_operands);
|
||||||
|
for (int i = 0; i < num_operands; ++i) {
|
||||||
|
Value num_elements =
|
||||||
|
if_at_most_one_non_scalar_builder.create<shape::NumElementsOp>(
|
||||||
|
loc, shapes[i]);
|
||||||
|
Value size_tensor =
|
||||||
|
if_at_most_one_non_scalar_builder.create<tensor::FromElementsOp>(
|
||||||
|
loc, num_elements);
|
||||||
|
Value reshaped =
|
||||||
|
if_at_most_one_non_scalar_builder.create<mhlo::DynamicReshapeOp>(
|
||||||
|
loc,
|
||||||
|
RankedTensorType::get({RankedTensorType::kDynamicSize},
|
||||||
|
operand_types[i].getElementType()),
|
||||||
|
transformed_operands[i], size_tensor);
|
||||||
|
reshaped_operands.push_back(reshaped);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto rank_one_result_type = RankedTensorType::get(
|
||||||
|
{RankedTensorType::kDynamicSize}, result_type.getElementType());
|
||||||
|
Value if_at_most_one_non_scalar_result =
|
||||||
|
if_at_most_one_non_scalar_builder.create<ChloOpTy>(
|
||||||
|
loc, ArrayRef<Type>{rank_one_result_type}, reshaped_operands,
|
||||||
|
op->getAttrs());
|
||||||
|
Value extended_result = extendToBroadcastShape(
|
||||||
|
if_at_most_one_non_scalar_builder, loc, result_type,
|
||||||
|
if_at_most_one_non_scalar_result, shapes);
|
||||||
|
if_at_most_one_non_scalar_builder.create<scf::YieldOp>(loc,
|
||||||
|
extended_result);
|
||||||
|
|
||||||
|
// If there is more than one shape which does not have exactly one element
|
||||||
//
|
//
|
||||||
// See if rhs has exactly one element
|
// See if all shapes are equal.
|
||||||
OpBuilder else_lhs_scalar_builder =
|
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
||||||
if_op.getElseBodyBuilder(rewriter.getListener());
|
Value equal_shapes =
|
||||||
auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
|
else_builder.create<shape::ShapeEqOp>(loc, shapes[0], shapes[1]);
|
||||||
loc, result_type,
|
for (int i = 2; i < num_operands; ++i) {
|
||||||
IsSingleElementShape(else_lhs_scalar_builder, op, shape_of_rhs), true);
|
Value are_equal =
|
||||||
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
|
else_builder.create<shape::ShapeEqOp>(loc, shapes[0], shapes[i]);
|
||||||
if_rhs_scalar_op.getResult(0));
|
equal_shapes = else_builder.create<AndOp>(loc, equal_shapes, are_equal);
|
||||||
OpBuilder if_rhs_scalar_builder =
|
}
|
||||||
if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
|
|
||||||
Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>(
|
|
||||||
loc, RankedTensorType::get({}, rhs_type.getElementType()), rhs);
|
|
||||||
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
|
|
||||||
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
|
|
||||||
op->getAttrs());
|
|
||||||
Value extended_if_rhs_scalar_result =
|
|
||||||
extendToBroadcastShape(if_rhs_scalar_builder, loc, if_rhs_scalar_result,
|
|
||||||
shape_of_lhs, shape_of_rhs);
|
|
||||||
if_rhs_scalar_builder.create<scf::YieldOp>(loc,
|
|
||||||
extended_if_rhs_scalar_result);
|
|
||||||
|
|
||||||
// If NEITHER shape has exactly one element
|
auto if_eq_shapes_op =
|
||||||
//
|
else_builder.create<scf::IfOp>(loc, result_type, equal_shapes, true);
|
||||||
// See if shapes are equal.
|
else_builder.create<scf::YieldOp>(loc, if_eq_shapes_op.getResult(0));
|
||||||
OpBuilder else_no_scalars_builder =
|
|
||||||
if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener());
|
|
||||||
Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
|
|
||||||
loc, shape_of_lhs, shape_of_rhs);
|
|
||||||
|
|
||||||
auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
|
|
||||||
loc, result_type, equal_shapes, true);
|
|
||||||
else_no_scalars_builder.create<scf::YieldOp>(loc,
|
|
||||||
if_eq_shapes_op.getResult(0));
|
|
||||||
|
|
||||||
OpBuilder if_eq_shapes_builder =
|
OpBuilder if_eq_shapes_builder =
|
||||||
if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
|
if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
|
||||||
Value non_broadcast_op =
|
Value non_broadcast_op = Adaptor::CreateOp(
|
||||||
Adaptor::CreateOp(op, result_type, {lhs, rhs}, if_eq_shapes_builder);
|
op, result_type, transformed_operands, if_eq_shapes_builder);
|
||||||
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
|
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
|
||||||
|
|
||||||
// If shapes do not have exactly one element, nor are equal
|
// If shapes do not have exactly one element, nor are equal
|
||||||
|
@ -472,9 +497,9 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||||
OpBuilder if_neq_shapes_builder =
|
OpBuilder if_neq_shapes_builder =
|
||||||
if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
|
if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
|
||||||
if_neq_shapes_builder.create<scf::YieldOp>(
|
if_neq_shapes_builder.create<scf::YieldOp>(
|
||||||
loc, ConvertUnrankedDynamicBroadcastOpHelper<
|
loc,
|
||||||
ChloOpTy, HloOpTy>::HandleBroadcastAndOp(if_neq_shapes_builder,
|
ConvertUnrankedDynamicBroadcastOpHelper<ChloOpTy>::HandleBroadcastAndOp(
|
||||||
op, {lhs, rhs}));
|
if_neq_shapes_builder, op, transformed_operands));
|
||||||
|
|
||||||
rewriter.replaceOp(op, {if_op.getResult(0)});
|
rewriter.replaceOp(op, {if_op.getResult(0)});
|
||||||
return success();
|
return success();
|
||||||
|
@ -494,37 +519,18 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||||
rewriter.create<ConstantIndexOp>(loc, 1));
|
rewriter.create<ConstantIndexOp>(loc, 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
Value extendToBroadcastShape(OpBuilder &builder, Location loc, Value value,
|
Value extendToBroadcastShape(OpBuilder &builder, Location loc,
|
||||||
Value shape_of_lhs, Value shape_of_rhs) const {
|
Type result_type, Value value,
|
||||||
|
ValueRange shapes) const {
|
||||||
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
||||||
{RankedTensorType::kDynamicSize}, builder.getIndexType());
|
{RankedTensorType::kDynamicSize}, builder.getIndexType());
|
||||||
Value broadcast_shape =
|
Value broadcast_shape = builder.create<shape::BroadcastOp>(
|
||||||
builder.create<shape::BroadcastOp>(loc, unknown_rank_extent_tensor_type,
|
loc, unknown_rank_extent_tensor_type, shapes, nullptr);
|
||||||
shape_of_lhs, shape_of_rhs, nullptr);
|
return builder.create<mhlo::DynamicReshapeOp>(loc, result_type, value,
|
||||||
return builder.create<mhlo::DynamicReshapeOp>(loc, value.getType(), value,
|
|
||||||
broadcast_shape);
|
broadcast_shape);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Rank-specialize chlo.broadcast_select ops.
|
|
||||||
struct ConvertUnrankedDynamicBroadcastSelectOp
|
|
||||||
: public OpConversionPattern<chlo::BroadcastSelectOp> {
|
|
||||||
using OpConversionPattern<chlo::BroadcastSelectOp>::OpConversionPattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(
|
|
||||||
chlo::BroadcastSelectOp op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
// For now only do the bare minimum and specialize for every rank. There is
|
|
||||||
// more potential for optimization here. This also is missing the
|
|
||||||
// specialization for rank 0.
|
|
||||||
rewriter.replaceOp(
|
|
||||||
op, {ConvertUnrankedDynamicBroadcastOpHelper<
|
|
||||||
chlo::BroadcastSelectOp,
|
|
||||||
mhlo::SelectOp>::HandleBroadcastAndOp(rewriter, op, operands)});
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TransformUnrankedHloPass
|
struct TransformUnrankedHloPass
|
||||||
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
@ -588,11 +594,14 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
||||||
#undef MAP_HLO
|
#undef MAP_HLO
|
||||||
#undef MAP_CHLO
|
#undef MAP_CHLO
|
||||||
#undef COMMA
|
#undef COMMA
|
||||||
chlo::PopulateForBroadcastingBinaryOp<
|
chlo::PopulateForBroadcastingBinaryOp<ConvertUnrankedDynamicBroadcastNaryOp>(
|
||||||
ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
|
context, patterns);
|
||||||
|
patterns->insert<ConvertUnrankedDynamicBroadcastNaryOp<
|
||||||
|
chlo::BroadcastSelectOp, mhlo::SelectOp,
|
||||||
|
chlo::HloNaryElementwiseAdaptor<chlo::BroadcastSelectOp,
|
||||||
|
mhlo::SelectOp>>>(context);
|
||||||
chlo::PopulateForBroadcastingBinaryOp<
|
chlo::PopulateForBroadcastingBinaryOp<
|
||||||
ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
|
ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
|
||||||
patterns->insert<ConvertUnrankedDynamicBroadcastSelectOp>(context);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
|
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
|
||||||
|
|
|
@ -159,138 +159,130 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
|
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
|
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
|
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
|
||||||
|
// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
||||||
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_LHS]], %[[C1]] : index
|
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_LHS]], %[[C1]] : index
|
||||||
// Handle scalar LHS case
|
// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[C0]], %[[C1]] : index
|
||||||
// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
|
// CHECK-NEXT: %[[COUNTER:.*]] = select %[[LHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[C0]] : index
|
||||||
// CHECK-NEXT: %[[SCALAR_LHS:.*]] = "mhlo.reshape"(%[[LHS]]) : (tensor<*xf32>) -> tensor<f32>
|
// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
||||||
// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_RHS]], %[[C1]] : index
|
||||||
|
// CHECK-NEXT: %[[COUNTER_PLUS_ONE2:.*]] = addi %[[COUNTER]], %[[C1]] : index
|
||||||
|
// CHECK-NEXT: %[[COUNTER2:.*]] = select %[[RHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE2]], %[[COUNTER]] : index
|
||||||
|
// Handle scalar case
|
||||||
|
// CHECK-NEXT: %[[IS_SCALAR_CASE:.*]] = cmpi uge, %[[COUNTER2]], %[[C1]] : index
|
||||||
|
// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[IS_SCALAR_CASE]] -> (tensor<*xf32>) {
|
||||||
|
// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_LHS]] : tensor<1xindex>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_RHS]] : tensor<1xindex>
|
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_RHS]] : tensor<1xindex>
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
// CHECK-NEXT: %[[LHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RESHAPED_RHS]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
|
// CHECK-NEXT: %[[SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
// CHECK-NEXT: %[[SHAPE_BROADCAST:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[SHAPE_BROADCAST_LHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[RESHAPED_EXTENDED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[SCALAR_RESULT]], %[[SHAPE_BROADCAST]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
// CHECK-NEXT: %[[RESHAPED_EXTENDED_LHS_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESHAPED_LHS_SCALAR_RESULT]], %[[SHAPE_BROADCAST_LHS]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RESULT]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_LHS_RESULT]] : tensor<*xf32>
|
|
||||||
// CHECK-NEXT: } else {
|
// CHECK-NEXT: } else {
|
||||||
// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
// CHECK-NEXT: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_RHS]], %[[C1]] : index
|
// Handle equal shapes case
|
||||||
// Handle scalar RHS case
|
// CHECK-NEXT: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
|
||||||
// CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
|
// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[SCALAR_RHS:.*]] = "mhlo.reshape"(%[[RHS]]) : (tensor<*xf32>) -> tensor<f32>
|
// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index
|
||||||
// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_LHS]] : tensor<1xindex>
|
// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor.from_elements %[[ANY_NUM]] : tensor<1xindex>
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
// CHECK-NEXT: %[[RHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[SCALAR_RHS]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
|
// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RHS_SCALAR_RESULT:.*]], %[[LHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = mhlo.add %[[FLATTENED_LHS]], %[[FLATTENED_RHS]] : tensor<?xf32>
|
||||||
// CHECK-NEXT: %[[SHAPE_BROADCAST_RHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
// CHECK-NEXT: %[[RESHAPED_EXTENDED_RHS_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESHAPED_RHS_SCALAR_RESULT]], %[[SHAPE_BROADCAST_RHS]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RHS_RESULT]] : tensor<*xf32>
|
|
||||||
// CHECK-NEXT: } else {
|
// CHECK-NEXT: } else {
|
||||||
// CHECK-NEXT: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
|
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||||
// Handle equal shapes case
|
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
|
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index
|
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
|
||||||
// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor.from_elements %[[ANY_NUM]] : tensor<1xindex>
|
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
|
||||||
// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||||
// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||||
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = mhlo.add %[[FLATTENED_LHS]], %[[FLATTENED_RHS]] : tensor<?xf32>
|
// Handle rank 1 specialization
|
||||||
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
|
||||||
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
|
// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
||||||
|
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
|
||||||
|
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||||
|
// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||||
|
// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
|
||||||
|
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: } else {
|
// CHECK-NEXT: } else {
|
||||||
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
|
||||||
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
|
// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C2]] : index
|
||||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
// Handle rank 2 specialization
|
||||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
|
||||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
|
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
|
||||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
|
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
|
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||||
// Handle rank 1 specialization
|
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
|
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||||
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
|
// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
|
||||||
// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
|
// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
|
|
||||||
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
|
|
||||||
// CHECK-NEXT: } else {
|
// CHECK-NEXT: } else {
|
||||||
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
|
// CHECK-NEXT: %[[C3:.*]] = constant 3 : index
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C2]] : index
|
// CHECK-NEXT: %[[GREATEST_RANK_IS_3:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C3]] : index
|
||||||
// Handle rank 2 specialization
|
// Handle rank 3 specialization
|
||||||
// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
|
// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
|
||||||
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
|
// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
|
||||||
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
|
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
|
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
|
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
|
||||||
// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: } else {
|
// CHECK-NEXT: } else {
|
||||||
// CHECK-NEXT: %[[C3:.*]] = constant 3 : index
|
// CHECK-NEXT: %[[C4:.*]] = constant 4 : index
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_3:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C3]] : index
|
// CHECK-NEXT: %[[GREATEST_RANK_IS_4:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C4]] : index
|
||||||
// Handle rank 3 specialization
|
// Handle rank 4 specialization
|
||||||
// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
|
// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
|
||||||
// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
|
// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
|
||||||
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
|
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
|
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
|
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
|
||||||
// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: } else {
|
// CHECK-NEXT: } else {
|
||||||
// CHECK-NEXT: %[[C4:.*]] = constant 4 : index
|
// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_4:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C4]] : index
|
// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index
|
||||||
// Handle rank 4 specialization
|
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]]
|
||||||
// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
|
// Handle rank 5 specialization
|
||||||
// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
|
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
|
||||||
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
|
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
|
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
|
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
|
||||||
// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: } else {
|
|
||||||
// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
|
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index
|
|
||||||
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]]
|
|
||||||
// Handle rank 5 specialization
|
|
||||||
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
|
|
||||||
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
|
||||||
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
|
||||||
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
|
|
||||||
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
|
|
||||||
// CHECK-NEXT: }
|
|
||||||
// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
|
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
|
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
|
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: scf.yield %[[VAL_71:.*]] : tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[VAL_71:.*]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
|
@ -315,36 +307,77 @@ func @selectUnrankedUnrankedUnranked(
|
||||||
// CHECK-NEXT: %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor<?xindex>
|
// CHECK-NEXT: %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
|
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %c0 = constant 0 : index
|
||||||
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xi1>, tensor<?xindex>) -> tensor<*xi1>
|
|
||||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
|
||||||
// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
|
|
||||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
|
|
||||||
// CHECK-NEXT: %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index
|
|
||||||
// CHECK-NEXT: %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index
|
|
||||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#2 : tensor<?xindex> -> index
|
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index
|
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index
|
|
||||||
// CHECK-NEXT: %c1 = constant 1 : index
|
// CHECK-NEXT: %c1 = constant 1 : index
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %c1 : index
|
// CHECK-NEXT: %[[NUM_ELEMENTS_PRED:.*]] = shape.num_elements %[[PRED_SHAPE]] : tensor<?xindex> -> index
|
||||||
// Handle rank 1 specialization
|
// CHECK-NEXT: %[[PRED_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_PRED]], %c1 : index
|
||||||
// CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %c0, %c1 : index
|
||||||
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex>
|
// CHECK-NEXT: %[[COUNTER:.*]] = select %[[PRED_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %c0 : index
|
||||||
// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[NUM_ELEMENTS_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
||||||
// CHECK-NEXT: %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor<?xindex> to tensor<1xindex>
|
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_LHS]], %c1 : index
|
||||||
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
|
// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[COUNTER]], %c1 : index
|
||||||
// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[COUNTER2:.*]] = select %[[LHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[COUNTER]] : index
|
||||||
// CHECK-NEXT: %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor<?xindex> to tensor<1xindex>
|
// CHECK-NEXT: %[[NUM_ELEMENTS_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_RHS]], %c1 : index
|
||||||
// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#2, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[COUNTER2]], %c1 : index
|
||||||
// CHECK-NEXT: %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor<?xindex> to tensor<1xindex>
|
// CHECK-NEXT: %[[COUNTER3:.*]] = select %[[RHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[COUNTER2]] : index
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
// CHECK-NEXT: %c2 = constant 2 : index
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
// CHECK-NEXT: %[[IS_SCALAR_CASE:.*]] = cmpi uge, %[[COUNTER3]], %c2 : index
|
||||||
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1:.*]] : tensor<?xf32> to tensor<*xf32>
|
// CHECK-NEXT: %[[IF_IS_SCALAR_CASE:.*]] = scf.if %[[IS_SCALAR_CASE]] -> (tensor<*xf32>) {
|
||||||
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
|
// CHECK-NEXT: %[[NUM_TENS_PRED:.*]] = tensor.from_elements %[[NUM_ELEMENTS_PRED]] : tensor<1xindex>
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[NUM_TENS_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
|
||||||
|
// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_ELEMENTS_LHS]] : tensor<1xindex>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_ELEMENTS_RHS]] : tensor<1xindex>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK-NEXT: %[[SCALAR_RESULT:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
// CHECK-NEXT: %[[SHAPE_BROADCAST:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_EXTENDED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[SCALAR_RESULT]], %[[SHAPE_BROADCAST]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RESULT]] : tensor<*xf32>
|
||||||
|
// CHECK-NEXT: } else {
|
||||||
|
// CHECK-NEXT: %[[FIRST_SHAPES_EQUAL:.*]] = shape.shape_eq %[[PRED_SHAPE]], %[[LHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
|
||||||
|
// CHECK-NEXT: %[[SECOND_SHAPES_EQUAL:.*]] = shape.shape_eq %[[PRED_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
|
||||||
|
// CHECK-NEXT: %[[ALL_SHAPES_EQUAL:.*]] = and %[[FIRST_SHAPES_EQUAL]], %[[SECOND_SHAPES_EQUAL]] : i1
|
||||||
|
// CHECK-NEXT: %[[IF_EQUAL_CASE:.*]] = scf.if %[[ALL_SHAPES_EQUAL]] -> (tensor<*xf32>) {
|
||||||
|
// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||||
|
// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index
|
||||||
|
// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor.from_elements %[[ANY_NUM]] : tensor<1xindex>
|
||||||
|
// CHECK-NEXT: %[[FLATTENED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[ANY_TENSOR]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
|
||||||
|
// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = "mhlo.select"(%[[FLATTENED_PRED]], %[[FLATTENED_LHS]], %[[FLATTENED_RHS]]) : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
|
||||||
|
// CHECK-NEXT: } else {
|
||||||
|
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||||
|
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
|
||||||
|
// CHECK-NEXT: %[[MINIMUM_RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xi1>, tensor<?xindex>) -> tensor<*xi1>
|
||||||
|
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
|
||||||
|
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
|
||||||
|
// CHECK-NEXT: %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index
|
||||||
|
// CHECK-NEXT: %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index
|
||||||
|
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#2 : tensor<?xindex> -> index
|
||||||
|
// CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index
|
||||||
|
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index
|
||||||
|
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %c1 : index
|
||||||
|
// Handle rank 1 specialization
|
||||||
|
// CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
||||||
|
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex>
|
||||||
|
// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||||
|
// CHECK-NEXT: %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor<?xindex> to tensor<1xindex>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
|
||||||
|
// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||||
|
// CHECK-NEXT: %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor<?xindex> to tensor<1xindex>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#2, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||||
|
// CHECK-NEXT: %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor<?xindex> to tensor<1xindex>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1:.*]] : tensor<?xf32> to tensor<*xf32>
|
||||||
|
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||||
|
|
Loading…
Reference in New Issue