diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
index 56d8130..2502415 100644
--- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
+++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
+#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h"
@@ -22,6 +24,7 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
 
 namespace mlir {
@@ -74,10 +77,6 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
 // - Legal combinations of degenerate (1-dim) implicit broadcasting.
 // The restriction on broadcast_dims derives from the definition of the
 // `shape.broadcast` op, which only supports prefix-padding.
-//
-// It may be possible to expand this pattern to operate on unranked tensors in
-// the future by emitting more code to dynamically differentiate based on rank.
-// Whether that is of any practical benefit remains to be seen.
 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
 struct ConvertRankedDynamicBroadcastBinaryOp
     : public OpRewritePattern<ChloOpTy> {
@@ -160,6 +159,68 @@ struct ConvertRankedDynamicBroadcastBinaryOp
   }
 };
 
+// Converts a broadcasting binary operation with a scalar operand and an
+// unranked operand to a ranked broadcasting operation by dynamically reshaping
+// the unranked operand to a 1D tensor. This is always safe because
+// broadcasting from a scalar to any other shape is valid.
+template <typename ChloOpTy, typename HloOpTy>
+struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
+    : public OpRewritePattern<ChloOpTy> {
+  using OpRewritePattern<ChloOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ChloOpTy op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Value lhs = op.lhs();
+    Value rhs = op.rhs();
+
+    auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
+    auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();
+
+    auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
+    auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();
+
+    bool lhs_is_scalar = lhs_ranked_type &&
+                         lhs_ranked_type.getShape().empty() &&
+                         rhs_unranked_type;
+    bool rhs_is_scalar = rhs_ranked_type &&
+                         rhs_ranked_type.getShape().empty() &&
+                         lhs_unranked_type;
+
+    // Only support the case where exactly one operand is scalar and the other
+    // is unranked. Other patterns in this file create more efficient lowerings
+    // for the cases where both ranks are known, or will handle the more
+    // generic case of both inputs being unranked.
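+    // Note that the two flags above are mutually exclusive, since each also
+    // requires the *other* operand to be unranked; the XOR below is therefore
+    // equivalent to an OR and only rejects the "neither applies" case.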
+    if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();
+
+    auto result_type =
+        op.getResult().getType().template dyn_cast<UnrankedTensorType>();
+
+    // Reshape the non-scalar operand into a dynamically sized, rank-1 tensor.
+    Value shape =
+        rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
+    Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
+    Value size = rewriter.create<shape::SizeToIndexOp>(loc, num_elements);
+    Value size_tensor = rewriter.create<TensorFromElementsOp>(loc, size);
+    Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
+        loc, RankedTensorType::get({-1}, result_type.getElementType()),
+        lhs_is_scalar ? rhs : lhs, size_tensor);
+
+    // Create a new ranked CHLO op that will be further lowered by other
+    // patterns into MHLO.
+    SmallVector<Value, 2> operands{lhs_is_scalar ? lhs : reshaped,
+                                   rhs_is_scalar ? rhs : reshaped};
+    Value computed = rewriter.create<ChloOpTy>(
+        loc, SmallVector<Type, 1>{reshaped.getType()}, operands,
+        op.getAttrs());
+
+    // Reshape the result back into an unranked tensor.
+    Value shape_tensor = rewriter.create<shape::ToExtentTensorOp>(
+        loc, RankedTensorType::get({-1}, rewriter.getIndexType()), shape);
+    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
+                                                        computed, shape_tensor);
+
+    return success();
+  }
+};
+
 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
 void PopulateForBinaryOp(MLIRContext *context,
                          OwningRewritePatternList *patterns) {
@@ -169,6 +230,9 @@ void PopulateForBinaryOp(MLIRContext *context,
   patterns->insert<
       ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
       context, 5);
+  patterns->insert<
+      ConvertUnrankedScalarDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy>>(
+      context);
 }
 
 template <typename FromOpTy, typename ToOpTy>
diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc
index c7ec85f..bdfdd39 100644
--- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc
+++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
@@ -37,6 +38,7 @@ struct TestChloLegalizeToHloPass
     // The conversion uses helpers from the Standard dialect.
     conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
     conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
+    conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();
 
     PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns);
diff --git a/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tests/chlo_legalize_to_hlo_broadcasts.mlir
index 20ad579..7782b4d 100644
--- a/tests/chlo_legalize_to_hlo_broadcasts.mlir
+++ b/tests/chlo_legalize_to_hlo_broadcasts.mlir
@@ -237,3 +237,77 @@ func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
   %0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
   return %0 : tensor<4xi1>
 }
+
+// -----
+func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
+      -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// CHECK-LABEL: func @addScalarUnranked(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32>
+// CHECK-SAME: ) -> tensor<*xf32> {
+// First handle the dynamic reshaping of the unranked operand
+// to a 1D tensor.
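+// The number of elements of the operand becomes the single dynamic extent of
+// the flattened shape that feeds mhlo.dynamic_reshape.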
+// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
+// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]]
+// CHECK: %[[SIZE:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
+// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[SIZE]]) : tensor<1xindex>
+// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// The assuming region is part of the second stage of lowering
+// with ranked broadcasting logic.
+// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<f32>
+// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
+// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_0]], %[[SHAPE_RESHAPED]]
+// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
+// CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape []
+// CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]]
+// CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[BROADCASTED_SHAPE]] : tensor<1xindex>
+// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
+// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
+// CHECK: }
+// As part of the unranked logic, the result is reshaped back
+// to an unranked tensor.
+// CHECK: %[[PROPER_SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_1]] : tensor<?xindex>
+// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT]], %[[PROPER_SHAPE_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
+// CHECK: }
+
+// -----
+func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
+  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
+      -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+// CHECK-LABEL: func @addUnrankedScalar(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
+// First handle the dynamic reshaping of the unranked operand
+// to a 1D tensor.
+// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
+// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]]
+// CHECK: %[[SIZE:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
+// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[SIZE]]) : tensor<1xindex>
+// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// The assuming region is part of the second stage of lowering
+// with ranked broadcasting logic.
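+// With the scalar on the right-hand side here, the broadcast with the empty
+// scalar shape folds away and the reshaped operand's shape is used directly
+// as the result shape.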
+// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
+// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32>
+// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]]
+// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
+// CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_RESHAPED]] : tensor<1xindex>
+// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
+// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
+// CHECK: }
+// As part of the unranked logic, the result is reshaped back
+// to an unranked tensor.
+// CHECK: %[[PROPER_SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_0]] : tensor<?xindex>
+// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT]], %[[PROPER_SHAPE_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
+// CHECK: }
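
For reference, a hand-written sketch of the intermediate IR produced by
ConvertUnrankedScalarDynamicBroadcastBinaryOp alone, before
ConvertRankedDynamicBroadcastBinaryOp lowers the inner chlo.broadcast_add (the
function name @example is illustrative and this snippet is not
FileCheck-verified):

func @example(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  // Flatten the unranked operand into a rank-1 tensor of its element count.
  %shape = shape.shape_of %arg1 : tensor<*xf32>
  %num_elements = shape.num_elements %shape
  %size = shape.size_to_index %num_elements
  %size_tensor = tensor_from_elements(%size) : tensor<1xindex>
  %flat = "mhlo.dynamic_reshape"(%arg1, %size_tensor) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
  // The scalar operand passes through unchanged; this now-ranked op is what
  // the ranked broadcast pattern subsequently lowers.
  %sum = chlo.broadcast_add %arg0, %flat : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
  // Restore the original (unranked) shape on the result.
  %extents = shape.to_extent_tensor %shape : tensor<?xindex>
  %result = "mhlo.dynamic_reshape"(%sum, %extents) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
  return %result : tensor<*xf32>
}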