Generalize the HloBinaryElementwiseAdaptor

By changing the signature to take a ValueRange, the same adaptor can also
serve ternary ops such as Select.
Also remove the special casing for HloComplexAdaptor: BroadcastComplexOp can
be handled by the generic adaptor as well.

PiperOrigin-RevId: 365777493
Adrian Kuegel, 2021-03-30 03:53:04 -07:00 (committed by TensorFlow MLIR Team)
parent 6388e8d9ee
commit c1a6ae8994
3 changed files with 16 additions and 27 deletions
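
For context on what the ValueRange signature enables, here is a minimal
hypothetical sketch of a ternary op reusing the generalized adaptor. This is
not part of this commit: BroadcastSelectOp is not wired up here, and the
variable names (pred, on_true, on_false) are illustrative.

  // Hypothetical: with CreateOp taking a ValueRange, the same adaptor can
  // construct ops of any arity, e.g. a three-operand select.
  using SelectAdaptor =
      HloNaryElementwiseAdaptor<BroadcastSelectOp, mhlo::SelectOp>;
  mhlo::SelectOp select_op = SelectAdaptor::CreateOp(
      from_op, result_type, {pred, on_true, on_false}, builder);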


@@ -25,30 +25,22 @@ limitations under the License.
 namespace mlir {
 namespace chlo {
 
-struct HloComplexAdaptor {
-  static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
-                                  Value broadcasted_lhs, Value broadcasted_rhs,
-                                  OpBuilder &builder) {
-    return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
-                                           broadcasted_lhs, broadcasted_rhs);
-  }
-};
 template <typename FromOpTy, typename ToOpTy>
-struct HloBinaryElementwiseAdaptor {
+struct HloNaryElementwiseAdaptor {
   static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
-                         Value broadcasted_lhs, Value broadcasted_rhs,
-                         OpBuilder &builder) {
+                         ValueRange broadcasted_operands, OpBuilder &builder) {
     return builder.create<ToOpTy>(from_op.getLoc(), result_type,
-                                  broadcasted_lhs, broadcasted_rhs);
+                                  broadcasted_operands);
   }
 };
 
 struct HloCompareAdaptor {
   static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
-                                  Value broadcasted_lhs, Value broadcasted_rhs,
+                                  ValueRange broadcasted_operands,
                                   OpBuilder &builder) {
     return builder.create<mhlo::CompareOp>(
-        from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
-        from_op.comparison_direction(), from_op.compare_typeAttr());
+        from_op.getLoc(), result_type, broadcasted_operands[0],
+        broadcasted_operands[1], from_op.comparison_direction(),
+        from_op.compare_typeAttr());
   }
 };

@@ -59,14 +51,15 @@ template <template <typename, typename, typename> class Pattern,
 void PopulateForBroadcastingBinaryOp(MLIRContext *context,
                                      OwningRewritePatternList *patterns,
                                      ConstructorArgs &&...args) {
-#define POPULATE_BCAST(ChloOp, HloOp)                                      \
-  patterns->insert<                                                        \
-      Pattern<ChloOp, HloOp, HloBinaryElementwiseAdaptor<ChloOp, HloOp>>>( \
+#define POPULATE_BCAST(ChloOp, HloOp)                                    \
+  patterns->insert<                                                      \
+      Pattern<ChloOp, HloOp, HloNaryElementwiseAdaptor<ChloOp, HloOp>>>( \
       context, args...);
 
   POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
   POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
   POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
+  POPULATE_BCAST(BroadcastComplexOp, mhlo::ComplexOp);
   POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
   POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
   POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);

@@ -83,9 +76,6 @@ void PopulateForBroadcastingBinaryOp(MLIRContext *context,
   POPULATE_BCAST(BroadcastZetaOp, ZetaOp);
 
   // Broadcasting ops requiring special construction.
-  patterns
-      ->insert<Pattern<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>>(
-          context, args...);
   patterns
       ->insert<Pattern<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>>(
           context, args...);
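
For reference, each POPULATE_BCAST line above expands to an ordinary pattern
registration with the generic adaptor; shown here for BroadcastAddOp:

  patterns->insert<
      Pattern<BroadcastAddOp, mhlo::AddOp,
              HloNaryElementwiseAdaptor<BroadcastAddOp, mhlo::AddOp>>>(
      context, args...);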


@@ -1119,9 +1119,8 @@ struct ConvertTrivialNonBroadcastBinaryOp
       }
     }
 
-    rewriter.replaceOp(
-        op, {Adaptor::CreateOp(op, op.getResult().getType(), operands[0],
-                               operands[1], rewriter)});
+    rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
+                                              operands, rewriter)});
     return success();
   }
 };

@@ -1214,8 +1213,8 @@ struct ConvertRankedDynamicBroadcastBinaryOp
         rewriter.getI64TensorAttr(rhs_broadcast_dimensions));
 
     // And generate the final non-broadcasted binary op.
-    Value final_result = Adaptor::CreateOp(op, result_type, broadcasted_lhs,
-                                           broadcasted_rhs, rewriter);
+    Value final_result = Adaptor::CreateOp(
+        op, result_type, {broadcasted_lhs, broadcasted_rhs}, rewriter);
     rewriter.create<shape::AssumingYieldOp>(loc, final_result);
     rewriter.replaceOp(op, {assuming_op.getResult(0)});
     return success();
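
One subtlety at the updated call sites: a braced list such as
{broadcasted_lhs, broadcasted_rhs} materializes a temporary array that the
ValueRange parameter merely views. That is safe here because the temporary
lives until the end of the full call expression, by which point the new op has
already been built. If the operand list had to outlive the expression, it
would need its own storage, e.g. (illustrative sketch):

  // ValueRange is a non-owning view; give the operands durable storage if
  // they must outlive the call expression.
  SmallVector<Value, 2> operands{broadcasted_lhs, broadcasted_rhs};
  Value final_result = Adaptor::CreateOp(op, result_type, operands, rewriter);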


@@ -463,7 +463,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
     OpBuilder if_eq_shapes_builder =
         if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
     Value non_broadcast_op =
-        Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
+        Adaptor::CreateOp(op, result_type, {lhs, rhs}, if_eq_shapes_builder);
     if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
 
     // If shapes do not have exactly one element, nor are equal