Generalize the HloBinaryElementwiseAdaptor

We can use it also for ternary ops like Select if we change the signature so
that a ValueRange is passed in.
Also remove special casing for HloComplexAdaptor. It can be handled with the
generic adaptor as well.

PiperOrigin-RevId: 365777493
This commit is contained in:
Adrian Kuegel 2021-03-30 03:53:04 -07:00 committed by TensorFlow MLIR Team
parent 6388e8d9ee
commit c1a6ae8994
3 changed files with 16 additions and 27 deletions

View File

@ -25,30 +25,22 @@ limitations under the License.
namespace mlir {
namespace chlo {
struct HloComplexAdaptor {
static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs);
}
};
template <typename FromOpTy, typename ToOpTy>
struct HloBinaryElementwiseAdaptor {
struct HloNaryElementwiseAdaptor {
static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
ValueRange broadcasted_operands, OpBuilder &builder) {
return builder.create<ToOpTy>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs);
broadcasted_operands);
}
};
struct HloCompareAdaptor {
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
ValueRange broadcasted_operands,
OpBuilder &builder) {
return builder.create<mhlo::CompareOp>(
from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
from_op.comparison_direction(), from_op.compare_typeAttr());
from_op.getLoc(), result_type, broadcasted_operands[0],
broadcasted_operands[1], from_op.comparison_direction(),
from_op.compare_typeAttr());
}
};
@ -59,14 +51,15 @@ template <template <typename, typename, typename> class Pattern,
void PopulateForBroadcastingBinaryOp(MLIRContext *context,
OwningRewritePatternList *patterns,
ConstructorArgs &&...args) {
#define POPULATE_BCAST(ChloOp, HloOp) \
patterns->insert< \
Pattern<ChloOp, HloOp, HloBinaryElementwiseAdaptor<ChloOp, HloOp>>>( \
#define POPULATE_BCAST(ChloOp, HloOp) \
patterns->insert< \
Pattern<ChloOp, HloOp, HloNaryElementwiseAdaptor<ChloOp, HloOp>>>( \
context, args...);
POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
POPULATE_BCAST(BroadcastComplexOp, mhlo::ComplexOp);
POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
@ -83,9 +76,6 @@ void PopulateForBroadcastingBinaryOp(MLIRContext *context,
POPULATE_BCAST(BroadcastZetaOp, ZetaOp);
// Broadcasting ops requiring special construction.
patterns
->insert<Pattern<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>>(
context, args...);
patterns
->insert<Pattern<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>>(
context, args...);

View File

@ -1119,9 +1119,8 @@ struct ConvertTrivialNonBroadcastBinaryOp
}
}
rewriter.replaceOp(
op, {Adaptor::CreateOp(op, op.getResult().getType(), operands[0],
operands[1], rewriter)});
rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
operands, rewriter)});
return success();
}
};
@ -1214,8 +1213,8 @@ struct ConvertRankedDynamicBroadcastBinaryOp
rewriter.getI64TensorAttr(rhs_broadcast_dimensions));
// And generate the final non-broadcasted binary op.
Value final_result = Adaptor::CreateOp(op, result_type, broadcasted_lhs,
broadcasted_rhs, rewriter);
Value final_result = Adaptor::CreateOp(
op, result_type, {broadcasted_lhs, broadcasted_rhs}, rewriter);
rewriter.create<shape::AssumingYieldOp>(loc, final_result);
rewriter.replaceOp(op, {assuming_op.getResult(0)});
return success();

View File

@ -463,7 +463,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
OpBuilder if_eq_shapes_builder =
if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
Value non_broadcast_op =
Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
Adaptor::CreateOp(op, result_type, {lhs, rhs}, if_eq_shapes_builder);
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
// If shapes do not have exactly one element, nor are equal