diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h b/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h new file mode 100644 index 0000000..5f18eeb --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSDETAIL_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSDETAIL_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace mhlo { + +#define GEN_PASS_CLASSES +#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" + +} // end namespace mhlo +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSDETAIL_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td index aa0f4c3..52cee9d 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td +++ b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td @@ -88,7 +88,7 @@ def OptimizeMhloPass : Pass<"mhlo-test-optimize", "FuncOp"> { } -def SinkConstantsToControlFlowPass : Pass<"mhlo-sink-constants-to-control-flow", "FuncOp"> { +def SinkConstantsToControlFlowPass : FunctionPass<"mhlo-sink-constants-to-control-flow"> { let summary = "Sink constants implicitly captured in control flow regions. This " "is necessary to export to XLA."; let constructor = "createSinkConstantsToControlFlowPass()"; diff --git a/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc b/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc index 8d677f4..d863d82 100644 --- a/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc +++ b/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc @@ -16,6 +16,7 @@ limitations under the License. #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Casting.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" @@ -39,7 +40,8 @@ namespace { // those within internally. Note that doing so is the only option in case of // values defined outside that are BlockArguments of any of the parent region. class SinkConstantsToControlFlowPass - : public mlir::PassWrapper { + : public SinkConstantsToControlFlowPassBase< + SinkConstantsToControlFlowPass> { void runOnFunction() override { getFunction().walk([](Operation* op) { if (auto while_op = llvm::dyn_cast(op)) {