Support CHLO broadcasting operations between scalar and unranked tensors.

This is done by reshaping the unranked tensor into a rank-1 ranked tensor, which yields safe broadcast/indexing logic when the other operand is a scalar.

PiperOrigin-RevId: 322553661
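The core of the change is a three-step rewrite: flatten the unranked operand to a rank-1 tensor, re-emit the broadcasting op on ranked operands (where the existing lowerings apply), and reshape the result back. A condensed sketch of the emitted IR for a scalar lhs follows; value names are illustrative, and the authoritative sequences are the pattern code and FileCheck tests in the diffs below.

    %0 = chlo.broadcast_add %scalar, %unranked
        : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>

    // becomes:

    // 1) Flatten the unranked operand to a dynamically sized rank-1 tensor.
    %shape = shape.shape_of %unranked : tensor<*xf32>
    %num_elements = shape.num_elements %shape
    %size = shape.size_to_index %num_elements
    %size_tensor = tensor_from_elements(%size) : tensor<1xindex>
    %flat = "mhlo.dynamic_reshape"(%unranked, %size_tensor)
        : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
    // 2) Re-emit the CHLO op on ranked operands; other patterns lower it to MHLO.
    %ranked_result = chlo.broadcast_add %scalar, %flat
        : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
    // 3) Restore the original shape of the result.
    %extents = shape.to_extent_tensor %shape : tensor<?xindex>
    %result = "mhlo.dynamic_reshape"(%ranked_result, %extents)
        : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>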
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
+#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h"
@@ -22,6 +24,7 @@ limitations under the License.
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
 #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
 
 namespace mlir {
@@ -74,10 +77,6 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
 // - Legal combinations of degenerate (1-dim) implicit broadcasting.
 // The restriction on broadcast_dims derives from the definition of the
 // `shape.broadcast` op, which only supports prefix-padding.
-//
-// It may be possible to expand this pattern to operate on unranked tensors in
-// the future by emitting more code to dynamically differentiate based on rank.
-// Whether that is of any practical benefit remains to be seen.
 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
 struct ConvertRankedDynamicBroadcastBinaryOp
     : public OpRewritePattern<ChloOpTy> {
@@ -160,6 +159,68 @@ struct ConvertRankedDynamicBroadcastBinaryOp
   }
 };
 
+// Converts a broadcasting binary operation with a scalar operand and an
+// unranked operand to a ranked broadcasting operation by dynamically reshaping
+// the unranked operand to a 1D tensor. This will always be safe because
+// broadcasting from a scalar to another shape always works.
+template <typename ChloOpTy, typename HloOpTy>
+struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
+    : public OpRewritePattern<ChloOpTy> {
+  using OpRewritePattern<ChloOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ChloOpTy op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Value lhs = op.lhs();
+    Value rhs = op.rhs();
+
+    auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
+    auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();
+
+    auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
+    auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();
+
+    bool lhs_is_scalar = lhs_ranked_type &&
+                         lhs_ranked_type.getShape().empty() &&
+                         rhs_unranked_type;
+    bool rhs_is_scalar = rhs_ranked_type &&
+                         rhs_ranked_type.getShape().empty() &&
+                         lhs_unranked_type;
+
+    // Only support the case where exactly one operand is scalar and the other
+    // is unranked. Other patterns in this file will create more efficient
+    // lowerings for cases where both ranks are known or will handle the more
+    // generic case of both inputs being unranked.
+    if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();
+
+    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
+
+    // Reshape the non-scalar value into a dynamically sized, rank-1 tensor.
+    Value shape =
+        rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
+    Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
+    Value size = rewriter.create<shape::SizeToIndexOp>(loc, num_elements);
+    Value size_tensor = rewriter.create<TensorFromElementsOp>(loc, size);
+    Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
+        loc, RankedTensorType::get({-1}, result_type.getElementType()),
+        lhs_is_scalar ? rhs : lhs, size_tensor);
+
+    // Create a new ranked Chlo op that will be further lowered by other
+    // patterns into Mhlo.
+    SmallVector<Value, 2> operands{lhs_is_scalar ? lhs : reshaped,
+                                   rhs_is_scalar ? rhs : reshaped};
+    Value computed = rewriter.create<ChloOpTy>(
+        loc, SmallVector<Type, 1>{reshaped.getType()}, operands, op.getAttrs());
+
+    // Reshape the result back into an unranked tensor.
+    Value shape_tensor = rewriter.create<shape::ToExtentTensorOp>(
+        loc, RankedTensorType::get({-1}, rewriter.getIndexType()), shape);
+    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
+                                                        computed, shape_tensor);
+
+    return success();
+  }
+};
+
 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
 void PopulateForBinaryOp(MLIRContext *context,
                          OwningRewritePatternList *patterns) {
@@ -169,6 +230,9 @@ void PopulateForBinaryOp(MLIRContext *context,
   patterns->insert<
       ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
      context, 5);
+  patterns->insert<
+      ConvertUnrankedScalarDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy>>(
+      context);
 }
 
 template <typename FromOpTy, typename ToOpTy>
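Note that ConvertUnrankedScalarDynamicBroadcastBinaryOp is registered with the default pattern benefit, while the ranked pattern keeps benefit 5, so the more efficient ranked lowering is still preferred whenever it applies. The new pattern accepts the scalar on either side; both forms exercised by the tests below are matched (%scalar and %unranked are illustrative names):

    %0 = chlo.broadcast_add %scalar, %unranked : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
    %1 = chlo.broadcast_add %unranked, %scalar : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>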
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
@@ -37,6 +38,7 @@ struct TestChloLegalizeToHloPass
   // The conversion uses helpers from the Standard dialect.
   conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
   conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
+  conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();
 
   PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns);
 
@@ -237,3 +237,77 @@ func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
   %0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
   return %0 : tensor<4xi1>
 }
+
+// -----
+func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
+      -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// CHECK-LABEL: func @addScalarUnranked(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32>
+// CHECK-SAME: ) -> tensor<*xf32> {
+// First handle the dynamic reshaping of the unranked operand
+// to a 1D tensor.
+// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
+// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]]
+// CHECK: %[[SIZE:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
+// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[SIZE]]) : tensor<1xindex>
+// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// The assuming region is part of the second stage of lowering
+// with ranked broadcasting logic.
+// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<f32>
+// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
+// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_0]], %[[SHAPE_RESHAPED]]
+// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
+// CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape []
+// CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]]
+// CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[BROADCASTED_SHAPE]] : tensor<1xindex>
+// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
+// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
+// CHECK: }
+// As part of the unranked logic, the result is reshaped back
+// to an unranked tensor.
+// CHECK: %[[PROPER_SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_1]] : tensor<?xindex>
+// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[VAL_19:.*]], %[[PROPER_SHAPE_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
+// CHECK: }
+
+// -----
+func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
+  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
+      -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+// CHECK-LABEL: func @addUnrankedScalar(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
+// First handle the dynamic reshaping of the unranked operand
+// to a 1D tensor.
+// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
+// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]]
+// CHECK: %[[SIZE:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
+// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[SIZE]]) : tensor<1xindex>
+// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// The assuming region is part of the second stage of lowering
+// with ranked broadcasting logic.
+// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
+// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32>
+// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]]
+// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
+// CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_RESHAPED]] : tensor<1xindex>
+// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
+// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
+// CHECK: }
+// As part of the unranked logic, the result is reshaped back
+// to an unranked tensor.
+// CHECK: %[[PROPER_SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_0]] : tensor<?xindex>
+// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[VAL_19:.*]], %[[PROPER_SHAPE_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
+// CHECK: }