Rework output to improve readability of intermediate MLIR code. (#87)

* Reorganize main function.

* Follow review comments.

* Emit constants are globals in Krnl and LLVM dialects.

* Output of non-value constants. Write full source to file.

* Fix e2e tests.

* Output constant free and full code in separate files.

* Emit separate files.

* Move file output management to utils.

* Elide the values of glotbal krnl constants.

* Add dual file output for Basic flag.

* Add tests.

* Add passes to cmake file.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-04-24 16:15:36 -04:00 committed by GitHub
parent 363ee26a52
commit 137ce767e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 308 additions and 22 deletions

View File

@ -263,7 +263,9 @@ set(ONNXMLIRWholeArchiveLibs
OMShapeInference
OMShapeInferenceOpInterface
OMAttributePromotion
OMPromotableConstOperandsOpInterface)
OMPromotableConstOperandsOpInterface
OMElideConstants
OMElideKrnlGlobalConstants)
# Function to construct linkage option for the static libraries that must be
# linked with --whole-archive (or equivalent).

View File

@ -34,6 +34,8 @@ target_link_libraries(onnx-mlir
OMShapeInferenceOpInterface
OMAttributePromotion
OMPromotableConstOperandsOpInterface
OMElideConstants
OMElideKrnlGlobalConstants
OMKrnlToAffine
OMKrnlToLLVM
OMONNXToKrnl

View File

@ -41,8 +41,8 @@ struct ONNXConstantOpLowering : public ConversionPattern {
auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc,
memRefType,
rewriter.getI64ArrayAttr(shape),
constantOp.value().getValue(),
rewriter.getStringAttr("constant_" + std::to_string(constantID)));
rewriter.getStringAttr("constant_" + std::to_string(constantID)),
constantOp.value().getValue());
// Increment constant ID:
constantID++;

View File

@ -199,7 +199,7 @@ def KrnlGlobalOp : Op<Krnl_Dialect, "global"> {
Operation for holding global data values.
}];
let arguments = (ins AnyAttr:$shape, AnyAttr:$value, StrAttr:$name);
let arguments = (ins AnyAttr:$shape, StrAttr:$name, OptionalAttr<AnyAttr>:$value);
let results = (outs AnyTypeOf<[AnyMemRef]>:$output);
let parser = ?;

View File

@ -9,6 +9,9 @@
//===----------------------------------------------------------------------===//
#include "src/MainUtils.hpp"
#include <fcntl.h>
#include <stdio.h>
#include <unistd.h>
using namespace std;
using namespace onnx_mlir;
@ -35,9 +38,10 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
}
}
void EmitLLVMBitCode(const mlir::OwningModuleRef &module) {
void EmitLLVMBitCode(
const mlir::OwningModuleRef &module, string outputFilename) {
error_code error;
llvm::raw_fd_ostream moduleBitcodeStream("model.bc", error,
llvm::raw_fd_ostream moduleBitcodeStream(outputFilename, error,
llvm::sys::fs::F_None);
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
moduleBitcodeStream);
@ -97,3 +101,67 @@ void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
LoadMLIR(inputFilename, context, module);
}
}
void outputCode(
mlir::OwningModuleRef &module, string filename, string extension) {
// Start a separate process to redirect the model output. I/O redirection
// changes will not be visible to the parent process.
if (fork() == 0) {
const char * tempFilename = (filename + extension).c_str();
freopen(tempFilename, "w", stderr);
module->dump();
fclose(stderr);
exit(0);
}
}
void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget,
mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
// For EmitONNXIR and EmitMLIR the constant value are embedded in the code
// thus making the code hard to read. These values can be elided by emitting
// two versions of the same source code:
// (1) a version with all the constant values included meant for being passed
// back to onnx-mlir for further processing and stored in:
//
// <name>.onnx.mlir
//
// (2) a version without constants meant for being inspected by users and
// stored in:
//
// <name>.mlir
//
// In the case of the LLVM Dialect IR the constant values are grouped
// outside the function code at the beginning of the file in which case the
// elision of these constants is not strictly required. Elision is also not
// necessary when emitting the .bc file.
if (emissionTarget == EmitLLVMBC) {
// Write LLVM bitcode to disk.
string outputFilename = outputBaseName + ".bc";
EmitLLVMBitCode(module, outputFilename);
printf("LLVM bitcode written to %s\n", outputFilename.c_str());
} else {
// Emit the version with all constants included.
outputCode(module, outputBaseName, ".onnx.mlir");
printf("Full MLIR code written to: \n\t%s\n\n",
(outputBaseName + ".onnx.mlir").c_str());
// Apply specific passes to clean up the code where necessary.
mlir::PassManager cleanSourcePM(&context);
if (emissionTarget == EmitONNXIR || emissionTarget == EmitONNXBasic)
cleanSourcePM.addPass(mlir::createElideConstantValuePass());
if (emissionTarget == EmitMLIR)
cleanSourcePM.addPass(mlir::createElideConstGlobalValuePass());
if (emissionTarget == EmitONNXBasic || emissionTarget == EmitONNXIR ||
emissionTarget == EmitMLIR) {
if (mlir::failed(cleanSourcePM.run(*module)))
llvm::errs() << "Could not apply simplification passes.\n";
outputCode(module, outputBaseName, ".mlir");
printf("Constant-free MLIR Code written to: \n\t%s\n\n",
(outputBaseName + ".mlir").c_str());
printf("Use:\n\t%s\nto continue lowering the code to other dialects.\n",
(outputBaseName + ".onnx.mlir").c_str());
}
}
}

