Support CHLO->LHLO lowering for broadcasting operations with both inputs unranked.
PiperOrigin-RevId: 323960733
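
For illustration, the case this lowering now covers is a CHLO broadcasting binary op whose operands are both unranked (taken from the test added below):

func @addUnrankedUnranked(
    %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>)
      -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

The new pattern rewrites such ops into nested scf.if specializations: scalar lhs, scalar rhs, equal operand shapes, and rank-specialized cases for result ranks 2 through 6, with an assertion for higher ranks.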
@@ -220,6 +220,229 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
  }
};

// Handles lowering of the following pattern to patterns that will be further
// matched by other patterns until they result in LHLO:
//   %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
//
// The sequence of specializations this handles is:
//   - Either operand being scalar
//   - Operands having equal shapes
//   - The resulting value being any of ranks [2,6]
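//
// As a rough sketch (pseudocode only, not the exact IR), the generated code
// has the following structure:
//   if rank(lhs) == 0:      reshape lhs to a rank-0 tensor, re-emit the CHLO op
//   else if rank(rhs) == 0: reshape rhs to a rank-0 tensor, re-emit the CHLO op
//   else if shape(lhs) == shape(rhs): emit the HLO op directly (no broadcast)
//   else: nested rank specializations for result ranks 2..6, assert otherwise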
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedDynamicBroadcastBinaryOp
    : public OpRewritePattern<ChloOpTy> {
  using OpRewritePattern<ChloOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(ChloOpTy op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    Value lhs = op.lhs();
    Value rhs = op.rhs();
    auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
    auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();

    // Only support unranked operands. If either operand is ranked, another
    // pattern will handle the lowering.
    if (!lhs_type || !rhs_type) return failure();

    // If lhs is scalar
    auto if_op = rewriter.create<scf::IfOp>(
        loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
    OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
    Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>(
        loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
    Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
        op.getAttrs());
    if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);

    // If lhs is NOT scalar
    //
    // See if rhs is scalar
    OpBuilder else_lhs_scalar_builder = if_op.getElseBodyBuilder();
    auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
        loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
        true);
    else_lhs_scalar_builder.create<scf::YieldOp>(loc,
                                                 if_rhs_scalar_op.getResult(0));
    OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
    Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>(
        loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
    Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
        op.getAttrs());
    if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);

    // If NEITHER operand is scalar
    //
    // See if shapes are equal.
    OpBuilder else_no_scalars_builder = if_rhs_scalar_op.getElseBodyBuilder();
    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
                                                    rewriter.getIndexType());
    Value shape_of_lhs =
        else_no_scalars_builder.create<shape::ToExtentTensorOp>(
            loc, extent_tensor_type,
            else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs)
                .getResult());
    Value shape_of_rhs =
        else_no_scalars_builder.create<shape::ToExtentTensorOp>(
            loc, extent_tensor_type,
            else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs)
                .getResult());
    Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
        loc, shape_of_lhs, shape_of_rhs);

    auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
        loc, result_type, equal_shapes, true);
    else_no_scalars_builder.create<scf::YieldOp>(loc,
                                                 if_eq_shapes_op.getResult(0));

    OpBuilder if_eq_shapes_builder = if_eq_shapes_op.getThenBodyBuilder();
    Value non_broadcast_op =
        Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
    if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);

    // If the shapes are neither scalar nor equal
    //
    // See if values are of a rank that we support.
    OpBuilder if_neq_shapes_builder = if_eq_shapes_op.getElseBodyBuilder();
    if_neq_shapes_builder.create<scf::YieldOp>(
        loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));

    rewriter.replaceOp(op, {if_op.getResult(0)});
    return success();
  }

 private:
  // Returns the dynamic result of checking whether the given value is a
  // scalar tensor.
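  //
  // The emitted check is roughly the following IR (illustrative only; value
  // names invented for readability):
  //   %shape = shape.shape_of %tensor
  //   %extents = shape.to_extent_tensor %shape : tensor<?xindex>
  //   %rank = shape.rank %extents
  //   %c0 = constant 0 : index
  //   %is_scalar = cmpi "eq", %rank, %c0 : index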
  Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
    auto loc = op.getLoc();
    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
                                                    rewriter.getIndexType());

    Value shape_of_tensor = rewriter.create<shape::ToExtentTensorOp>(
        loc, extent_tensor_type,
        rewriter.create<shape::ShapeOfOp>(loc, tensor).getResult());
    Value rank_tensor = rewriter.create<shape::RankOp>(
        loc, rewriter.getIndexType(), shape_of_tensor);
    return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
                                   rank_tensor,
                                   rewriter.create<ConstantIndexOp>(loc, 0));
  }

  // Create the if statement and code for a broadcasting op with a result of a
  // given rank.
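  //
  // For targeted_rank == 2, the specialized branch emits roughly the
  // following (illustrative only; see the lit test below for the exact IR,
  // and note the value names here are invented):
  //   %c2 = shape.const_shape [1, 1]
  //   %lhs_bcast = shape.broadcast %lhs_shape, %c2
  //   %rhs_bcast = shape.broadcast %rhs_shape, %c2
  //   %lhs2 = "mhlo.dynamic_reshape"(%lhs, to_extent_tensor(%lhs_bcast))
  //   %rhs2 = "mhlo.dynamic_reshape"(%rhs, to_extent_tensor(%rhs_bcast))
  //   %res2 = chlo.op %lhs2, %rhs2 : (tensor<?x?xTy>, tensor<?x?xTy>)
  //   %res  = tensor_cast %res2 : tensor<?x?xTy> to tensor<*xTy>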
  scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
                                                Value lhs, Value rhs,
                                                Value actual_rank,
                                                int targeted_rank) const {
    auto loc = op.getLoc();

    // Create the if block to place the current specialized logic in.
    Value greater_rank_is_n = builder.create<CmpIOp>(
        loc, CmpIPredicate::eq, actual_rank,
        builder.create<ConstantIndexOp>(loc, targeted_rank));
    auto if_op =
        builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true);
    OpBuilder if_builder = if_op.getThenBodyBuilder();

    // Handle shape broadcasting and inference.
    Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
    Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
    auto extent_tensor_type =
        RankedTensorType::get({targeted_rank}, builder.getIndexType());
    auto reshaped_type = RankedTensorType::get(
        llvm::SmallVector<int64_t, 6>(targeted_rank,
                                      RankedTensorType::kDynamicSize),
        lhs.getType().template dyn_cast<TensorType>().getElementType());
    Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
        loc, extent_tensor_type,
        mlir::DenseIntElementsAttr::get(extent_tensor_type, ranked_shape));
    // TODO(tpopp): Return extent tensors when possible to signal that this is a
    // guaranteed safe broadcast by construction.
    Value extended_lhs = if_builder.create<shape::BroadcastOp>(
        loc, lhs_shape, ranked_shape_val, nullptr);
    Value extended_rhs = if_builder.create<shape::BroadcastOp>(
        loc, rhs_shape, ranked_shape_val, nullptr);
    Value lhs_extent_tensor = if_builder.create<shape::ToExtentTensorOp>(
        loc, extent_tensor_type, extended_lhs);
    Value rhs_extent_tensor = if_builder.create<shape::ToExtentTensorOp>(
        loc, extent_tensor_type, extended_rhs);

    // 1. Reshape operands to the given rank (with the same number of elements)
    // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
    //    can be broadcasted and do the actual broadcasting)
    // 3. Type erase the output back to unranked
    Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
        loc, reshaped_type, lhs, lhs_extent_tensor);
    Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
        loc, reshaped_type, rhs, rhs_extent_tensor);
    Value result = if_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{reshaped_type},
        ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
    Value reshaped_result = if_builder.create<TensorCastOp>(
        loc, UnrankedTensorType::get(reshaped_type.getElementType()), result);
    if_builder.create<scf::YieldOp>(loc, reshaped_result);

    // Return the if_op, so the result can be used and the else block can be
    // used for the next rank specialized step.
    return if_op;
  }

  // Iterates over the desired ranks to be specialized and generates the code
  // snippet for each case.
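  //
  // The result is a chain of nested scf.if ops, roughly:
  //   if rank == 2 { ... } else { if rank == 3 { ... } else { ...
  //   if rank == 6 { ... } else { assert(false) } ... } }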
  Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
                             Value rhs) const {
    constexpr int max_rank_specialization = 7;
    auto loc = op.getLoc();

    // Find the larger rank of the 2 operands.
    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
                                                    rewriter.getIndexType());
    Value lhs_shape =
        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
    Value rhs_shape =
        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
    Value lhs_rank =
        rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
    Value rhs_rank =
        rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
    Value greater_rank_lhs =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
    Value greater_rank =
        rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);

    // Generate a list of nested if/else statements to handle rank
    // specializations from 2-6.
    scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
                                                          rhs, greater_rank, 2);

    // Put each subsequent rank specialization inside the else statement of the
    // previous one.
    OpBuilder else_builder = if_op.getElseBodyBuilder();
    for (int i = 3; i < max_rank_specialization; i++) {
      auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
                                                          rhs, greater_rank, i);

      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
      else_builder = inner_if.getElseBodyBuilder();
    }

    // Fire an assertion if none of the rank specializations applied (one of the
    // ranks was greater than 6).
    else_builder.create<AssertOp>(
        loc, else_builder.create<ConstantIntOp>(loc, 0, 1),
        "Input for dynamic binary op lowering was of a rank greater than 6");
    else_builder.create<scf::YieldOp>(loc, lhs);

    // Return the result of the outermost if statement.
    return if_op.getResult(0);
  }
};

template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
void PopulateForBinaryOp(MLIRContext *context,
                         OwningRewritePatternList *patterns) {
@@ -230,7 +453,8 @@ void PopulateForBinaryOp(MLIRContext *context,
      ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
      context, 5);
  patterns->insert<
-      ConvertUnrankedScalarDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy>>(
+      ConvertUnrankedScalarDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy>,
+      ConvertUnrankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
      context);
}

@@ -311,3 +311,141 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf3
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_2]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK: }

// -----
func @addUnrankedUnranked(
    %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>)
      -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @addUnrankedUnranked(
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32>
// CHECK: %[[LHS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[LHS_SHAPE]] : tensor<?xindex>
// CHECK: %[[RANK_LHS:.*]] = shape.rank %[[LHS_EXTENT_TENSOR]]
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
// Handle scalar LHS case
// CHECK: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK: %[[SCALAR_LHS:.*]] = "mhlo.reshape"(%[[LHS]]) : (tensor<*xf32>) -> tensor<f32>
// CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: scf.yield %[[VAL_10]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32>
// CHECK: %[[RHS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[RHS_SHAPE]] : tensor<?xindex>
// CHECK: %[[RANK_RHS:.*]] = shape.rank %[[RHS_EXTENT_TENSOR]]
// CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
// Handle scalar RHS case
// CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK: %[[SCALAR_RHS:.*]] = "mhlo.reshape"(%[[RHS]]) : (tensor<*xf32>) -> tensor<f32>
// CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
// CHECK: scf.yield %[[VAL_16]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_EXTENT_TENSOR]], %[[RHS_EXTENT_TENSOR]] : tensor<?xindex>, tensor<?xindex>
// Handle equal shapes case
// CHECK: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[VAL_19:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor<*xf32>
|
||||
// CHECK: scf.yield %[[VAL_19]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex>
|
||||
// CHECK: %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex>
|
||||
// CHECK: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index
|
||||
// Handle rank 2 specialization
|
||||
// CHECK: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
|
||||
// CHECK: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]]
|
||||
// CHECK: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]]
|
||||
// CHECK: %[[EXTENT_LHS_2:.*]] = shape.to_extent_tensor %[[BROADCASTED_LHS_2]] : !shape.shape -> tensor<2xindex>
|
||||
// CHECK: %[[EXTENT_RHS_2:.*]] = shape.to_extent_tensor %[[BROADCASTED_RHS_2]] : !shape.shape -> tensor<2xindex>
|
||||
// CHECK: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[EXTENT_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[EXTENT_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
|
||||
// CHECK: scf.yield %[[RESULT_2]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[C3:.*]] = constant 3 : index
|
||||
// CHECK: %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index
|
||||
// Handle rank 3 specialization
|
||||
// CHECK: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
|
||||
// CHECK: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]]
|
||||
// CHECK: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]]
|
||||
// CHECK: %[[EXTENT_LHS_3:.*]] = shape.to_extent_tensor %[[BROADCASTED_LHS_3]] : !shape.shape -> tensor<3xindex>
|
||||
// CHECK: %[[EXTENT_RHS_3:.*]] = shape.to_extent_tensor %[[BROADCASTED_RHS_3]] : !shape.shape -> tensor<3xindex>
|
||||
// CHECK: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[EXTENT_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[EXTENT_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK: scf.yield %[[RESULT_3]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[C4:.*]] = constant 4 : index
|
||||
// CHECK: %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index
|
||||
// Handle rank 4 specialization
|
||||
// CHECK: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
|
||||
// CHECK: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]]
|
||||
// CHECK: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]]
|
||||
// CHECK: %[[EXTENT_LHS_4:.*]] = shape.to_extent_tensor %[[BROADCASTED_LHS_4]] : !shape.shape -> tensor<4xindex>
|
||||
// CHECK: %[[EXTENT_RHS_4:.*]] = shape.to_extent_tensor %[[BROADCASTED_RHS_4]] : !shape.shape -> tensor<4xindex>
|
||||
// CHECK: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[EXTENT_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[EXTENT_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK: scf.yield %[[RESULT_4]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[C5:.*]] = constant 5 : index
|
||||
// CHECK: %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index
|
||||
// Handle rank 5 specialization
|
||||
// CHECK: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
|
||||
// CHECK: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]]
|
||||
// CHECK: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]]
|
||||
// CHECK: %[[EXTENT_LHS_5:.*]] = shape.to_extent_tensor %[[BROADCASTED_LHS_5]] : !shape.shape -> tensor<5xindex>
|
||||
// CHECK: %[[EXTENT_RHS_5:.*]] = shape.to_extent_tensor %[[BROADCASTED_RHS_5]] : !shape.shape -> tensor<5xindex>
|
||||
// CHECK: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[EXTENT_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[EXTENT_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK: scf.yield %[[RESULT_5]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[C6:.*]] = constant 6 : index
|
||||
// CHECK: %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index
|
||||
// Handle rank 6 specialization
|
||||
// CHECK: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) {
|
||||
// CHECK: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
|
||||
// CHECK: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]]
|
||||
// CHECK: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]]
|
||||
// CHECK: %[[EXTENT_LHS_6:.*]] = shape.to_extent_tensor %[[BROADCASTED_LHS_6]] : !shape.shape -> tensor<6xindex>
|
||||
// CHECK: %[[EXTENT_RHS_6:.*]] = shape.to_extent_tensor %[[BROADCASTED_RHS_6]] : !shape.shape -> tensor<6xindex>
|
||||
// CHECK: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[EXTENT_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[EXTENT_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK: scf.yield %[[RESULT_6]] : tensor<*xf32>
|
||||
// CHECK: } else {
|
||||
// CHECK: %false = constant false
|
||||
// CHECK: assert %false
|
||||
// CHECK: scf.yield %[[LHS]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[VAL_64:.*]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[VAL_69:.*]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
|
||||
// CHECK: }
|
||||
// CHECK: return %[[VAL_71:.*]] : tensor<*xf32>
|
||||
// CHECK: }