From 577a81a66d4741259e52e6e78f2cad0ce7604d38 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Fri, 31 Jul 2020 16:01:37 -0700 Subject: [PATCH] Sink standard dialect constants in sink_constants_to_control_flow pass This is required before exporting HLO dialect ops with standard dialect constant to XLA. Also, sink constants for sort op as well. Added a TODO to generalize this pass to handle more ops and non-const values defined outside. PiperOrigin-RevId: 324301911 --- .../sink_constants_to_control_flow.cc | 26 +++++++++++++------ tests/sink-constants-to-control-flow.mlir | 14 ++++++++++ 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc b/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc index 0f31e61..f2ebf9d 100644 --- a/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc +++ b/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/RegionUtils.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" namespace mlir { namespace mhlo { @@ -29,6 +30,13 @@ namespace { // A pass that sinks constants implicitly captured in control flow regions. This // is necessary to export to XLA. +// TODO(hinsu): Generalize this pass to handle all the ops with regions. Any +// value used within the region that is defined outside of op's region should be +// sank to the regions and not just the constants. Ops such as If and While +// whose computations doesn't require fixed signature like Sort or Reduce have +// an option to pass outside values as operands of the op to avoid recomputing +// those within internally. Note that doing so is the only option in case of +// BlockArguments. class SinkConstantsToControlFlowPass : public mlir::PassWrapper { void runOnFunction() override { @@ -39,6 +47,8 @@ class SinkConstantsToControlFlowPass } else if (auto if_op = llvm::dyn_cast(op)) { SinkToRegion(&if_op.true_branch()); SinkToRegion(&if_op.false_branch()); + } else if (auto sort_op = llvm::dyn_cast(op)) { + SinkToRegion(&sort_op.comparator()); } }); } @@ -46,26 +56,26 @@ class SinkConstantsToControlFlowPass private: // Performs constant sinking into a region. static void SinkToRegion(Region* region) { - llvm::DenseMap sunk_constant; + llvm::DenseMap sunk_constant; visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) { Value constant = use->get(); - auto const_op = dyn_cast_or_null(constant.getDefiningOp()); - if (!const_op) return; + auto op = constant.getDefiningOp(); + if (!op || !isa(op)) return; auto map_entry = sunk_constant.try_emplace(constant, nullptr); if (!map_entry.second) { // This constant has already been cloned into the region, reuse it. - use->set(map_entry.first->getSecond().getResult()); - if (constant.use_empty()) const_op.erase(); + use->set(map_entry.first->getSecond()->getResult(0)); + if (op->use_empty()) op->erase(); return; } if (constant.hasOneUse()) { - const_op.getOperation()->moveBefore(®ion->front().front()); + op->moveBefore(®ion->front().front()); return; } - map_entry.first->getSecond() = const_op.clone(); + map_entry.first->getSecond() = op->clone(); region->front().getOperations().insert(region->front().begin(), map_entry.first->getSecond()); - use->set(map_entry.first->getSecond().getResult()); + use->set(map_entry.first->getSecond()->getResult(0)); }); } }; diff --git a/tests/sink-constants-to-control-flow.mlir b/tests/sink-constants-to-control-flow.mlir index f8b6b62..9e18ad8 100644 --- a/tests/sink-constants-to-control-flow.mlir +++ b/tests/sink-constants-to-control-flow.mlir @@ -58,3 +58,17 @@ func @sink_const_to_conditional(%arg0: tensor) -> tensor { %9 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple>) -> tensor return %9 : tensor } + +func @sink_const_to_sort(%arg0: tensor<16xf32>) { + %c0 = constant dense<1.0> : tensor + // CHECK: "mhlo.sort" + %0 = "mhlo.sort"(%arg0) ( { + ^bb0(%arg1: tensor, %arg2: tensor): + // CHECK: constant dense<1.000000e+00> + %1 = "mhlo.divide"(%arg1, %c0) : (tensor, tensor) -> tensor + %2 = "mhlo.divide"(%arg2, %c0) : (tensor, tensor) -> tensor + %3 = "mhlo.compare"(%1, %2) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "mhlo.return"(%3) : (tensor) -> () + }) {is_stable = true} : (tensor<16xf32>) -> tensor<16xi32> + return +}