[mlir][hlo] Refactor rank specialization to allow an arbitrary number of inputs
This actually simplifies the code a bit. PiperOrigin-RevId: 358201038
This commit is contained in:
parent
ca4034b56e
commit
b42def4612
|
@ -202,6 +202,149 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename ChloOpTy, typename HloOpTy>
|
||||||
|
struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||||
|
// Returns the dynamic result of checking the given value is effectively a
|
||||||
|
// scalar shape (i.e. the number of elements is 1).
|
||||||
|
static Value GreaterRankIsN(OpBuilder &builder, Location loc,
|
||||||
|
Value actual_rank, int targeted_rank) {
|
||||||
|
return builder.create<CmpIOp>(
|
||||||
|
loc, CmpIPredicate::eq, actual_rank,
|
||||||
|
builder.create<ConstantIndexOp>(loc, targeted_rank));
|
||||||
|
}
|
||||||
|
|
||||||
|
static scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
|
||||||
|
OpBuilder &builder, ChloOpTy op, Value actual_rank, int targeted_rank) {
|
||||||
|
// Create the if block to place the current specialized logic in.
|
||||||
|
Value greater_rank_is_n =
|
||||||
|
GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
|
||||||
|
return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
|
||||||
|
greater_rank_is_n, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op,
|
||||||
|
Value value, int targeted_rank) {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
Value shape = builder.create<shape::ShapeOfOp>(loc, value);
|
||||||
|
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
|
||||||
|
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
||||||
|
{RankedTensorType::kDynamicSize}, builder.getIndexType());
|
||||||
|
auto known_rank_extent_tensor_type =
|
||||||
|
RankedTensorType::get({targeted_rank}, builder.getIndexType());
|
||||||
|
Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
|
||||||
|
loc, known_rank_extent_tensor_type,
|
||||||
|
mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
|
||||||
|
ranked_shape));
|
||||||
|
Value extended_value = builder.create<shape::BroadcastOp>(
|
||||||
|
loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr);
|
||||||
|
return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
|
||||||
|
extended_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the if statement and code for a broadcasting op with a result of a
|
||||||
|
// given rank.
|
||||||
|
static void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder,
|
||||||
|
ChloOpTy op,
|
||||||
|
ValueRange operands,
|
||||||
|
int targeted_rank) {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
SmallVector<Value, 2> reshaped_operands;
|
||||||
|
|
||||||
|
auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
|
||||||
|
targeted_rank, RankedTensorType::kDynamicSize);
|
||||||
|
|
||||||
|
for (Value operand : operands) {
|
||||||
|
// Handle shape broadcasting and inference.
|
||||||
|
Value extended_operand_casted =
|
||||||
|
createBroadcastToKnownRank(if_builder, op, operand, targeted_rank);
|
||||||
|
|
||||||
|
// 1. Reshape operands to the given rank (with the same number of
|
||||||
|
// elements)
|
||||||
|
// 2. Compute the ranked-broadcasted ChloOp (which will assert that the
|
||||||
|
// ops
|
||||||
|
// can be broadcasted and do the actual broadcasting)
|
||||||
|
// 3. Type erase the output back to unranked
|
||||||
|
auto reshaped_type = RankedTensorType::get(
|
||||||
|
dynamic_dimensions,
|
||||||
|
operand.getType().template dyn_cast<TensorType>().getElementType());
|
||||||
|
Value reshaped_operand = if_builder.create<mhlo::DynamicReshapeOp>(
|
||||||
|
loc, reshaped_type, operand, extended_operand_casted);
|
||||||
|
reshaped_operands.push_back(reshaped_operand);
|
||||||
|
}
|
||||||
|
auto result_element_type = op.getResult()
|
||||||
|
.getType()
|
||||||
|
.template dyn_cast<TensorType>()
|
||||||
|
.getElementType();
|
||||||
|
auto result_type =
|
||||||
|
RankedTensorType::get(dynamic_dimensions, result_element_type);
|
||||||
|
Value result = if_builder.create<ChloOpTy>(
|
||||||
|
loc, ArrayRef<Type>{result_type}, reshaped_operands, op.getAttrs());
|
||||||
|
Value reshaped_result = if_builder.create<tensor::CastOp>(
|
||||||
|
loc, UnrankedTensorType::get(result_element_type), result);
|
||||||
|
if_builder.create<scf::YieldOp>(loc, reshaped_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterates over the desired ranks to be specialized and generates the code
|
||||||
|
// snippet for each case.
|
||||||
|
static Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op,
|
||||||
|
ValueRange operands) {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
|
// Find the larger rank of the operands.
|
||||||
|
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||||
|
rewriter.getIndexType());
|
||||||
|
Value greater_rank;
|
||||||
|
for (Value operand : operands) {
|
||||||
|
Value shape =
|
||||||
|
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
|
||||||
|
Value rank =
|
||||||
|
rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
|
||||||
|
if (!greater_rank) {
|
||||||
|
greater_rank = rank;
|
||||||
|
} else {
|
||||||
|
Value greater_rank_compare = rewriter.create<CmpIOp>(
|
||||||
|
loc, CmpIPredicate::sgt, greater_rank, rank);
|
||||||
|
greater_rank = rewriter.create<SelectOp>(loc, greater_rank_compare,
|
||||||
|
greater_rank, rank);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a list of nested if/else statements to handle rank
|
||||||
|
// specializations from 1 to `kMaxRankSpecialization`.
|
||||||
|
scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
|
||||||
|
rewriter, op, greater_rank, 1);
|
||||||
|
OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
|
||||||
|
createRankSpecializedBroadcastAndOp(if_builder, op, operands, 1);
|
||||||
|
|
||||||
|
// Put each subsequent rank specialization inside the else statement of the
|
||||||
|
// previous one.
|
||||||
|
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
||||||
|
constexpr int kMaxRankSpecialization = 6;
|
||||||
|
for (int i = 2; i < kMaxRankSpecialization; i++) {
|
||||||
|
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
|
||||||
|
else_builder, op, greater_rank, i);
|
||||||
|
if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
|
||||||
|
createRankSpecializedBroadcastAndOp(if_builder, op, operands, i);
|
||||||
|
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
|
||||||
|
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
|
||||||
|
}
|
||||||
|
// Fire an assertion if none of the rank specializations applied (one of
|
||||||
|
// the ranks was greater than `kMaxRankSpecialization`).
|
||||||
|
else_builder.create<AssertOp>(
|
||||||
|
loc,
|
||||||
|
GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
|
||||||
|
kMaxRankSpecialization),
|
||||||
|
"Input for dynamic binary op lowering was of a rank greater than " +
|
||||||
|
std::to_string(kMaxRankSpecialization));
|
||||||
|
// Add the rank 6 specialization to the innermost else block.
|
||||||
|
createRankSpecializedBroadcastAndOp(else_builder, op, operands,
|
||||||
|
kMaxRankSpecialization);
|
||||||
|
|
||||||
|
// Return the result of the outermost if statement.
|
||||||
|
return if_op.getResult(0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Handles lowering of the following pattern to patterns that will be further
|
// Handles lowering of the following pattern to patterns that will be further
|
||||||
// matched by other patterns until they result in LHLO:
|
// matched by other patterns until they result in LHLO:
|
||||||
// %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
|
// %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
|
||||||
|
@ -298,7 +441,9 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||||
OpBuilder if_neq_shapes_builder =
|
OpBuilder if_neq_shapes_builder =
|
||||||
if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
|
if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
|
||||||
if_neq_shapes_builder.create<scf::YieldOp>(
|
if_neq_shapes_builder.create<scf::YieldOp>(
|
||||||
loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));
|
loc, ConvertUnrankedDynamicBroadcastOpHelper<
|
||||||
|
ChloOpTy, HloOpTy>::HandleBroadcastAndOp(if_neq_shapes_builder,
|
||||||
|
op, {lhs, rhs}));
|
||||||
|
|
||||||
rewriter.replaceOp(op, {if_op.getResult(0)});
|
rewriter.replaceOp(op, {if_op.getResult(0)});
|
||||||
return success();
|
return success();
|
||||||
|
@ -318,23 +463,6 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||||
rewriter.create<ConstantIndexOp>(loc, 1));
|
rewriter.create<ConstantIndexOp>(loc, 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank,
|
|
||||||
int targeted_rank) const {
|
|
||||||
return builder.create<CmpIOp>(
|
|
||||||
loc, CmpIPredicate::eq, actual_rank,
|
|
||||||
builder.create<ConstantIndexOp>(loc, targeted_rank));
|
|
||||||
}
|
|
||||||
|
|
||||||
scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
|
|
||||||
OpBuilder &builder, ChloOpTy op, Value actual_rank,
|
|
||||||
int targeted_rank) const {
|
|
||||||
// Create the if block to place the current specialized logic in.
|
|
||||||
Value greater_rank_is_n =
|
|
||||||
GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
|
|
||||||
return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
|
|
||||||
greater_rank_is_n, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
Value extendToBroadcastShape(OpBuilder &builder, Location loc, Value value,
|
Value extendToBroadcastShape(OpBuilder &builder, Location loc, Value value,
|
||||||
Value shape_of_lhs, Value shape_of_rhs) const {
|
Value shape_of_lhs, Value shape_of_rhs) const {
|
||||||
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
||||||
|
@ -345,122 +473,6 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||||
return builder.create<mhlo::DynamicReshapeOp>(loc, value.getType(), value,
|
return builder.create<mhlo::DynamicReshapeOp>(loc, value.getType(), value,
|
||||||
broadcast_shape);
|
broadcast_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op, Value value,
|
|
||||||
int targeted_rank) const {
|
|
||||||
auto loc = op.getLoc();
|
|
||||||
Value shape = builder.create<shape::ShapeOfOp>(loc, value);
|
|
||||||
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
|
|
||||||
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
|
||||||
{RankedTensorType::kDynamicSize}, builder.getIndexType());
|
|
||||||
auto known_rank_extent_tensor_type =
|
|
||||||
RankedTensorType::get({targeted_rank}, builder.getIndexType());
|
|
||||||
Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
|
|
||||||
loc, known_rank_extent_tensor_type,
|
|
||||||
mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
|
|
||||||
ranked_shape));
|
|
||||||
Value extended_value = builder.create<shape::BroadcastOp>(
|
|
||||||
loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr);
|
|
||||||
return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
|
|
||||||
extended_value);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the if statement and code for a broadcasting op with a result of a
|
|
||||||
// given rank.
|
|
||||||
void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op,
|
|
||||||
Value lhs, Value rhs,
|
|
||||||
int targeted_rank) const {
|
|
||||||
auto loc = op.getLoc();
|
|
||||||
|
|
||||||
// Handle shape broadcasting and inference.
|
|
||||||
Value extended_lhs_casted =
|
|
||||||
createBroadcastToKnownRank(if_builder, op, lhs, targeted_rank);
|
|
||||||
Value extended_rhs_casted =
|
|
||||||
createBroadcastToKnownRank(if_builder, op, rhs, targeted_rank);
|
|
||||||
auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
|
|
||||||
targeted_rank, RankedTensorType::kDynamicSize);
|
|
||||||
auto reshaped_type = RankedTensorType::get(
|
|
||||||
dynamic_dimensions,
|
|
||||||
lhs.getType().template dyn_cast<TensorType>().getElementType());
|
|
||||||
|
|
||||||
// 1. Reshape operands to the given rank (with the same number of elements)
|
|
||||||
// 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
|
|
||||||
// can be broadcasted and do the actual broadcasting)
|
|
||||||
// 3. Type erase the output back to unranked
|
|
||||||
Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
|
|
||||||
loc, reshaped_type, lhs, extended_lhs_casted);
|
|
||||||
Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
|
|
||||||
loc, reshaped_type, rhs, extended_rhs_casted);
|
|
||||||
auto result_element_type = op.getResult()
|
|
||||||
.getType()
|
|
||||||
.template dyn_cast<TensorType>()
|
|
||||||
.getElementType();
|
|
||||||
auto result_type =
|
|
||||||
RankedTensorType::get(dynamic_dimensions, result_element_type);
|
|
||||||
Value result = if_builder.create<ChloOpTy>(
|
|
||||||
loc, ArrayRef<Type>{result_type},
|
|
||||||
ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
|
|
||||||
Value reshaped_result = if_builder.create<tensor::CastOp>(
|
|
||||||
loc, UnrankedTensorType::get(result_element_type), result);
|
|
||||||
if_builder.create<scf::YieldOp>(loc, reshaped_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterates over the desired ranks to be specialized and generates the code
|
|
||||||
// snippet for each case.
|
|
||||||
Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
|
|
||||||
Value rhs) const {
|
|
||||||
auto loc = op.getLoc();
|
|
||||||
|
|
||||||
// Find the larger rank of the 2 operands.
|
|
||||||
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
|
|
||||||
rewriter.getIndexType());
|
|
||||||
Value lhs_shape =
|
|
||||||
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
|
|
||||||
Value rhs_shape =
|
|
||||||
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
|
|
||||||
Value lhs_rank =
|
|
||||||
rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), lhs_shape);
|
|
||||||
Value rhs_rank =
|
|
||||||
rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), rhs_shape);
|
|
||||||
Value greater_rank_lhs =
|
|
||||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
|
|
||||||
Value greater_rank =
|
|
||||||
rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
|
|
||||||
|
|
||||||
// Generate a list of nested if/else statements to handle rank
|
|
||||||
// specializations from 1 to `kMaxRankSpecialization`.
|
|
||||||
scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
|
|
||||||
rewriter, op, greater_rank, 1);
|
|
||||||
OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
|
|
||||||
createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, 1);
|
|
||||||
|
|
||||||
// Put each subsequent rank specialization inside the else statement of the
|
|
||||||
// previous one.
|
|
||||||
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
|
||||||
constexpr int kMaxRankSpecialization = 6;
|
|
||||||
for (int i = 2; i < kMaxRankSpecialization; i++) {
|
|
||||||
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
|
|
||||||
else_builder, op, greater_rank, i);
|
|
||||||
if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
|
|
||||||
createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, i);
|
|
||||||
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
|
|
||||||
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
|
|
||||||
}
|
|
||||||
// Fire an assertion if none of the rank specializations applied (one of
|
|
||||||
// the ranks was greater than `kMaxRankSpecialization`).
|
|
||||||
else_builder.create<AssertOp>(
|
|
||||||
loc,
|
|
||||||
GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
|
|
||||||
kMaxRankSpecialization),
|
|
||||||
"Input for dynamic binary op lowering was of a rank greater than " +
|
|
||||||
std::to_string(kMaxRankSpecialization));
|
|
||||||
// Add the rank 6 specialization to the innermost else block.
|
|
||||||
createRankSpecializedBroadcastAndOp(else_builder, op, lhs, rhs,
|
|
||||||
kMaxRankSpecialization);
|
|
||||||
|
|
||||||
// Return the result of the outermost if statement.
|
|
||||||
return if_op.getResult(0);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TransformUnrankedHloPass
|
struct TransformUnrankedHloPass
|
||||||
|
|
|
@ -209,9 +209,9 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
|
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
|
||||||
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
|
// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
|
// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
|
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
|
||||||
|
@ -224,9 +224,9 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
|
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
|
||||||
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
|
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||||
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
|
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
|
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
|
||||||
|
@ -239,9 +239,9 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
|
// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
|
||||||
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
|
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
|
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
|
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
|
||||||
|
@ -254,9 +254,9 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
|
// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
|
||||||
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
|
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
|
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
|
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
|
||||||
|
@ -269,9 +269,9 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
|
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
|
||||||
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
|
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
|
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
|
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
|
||||||
|
@ -284,9 +284,9 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
|
// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
|
||||||
// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor.cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
|
// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor.cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
|
||||||
|
// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor.cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
|
// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor.cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
|
// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
|
// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_6:.*]] = tensor.cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
|
// CHECK-NEXT: %[[RESULT_6:.*]] = tensor.cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
|
||||||
|
|
Loading…
Reference in New Issue