Forward listeners in LhloLegalizeToParallelLoops builders

PiperOrigin-RevId: 347554379
This commit is contained in:
Tres Popp 2020-12-15 00:12:32 -08:00 committed by TensorFlow MLIR Team
parent 1a58f19664
commit 9df327d88f
1 changed files with 12 additions and 8 deletions

View File

@ -437,12 +437,14 @@ class ReduceWindowOpConverter
loc, operand_type.getElementType(), mapped_ivs.in_bounds, loc, operand_type.getElementType(), mapped_ivs.in_bounds,
/*withElseRegion=*/true); /*withElseRegion=*/true);
OpBuilder then_builder = elem_or_init.getThenBodyBuilder(); OpBuilder then_builder =
elem_or_init.getThenBodyBuilder(rewriter->getListener());
Value elem = then_builder.create<mlir::LoadOp>( Value elem = then_builder.create<mlir::LoadOp>(
loc, reduce_window_op.operand(), mapped_ivs.ivs); loc, reduce_window_op.operand(), mapped_ivs.ivs);
then_builder.create<scf::YieldOp>(loc, elem); then_builder.create<scf::YieldOp>(loc, elem);
OpBuilder else_builder = elem_or_init.getElseBodyBuilder(); OpBuilder else_builder =
elem_or_init.getElseBodyBuilder(rewriter->getListener());
else_builder.create<scf::YieldOp>(loc, *window_loop.initVals().begin()); else_builder.create<scf::YieldOp>(loc, *window_loop.initVals().begin());
return rewriter->create<scf::ReduceOp>(loc, return rewriter->create<scf::ReduceOp>(loc,
@ -617,7 +619,8 @@ class SelectAndScatterOpConverter
// Case when we are inside boundaries of 'arg' and not in the pad area. // Case when we are inside boundaries of 'arg' and not in the pad area.
{ {
OpBuilder in_bounds_then_b = if_in_bounds.getThenBodyBuilder(); OpBuilder in_bounds_then_b =
if_in_bounds.getThenBodyBuilder(b->getListener());
auto select_or_init_results = SelectOrInitialize( auto select_or_init_results = SelectOrInitialize(
s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b); s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b);
in_bounds_then_b.create<scf::YieldOp>(loc, select_or_init_results); in_bounds_then_b.create<scf::YieldOp>(loc, select_or_init_results);
@ -625,7 +628,8 @@ class SelectAndScatterOpConverter
// Case when we are in the pad. // Case when we are in the pad.
{ {
OpBuilder in_bounds_else_b = if_in_bounds.getElseBodyBuilder(); OpBuilder in_bounds_else_b =
if_in_bounds.getElseBodyBuilder(b->getListener());
in_bounds_else_b.create<scf::YieldOp>(loc, ivs_val_flag.to_vector()); in_bounds_else_b.create<scf::YieldOp>(loc, ivs_val_flag.to_vector());
} }
@ -651,7 +655,7 @@ class SelectAndScatterOpConverter
// element in boundaries of the operand. Select function has to be computed // element in boundaries of the operand. Select function has to be computed
// here. // here.
{ {
OpBuilder if_init_then_b = if_init.getThenBodyBuilder(); OpBuilder if_init_then_b = if_init.getThenBodyBuilder(b->getListener());
auto& lhlo_select = s_and_s_op.select().front(); auto& lhlo_select = s_and_s_op.select().front();
Value pred = Value pred =
@ -664,14 +668,14 @@ class SelectAndScatterOpConverter
// Pred == true, therefore pack newly selected ivs, val and init flag back // Pred == true, therefore pack newly selected ivs, val and init flag back
// to iter_args and return. // to iter_args and return.
{ {
OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder(); OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder(b->getListener());
if_pred_then_b.create<scf::YieldOp>( if_pred_then_b.create<scf::YieldOp>(
loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector()); loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
} }
// Pred == false, therefore return old iter_args. // Pred == false, therefore return old iter_args.
{ {
OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder(); OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder(b->getListener());
if_pred_else_b.create<scf::YieldOp>(loc, ivs_val_flag->to_vector()); if_pred_else_b.create<scf::YieldOp>(loc, ivs_val_flag->to_vector());
} }
@ -680,7 +684,7 @@ class SelectAndScatterOpConverter
// Init == false, i.e. only pad was visited before and this is the first // Init == false, i.e. only pad was visited before and this is the first
// element in the boundaries of the operand. // element in the boundaries of the operand.
{ {
OpBuilder if_init_else_b = if_init.getElseBodyBuilder(); OpBuilder if_init_else_b = if_init.getElseBodyBuilder(b->getListener());
if_init_else_b.create<scf::YieldOp>( if_init_else_b.create<scf::YieldOp>(
loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector()); loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());