diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 3eecea2..126c7ad 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -270,9 +270,8 @@ def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
 // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
 class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
-    HLO_Op<mnemonic, traits # [Elementwise]> {
+    HLO_Op<mnemonic, traits # [Elementwise, SameOperandsAndResultShape]> {
   let arguments = (ins
       HLO_Tensor:$lhs,
       HLO_Tensor:$rhs
diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
index 70d6653..6bf23f4 100644
--- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
+++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
@@ -77,14 +77,12 @@ struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
 };
 
 // We can only move up broadcasting ops that apply to the result of a
-// shape-preserving operation. For now, we restrict this to unary operations.
-// TODO(frgossen): Generalize this to n-ary operations.
+// shape-preserving operation.
 bool isDynamicBroadcastInDimOpMovable(Value operand) {
   Operation *producer_op = operand.getDefiningOp();
   return producer_op != nullptr &&
          producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() &&
-         producer_op->hasTrait<OpTrait::Elementwise>() &&
-         producer_op->getNumOperands() == 1;
+         producer_op->hasTrait<OpTrait::Elementwise>();
 }
 
 // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
@@ -96,8 +94,6 @@ struct MoveUpBroadcastInDimOpConversion
   LogicalResult matchAndRewrite(
       DynamicBroadcastInDimOp bcast_op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
-    // We can only move up broadcasting ops that apply to the result of a
-    // shape-preserving operation.
     DynamicBroadcastInDimOp::Adaptor transformed(operands);
     if (!isDynamicBroadcastInDimOpMovable(transformed.operand()))
       return failure();
diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
index 7fe3302..b97f1ab 100644
--- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir
+++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
@@ -29,10 +29,10 @@ func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
 
 // -----
 
-// Broadcasts can be moved up over shape-preserving operations.
-// CHECK-LABEL: @bcast
+// Broadcasts can be moved up over unary shape-preserving operations.
+// CHECK-LABEL: @bcast_unary
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>, %[[OUT_DIMS:.*]]: tensor<3xindex>)
-func @bcast(%arg : tensor<?x32xi16>, %out_dims : tensor<3xindex>)
+func @bcast_unary(%arg : tensor<?x32xi16>, %out_dims : tensor<3xindex>)
     -> tensor<?x?x32xi16> {
   // CHECK: %[[BCASTED_OPERAND:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[OUT_DIMS]])
   // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xi16>, tensor<3xindex>) -> tensor<?x?x32xi16>
@@ -46,6 +46,24 @@ func @bcast(%arg : tensor<?x32xi16>, %out_dims : tensor<3xindex>)
 
 // -----
 
+// Broadcasts can be moved up over n-ary shape-preserving operations.
+// CHECK-LABEL: @bcast_nary
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xi16>, %[[ARG1:.*]]: tensor<?x32xi16>, %[[OUT_DIMS:.*]]: tensor<3xindex>)
+func @bcast_nary(%arg0 : tensor<?x32xi16>, %arg1 : tensor<?x32xi16>,
+    %out_dims : tensor<3xindex>) -> tensor<?x?x32xi16> {
+  // CHECK-NOT: subtract
+  // CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[OUT_DIMS]])
+  // CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[OUT_DIMS]])
+  // CHECK: %{{.*}} = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]] : tensor<?x?x32xi16>
+  %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xi16>
+  %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %out_dims) {
+      broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } :
+      (tensor<?x32xi16>, tensor<3xindex>) -> tensor<?x?x32xi16>
+  return %1 : tensor<?x?x32xi16>
+}
+
+// -----
+
 // Exemplary IR as it appears in the lowering with `tf.Sub` and `tf.Cast`.
 // CHECK-LABEL: @cast_sub
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xi16>, %[[ARG1:.*]]: tensor<?x?x32xf16>) -> tensor<?x?x32xf16>
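
The diff relaxes the movability predicate but only hints at the rewrite that the
new @bcast_nary test exercises. Below is a minimal sketch of that generalized
rewrite, assuming the MLIR builder API of this era (`createOperation`, ODS
accessors such as `operand()` and `output_dimensions()`); the helper name
`moveBroadcastAboveNaryProducer` and variable names are illustrative, not the
pass's verbatim code.

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace mhlo {

// Sketch: move a dynamic_broadcast_in_dim above an n-ary shape-preserving
// producer by broadcasting each producer operand individually and re-creating
// the producer on the broadcasted values.
LogicalResult moveBroadcastAboveNaryProducer(DynamicBroadcastInDimOp bcast_op,
                                             PatternRewriter &rewriter) {
  Operation *producer_op = bcast_op.operand().getDefiningOp();
  if (!producer_op ||
      !producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() ||
      !producer_op->hasTrait<OpTrait::Elementwise>())
    return failure();

  // Assumes ranked tensors, as in the tests above; an unranked result bails.
  auto result_ty = bcast_op.getType().dyn_cast<RankedTensorType>();
  if (!result_ty) return failure();

  // Broadcast every operand with the same output dimensions and
  // broadcast_dimensions; only the element type differs per operand.
  Location loc = bcast_op.getLoc();
  SmallVector<Value, 2> bcasted_operands;
  for (Value operand : producer_op->getOperands()) {
    auto operand_ty = operand.getType().dyn_cast<RankedTensorType>();
    if (!operand_ty) return failure();
    auto bcasted_ty = RankedTensorType::get(result_ty.getShape(),
                                            operand_ty.getElementType());
    bcasted_operands.push_back(rewriter.create<DynamicBroadcastInDimOp>(
        loc, bcasted_ty, operand, bcast_op.output_dimensions(),
        bcast_op.broadcast_dimensions()));
  }

  // Re-create the producer on the broadcasted operands. Because the producer
  // is elementwise and shape-preserving, its result takes the broadcasted
  // shape, so it can directly replace the original broadcast.
  OperationState new_producer_state(loc,
                                    producer_op->getName().getStringRef());
  new_producer_state.addOperands(bcasted_operands);
  new_producer_state.addTypes({bcast_op.getType()});
  new_producer_state.addAttributes(producer_op->getAttrs());
  Operation *new_producer = rewriter.createOperation(new_producer_state);

  rewriter.replaceOp(bcast_op, new_producer->getResults());
  return success();
}

}  // namespace mhlo
}  // namespace mlir

Duplicating one broadcast per operand temporarily grows the IR, but it moves
all broadcasts next to the shape computations and leaves a purely elementwise
producer behind, which is what the fusion this pass prepares for wants; the
remaining TODO in the diff suggests the pass is expected to limit this to
broadcasts that actually have a consumer.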