View File

@ -48,7 +48,8 @@ enum EmissionTargetType {
void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context,
mlir::OwningModuleRef &module);
void EmitLLVMBitCode(const mlir::OwningModuleRef &module);
void EmitLLVMBitCode(
const mlir::OwningModuleRef &module, std::string outputFilename);
void registerDialects();
@ -61,5 +62,13 @@ void addKrnlToAffinePasses(mlir::PassManager &pm);
void addKrnlToLLVMPasses(mlir::PassManager &pm);
void processInputFile(std::string inputFilename,
EmissionTargetType emissionTarget, mlir::MLIRContext &context,
mlir::OwningModuleRef &module);
EmissionTargetType emissionTarget, mlir::MLIRContext &context,
mlir::OwningModuleRef &module);
void outputCode(
mlir::OwningModuleRef &module, std::string filename,
std::string extension);
void emitOutputFiles(std::string outputBaseName,
EmissionTargetType emissionTarget, mlir::MLIRContext &context,
mlir::OwningModuleRef &module);

View File

@ -23,12 +23,18 @@ std::unique_ptr<Pass> createShapeInferencePass();
/// Pass for promoting constant operands to attributes.
std::unique_ptr<Pass> createAttributePromotionPass();
/// Pass for eliding the values of constant operations.
std::unique_ptr<Pass> createElideConstantValuePass();
/// Add pass for lowering to Krnl IR.
std::unique_ptr<Pass> createLowerToKrnlPass();
/// Pass for lowering frontend dialects to Krnl IR dialect.
std::unique_ptr<Pass> createLowerKrnlPass();
/// Pass for eliding the values of global Krnl operations.
std::unique_ptr<Pass> createElideConstGlobalValuePass();
/// Pass for lowering Krnl dialect to LLVM dialect.
std::unique_ptr<Pass> createKrnlLowerToLLVMPass();

View File

@ -24,4 +24,16 @@ add_dependencies(OMKrnlToLLVM
OMKrnlOps
OMONNXOps)
add_library(OMElideKrnlGlobalConstants
ElideKrnlGlobalConstants.cpp)
target_include_directories(OMElideKrnlGlobalConstants
PRIVATE
${ONNX_MLIR_SRC_ROOT}
${ONNX_MLIR_BIN_ROOT}
${ONNX_MLIR_SRC_ROOT})
target_link_libraries(OMElideKrnlGlobalConstants
${MLIRLibs}
OMKrnlOps
OMONNXOps)
add_subdirectory(ONNX)

View File

