From 824bc9c4259f9fad0ada9c4fd0a3eae207099a14 Mon Sep 17 00:00:00 2001
From: Adrian Kuegel <akuegel@google.com>
Date: Sun, 14 Feb 2021 23:24:45 -0800
Subject: [PATCH] Improve broadcast transformation to treat dynamic shapes with
 1 element as scalar.

A shape that contains exactly one element is effectively a scalar. This
leads to a speedup in cases where we have a binary op with one operand
that is effectively a scalar, because we can use the fast path.

PiperOrigin-RevId: 357515552
---
 .../mhlo/transforms/transform_unranked_hlo.cc | 125 ++++++++++--------
 tests/hlo-transform-unranked.mlir             |  33 ++---
 2 files changed, 91 insertions(+), 67 deletions(-)

diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
index 3b52288..7c47b6f 100644
--- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
+++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
@@ -230,46 +230,54 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
     // pattern will handle the lowering.
     if (!lhs_type || !rhs_type) return failure();
 
-    // If lhs is scalar
+    Value shape_of_lhs = rewriter.create<shape::ShapeOfOp>(loc, lhs);
+    Value shape_of_rhs = rewriter.create<shape::ShapeOfOp>(loc, rhs);
+
+    // If lhs has exactly one element
     auto if_op = rewriter.create<scf::IfOp>(
-        loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
+        loc, result_type, IsSingleElementShape(rewriter, op, shape_of_lhs),
+        true);
     OpBuilder if_lhs_scalar_builder =
         if_op.getThenBodyBuilder(rewriter.getListener());
-    Value reshaped_lhs = if_lhs_scalar_builder.create<tensor::CastOp>(
+    Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>(
         loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
     Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
         loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
         op.getAttrs());
-    if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);
+    Value extended_if_lhs_scalar_result =
+        extendToBroadcastShape(if_lhs_scalar_builder, loc, if_lhs_scalar_result,
+                               shape_of_lhs, shape_of_rhs);
+    if_lhs_scalar_builder.create<scf::YieldOp>(loc,
+                                               extended_if_lhs_scalar_result);
 
-    // If lhs is NOT scalar
+    // If lhs does not have exactly one element
     //
-    // See if rhs is scalar
+    // See if rhs has exactly one element
     OpBuilder else_lhs_scalar_builder =
         if_op.getElseBodyBuilder(rewriter.getListener());
     auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
-        loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
-        true);
+        loc, result_type,
+        IsSingleElementShape(else_lhs_scalar_builder, op, shape_of_rhs), true);
     else_lhs_scalar_builder.create<scf::YieldOp>(loc,
                                                  if_rhs_scalar_op.getResult(0));
     OpBuilder if_rhs_scalar_builder =
         if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
-    Value reshaped_rhs = if_rhs_scalar_builder.create<tensor::CastOp>(
-        loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
+    Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>(
+        loc, RankedTensorType::get({}, rhs_type.getElementType()), rhs);
     Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
         loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
         op.getAttrs());
-    if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);
+    Value extended_if_rhs_scalar_result =
+        extendToBroadcastShape(if_rhs_scalar_builder, loc, if_rhs_scalar_result,
+                               shape_of_lhs, shape_of_rhs);
+    if_rhs_scalar_builder.create<scf::YieldOp>(loc,
+                                               extended_if_rhs_scalar_result);
 
-    // If NEITHER shape is scalar
+    // If NEITHER shape has exactly one element
     //
     // See if shapes are equal.
     OpBuilder else_no_scalars_builder =
         if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener());
-    Value shape_of_lhs =
-        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
-    Value shape_of_rhs =
-        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
     Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
         loc, shape_of_lhs, shape_of_rhs);
@@ -284,7 +292,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
         Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
     if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
 
-    // If shapes are not scalar, nor equal
+    // If shapes do not have exactly one element, nor are equal
     //
     // See if values are of a rank that we support.
     OpBuilder if_neq_shapes_builder =
@@ -297,16 +305,17 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
   }
 
  private:
-  // Returns the dynamic result of checking the given value is a scalar tensor.
-  Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
+  // Returns the dynamic result of checking the given value is effectively a
+  // scalar shape (i.e. the number of elements is 1).
+  Value IsSingleElementShape(OpBuilder &rewriter, ChloOpTy op,
+                             Value shape_of_tensor) const {
     auto loc = op.getLoc();
-    Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
-    Value rank_tensor = rewriter.create<shape::RankOp>(
-        loc, rewriter.getIndexType(), shape_of_tensor);
+    Value num_elements =
+        rewriter.create<shape::NumElementsOp>(loc, shape_of_tensor);
     return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(),
                                    CmpIPredicate::eq,
-                                   rank_tensor,
-                                   rewriter.create<ConstantIndexOp>(loc, 0));
+                                   num_elements,
+                                   rewriter.create<ConstantIndexOp>(loc, 1));
   }
 
   Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank,
@@ -326,6 +335,36 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
                                       greater_rank_is_n, true);
   }
 
+  Value extendToBroadcastShape(OpBuilder &builder, Location loc, Value value,
+                               Value shape_of_lhs, Value shape_of_rhs) const {
+    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
+        {RankedTensorType::kDynamicSize}, builder.getIndexType());
+    Value broadcast_shape =
+        builder.create<shape::BroadcastOp>(loc, unknown_rank_extent_tensor_type,
+                                           shape_of_lhs, shape_of_rhs, nullptr);
+    return builder.create<mhlo::DynamicReshapeOp>(loc, value.getType(), value,
+                                                  broadcast_shape);
+  }
+
+  Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op, Value value,
+                                   int targeted_rank) const {
+    auto loc = op.getLoc();
+    Value shape = builder.create<shape::ShapeOfOp>(loc, value);
+    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
+    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
+        {RankedTensorType::kDynamicSize}, builder.getIndexType());
+    auto known_rank_extent_tensor_type =
+        RankedTensorType::get({targeted_rank}, builder.getIndexType());
+    Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
+        loc, known_rank_extent_tensor_type,
+        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
+                                        ranked_shape));
+    Value extended_value = builder.create<shape::BroadcastOp>(
+        loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr);
+    return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
+                                          extended_value);
+  }
+
   // Create the if statement and code for a broadcasting op with a result of a
   // given rank.
   void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op,
@@ -333,32 +372,16 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
                                            int targeted_rank) const {
     auto loc = op.getLoc();
 
-    // Handle shape broadcasting and inferrence.
-    Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
-    Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
-    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
-    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
-        {RankedTensorType::kDynamicSize}, if_builder.getIndexType());
-    auto known_rank_extent_tensor_type =
-        RankedTensorType::get({targeted_rank}, if_builder.getIndexType());
+    // Handle shape broadcasting and inference.
+    Value extended_lhs_casted =
+        createBroadcastToKnownRank(if_builder, op, lhs, targeted_rank);
+    Value extended_rhs_casted =
+        createBroadcastToKnownRank(if_builder, op, rhs, targeted_rank);
+    auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
+        targeted_rank, RankedTensorType::kDynamicSize);
     auto reshaped_type = RankedTensorType::get(
-        llvm::SmallVector<int64_t, 6>(targeted_rank,
-                                      RankedTensorType::kDynamicSize),
+        dynamic_dimensions,
         lhs.getType().template dyn_cast<UnrankedTensorType>().getElementType());
-    Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
-        loc, known_rank_extent_tensor_type,
-        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
-                                        ranked_shape));
-    Value extended_lhs = if_builder.create<shape::BroadcastOp>(
-        loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
-        nullptr);
-    Value extended_lhs_casted = if_builder.create<tensor::CastOp>(
-        loc, known_rank_extent_tensor_type, extended_lhs);
-    Value extended_rhs = if_builder.create<shape::BroadcastOp>(
-        loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
-        nullptr);
-    Value extended_rhs_casted = if_builder.create<tensor::CastOp>(
-        loc, known_rank_extent_tensor_type, extended_rhs);
 
     // 1. Reshape operands to the given rank (with the same number of elements)
     // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
@@ -372,10 +395,8 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
             .getType()
             .template dyn_cast<UnrankedTensorType>()
             .getElementType();
-    auto result_type = RankedTensorType::get(
-        llvm::SmallVector<int64_t, 6>(targeted_rank,
-                                      RankedTensorType::kDynamicSize),
-        result_element_type);
+    auto result_type =
+        RankedTensorType::get(dynamic_dimensions, result_element_type);
     Value result = if_builder.create<ChloOpTy>(
         loc, ArrayRef<Type>{result_type},
         ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
diff --git a/tests/hlo-transform-unranked.mlir b/tests/hlo-transform-unranked.mlir
index d18df73..a074763 100644
--- a/tests/hlo-transform-unranked.mlir
+++ b/tests/hlo-transform-unranked.mlir
@@ -158,32 +158,34 @@ func @addUnrankedUnranked(
 // CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
 // CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
-// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
-// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
-// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[LHS_RANK]], %[[C0]] : index
+// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_LHS]], %[[C1]] : index
 //             Handle scalar LHS case
 // CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
-// CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor.cast %[[LHS]] : tensor<*xf32> to tensor<f32>
-// CHECK-NEXT: %[[RHS_SHAPE_1:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
-// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE_1]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[SCALAR_LHS:.*]] = "mhlo.reshape"(%[[LHS]]) : (tensor<*xf32>) -> tensor<f32>
+// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
 // CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_RHS]] : tensor<1xindex>
 // CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT: %[[LHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RESHAPED_RHS]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
-// CHECK-NEXT: %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
-// CHECK-NEXT: scf.yield %[[RESHAPED_LHS_SCALAR_RESULT]] : tensor<*xf32>
+// CHECK-NEXT: %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: %[[SHAPE_BROADCAST_LHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[RESHAPED_EXTENDED_LHS_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESHAPED_LHS_SCALAR_RESULT]], %[[SHAPE_BROADCAST_LHS]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_LHS_RESULT]] : tensor<*xf32>
 // CHECK-NEXT: } else {
-// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
-// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
-// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[RHS_RANK]], %[[C0]] : index
+// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_RHS]], %[[C1]] : index
 //             Handle scalar RHS case
 // CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
-// CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor.cast %[[RHS]] : tensor<*xf32> to tensor<f32>
-// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[SCALAR_RHS:.*]] = "mhlo.reshape"(%[[RHS]]) : (tensor<*xf32>) -> tensor<f32>
 // CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_LHS]] : tensor<1xindex>
 // CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT: %[[RHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[SCALAR_RHS]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
 // CHECK-NEXT: %[[RESHAPED_RHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RHS_SCALAR_RESULT:.*]], %[[LHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
-// CHECK-NEXT: scf.yield %[[RESHAPED_RHS_SCALAR_RESULT]] : tensor<*xf32>
+// CHECK-NEXT: %[[SHAPE_BROADCAST_RHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[RESHAPED_EXTENDED_RHS_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESHAPED_RHS_SCALAR_RESULT]], %[[SHAPE_BROADCAST_RHS]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RHS_RESULT]] : tensor<*xf32>
 // CHECK-NEXT: } else {
 // CHECK-NEXT: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
 //             Handle equal shapes case
@@ -197,10 +199,11 @@ func @addUnrankedUnranked(
 // CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 // CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
 // CHECK-NEXT: } else {
+// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
 // CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
 // CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
 //             Handle rank 1 specialization
-// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
 // CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
 // CHECK-NEXT: %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
 // CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
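
Note (not part of the patch): as a standalone sketch of the change's core idea, the runtime predicate the pass now emits corresponds to the IR below, built from the same ops the CHECK lines above verify. The function name and operand are illustrative only; the old code instead compared shape.rank against 0, which missed single-element tensors such as tensor<1x1xf32>.

    // Illustrative sketch: the single-element check that replaces the old
    // rank-0 check. Any operand with exactly one element, e.g. tensor<1x1xf32>
    // as well as tensor<f32>, now takes the scalar fast path.
    func @is_effective_scalar(%operand: tensor<*xf32>) -> i1 {
      %shape = shape.shape_of %operand : tensor<*xf32> -> tensor<?xindex>
      // Total element count of the dynamic shape.
      %n = shape.num_elements %shape : tensor<?xindex> -> index
      %c1 = constant 1 : index
      // Exactly one element means the operand broadcasts like a scalar.
      %is_scalar = cmpi eq, %n, %c1 : index
      return %is_scalar : i1
    }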