Sink standard dialect constants in sink_constants_to_control_flow pass
This is required before exporting HLO dialect ops with standard dialect constant to XLA. Also, sink constants for sort op as well. Added a TODO to generalize this pass to handle more ops and non-const values defined outside. PiperOrigin-RevId: 324301911
This commit is contained in:
parent
1c535f1718
commit
577a81a66d
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/Transforms/RegionUtils.h"
|
#include "mlir/Transforms/RegionUtils.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace mhlo {
|
namespace mhlo {
|
||||||
|
@ -29,6 +30,13 @@ namespace {
|
||||||
|
|
||||||
// A pass that sinks constants implicitly captured in control flow regions. This
|
// A pass that sinks constants implicitly captured in control flow regions. This
|
||||||
// is necessary to export to XLA.
|
// is necessary to export to XLA.
|
||||||
|
// TODO(hinsu): Generalize this pass to handle all the ops with regions. Any
|
||||||
|
// value used within the region that is defined outside of op's region should be
|
||||||
|
// sank to the regions and not just the constants. Ops such as If and While
|
||||||
|
// whose computations doesn't require fixed signature like Sort or Reduce have
|
||||||
|
// an option to pass outside values as operands of the op to avoid recomputing
|
||||||
|
// those within internally. Note that doing so is the only option in case of
|
||||||
|
// BlockArguments.
|
||||||
class SinkConstantsToControlFlowPass
|
class SinkConstantsToControlFlowPass
|
||||||
: public mlir::PassWrapper<SinkConstantsToControlFlowPass, FunctionPass> {
|
: public mlir::PassWrapper<SinkConstantsToControlFlowPass, FunctionPass> {
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
|
@ -39,6 +47,8 @@ class SinkConstantsToControlFlowPass
|
||||||
} else if (auto if_op = llvm::dyn_cast<IfOp>(op)) {
|
} else if (auto if_op = llvm::dyn_cast<IfOp>(op)) {
|
||||||
SinkToRegion(&if_op.true_branch());
|
SinkToRegion(&if_op.true_branch());
|
||||||
SinkToRegion(&if_op.false_branch());
|
SinkToRegion(&if_op.false_branch());
|
||||||
|
} else if (auto sort_op = llvm::dyn_cast<SortOp>(op)) {
|
||||||
|
SinkToRegion(&sort_op.comparator());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -46,26 +56,26 @@ class SinkConstantsToControlFlowPass
|
||||||
private:
|
private:
|
||||||
// Performs constant sinking into a region.
|
// Performs constant sinking into a region.
|
||||||
static void SinkToRegion(Region* region) {
|
static void SinkToRegion(Region* region) {
|
||||||
llvm::DenseMap<Value, ConstOp> sunk_constant;
|
llvm::DenseMap<Value, Operation*> sunk_constant;
|
||||||
visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
|
visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
|
||||||
Value constant = use->get();
|
Value constant = use->get();
|
||||||
auto const_op = dyn_cast_or_null<ConstOp>(constant.getDefiningOp());
|
auto op = constant.getDefiningOp();
|
||||||
if (!const_op) return;
|
if (!op || !isa<ConstOp, ConstantOp>(op)) return;
|
||||||
auto map_entry = sunk_constant.try_emplace(constant, nullptr);
|
auto map_entry = sunk_constant.try_emplace(constant, nullptr);
|
||||||
if (!map_entry.second) {
|
if (!map_entry.second) {
|
||||||
// This constant has already been cloned into the region, reuse it.
|
// This constant has already been cloned into the region, reuse it.
|
||||||
use->set(map_entry.first->getSecond().getResult());
|
use->set(map_entry.first->getSecond()->getResult(0));
|
||||||
if (constant.use_empty()) const_op.erase();
|
if (op->use_empty()) op->erase();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (constant.hasOneUse()) {
|
if (constant.hasOneUse()) {
|
||||||
const_op.getOperation()->moveBefore(®ion->front().front());
|
op->moveBefore(®ion->front().front());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
map_entry.first->getSecond() = const_op.clone();
|
map_entry.first->getSecond() = op->clone();
|
||||||
region->front().getOperations().insert(region->front().begin(),
|
region->front().getOperations().insert(region->front().begin(),
|
||||||
map_entry.first->getSecond());
|
map_entry.first->getSecond());
|
||||||
use->set(map_entry.first->getSecond().getResult());
|
use->set(map_entry.first->getSecond()->getResult(0));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -58,3 +58,17 @@ func @sink_const_to_conditional(%arg0: tensor<i64>) -> tensor<i64> {
|
||||||
%9 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
|
%9 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
|
||||||
return %9 : tensor<i64>
|
return %9 : tensor<i64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @sink_const_to_sort(%arg0: tensor<16xf32>) {
|
||||||
|
%c0 = constant dense<1.0> : tensor<f32>
|
||||||
|
// CHECK: "mhlo.sort"
|
||||||
|
%0 = "mhlo.sort"(%arg0) ( {
|
||||||
|
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
|
||||||
|
// CHECK: constant dense<1.000000e+00>
|
||||||
|
%1 = "mhlo.divide"(%arg1, %c0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
%2 = "mhlo.divide"(%arg2, %c0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
%3 = "mhlo.compare"(%1, %2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
|
"mhlo.return"(%3) : (tensor<i1>) -> ()
|
||||||
|
}) {is_stable = true} : (tensor<16xf32>) -> tensor<16xi32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue