From c1a6ae89945036e51fda2785c051f96ac1ddc5b9 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 30 Mar 2021 03:53:04 -0700 Subject: [PATCH] Generalize the HloBinaryElementwiseAdaptor We can use it also for ternary ops like Select if we change the signature so that a ValueRange is passed in. Also remove special casing for HloComplexAdaptor. It can be handled with the generic adaptor as well. PiperOrigin-RevId: 365777493 --- .../mhlo/transforms/map_chlo_to_hlo_op.h | 32 +++++++------------ .../mhlo/transforms/chlo_legalize_to_hlo.cc | 9 +++--- .../mhlo/transforms/transform_unranked_hlo.cc | 2 +- 3 files changed, 16 insertions(+), 27 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h index d9e637d..2b061c0 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h @@ -25,30 +25,22 @@ limitations under the License. namespace mlir { namespace chlo { -struct HloComplexAdaptor { - static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type, - Value broadcasted_lhs, Value broadcasted_rhs, - OpBuilder &builder) { - return builder.create(from_op.getLoc(), result_type, - broadcasted_lhs, broadcasted_rhs); - } -}; template -struct HloBinaryElementwiseAdaptor { +struct HloNaryElementwiseAdaptor { static ToOpTy CreateOp(FromOpTy from_op, Type result_type, - Value broadcasted_lhs, Value broadcasted_rhs, - OpBuilder &builder) { + ValueRange broadcasted_operands, OpBuilder &builder) { return builder.create(from_op.getLoc(), result_type, - broadcasted_lhs, broadcasted_rhs); + broadcasted_operands); } }; struct HloCompareAdaptor { static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type, - Value broadcasted_lhs, Value broadcasted_rhs, + ValueRange broadcasted_operands, OpBuilder &builder) { return builder.create( - from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs, - from_op.comparison_direction(), from_op.compare_typeAttr()); + from_op.getLoc(), result_type, broadcasted_operands[0], + broadcasted_operands[1], from_op.comparison_direction(), + from_op.compare_typeAttr()); } }; @@ -59,14 +51,15 @@ template