diff --git a/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc b/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc index f2ebf9d..70f0c58 100644 --- a/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc +++ b/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc @@ -30,13 +30,14 @@ namespace { // A pass that sinks constants implicitly captured in control flow regions. This // is necessary to export to XLA. +// // TODO(hinsu): Generalize this pass to handle all the ops with regions. Any // value used within the region that is defined outside of op's region should be // sank to the regions and not just the constants. Ops such as If and While // whose computations doesn't require fixed signature like Sort or Reduce have // an option to pass outside values as operands of the op to avoid recomputing // those within internally. Note that doing so is the only option in case of -// BlockArguments. +// values defined outside that are BlockArguments of any of the parent region. class SinkConstantsToControlFlowPass : public mlir::PassWrapper { void runOnFunction() override { @@ -60,7 +61,7 @@ class SinkConstantsToControlFlowPass visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) { Value constant = use->get(); auto op = constant.getDefiningOp(); - if (!op || !isa(op)) return; + if (!op || !op->hasTrait()) return; auto map_entry = sunk_constant.try_emplace(constant, nullptr); if (!map_entry.second) { // This constant has already been cloned into the region, reuse it. @@ -82,6 +83,8 @@ class SinkConstantsToControlFlowPass } // anonymous namespace +// TODO(hinsu): Rename this pass and move to a different file along with the +// generalization to make all ops isolated from above. std::unique_ptr> createSinkConstantsToControlFlowPass() { return std::make_unique(); }