@ -0,0 +1,76 @@
//===- ElideKrnlGlobalConstants.cpp - Krnl Constant lobal Value Elision ---===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// In practice, the constant values of Global Krnl operations may be large
// enough to hinder the readability of the MLIR intermediate representation.
//
// This file creates a pass which elides the explicit values of constant
// global operations. This pass has purely cosmetic purposes and should only be
// run to obtain a compact representation of the program when emitting Krnl
// dialect code. This pass should never be invoked on code meant to be run.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Pass/Passes.hpp"
using namespace mlir;
namespace {
/*!
* RewritePattern that replaces existing constant Krnl global values
* with a similar operation which preserves all attributes except the value
* attribute.
*/
class KrnlConstGlobalValueElision : public OpRewritePattern<KrnlGlobalOp> {
public:
using OpRewritePattern<KrnlGlobalOp>::OpRewritePattern;
LogicalResult matchAndRewrite(KrnlGlobalOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
if (op.value().hasValue()) {
auto newGlobalOp = rewriter.create<KrnlGlobalOp>(
loc, op.getResult().getType(), op.shape(), op.name(), nullptr);
rewriter.replaceOp(op, newGlobalOp.getResult());
}
return success();
}
};
/*!
* Function pass that performs constant value elision of Krnl globals.
*/
class ElideConstGlobalValuePass
: public mlir::FunctionPass<ElideConstGlobalValuePass> {
public:
void runOnFunction() override {
auto function = getFunction();
ConversionTarget target(getContext());
OwningRewritePatternList patterns;
patterns.insert<KrnlConstGlobalValueElision>(&getContext());
applyPatternsGreedily(function, patterns);
}
};
} // namespace
std::unique_ptr<Pass> mlir::createElideConstGlobalValuePass() {
return std::make_unique<ElideConstGlobalValuePass>();
}
static PassRegistration<ElideConstGlobalValuePass> pass("elide-krnl-constants",
"Elide the constant values of the Global Krnl operations.");

View File

@ -143,9 +143,12 @@ public:
OpBuilder::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
assert(krnlGlobalOp.value().hasValue() &&
"Krnl Global must always have a value");
global = rewriter.create<LLVM::GlobalOp>(loc,
llvmGlobalType, /*isConstant=*/true,
LLVM::Linkage::Internal, name, krnlGlobalOp.value());
LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue());
}
// Some frequently used types.

View File

@ -1,5 +1,4 @@
//===----- attribute_promotion.cpp - Attribute Promotion
//-------------------===//
//===----- attribute_promotion.cpp - Attribute Promotion ------------------===//
//
// Copyright 2020 The IBM Research Authors.
//

View File

@ -8,6 +8,12 @@ target_include_directories(OMAttributePromotion
add_dependencies(OMAttributePromotion
OMPromotableConstOperandsOpInterface)
add_library(OMElideConstants
ElideConstants.cpp)
target_include_directories(OMElideConstants
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
${ONNF_MLIR_SRC_ROOT})
set(LLVM_TARGET_DEFINITIONS ONNXRewrite.td)
onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters)
add_public_tablegen_target(OMONNXRewriteIncGen)

View File

@ -0,0 +1,83 @@
//===----- ElideConstants.cpp - Elide Constant Values ---------------------===//
//
// Copyright 2020 The IBM Research Authors.
//
// =============================================================================
//
// In practice, the constant values of Constant operations may be large enough
// to hinder the readability of the MLIR intermediate representation.
//
// This file creates a pass which elides the explicit values of Constant
// operations. This pass has purely cosmetic purposes and should only be run to
// obtain a compact representation of the program when emitting ONNX and KRNL
// Dialect code. This pass should never be invoked on code meant to be run.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Pass/Passes.hpp"
using namespace mlir;
namespace {
/*!
* RewritePattern that replaces existing Constant operations
* with Constant operations with the same shape information but
* no values.
*/
class ConstantValueElision : public OpRewritePattern<ONNXConstantOp> {
public:
using OpRewritePattern<ONNXConstantOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXConstantOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto constOp = llvm::dyn_cast<ONNXConstantOp>(&op);
if (constOp->sparse_value().hasValue())
emitError(loc, "Only support dense values at this time");
if (constOp->value().hasValue()) {
auto newConstOp = rewriter.create<ONNXConstantOp>(
loc, constOp->getResult().getType(), nullptr, nullptr);
rewriter.replaceOp(op, newConstOp.getResult());
}
return success();
}
};
/*!
* Function pass that performs constant value elision.
*/
class ElideConstantValuePass
: public mlir::FunctionPass<ElideConstantValuePass> {
public:
void runOnFunction() override {
auto function = getFunction();
ConversionTarget target(getContext());
OwningRewritePatternList patterns;
patterns.insert<ConstantValueElision>(&getContext());
applyPatternsGreedily(function, patterns);
}
};
} // end anonymous namespace
/*!
* Create a Constant Value Elision pass.
*/
std::unique_ptr<mlir::Pass> mlir::createElideConstantValuePass() {
return std::make_unique<ElideConstantValuePass>();
}
static PassRegistration<ElideConstantValuePass> pass(
"elide-constants", "Elide values of constant operations.");

