Generalize the HloBinaryElementwiseAdaptor
We can also use it for ternary ops like Select if we change the signature so that a ValueRange is passed in. Also remove the special-cased HloComplexAdaptor; complex ops can be handled by the generic adaptor as well.

PiperOrigin-RevId: 365777493
commit c1a6ae8994
parent 6388e8d9ee
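For illustration only: a minimal sketch of what the ValueRange-based signature enables for a ternary op such as Select. The BroadcastSelectOp/mhlo::SelectOp pairing and the helper below are assumptions for the example; this change itself only reworks the adaptors and their existing binary call sites.

// Sketch, not part of this change: the generic adaptor applied to a ternary
// op, assuming mhlo::SelectOp accepts the (result type, operands) builder form
// that the adaptor uses internally.
using SelectAdaptor =
    HloNaryElementwiseAdaptor<BroadcastSelectOp, mhlo::SelectOp>;

// Hypothetical helper inside a rewrite pattern: all broadcasted operands are
// forwarded through one ValueRange, so operand count no longer matters.
static Value CreateBroadcastedSelect(BroadcastSelectOp op, Type result_type,
                                     Value pred, Value on_true, Value on_false,
                                     OpBuilder &builder) {
  return SelectAdaptor::CreateOp(op, result_type, {pred, on_true, on_false},
                                 builder);
}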
@@ -25,30 +25,22 @@ limitations under the License.
 namespace mlir {
 namespace chlo {
 
-struct HloComplexAdaptor {
-  static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
-                                  Value broadcasted_lhs, Value broadcasted_rhs,
-                                  OpBuilder &builder) {
-    return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
-                                           broadcasted_lhs, broadcasted_rhs);
-  }
-};
 template <typename FromOpTy, typename ToOpTy>
-struct HloBinaryElementwiseAdaptor {
+struct HloNaryElementwiseAdaptor {
   static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
-                         Value broadcasted_lhs, Value broadcasted_rhs,
-                         OpBuilder &builder) {
+                         ValueRange broadcasted_operands, OpBuilder &builder) {
     return builder.create<ToOpTy>(from_op.getLoc(), result_type,
-                                  broadcasted_lhs, broadcasted_rhs);
+                                  broadcasted_operands);
   }
 };
 struct HloCompareAdaptor {
   static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
-                                  Value broadcasted_lhs, Value broadcasted_rhs,
+                                  ValueRange broadcasted_operands,
                                   OpBuilder &builder) {
     return builder.create<mhlo::CompareOp>(
-        from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
-        from_op.comparison_direction(), from_op.compare_typeAttr());
+        from_op.getLoc(), result_type, broadcasted_operands[0],
+        broadcasted_operands[1], from_op.comparison_direction(),
+        from_op.compare_typeAttr());
   }
 };
 
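Note (not part of the diff): the deleted HloComplexAdaptor is now just one instantiation of the generic adaptor. The sketch below, with hypothetical surrounding variables, builds the same mhlo.complex op through the generic path.

// Equivalent of the removed HloComplexAdaptor::CreateOp, via the generic
// adaptor; op, result_type, broadcasted_lhs/rhs, and builder come from the
// surrounding rewrite pattern.
mhlo::ComplexOp complex_op =
    HloNaryElementwiseAdaptor<BroadcastComplexOp, mhlo::ComplexOp>::CreateOp(
        op, result_type, {broadcasted_lhs, broadcasted_rhs}, builder);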
@@ -59,14 +51,15 @@ template <template <typename, typename, typename> class Pattern,
 void PopulateForBroadcastingBinaryOp(MLIRContext *context,
                                      OwningRewritePatternList *patterns,
                                      ConstructorArgs &&...args) {
-#define POPULATE_BCAST(ChloOp, HloOp)                                      \
-  patterns->insert<                                                        \
-      Pattern<ChloOp, HloOp, HloBinaryElementwiseAdaptor<ChloOp, HloOp>>>( \
+#define POPULATE_BCAST(ChloOp, HloOp)                                    \
+  patterns->insert<                                                      \
+      Pattern<ChloOp, HloOp, HloNaryElementwiseAdaptor<ChloOp, HloOp>>>( \
       context, args...);
 
   POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
   POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
   POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
+  POPULATE_BCAST(BroadcastComplexOp, mhlo::ComplexOp);
   POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
   POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
   POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
@@ -83,9 +76,6 @@ void PopulateForBroadcastingBinaryOp(MLIRContext *context,
   POPULATE_BCAST(BroadcastZetaOp, ZetaOp);
 
   // Broadcasting ops requiring special construction.
-  patterns
-      ->insert<Pattern<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>>(
-          context, args...);
   patterns
       ->insert<Pattern<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>>(
           context, args...);
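For reference, the POPULATE_BCAST(BroadcastComplexOp, mhlo::ComplexOp) line added in the previous hunk expands to essentially the registration removed here, only with the generic adaptor:

// Expansion of POPULATE_BCAST(BroadcastComplexOp, mhlo::ComplexOp):
patterns->insert<
    Pattern<BroadcastComplexOp, mhlo::ComplexOp,
            HloNaryElementwiseAdaptor<BroadcastComplexOp, mhlo::ComplexOp>>>(
    context, args...);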
@@ -1119,9 +1119,8 @@ struct ConvertTrivialNonBroadcastBinaryOp
       }
     }
 
-    rewriter.replaceOp(
-        op, {Adaptor::CreateOp(op, op.getResult().getType(), operands[0],
-                               operands[1], rewriter)});
+    rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
+                                              operands, rewriter)});
     return success();
   }
 };
@@ -1214,8 +1213,8 @@ struct ConvertRankedDynamicBroadcastBinaryOp
         rewriter.getI64TensorAttr(rhs_broadcast_dimensions));
 
     // And generate the final non-broadcasted binary op.
-    Value final_result = Adaptor::CreateOp(op, result_type, broadcasted_lhs,
-                                           broadcasted_rhs, rewriter);
+    Value final_result = Adaptor::CreateOp(
+        op, result_type, {broadcasted_lhs, broadcasted_rhs}, rewriter);
     rewriter.create<shape::AssumingYieldOp>(loc, final_result);
     rewriter.replaceOp(op, {assuming_op.getResult(0)});
     return success();
@@ -463,7 +463,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
     OpBuilder if_eq_shapes_builder =
         if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
     Value non_broadcast_op =
-        Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
+        Adaptor::CreateOp(op, result_type, {lhs, rhs}, if_eq_shapes_builder);
     if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
 
     // If shapes do not have exactly one element, nor are equal
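A note on the updated call sites: braced lists such as {lhs, rhs} work because ValueRange is implicitly constructible from an initializer list of Values. A minimal sketch of the idiom and its lifetime caveat (Adaptor and the variables are placeholders):

// ValueRange is a non-owning view. The braced list creates a temporary backing
// array that lives only for the full call expression, which is exactly how the
// updated call sites use it; storing the resulting ValueRange for later use
// would leave it dangling.
Value result = Adaptor::CreateOp(op, result_type, {lhs, rhs}, builder);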