From 9df327d88fa9e7f5839b713e864e980671fa7214 Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Tue, 15 Dec 2020 00:12:32 -0800 Subject: [PATCH] Forward listeners in LhloLegalizeToParallelLoops builders PiperOrigin-RevId: 347554379 --- .../lhlo_legalize_to_parallel_loops.cc | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc index 78d681b..3f3c4e3 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc @@ -437,12 +437,14 @@ class ReduceWindowOpConverter loc, operand_type.getElementType(), mapped_ivs.in_bounds, /*withElseRegion=*/true); - OpBuilder then_builder = elem_or_init.getThenBodyBuilder(); + OpBuilder then_builder = + elem_or_init.getThenBodyBuilder(rewriter->getListener()); Value elem = then_builder.create( loc, reduce_window_op.operand(), mapped_ivs.ivs); then_builder.create(loc, elem); - OpBuilder else_builder = elem_or_init.getElseBodyBuilder(); + OpBuilder else_builder = + elem_or_init.getElseBodyBuilder(rewriter->getListener()); else_builder.create(loc, *window_loop.initVals().begin()); return rewriter->create(loc, @@ -617,7 +619,8 @@ class SelectAndScatterOpConverter // Case when we are inside boundaries of 'arg' and not in the pad area. { - OpBuilder in_bounds_then_b = if_in_bounds.getThenBodyBuilder(); + OpBuilder in_bounds_then_b = + if_in_bounds.getThenBodyBuilder(b->getListener()); auto select_or_init_results = SelectOrInitialize( s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b); in_bounds_then_b.create(loc, select_or_init_results); @@ -625,7 +628,8 @@ class SelectAndScatterOpConverter // Case when we are in the pad. { - OpBuilder in_bounds_else_b = if_in_bounds.getElseBodyBuilder(); + OpBuilder in_bounds_else_b = + if_in_bounds.getElseBodyBuilder(b->getListener()); in_bounds_else_b.create(loc, ivs_val_flag.to_vector()); } @@ -651,7 +655,7 @@ class SelectAndScatterOpConverter // element in boundaries of the operand. Select function has to be computed // here. { - OpBuilder if_init_then_b = if_init.getThenBodyBuilder(); + OpBuilder if_init_then_b = if_init.getThenBodyBuilder(b->getListener()); auto& lhlo_select = s_and_s_op.select().front(); Value pred = @@ -664,14 +668,14 @@ class SelectAndScatterOpConverter // Pred == true, therefore pack newly selected ivs, val and init flag back // to iter_args and return. { - OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder(); + OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder(b->getListener()); if_pred_then_b.create( loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector()); } // Pred == false, therefore return old iter_args. { - OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder(); + OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder(b->getListener()); if_pred_else_b.create(loc, ivs_val_flag->to_vector()); } @@ -680,7 +684,7 @@ class SelectAndScatterOpConverter // Init == false, i.e. only pad was visited before and this is the first // element in the boundaries of the operand. { - OpBuilder if_init_else_b = if_init.getElseBodyBuilder(); + OpBuilder if_init_else_b = if_init.getElseBodyBuilder(b->getListener()); if_init_else_b.create( loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());