Forward listeners in LhloLegalizeToParallelLoops builders

PiperOrigin-RevId: 347554379
This commit is contained in:
Tres Popp 2020-12-15 00:12:32 -08:00 committed by TensorFlow MLIR Team
parent 1a58f19664
commit 9df327d88f
1 changed files with 12 additions and 8 deletions

View File

@ -437,12 +437,14 @@ class ReduceWindowOpConverter
loc, operand_type.getElementType(), mapped_ivs.in_bounds,
/*withElseRegion=*/true);
OpBuilder then_builder = elem_or_init.getThenBodyBuilder();
OpBuilder then_builder =
elem_or_init.getThenBodyBuilder(rewriter->getListener());
Value elem = then_builder.create<mlir::LoadOp>(
loc, reduce_window_op.operand(), mapped_ivs.ivs);
then_builder.create<scf::YieldOp>(loc, elem);
OpBuilder else_builder = elem_or_init.getElseBodyBuilder();
OpBuilder else_builder =
elem_or_init.getElseBodyBuilder(rewriter->getListener());
else_builder.create<scf::YieldOp>(loc, *window_loop.initVals().begin());
return rewriter->create<scf::ReduceOp>(loc,
@ -617,7 +619,8 @@ class SelectAndScatterOpConverter
// Case when we are inside boundaries of 'arg' and not in the pad area.
{
OpBuilder in_bounds_then_b = if_in_bounds.getThenBodyBuilder();
OpBuilder in_bounds_then_b =
if_in_bounds.getThenBodyBuilder(b->getListener());
auto select_or_init_results = SelectOrInitialize(
s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b);
in_bounds_then_b.create<scf::YieldOp>(loc, select_or_init_results);
@ -625,7 +628,8 @@ class SelectAndScatterOpConverter
// Case when we are in the pad.
{
OpBuilder in_bounds_else_b = if_in_bounds.getElseBodyBuilder();
OpBuilder in_bounds_else_b =
if_in_bounds.getElseBodyBuilder(b->getListener());
in_bounds_else_b.create<scf::YieldOp>(loc, ivs_val_flag.to_vector());
}
@ -651,7 +655,7 @@ class SelectAndScatterOpConverter
// element in boundaries of the operand. Select function has to be computed
// here.
{
OpBuilder if_init_then_b = if_init.getThenBodyBuilder();
OpBuilder if_init_then_b = if_init.getThenBodyBuilder(b->getListener());
auto& lhlo_select = s_and_s_op.select().front();
Value pred =
@ -664,14 +668,14 @@ class SelectAndScatterOpConverter
// Pred == true, therefore pack newly selected ivs, val and init flag back
// to iter_args and return.
{
OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder();
OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder(b->getListener());
if_pred_then_b.create<scf::YieldOp>(
loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
}
// Pred == false, therefore return old iter_args.
{
OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder();
OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder(b->getListener());
if_pred_else_b.create<scf::YieldOp>(loc, ivs_val_flag->to_vector());
}
@ -680,7 +684,7 @@ class SelectAndScatterOpConverter
// Init == false, i.e. only pad was visited before and this is the first
// element in the boundaries of the operand.
{
OpBuilder if_init_else_b = if_init.getElseBodyBuilder();
OpBuilder if_init_else_b = if_init.getElseBodyBuilder(b->getListener());
if_init_else_b.create<scf::YieldOp>(
loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());