Generalize the HloBinaryElementwiseAdaptor
We can use it also for ternary ops like Select if we change the signature so that a ValueRange is passed in. Also remove special casing for HloComplexAdaptor. It can be handled with the generic adaptor as well. PiperOrigin-RevId: 365777493
This commit is contained in:
parent
6388e8d9ee
commit
c1a6ae8994
|
@ -25,30 +25,22 @@ limitations under the License.
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace chlo {
|
namespace chlo {
|
||||||
|
|
||||||
struct HloComplexAdaptor {
|
|
||||||
static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
|
|
||||||
Value broadcasted_lhs, Value broadcasted_rhs,
|
|
||||||
OpBuilder &builder) {
|
|
||||||
return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
|
|
||||||
broadcasted_lhs, broadcasted_rhs);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
template <typename FromOpTy, typename ToOpTy>
|
template <typename FromOpTy, typename ToOpTy>
|
||||||
struct HloBinaryElementwiseAdaptor {
|
struct HloNaryElementwiseAdaptor {
|
||||||
static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
|
static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
|
||||||
Value broadcasted_lhs, Value broadcasted_rhs,
|
ValueRange broadcasted_operands, OpBuilder &builder) {
|
||||||
OpBuilder &builder) {
|
|
||||||
return builder.create<ToOpTy>(from_op.getLoc(), result_type,
|
return builder.create<ToOpTy>(from_op.getLoc(), result_type,
|
||||||
broadcasted_lhs, broadcasted_rhs);
|
broadcasted_operands);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct HloCompareAdaptor {
|
struct HloCompareAdaptor {
|
||||||
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
|
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
|
||||||
Value broadcasted_lhs, Value broadcasted_rhs,
|
ValueRange broadcasted_operands,
|
||||||
OpBuilder &builder) {
|
OpBuilder &builder) {
|
||||||
return builder.create<mhlo::CompareOp>(
|
return builder.create<mhlo::CompareOp>(
|
||||||
from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
|
from_op.getLoc(), result_type, broadcasted_operands[0],
|
||||||
from_op.comparison_direction(), from_op.compare_typeAttr());
|
broadcasted_operands[1], from_op.comparison_direction(),
|
||||||
|
from_op.compare_typeAttr());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -61,12 +53,13 @@ void PopulateForBroadcastingBinaryOp(MLIRContext *context,
|
||||||
ConstructorArgs &&...args) {
|
ConstructorArgs &&...args) {
|
||||||
#define POPULATE_BCAST(ChloOp, HloOp) \
|
#define POPULATE_BCAST(ChloOp, HloOp) \
|
||||||
patterns->insert< \
|
patterns->insert< \
|
||||||
Pattern<ChloOp, HloOp, HloBinaryElementwiseAdaptor<ChloOp, HloOp>>>( \
|
Pattern<ChloOp, HloOp, HloNaryElementwiseAdaptor<ChloOp, HloOp>>>( \
|
||||||
context, args...);
|
context, args...);
|
||||||
|
|
||||||
POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
|
POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
|
||||||
POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
|
POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
|
||||||
POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
|
POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
|
||||||
|
POPULATE_BCAST(BroadcastComplexOp, mhlo::ComplexOp);
|
||||||
POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
|
POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
|
||||||
POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
|
POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
|
||||||
POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
|
POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
|
||||||
|
@ -83,9 +76,6 @@ void PopulateForBroadcastingBinaryOp(MLIRContext *context,
|
||||||
POPULATE_BCAST(BroadcastZetaOp, ZetaOp);
|
POPULATE_BCAST(BroadcastZetaOp, ZetaOp);
|
||||||
|
|
||||||
// Broadcasting ops requiring special construction.
|
// Broadcasting ops requiring special construction.
|
||||||
patterns
|
|
||||||
->insert<Pattern<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>>(
|
|
||||||
context, args...);
|
|
||||||
patterns
|
patterns
|
||||||
->insert<Pattern<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>>(
|
->insert<Pattern<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>>(
|
||||||
context, args...);
|
context, args...);
|
||||||
|
|
|
@ -1119,9 +1119,8 @@ struct ConvertTrivialNonBroadcastBinaryOp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(
|
rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
|
||||||
op, {Adaptor::CreateOp(op, op.getResult().getType(), operands[0],
|
operands, rewriter)});
|
||||||
operands[1], rewriter)});
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1214,8 +1213,8 @@ struct ConvertRankedDynamicBroadcastBinaryOp
|
||||||
rewriter.getI64TensorAttr(rhs_broadcast_dimensions));
|
rewriter.getI64TensorAttr(rhs_broadcast_dimensions));
|
||||||
|
|
||||||
// And generate the final non-broadcasted binary op.
|
// And generate the final non-broadcasted binary op.
|
||||||
Value final_result = Adaptor::CreateOp(op, result_type, broadcasted_lhs,
|
Value final_result = Adaptor::CreateOp(
|
||||||
broadcasted_rhs, rewriter);
|
op, result_type, {broadcasted_lhs, broadcasted_rhs}, rewriter);
|
||||||
rewriter.create<shape::AssumingYieldOp>(loc, final_result);
|
rewriter.create<shape::AssumingYieldOp>(loc, final_result);
|
||||||
rewriter.replaceOp(op, {assuming_op.getResult(0)});
|
rewriter.replaceOp(op, {assuming_op.getResult(0)});
|
||||||
return success();
|
return success();
|
||||||
|
|
|
@ -463,7 +463,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||||
OpBuilder if_eq_shapes_builder =
|
OpBuilder if_eq_shapes_builder =
|
||||||
if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
|
if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
|
||||||
Value non_broadcast_op =
|
Value non_broadcast_op =
|
||||||
Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
|
Adaptor::CreateOp(op, result_type, {lhs, rhs}, if_eq_shapes_builder);
|
||||||
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
|
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
|
||||||
|
|
||||||
// If shapes do not have exactly one element, nor are equal
|
// If shapes do not have exactly one element, nor are equal
|
||||||
|
|
Loading…
Reference in New Issue