mlir-hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp

122 lines
4.1 KiB
C++
Raw Normal View History

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir-hlo/Dialect/mhlo/IR/register.h"
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/MlirOptMain.h"
// NOLINTNEXTLINE
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
// NOLINTNEXTLINE
static llvm::cl::opt<std::string> outputFilename(
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> splitInputFile(
"split-input-file",
llvm::cl::desc("Split the input file into pieces and process each "
"chunk independently"),
llvm::cl::init(false));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
llvm::cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"),
llvm::cl::init(false));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> verifyPasses(
"verify-each",
llvm::cl::desc("Run the verifier after each transformation pass"),
llvm::cl::init(true));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> allowUnregisteredDialects(
"allow-unregistered-dialect",
llvm::cl::desc("Allow operation with no registered dialects"),
llvm::cl::init(false));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> showDialects(
"show-dialects", llvm::cl::desc("Print the list of registered dialects"),
llvm::cl::init(false));
int main(int argc, char **argv) {
mlir::registerAllDialects();
mlir::registerAllPasses();
mlir::mhlo::registerAllDialects();
mlir::mhlo::registerAllMhloPasses();
mlir::lmhlo::registerAllLmhloPasses();
llvm::InitLLVM y(argc, argv);
// Register any pass manager command line options.
mlir::registerPassManagerCLOptions();
mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
// Parse pass names in main to ensure static initialization completed.
llvm::cl::ParseCommandLineOptions(argc, argv,
"MLIR modular optimizer driver\n");
if (showDialects) {
mlir::MLIRContext context;
llvm::outs() << "Registered Dialects:\n";
for (mlir::Dialect *dialect : context.getRegisteredDialects()) {
llvm::outs() << dialect->getNamespace() << "\n";
}
return 0;
}
// Set up the input file.
std::string errorMessage;
auto file = mlir::openInputFile(inputFilename, &errorMessage);
if (!file) {
llvm::errs() << errorMessage << "\n";
return 1;
}
auto output = mlir::openOutputFile(outputFilename, &errorMessage);
if (!output) {
llvm::errs() << errorMessage << "\n";
exit(1);
}
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline,
splitInputFile, verifyDiagnostics, verifyPasses,
allowUnregisteredDialects))) {
return 1;
}
// Keep the output file if the invocation of MlirOptMain was successful.
output->keep();
return 0;
}