[MLIR][HLO] Move broadcasts over n-ary shape-preserving ops
This will open up more fusion opportunities.

PiperOrigin-RevId: 364577231
parent 5bc4bf0834
commit 8987dfd1d6
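For illustration, a sketch of the rewrite (distilled from the new @bcast_nary test below; the value names are assumed): a broadcast of an n-ary elementwise result is replaced by broadcasts of the operands, so the elementwise op runs at the broadcasted shape, adjacent to whatever consumes it.

// Before: the broadcast consumes the result of the binary op.
%0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf32>
%1 = "mhlo.dynamic_broadcast_in_dim"(%0, %out_dims) {
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } :
    (tensor<?x32xf32>, tensor<3xindex>) -> tensor<?x?x32xf32>

// After: both operands are broadcast first, then subtracted at the result shape.
%2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %out_dims) {
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } :
    (tensor<?x32xf32>, tensor<3xindex>) -> tensor<?x?x32xf32>
%3 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %out_dims) {
    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } :
    (tensor<?x32xf32>, tensor<3xindex>) -> tensor<?x?x32xf32>
%4 = mhlo.subtract %2, %3 : tensor<?x?x32xf32>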
@@ -270,9 +270,8 @@ def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
 class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
-    HLO_Op<mnemonic, !listconcat(traits,
-        [InferShapedTypeOpInterface, InferFusibilityOpInterface,
-         SameOperandsAndResultShape])> {
+    HLO_Op<mnemonic, traits # [InferShapedTypeOpInterface,
+        InferFusibilityOpInterface, SameOperandsAndResultShape, Elementwise]> {
   let arguments = (ins
     HLO_Tensor:$lhs,
     HLO_Tensor:$rhs
@@ -77,14 +77,12 @@ struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
 };
 
 // We can only move up broadcasting ops that apply to the result of a
-// shape-preserving operation. For now, we restrict this to unary operations.
-// TODO(frgossen): Generalize this to n-ary operations.
+// shape-preserving operation.
 bool isDynamicBroadcastInDimOpMovable(Value operand) {
   Operation *producer_op = operand.getDefiningOp();
   return producer_op != nullptr &&
          producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() &&
-         producer_op->hasTrait<OpTrait::Elementwise>() &&
-         producer_op->getNumOperands() == 1;
+         producer_op->hasTrait<OpTrait::Elementwise>();
 }
 
 // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
@@ -96,8 +94,6 @@ struct MoveUpBroadcastInDimOpConversion
   LogicalResult matchAndRewrite(
       DynamicBroadcastInDimOp bcast_op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
-    // We can only move up broadcasting ops that apply to the result of a
-    // shape-preserving operation.
     DynamicBroadcastInDimOp::Adaptor transformed(operands);
     if (!isDynamicBroadcastInDimOpMovable(transformed.operand()))
       return failure();
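For contrast, a hypothetical non-example (not part of this commit): the relaxed predicate still requires both the SameOperandsAndResultShape and Elementwise traits on the producer, so a broadcast whose operand comes from a shape-changing op such as mhlo.dynamic_reshape is left in place.

// The producer changes the shape of its operand, so
// isDynamicBroadcastInDimOpMovable returns false and the pattern does not fire.
func @not_movable(%arg : tensor<?x32xf32>, %shape : tensor<1xindex>,
    %out_dims : tensor<3xindex>) -> tensor<?x?x?xf32> {
  %0 = "mhlo.dynamic_reshape"(%arg, %shape)
      : (tensor<?x32xf32>, tensor<1xindex>) -> tensor<?xf32>
  %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %out_dims) {
      broadcast_dimensions = dense<2> : tensor<1xi64> } :
      (tensor<?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}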
@@ -29,10 +29,10 @@ func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
 
 // -----
 
-// Broadcasts can be moved up over shape-preserving operations.
-// CHECK-LABEL: @bcast
+// Broadcasts can be moved up over unary shape-preserving operations.
+// CHECK-LABEL: @bcast_unary
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>, %[[OUT_DIMS:.*]]: tensor<3xindex>)
-func @bcast(%arg : tensor<?x32xi16>, %out_dims : tensor<3xindex>)
+func @bcast_unary(%arg : tensor<?x32xi16>, %out_dims : tensor<3xindex>)
     -> tensor<?x?x32xf16> {
   // CHECK: %[[BCASTED_OPERAND:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[OUT_DIMS]])
   // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xi16>, tensor<3xindex>) -> tensor<?x?x32xi16>
@@ -46,6 +46,24 @@ func @bcast(%arg : tensor<?x32xi16>, %out_dims : tensor<3xindex>)
 
 // -----
 
+// Broadcasts can be moved up over n-ary shape-preserving operations.
+// CHECK-LABEL: @bcast_nary
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf32>, %[[ARG1:.*]]: tensor<?x32xf32>, %[[OUT_DIMS:.*]]: tensor<3xindex>)
+func @bcast_nary(%arg0 : tensor<?x32xf32>, %arg1 : tensor<?x32xf32>,
+    %out_dims : tensor<3xindex>) -> tensor<?x?x32xf32> {
+  // CHECK-NOT: subtract
+  // CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[OUT_DIMS]])
+  // CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[OUT_DIMS]])
+  // CHECK: %{{.*}} = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]] : tensor<?x?x32xf32>
+  %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf32>
+  %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %out_dims) {
+      broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } :
+      (tensor<?x32xf32>, tensor<3xindex>) -> tensor<?x?x32xf32>
+  return %1 : tensor<?x?x32xf32>
+}
+
+// -----
+
 // Exemplary IR as it appears in the lowering with `tf.Sub` and `tf.Cast`.
 // CHECK-LABEL: @cast_sub
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xi16>, %[[ARG1:.*]]: tensor<?x?x32xf16>) -> tensor<?x?x32xf16>