Legalize TensorFlow NonMaxSuppressionV4 and SelfAdjointEigV2Op ops to HLO

Added support for HLO ops bitcast-convert, sort and while in MlirHloBuilder and enabled tests for NonMaxSuppressionV4 and SelfAdjointEigV2Op using these ops.

PiperOrigin-RevId: 324360651
This commit is contained in:
Smit Hinsu 2020-07-31 23:10:33 -07:00 committed by TensorFlow MLIR Team
parent 7809320a5e
commit 3fe9a7d2db
1 changed files with 5 additions and 2 deletions

View File

@ -30,13 +30,14 @@ namespace {
// A pass that sinks constants implicitly captured in control flow regions. This
// is necessary to export to XLA.
//
// TODO(hinsu): Generalize this pass to handle all the ops with regions. Any
// value used within the region that is defined outside of op's region should be
// sank to the regions and not just the constants. Ops such as If and While
// whose computations doesn't require fixed signature like Sort or Reduce have
// an option to pass outside values as operands of the op to avoid recomputing
// those within internally. Note that doing so is the only option in case of
// BlockArguments.
// values defined outside that are BlockArguments of any of the parent region.
class SinkConstantsToControlFlowPass
: public mlir::PassWrapper<SinkConstantsToControlFlowPass, FunctionPass> {
void runOnFunction() override {
@ -60,7 +61,7 @@ class SinkConstantsToControlFlowPass
visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
Value constant = use->get();
auto op = constant.getDefiningOp();
if (!op || !isa<ConstOp, ConstantOp>(op)) return;
if (!op || !op->hasTrait<OpTrait::ConstantLike>()) return;
auto map_entry = sunk_constant.try_emplace(constant, nullptr);
if (!map_entry.second) {
// This constant has already been cloned into the region, reuse it.
@ -82,6 +83,8 @@ class SinkConstantsToControlFlowPass
} // anonymous namespace
// TODO(hinsu): Rename this pass and move to a different file along with the
// generalization to make all ops isolated from above.
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
return std::make_unique<SinkConstantsToControlFlowPass>();
}