diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp index f1440b5..4fdc94a 100644 --- a/src/MainUtils.cpp +++ b/src/MainUtils.cpp @@ -463,8 +463,6 @@ llvm::cl::opt printIR("printIR", void outputCode( mlir::OwningModuleRef &module, string filename, string extension) { - // Start a separate process to redirect the model output. I/O redirection - // changes will not be visible to the parent process. string tempFilename = filename + extension; mlir::OpPrintingFlags flags; if (preserveLocations) @@ -473,18 +471,17 @@ void outputCode( #ifdef _WIN32 // copy original stderr file number int stderrOrigin = _dup(_fileno(stderr)); +#else + int stderrOrigin = dup(fileno(stderr)); +#endif freopen(tempFilename.c_str(), "w", stderr); module->print(llvm::errs(), flags); fflush(stderr); // set modified stderr as original stderr +#ifdef _WIN32 _dup2(stderrOrigin, _fileno(stderr)); #else - if (fork() == 0) { - freopen(tempFilename.c_str(), "w", stderr); - module->print(llvm::errs(), flags); - fclose(stderr); - exit(0); - } + dup2(stderrOrigin, fileno(stderr)); #endif if (printIR) module->print(llvm::outs(), flags); diff --git a/src/main.cpp b/src/main.cpp index 29989e8..aff99ad 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -21,6 +21,10 @@ int main(int argc, char *argv[]) { llvm::cl::desc(""), llvm::cl::init("-"), llvm::cl::cat(OnnxMlirOptions)); + llvm::cl::opt outputBaseName("o", + llvm::cl::desc("Base path for output files, extensions will be added."), + llvm::cl::value_desc("path"), llvm::cl::cat(OnnxMlirOptions), + llvm::cl::ValueRequired); llvm::cl::opt emissionTarget( llvm::cl::desc("Choose target to emit:"), llvm::cl::values( @@ -46,9 +50,9 @@ int main(int argc, char *argv[]) { mlir::OwningModuleRef module; processInputFile(inputFilename, emissionTarget, context, module); - // Input file base name. - string outputBaseName = - inputFilename.substr(0, inputFilename.find_last_of(".")); + // Input file base name, replace path if required. + if (outputBaseName == "") + outputBaseName = inputFilename.substr(0, inputFilename.find_last_of(".")); return compileModule(module, context, outputBaseName, emissionTarget); }