View File

@ -43,9 +43,14 @@ int main(int argc, char *argv[]) {
mlir::OwningModuleRef module;
processInputFile(inputFilename, emissionTarget, context, module);
// Input file base name.
string outputBaseName =
inputFilename.substr(0, inputFilename.find_last_of("."));
mlir::PassManager pm(&context);
if (emissionTarget >= EmitONNXIR)
if (emissionTarget >= EmitONNXIR) {
addONNXToMLIRPasses(pm);
}
if (emissionTarget >= EmitMLIR) {
addONNXToKrnlPasses(pm);
@ -58,12 +63,7 @@ int main(int argc, char *argv[]) {
if (mlir::failed(pm.run(*module)))
return 4;
if (emissionTarget == EmitLLVMBC) {
// Write LLVM bitcode to disk.
EmitLLVMBitCode(module);
printf("LLVM bitcode written to ./model.bc");
} else
module->dump();
emitOutputFiles(outputBaseName, emissionTarget, context, module);
return 0;
}

View File

@ -42,13 +42,13 @@ class DummyBackend(onnx.backend.base.Backend):
execute_commands([ONNX_MLIR, "temp_model.onnx"])
# Call llc to generate object file from bitcode.
execute_commands(
[LLC, "-filetype=obj", "-relocation-model=pic", "model.bc"])
[LLC, "-filetype=obj", "-relocation-model=pic", "temp_model.bc"])
# Generate shared library from object file, linking with c runtime.
execute_commands([
CXX, "-shared", "-fPIC", "model.o", "-o", "model.so",
CXX, "-shared", "-fPIC", "temp_model.o", "-o", "temp_model.so",
"-L" + RUNTIME_DIR, "-lcruntime"
])
return ExecutionSession("./model.so", "_dyn_entry_point_main_graph")
return ExecutionSession("./temp_model.so", "_dyn_entry_point_main_graph")
@classmethod
def supports_device(cls, device):

View File

@ -0,0 +1,10 @@
// RUN: onnx-mlir-opt --elide-constants %s -split-input-file | FileCheck %s
// CHECK-LABEL: func @test_elide_constant(%arg0: tensor<1xf32>) -> tensor<1x10xf32>
func @test_elide_constant(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
%0 = "onnx.Constant"() {value = dense<[[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]]> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
"std.return"(%0) : (tensor<1x10xf32>) -> ()
// CHECK: %0 = "onnx.Constant"() : () -> tensor<1x10xf32>
// CHECK: return %0 : tensor<1x10xf32>
}

View File

@ -0,0 +1,10 @@
// RUN: onnx-mlir-opt --elide-krnl-constants %s -split-input-file | FileCheck %s
// CHECK-LABEL: func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x10xf32>
func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x10xf32> {
%0 = "krnl.global"() {name = "constant_0", shape = [1, 10], value = dense<[[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]]> : tensor<1x10xf32>} : () -> memref<1x10xf32>
return %0 : memref<1x10xf32>
// CHECK: %0 = "krnl.global"() {name = "constant_0", shape = [1, 10]} : () -> memref<1x10xf32>
// CHECK: return %0 : memref<1x10xf32>
}