Legalize MinimumBroadcastShapes op.
Use it in TransformUnrankedHloPass, which allows to reduce the maximum rank for rank specialized broadcast from 6 to 5. PiperOrigin-RevId: 360415743
This commit is contained in:
parent
329b1fd071
commit
0683db3b24
|
@ -51,6 +51,7 @@ struct ChloLegalizeToHloPass
|
|||
conversionTarget.addLegalDialect<
|
||||
MhloDialect, mlir::StandardOpsDialect, mlir::tensor::TensorDialect,
|
||||
mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();
|
||||
conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>();
|
||||
|
||||
if (broadcast_only_) {
|
||||
chlo::PopulateChloBroadcastingPatterns(&getContext(),
|
||||
|
|
|
@ -223,9 +223,8 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
|||
}
|
||||
|
||||
static Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op,
|
||||
Value value, int targeted_rank) {
|
||||
Value shape, int targeted_rank) {
|
||||
auto loc = op.getLoc();
|
||||
Value shape = builder.create<shape::ShapeOfOp>(loc, value);
|
||||
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
|
||||
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
||||
{RankedTensorType::kDynamicSize}, builder.getIndexType());
|
||||
|
@ -246,6 +245,7 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
|||
static void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder,
|
||||
ChloOpTy op,
|
||||
ValueRange operands,
|
||||
ValueRange operand_shapes,
|
||||
int targeted_rank) {
|
||||
auto loc = op.getLoc();
|
||||
SmallVector<Value, 2> reshaped_operands;
|
||||
|
@ -253,10 +253,12 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
|||
auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
|
||||
targeted_rank, RankedTensorType::kDynamicSize);
|
||||
|
||||
for (Value operand : operands) {
|
||||
for (auto it : llvm::zip(operands, operand_shapes)) {
|
||||
Value operand, shape;
|
||||
std::tie(operand, shape) = it;
|
||||
// Handle shape broadcasting and inference.
|
||||
Value extended_operand_casted =
|
||||
createBroadcastToKnownRank(if_builder, op, operand, targeted_rank);
|
||||
createBroadcastToKnownRank(if_builder, op, shape, targeted_rank);
|
||||
|
||||
// 1. Reshape operands to the given rank (with the same number of
|
||||
// elements)
|
||||
|
@ -290,13 +292,37 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
|||
ValueRange operands) {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Find the larger rank of the operands.
|
||||
// Get the minimum broadcast shapes of the operands.
|
||||
SmallVector<Value> shapes;
|
||||
shapes.reserve(operands.size());
|
||||
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
rewriter.getIndexType());
|
||||
Value greater_rank;
|
||||
for (Value operand : operands) {
|
||||
Value shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
|
||||
shapes.push_back(shape);
|
||||
}
|
||||
auto broadcast_shape = rewriter.create<shape::BroadcastOp>(
|
||||
loc, extent_tensor_type, shapes, nullptr);
|
||||
SmallVector<Type> result_types(shapes.size(), extent_tensor_type);
|
||||
auto reduced_shapes =
|
||||
rewriter
|
||||
.create<chlo::MinimumBroadcastShapesOp>(loc, result_types, shapes)
|
||||
.results();
|
||||
SmallVector<Value> reshaped_operands;
|
||||
reshaped_operands.reserve(operands.size());
|
||||
for (auto it : llvm::zip(operands, reduced_shapes)) {
|
||||
Value operand;
|
||||
Value reduced_shape;
|
||||
std::tie(operand, reduced_shape) = it;
|
||||
auto reshaped_operand = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, operand.getType(), operand, reduced_shape);
|
||||
reshaped_operands.push_back(reshaped_operand);
|
||||
}
|
||||
|
||||
// Find the largest rank of the operands.
|
||||
Value greater_rank;
|
||||
for (Value shape : reduced_shapes) {
|
||||
Value rank =
|
||||
rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
|
||||
if (!greater_rank) {
|
||||
|
@ -314,17 +340,19 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
|||
scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
|
||||
rewriter, op, greater_rank, 1);
|
||||
OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
|
||||
createRankSpecializedBroadcastAndOp(if_builder, op, operands, 1);
|
||||
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
|
||||
reduced_shapes, 1);
|
||||
|
||||
// Put each subsequent rank specialization inside the else statement of the
|
||||
// previous one.
|
||||
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
||||
constexpr int kMaxRankSpecialization = 6;
|
||||
constexpr int kMaxRankSpecialization = 5;
|
||||
for (int i = 2; i < kMaxRankSpecialization; i++) {
|
||||
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
|
||||
else_builder, op, greater_rank, i);
|
||||
if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
|
||||
createRankSpecializedBroadcastAndOp(if_builder, op, operands, i);
|
||||
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
|
||||
reduced_shapes, i);
|
||||
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
|
||||
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
|
||||
}
|
||||
|
@ -336,12 +364,15 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
|||
kMaxRankSpecialization),
|
||||
"Input for dynamic binary op lowering was of a rank greater than " +
|
||||
std::to_string(kMaxRankSpecialization));
|
||||
// Add the rank 6 specialization to the innermost else block.
|
||||
createRankSpecializedBroadcastAndOp(else_builder, op, operands,
|
||||
kMaxRankSpecialization);
|
||||
// Add the rank 5 specialization to the innermost else block.
|
||||
createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands,
|
||||
reduced_shapes, kMaxRankSpecialization);
|
||||
|
||||
// Return the result of the outermost if statement.
|
||||
return if_op.getResult(0);
|
||||
// Return the reshaped result of the outermost if statement.
|
||||
auto result = if_op.getResult(0);
|
||||
auto reshaped_result = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, result.getType(), result, broadcast_shape);
|
||||
return reshaped_result;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -497,16 +528,17 @@ struct ConvertUnrankedDynamicBroadcastSelectOp
|
|||
struct TransformUnrankedHloPass
|
||||
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
|
||||
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
|
||||
shape::ShapeDialect>();
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
// Setup conversion target.
|
||||
MLIRContext &ctx = getContext();
|
||||
ConversionTarget target(ctx);
|
||||
target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
|
||||
shape::ShapeDialect, scf::SCFDialect,
|
||||
tensor::TensorDialect>();
|
||||
target.addLegalDialect<chlo::HloClientDialect, mhlo::MhloDialect,
|
||||
StandardOpsDialect, shape::ShapeDialect,
|
||||
scf::SCFDialect, tensor::TensorDialect>();
|
||||
target.addLegalOp<FuncOp>();
|
||||
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
|
||||
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
|
||||
|
|
|
@ -199,20 +199,24 @@ func @addUnrankedUnranked(
|
|||
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
|
||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// Handle rank 1 specialization
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
|
||||
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
|
||||
|
@ -222,12 +226,12 @@ func @addUnrankedUnranked(
|
|||
// Handle rank 2 specialization
|
||||
// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
|
||||
|
@ -237,12 +241,12 @@ func @addUnrankedUnranked(
|
|||
// Handle rank 3 specialization
|
||||
// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
|
||||
|
@ -252,47 +256,30 @@ func @addUnrankedUnranked(
|
|||
// Handle rank 4 specialization
|
||||
// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index
|
||||
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]]
|
||||
// Handle rank 5 specialization
|
||||
// CHECK-NEXT: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[C6:.*]] = constant 6 : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_6:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C6]] : index
|
||||
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_6]]
|
||||
// Handle rank 6 specialization
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor.cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor.cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_6:.*]] = tensor.cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
|
@ -300,7 +287,8 @@ func @addUnrankedUnranked(
|
|||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_69:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
|
||||
// CHECK-NEXT: }
|
||||
|
@ -325,13 +313,18 @@ func @selectUnrankedUnrankedUnranked(
|
|||
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK-NEXT: %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[PRED_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
|
||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xi1>, tensor<?xindex>) -> tensor<*xi1>
|
||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#2 : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK-NEXT: %c1 = constant 1 : index
|
||||
|
@ -339,15 +332,15 @@ func @selectUnrankedUnrankedUnranked(
|
|||
// Handle rank 1 specialization
|
||||
// CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
|
||||
// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#2, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
||||
// CHECK-NEXT: %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1:.*]] : tensor<?xf32> to tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
|
||||
|
@ -357,4 +350,3 @@ func @selectUnrankedUnrankedUnranked(
|
|||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
|
||||
|
|
Loading…
Reference in New Issue