diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h index d9e637d..2b061c0 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h @@ -25,30 +25,22 @@ limitations under the License. namespace mlir { namespace chlo { -struct HloComplexAdaptor { - static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type, - Value broadcasted_lhs, Value broadcasted_rhs, - OpBuilder &builder) { - return builder.create(from_op.getLoc(), result_type, - broadcasted_lhs, broadcasted_rhs); - } -}; template -struct HloBinaryElementwiseAdaptor { +struct HloNaryElementwiseAdaptor { static ToOpTy CreateOp(FromOpTy from_op, Type result_type, - Value broadcasted_lhs, Value broadcasted_rhs, - OpBuilder &builder) { + ValueRange broadcasted_operands, OpBuilder &builder) { return builder.create(from_op.getLoc(), result_type, - broadcasted_lhs, broadcasted_rhs); + broadcasted_operands); } }; struct HloCompareAdaptor { static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type, - Value broadcasted_lhs, Value broadcasted_rhs, + ValueRange broadcasted_operands, OpBuilder &builder) { return builder.create( - from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs, - from_op.comparison_direction(), from_op.compare_typeAttr()); + from_op.getLoc(), result_type, broadcasted_operands[0], + broadcasted_operands[1], from_op.comparison_direction(), + from_op.compare_typeAttr()); } }; @@ -59,14 +51,15 @@ template