Add a header for table-gen generated pass for MHLO and use it in SinkConstantsToControlFlowPass
A non globally registered pass should define `getName()` in order to generate correct crash reproducers. This is something we get "for free" when using the TableGen generated base class. We should also migrate the other passes to the same mechanism and remove the static global registration. PiperOrigin-RevId: 332976907
This commit is contained in:
parent
0bb017098a
commit
7abd557a61
|
@ -0,0 +1,30 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSDETAIL_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSDETAIL_H_
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace mhlo {
|
||||||
|
|
||||||
|
#define GEN_PASS_CLASSES
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc"
|
||||||
|
|
||||||
|
} // end namespace mhlo
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSDETAIL_H_
|
|
@ -88,7 +88,7 @@ def OptimizeMhloPass : Pass<"mhlo-test-optimize", "FuncOp"> {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def SinkConstantsToControlFlowPass : Pass<"mhlo-sink-constants-to-control-flow", "FuncOp"> {
|
def SinkConstantsToControlFlowPass : FunctionPass<"mhlo-sink-constants-to-control-flow"> {
|
||||||
let summary = "Sink constants implicitly captured in control flow regions. This "
|
let summary = "Sink constants implicitly captured in control flow regions. This "
|
||||||
"is necessary to export to XLA.";
|
"is necessary to export to XLA.";
|
||||||
let constructor = "createSinkConstantsToControlFlowPass()";
|
let constructor = "createSinkConstantsToControlFlowPass()";
|
||||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
@ -39,7 +40,8 @@ namespace {
|
||||||
// those within internally. Note that doing so is the only option in case of
|
// those within internally. Note that doing so is the only option in case of
|
||||||
// values defined outside that are BlockArguments of any of the parent region.
|
// values defined outside that are BlockArguments of any of the parent region.
|
||||||
class SinkConstantsToControlFlowPass
|
class SinkConstantsToControlFlowPass
|
||||||
: public mlir::PassWrapper<SinkConstantsToControlFlowPass, FunctionPass> {
|
: public SinkConstantsToControlFlowPassBase<
|
||||||
|
SinkConstantsToControlFlowPass> {
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
getFunction().walk([](Operation* op) {
|
getFunction().walk([](Operation* op) {
|
||||||
if (auto while_op = llvm::dyn_cast<WhileOp>(op)) {
|
if (auto while_op = llvm::dyn_cast<WhileOp>(op)) {
|
||||||
|
|
Loading…
Reference in New Issue