From 0683db3b24fe075cdecfc53e680da33816902d7a Mon Sep 17 00:00:00 2001
From: Adrian Kuegel
Date: Tue, 2 Mar 2021 06:37:48 -0800
Subject: [PATCH] Legalize MinimumBroadcastShapes op.

Use it in TransformUnrankedHloPass, which allows reducing the maximum
rank for rank-specialized broadcasts from 6 to 5.

PiperOrigin-RevId: 360415743
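Note on the op being legalized: chlo.minimum_broadcast_shapes shrinks the operand shapes before rank specialization by dropping dimensions that are size 1 in every operand and merging adjacent dimensions that are broadcast the same way. The standalone C++ sketch below illustrates that reduction for two operands with static extents; the helper names and the exact merge rule are a reading of the op's intent, not code from this patch.

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Which operand is broadcast (i.e. has extent 1) in a given dimension.
enum class Bcast { kLhs, kRhs, kNone };

static Bcast Classify(int64_t lhs_dim, int64_t rhs_dim) {
  if (lhs_dim == 1 && rhs_dim != 1) return Bcast::kLhs;
  if (rhs_dim == 1 && lhs_dim != 1) return Bcast::kRhs;
  return Bcast::kNone;  // equal extents on both sides
}

// Reduction for two broadcast-compatible static shapes: drop dimensions that
// are 1 in both operands, merge adjacent dimensions with the same behavior.
static std::pair<std::vector<int64_t>, std::vector<int64_t>>
MinimumBroadcastShapes(std::vector<int64_t> lhs, std::vector<int64_t> rhs) {
  // Align ranks by padding the shorter shape with leading 1s.
  while (lhs.size() < rhs.size()) lhs.insert(lhs.begin(), 1);
  while (rhs.size() < lhs.size()) rhs.insert(rhs.begin(), 1);
  std::vector<int64_t> out_lhs, out_rhs;
  for (size_t i = 0; i < lhs.size(); ++i) {
    if (lhs[i] == 1 && rhs[i] == 1) continue;  // broadcast in both: drop
    if (!out_lhs.empty() &&
        Classify(out_lhs.back(), out_rhs.back()) == Classify(lhs[i], rhs[i])) {
      out_lhs.back() *= lhs[i];  // same behavior as the previously kept
      out_rhs.back() *= rhs[i];  // dimension, so the two can be merged
    } else {
      out_lhs.push_back(lhs[i]);
      out_rhs.push_back(rhs[i]);
    }
  }
  if (out_lhs.empty()) {  // every dimension was size 1
    out_lhs.push_back(1);
    out_rhs.push_back(1);
  }
  return {out_lhs, out_rhs};
}

static void Print(const std::vector<int64_t> &shape) {
  for (int64_t dim : shape) std::cout << dim << ' ';
  std::cout << '\n';
}

int main() {
  // A rank-6 problem that reduces to rank 2: dims 0-3 are equal in both
  // operands (merged into one), dims 4-5 broadcast the lhs (also merged).
  auto reduced = MinimumBroadcastShapes({2, 3, 4, 5, 1, 1}, {2, 3, 4, 5, 6, 7});
  Print(reduced.first);   // prints: 120 1
  Print(reduced.second);  // prints: 120 42
  return 0;
}

On this rank-6 pair the reduction yields {120, 1} and {120, 42}, so the reduced problem is handled by the rank-2 specialization. Shapes that stay above rank 5 even after reduction are still caught by the assert kept in the lowering.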
---
 .../transforms/chlo_legalize_to_hlo_pass.cc   |   1 +
 .../mhlo/transforms/transform_unranked_hlo.cc |  68 ++++++++---
 tests/hlo-transform-unranked.mlir             | 110 ++++++++----------
 3 files changed, 102 insertions(+), 77 deletions(-)

diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc
index f9ee38b..5daf607 100644
--- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc
+++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc
@@ -51,6 +51,7 @@ struct ChloLegalizeToHloPass
     conversionTarget.addLegalDialect<
         MhloDialect, mlir::StandardOpsDialect, mlir::tensor::TensorDialect,
         mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();
+    conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>();
 
     if (broadcast_only_) {
       chlo::PopulateChloBroadcastingPatterns(&getContext(),
diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
index 6164185..9325d7d 100644
--- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
+++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
@@ -223,9 +223,8 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
   }
 
   static Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op,
-                                          Value value, int targeted_rank) {
+                                          Value shape, int targeted_rank) {
     auto loc = op.getLoc();
-    Value shape = builder.create<shape::ShapeOfOp>(loc, value);
     SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
     auto unknown_rank_extent_tensor_type = RankedTensorType::get(
         {RankedTensorType::kDynamicSize}, builder.getIndexType());
@@ -246,6 +245,7 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
   static void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder,
                                                   ChloOpTy op,
                                                   ValueRange operands,
+                                                  ValueRange operand_shapes,
                                                   int targeted_rank) {
     auto loc = op.getLoc();
     SmallVector<Value, 2> reshaped_operands;
@@ -253,10 +253,12 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
     auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
         targeted_rank, RankedTensorType::kDynamicSize);
 
-    for (Value operand : operands) {
+    for (auto it : llvm::zip(operands, operand_shapes)) {
+      Value operand, shape;
+      std::tie(operand, shape) = it;
       // Handle shape broadcasting and inference.
       Value extended_operand_casted =
-          createBroadcastToKnownRank(if_builder, op, operand, targeted_rank);
+          createBroadcastToKnownRank(if_builder, op, shape, targeted_rank);
 
       // 1. Reshape operands to the given rank (with the same number of
       // elements)
@@ -290,13 +292,37 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
                                              ValueRange operands) {
     auto loc = op.getLoc();
 
-    // Find the larger rank of the operands.
+    // Get the minimum broadcast shapes of the operands.
+    SmallVector<Value> shapes;
+    shapes.reserve(operands.size());
     auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
                                                     rewriter.getIndexType());
-    Value greater_rank;
     for (Value operand : operands) {
       Value shape =
           rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
+      shapes.push_back(shape);
+    }
+    auto broadcast_shape = rewriter.create<shape::BroadcastOp>(
+        loc, extent_tensor_type, shapes, nullptr);
+    SmallVector<Type> result_types(shapes.size(), extent_tensor_type);
+    auto reduced_shapes =
+        rewriter
+            .create<chlo::MinimumBroadcastShapesOp>(loc, result_types, shapes)
+            .results();
+    SmallVector<Value> reshaped_operands;
+    reshaped_operands.reserve(operands.size());
+    for (auto it : llvm::zip(operands, reduced_shapes)) {
+      Value operand;
+      Value reduced_shape;
+      std::tie(operand, reduced_shape) = it;
+      auto reshaped_operand = rewriter.create<mhlo::DynamicReshapeOp>(
+          loc, operand.getType(), operand, reduced_shape);
+      reshaped_operands.push_back(reshaped_operand);
+    }
+
+    // Find the largest rank of the operands.
+    Value greater_rank;
+    for (Value shape : reduced_shapes) {
       Value rank =
           rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
       if (!greater_rank) {
@@ -314,17 +340,19 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
     scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
         rewriter, op, greater_rank, 1);
     OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
-    createRankSpecializedBroadcastAndOp(if_builder, op, operands, 1);
+    createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
+                                        reduced_shapes, 1);
 
     // Put each subsequent rank specialization inside the else statement of the
     // previous one.
     OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
-    constexpr int kMaxRankSpecialization = 6;
+    constexpr int kMaxRankSpecialization = 5;
     for (int i = 2; i < kMaxRankSpecialization; i++) {
       auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
           else_builder, op, greater_rank, i);
       if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
-      createRankSpecializedBroadcastAndOp(if_builder, op, operands, i);
+      createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
+                                          reduced_shapes, i);
       else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
       else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
     }
@@ -336,12 +364,15 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
                               kMaxRankSpecialization),
         "Input for dynamic binary op lowering was of a rank greater than " +
             std::to_string(kMaxRankSpecialization));
-    // Add the rank 6 specialization to the innermost else block.
-    createRankSpecializedBroadcastAndOp(else_builder, op, operands,
-                                        kMaxRankSpecialization);
+    // Add the rank 5 specialization to the innermost else block.
+    createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands,
+                                        reduced_shapes, kMaxRankSpecialization);
 
-    // Return the result of the outermost if statement.
-    return if_op.getResult(0);
+    // Return the reshaped result of the outermost if statement.
+    auto result = if_op.getResult(0);
+    auto reshaped_result = rewriter.create<mhlo::DynamicReshapeOp>(
+        loc, result.getType(), result, broadcast_shape);
+    return reshaped_result;
   }
 };
 
@@ -497,16 +528,17 @@ struct ConvertUnrankedDynamicBroadcastSelectOp
 struct TransformUnrankedHloPass
     : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<shape::ShapeDialect, scf::SCFDialect>();
+    registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
+                    shape::ShapeDialect, scf::SCFDialect>();
   }
 
   void runOnFunction() override {
     // Setup conversion target.
    MLIRContext &ctx = getContext();
     ConversionTarget target(ctx);
-    target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
-                           shape::ShapeDialect, scf::SCFDialect,
-                           tensor::TensorDialect>();
+    target.addLegalDialect<mhlo::MhloDialect, chlo::HloClientDialect,
+                           StandardOpsDialect, shape::ShapeDialect,
+                           scf::SCFDialect, tensor::TensorDialect>();
     target.addLegalOp<FuncOp>();
 #define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
 #define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
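To summarize the control flow the updated helper now emits: compute the full broadcast result shape up front, run the rank-specialized op on operands reshaped to their minimum broadcast shapes, and reshape the result back at the end. The C++ below is a host-side analogy of that bookkeeping with invented names (BroadcastShapes, PickSpecialization); it is illustrative only and not part of the pass.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Host-side mirror of the shape bookkeeping the helper now emits:
//   1. result_shape   = shape.broadcast(operand shapes), computed up front
//   2. reduced shapes = chlo.minimum_broadcast_shapes(operand shapes)
//   3. operands are mhlo.dynamic_reshape'd to the reduced shapes and
//      dispatched to a rank-r specialization, r <= kMaxRankSpecialization
//   4. the specialized result is mhlo.dynamic_reshape'd back to result_shape
constexpr int kMaxRankSpecialization = 5;

// Step 1 for two static shapes: plain NumPy-style broadcasting.
std::vector<int64_t> BroadcastShapes(std::vector<int64_t> a,
                                     std::vector<int64_t> b) {
  while (a.size() < b.size()) a.insert(a.begin(), 1);
  while (b.size() < a.size()) b.insert(b.begin(), 1);
  std::vector<int64_t> result(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    assert((a[i] == b[i] || a[i] == 1 || b[i] == 1) && "incompatible shapes");
    result[i] = std::max(a[i], b[i]);
  }
  return result;
}

// Step 3's dispatch, mirroring the emitted scf.if cascade: ranks 1 through 4
// each get their own branch; the innermost else asserts and handles rank 5.
int PickSpecialization(int reduced_rank) {
  for (int rank = 1; rank < kMaxRankSpecialization; ++rank) {
    if (reduced_rank == rank) return rank;
  }
  assert(reduced_rank == kMaxRankSpecialization &&
         "Input for dynamic binary op lowering was of a rank greater than 5");
  return kMaxRankSpecialization;
}

Because the full-rank result shape survives only as data for the final mhlo.dynamic_reshape, the rank specializations never see more than the reduced rank, which is what makes kMaxRankSpecialization = 5 sufficient where 6 was needed before. The FileCheck updates below exercise exactly this structure.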
diff --git a/tests/hlo-transform-unranked.mlir b/tests/hlo-transform-unranked.mlir
index 6ae07ce..b3abe71 100644
--- a/tests/hlo-transform-unranked.mlir
+++ b/tests/hlo-transform-unranked.mlir
@@ -199,20 +199,24 @@ func @addUnrankedUnranked(
 // CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 // CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
 // CHECK-NEXT: } else {
-// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
-// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
+// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
+// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
 // CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
 // CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
 // Handle rank 1 specialization
 // CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
-// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
+// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
 // CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
-// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
-// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
 // CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
 // CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
@@ -222,12 +226,12 @@ func @addUnrankedUnranked(
 // Handle rank 2 specialization
 // CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
 // CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
-// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
-// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 // CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
 // CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
@@ -237,12 +241,12 @@ func @addUnrankedUnranked(
 // Handle rank 3 specialization
 // CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
 // CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
-// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
-// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
-// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
 // CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
 // CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
@@ -252,47 +256,30 @@ func @addUnrankedUnranked(
 // Handle rank 4 specialization
 // CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
 // CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
-// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
-// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
+// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
-// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
+// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 // CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
 // CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
 // CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
 // CHECK-NEXT: } else {
 // CHECK-NEXT: %[[C5:.*]] = constant 5 : index
 // CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index
+// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]]
 // Handle rank 5 specialization
-// CHECK-NEXT: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
-// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
-// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
-// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
-// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
-// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
-// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
-// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
-// CHECK-NEXT: } else {
-// CHECK-NEXT: %[[C6:.*]] = constant 6 : index
-// CHECK-NEXT: %[[GREATEST_RANK_IS_6:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C6]] : index
-// CHECK-NEXT: assert %[[GREATEST_RANK_IS_6]]
-// Handle rank 6 specialization
-// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
-// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor.cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
-// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor.cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
-// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
-// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
-// CHECK-NEXT: %[[RESULT_6:.*]] = tensor.cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
-// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32>
-// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
+// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
+// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
+// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
+// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
+// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
 // CHECK-NEXT: }
 // CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
 // CHECK-NEXT: }
@@ -300,7 +287,8 @@ func @addUnrankedUnranked(
 // CHECK-NEXT: }
 // CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
 // CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[VAL_69:.*]] : tensor<*xf32>
+// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
 // CHECK-NEXT: }
 // CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
 // CHECK-NEXT: }
@@ -325,13 +313,18 @@ func @selectUnrankedUnrankedUnranked(
 // CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
 // CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-NEXT: %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor<?xindex>
-// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[PRED_SHAPE]] : tensor<?xindex> -> index
 // CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
-// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
+// CHECK-NEXT: %[[MINIMUM_RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xi1>, tensor<?xindex>) -> tensor<*xi1>
+// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
+// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
 // CHECK-NEXT: %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index
 // CHECK-NEXT: %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index
-// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
-// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#2 : tensor<?xindex> -> index
 // CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index
 // CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index
 // CHECK-NEXT: %c1 = constant 1 : index
@@ -339,15 +332,15 @@ func @selectUnrankedUnrankedUnranked(
 // Handle rank 1 specialization
 // CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
 // CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex>
-// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
-// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
+// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
-// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#2, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
 // CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1:.*]] : tensor<?xf32> to tensor<*xf32>
 // CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
@@ -357,4 +350,3 @@ func @selectUnrankedUnrankedUnranked(
 // CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
 // CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
-// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>