Rework output to improve readability of intermediate MLIR code. (#87)
* Reorganize main function. * Follow review comments. * Emit constants are globals in Krnl and LLVM dialects. * Output of non-value constants. Write full source to file. * Fix e2e tests. * Output constant free and full code in separate files. * Emit separate files. * Move file output management to utils. * Elide the values of glotbal krnl constants. * Add dual file output for Basic flag. * Add tests. * Add passes to cmake file.
This commit is contained in:
parent
363ee26a52
commit
137ce767e6
|
@ -263,7 +263,9 @@ set(ONNXMLIRWholeArchiveLibs
|
|||
OMShapeInference
|
||||
OMShapeInferenceOpInterface
|
||||
OMAttributePromotion
|
||||
OMPromotableConstOperandsOpInterface)
|
||||
OMPromotableConstOperandsOpInterface
|
||||
OMElideConstants
|
||||
OMElideKrnlGlobalConstants)
|
||||
|
||||
# Function to construct linkage option for the static libraries that must be
|
||||
# linked with --whole-archive (or equivalent).
|
||||
|
|
|
@ -34,6 +34,8 @@ target_link_libraries(onnx-mlir
|
|||
OMShapeInferenceOpInterface
|
||||
OMAttributePromotion
|
||||
OMPromotableConstOperandsOpInterface
|
||||
OMElideConstants
|
||||
OMElideKrnlGlobalConstants
|
||||
OMKrnlToAffine
|
||||
OMKrnlToLLVM
|
||||
OMONNXToKrnl
|
||||
|
|
|
@ -41,8 +41,8 @@ struct ONNXConstantOpLowering : public ConversionPattern {
|
|||
auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc,
|
||||
memRefType,
|
||||
rewriter.getI64ArrayAttr(shape),
|
||||
constantOp.value().getValue(),
|
||||
rewriter.getStringAttr("constant_" + std::to_string(constantID)));
|
||||
rewriter.getStringAttr("constant_" + std::to_string(constantID)),
|
||||
constantOp.value().getValue());
|
||||
|
||||
// Increment constant ID:
|
||||
constantID++;
|
||||
|
|
|
@ -199,7 +199,7 @@ def KrnlGlobalOp : Op<Krnl_Dialect, "global"> {
|
|||
Operation for holding global data values.
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyAttr:$shape, AnyAttr:$value, StrAttr:$name);
|
||||
let arguments = (ins AnyAttr:$shape, StrAttr:$name, OptionalAttr<AnyAttr>:$value);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef]>:$output);
|
||||
|
||||
let parser = ?;
|
||||
|
|
|
@ -9,6 +9,9 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "src/MainUtils.hpp"
|
||||
#include <fcntl.h>
|
||||
#include <stdio.h>
|
||||
#include <unistd.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace onnx_mlir;
|
||||
|
@ -35,9 +38,10 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
|
|||
}
|
||||
}
|
||||
|
||||
void EmitLLVMBitCode(const mlir::OwningModuleRef &module) {
|
||||
void EmitLLVMBitCode(
|
||||
const mlir::OwningModuleRef &module, string outputFilename) {
|
||||
error_code error;
|
||||
llvm::raw_fd_ostream moduleBitcodeStream("model.bc", error,
|
||||
llvm::raw_fd_ostream moduleBitcodeStream(outputFilename, error,
|
||||
llvm::sys::fs::F_None);
|
||||
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
|
||||
moduleBitcodeStream);
|
||||
|
@ -97,3 +101,67 @@ void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
|
|||
LoadMLIR(inputFilename, context, module);
|
||||
}
|
||||
}
|
||||
|
||||
void outputCode(
|
||||
mlir::OwningModuleRef &module, string filename, string extension) {
|
||||
// Start a separate process to redirect the model output. I/O redirection
|
||||
// changes will not be visible to the parent process.
|
||||
if (fork() == 0) {
|
||||
const char * tempFilename = (filename + extension).c_str();
|
||||
freopen(tempFilename, "w", stderr);
|
||||
module->dump();
|
||||
fclose(stderr);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget,
|
||||
mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
|
||||
// For EmitONNXIR and EmitMLIR the constant value are embedded in the code
|
||||
// thus making the code hard to read. These values can be elided by emitting
|
||||
// two versions of the same source code:
|
||||
// (1) a version with all the constant values included meant for being passed
|
||||
// back to onnx-mlir for further processing and stored in:
|
||||
//
|
||||
// <name>.onnx.mlir
|
||||
//
|
||||
// (2) a version without constants meant for being inspected by users and
|
||||
// stored in:
|
||||
//
|
||||
// <name>.mlir
|
||||
//
|
||||
// In the case of the LLVM Dialect IR the constant values are grouped
|
||||
// outside the function code at the beginning of the file in which case the
|
||||
// elision of these constants is not strictly required. Elision is also not
|
||||
// necessary when emitting the .bc file.
|
||||
if (emissionTarget == EmitLLVMBC) {
|
||||
// Write LLVM bitcode to disk.
|
||||
string outputFilename = outputBaseName + ".bc";
|
||||
EmitLLVMBitCode(module, outputFilename);
|
||||
printf("LLVM bitcode written to %s\n", outputFilename.c_str());
|
||||
} else {
|
||||
// Emit the version with all constants included.
|
||||
outputCode(module, outputBaseName, ".onnx.mlir");
|
||||
printf("Full MLIR code written to: \n\t%s\n\n",
|
||||
(outputBaseName + ".onnx.mlir").c_str());
|
||||
|
||||
// Apply specific passes to clean up the code where necessary.
|
||||
mlir::PassManager cleanSourcePM(&context);
|
||||
if (emissionTarget == EmitONNXIR || emissionTarget == EmitONNXBasic)
|
||||
cleanSourcePM.addPass(mlir::createElideConstantValuePass());
|
||||
if (emissionTarget == EmitMLIR)
|
||||
cleanSourcePM.addPass(mlir::createElideConstGlobalValuePass());
|
||||
|
||||
if (emissionTarget == EmitONNXBasic || emissionTarget == EmitONNXIR ||
|
||||
emissionTarget == EmitMLIR) {
|
||||
if (mlir::failed(cleanSourcePM.run(*module)))
|
||||
llvm::errs() << "Could not apply simplification passes.\n";
|
||||
outputCode(module, outputBaseName, ".mlir");
|
||||
printf("Constant-free MLIR Code written to: \n\t%s\n\n",
|
||||
(outputBaseName + ".mlir").c_str());
|
||||
|
||||
printf("Use:\n\t%s\nto continue lowering the code to other dialects.\n",
|
||||
(outputBaseName + ".onnx.mlir").c_str());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -48,7 +48,8 @@ enum EmissionTargetType {
|
|||
void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context,
|
||||
mlir::OwningModuleRef &module);
|
||||
|
||||
void EmitLLVMBitCode(const mlir::OwningModuleRef &module);
|
||||
void EmitLLVMBitCode(
|
||||
const mlir::OwningModuleRef &module, std::string outputFilename);
|
||||
|
||||
void registerDialects();
|
||||
|
||||
|
@ -61,5 +62,13 @@ void addKrnlToAffinePasses(mlir::PassManager &pm);
|
|||
void addKrnlToLLVMPasses(mlir::PassManager &pm);
|
||||
|
||||
void processInputFile(std::string inputFilename,
|
||||
EmissionTargetType emissionTarget, mlir::MLIRContext &context,
|
||||
mlir::OwningModuleRef &module);
|
||||
EmissionTargetType emissionTarget, mlir::MLIRContext &context,
|
||||
mlir::OwningModuleRef &module);
|
||||
|
||||
void outputCode(
|
||||
mlir::OwningModuleRef &module, std::string filename,
|
||||
std::string extension);
|
||||
|
||||
void emitOutputFiles(std::string outputBaseName,
|
||||
EmissionTargetType emissionTarget, mlir::MLIRContext &context,
|
||||
mlir::OwningModuleRef &module);
|
|
@ -23,12 +23,18 @@ std::unique_ptr<Pass> createShapeInferencePass();
|
|||
/// Pass for promoting constant operands to attributes.
|
||||
std::unique_ptr<Pass> createAttributePromotionPass();
|
||||
|
||||
/// Pass for eliding the values of constant operations.
|
||||
std::unique_ptr<Pass> createElideConstantValuePass();
|
||||
|
||||
/// Add pass for lowering to Krnl IR.
|
||||
std::unique_ptr<Pass> createLowerToKrnlPass();
|
||||
|
||||
/// Pass for lowering frontend dialects to Krnl IR dialect.
|
||||
std::unique_ptr<Pass> createLowerKrnlPass();
|
||||
|
||||
/// Pass for eliding the values of global Krnl operations.
|
||||
std::unique_ptr<Pass> createElideConstGlobalValuePass();
|
||||
|
||||
/// Pass for lowering Krnl dialect to LLVM dialect.
|
||||
std::unique_ptr<Pass> createKrnlLowerToLLVMPass();
|
||||
|
||||
|
|
|
@ -24,4 +24,16 @@ add_dependencies(OMKrnlToLLVM
|
|||
OMKrnlOps
|
||||
OMONNXOps)
|
||||
|
||||
add_library(OMElideKrnlGlobalConstants
|
||||
ElideKrnlGlobalConstants.cpp)
|
||||
target_include_directories(OMElideKrnlGlobalConstants
|
||||
PRIVATE
|
||||
${ONNX_MLIR_SRC_ROOT}
|
||||
${ONNX_MLIR_BIN_ROOT}
|
||||
${ONNX_MLIR_SRC_ROOT})
|
||||
target_link_libraries(OMElideKrnlGlobalConstants
|
||||
${MLIRLibs}
|
||||
OMKrnlOps
|
||||
OMONNXOps)
|
||||
|
||||
add_subdirectory(ONNX)
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
//===- ElideKrnlGlobalConstants.cpp - Krnl Constant lobal Value Elision ---===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// In practice, the constant values of Global Krnl operations may be large
|
||||
// enough to hinder the readability of the MLIR intermediate representation.
|
||||
//
|
||||
// This file creates a pass which elides the explicit values of constant
|
||||
// global operations. This pass has purely cosmetic purposes and should only be
|
||||
// run to obtain a compact representation of the program when emitting Krnl
|
||||
// dialect code. This pass should never be invoked on code meant to be run.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
||||
#include "src/Pass/Passes.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
/*!
|
||||
* RewritePattern that replaces existing constant Krnl global values
|
||||
* with a similar operation which preserves all attributes except the value
|
||||
* attribute.
|
||||
*/
|
||||
|
||||
class KrnlConstGlobalValueElision : public OpRewritePattern<KrnlGlobalOp> {
|
||||
public:
|
||||
using OpRewritePattern<KrnlGlobalOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(KrnlGlobalOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
if (op.value().hasValue()) {
|
||||
auto newGlobalOp = rewriter.create<KrnlGlobalOp>(
|
||||
loc, op.getResult().getType(), op.shape(), op.name(), nullptr);
|
||||
rewriter.replaceOp(op, newGlobalOp.getResult());
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* Function pass that performs constant value elision of Krnl globals.
|
||||
*/
|
||||
class ElideConstGlobalValuePass
|
||||
: public mlir::FunctionPass<ElideConstGlobalValuePass> {
|
||||
public:
|
||||
void runOnFunction() override {
|
||||
auto function = getFunction();
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<KrnlConstGlobalValueElision>(&getContext());
|
||||
|
||||
applyPatternsGreedily(function, patterns);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createElideConstGlobalValuePass() {
|
||||
return std::make_unique<ElideConstGlobalValuePass>();
|
||||
}
|
||||
|
||||
static PassRegistration<ElideConstGlobalValuePass> pass("elide-krnl-constants",
|
||||
"Elide the constant values of the Global Krnl operations.");
|
|
@ -143,9 +143,12 @@ public:
|
|||
OpBuilder::InsertionGuard insertGuard(rewriter);
|
||||
rewriter.setInsertionPointToStart(module.getBody());
|
||||
|
||||
assert(krnlGlobalOp.value().hasValue() &&
|
||||
"Krnl Global must always have a value");
|
||||
|
||||
global = rewriter.create<LLVM::GlobalOp>(loc,
|
||||
llvmGlobalType, /*isConstant=*/true,
|
||||
LLVM::Linkage::Internal, name, krnlGlobalOp.value());
|
||||
LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue());
|
||||
}
|
||||
|
||||
// Some frequently used types.
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
//===----- attribute_promotion.cpp - Attribute Promotion
|
||||
//-------------------===//
|
||||
//===----- attribute_promotion.cpp - Attribute Promotion ------------------===//
|
||||
//
|
||||
// Copyright 2020 The IBM Research Authors.
|
||||
//
|
||||
|
|
|
@ -8,6 +8,12 @@ target_include_directories(OMAttributePromotion
|
|||
add_dependencies(OMAttributePromotion
|
||||
OMPromotableConstOperandsOpInterface)
|
||||
|
||||
add_library(OMElideConstants
|
||||
ElideConstants.cpp)
|
||||
target_include_directories(OMElideConstants
|
||||
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||
${ONNF_MLIR_SRC_ROOT})
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS ONNXRewrite.td)
|
||||
onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters)
|
||||
add_public_tablegen_target(OMONNXRewriteIncGen)
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
//===----- ElideConstants.cpp - Elide Constant Values ---------------------===//
|
||||
//
|
||||
// Copyright 2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// In practice, the constant values of Constant operations may be large enough
|
||||
// to hinder the readability of the MLIR intermediate representation.
|
||||
//
|
||||
// This file creates a pass which elides the explicit values of Constant
|
||||
// operations. This pass has purely cosmetic purposes and should only be run to
|
||||
// obtain a compact representation of the program when emitting ONNX and KRNL
|
||||
// Dialect code. This pass should never be invoked on code meant to be run.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
#include "src/Pass/Passes.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
/*!
|
||||
* RewritePattern that replaces existing Constant operations
|
||||
* with Constant operations with the same shape information but
|
||||
* no values.
|
||||
*/
|
||||
|
||||
class ConstantValueElision : public OpRewritePattern<ONNXConstantOp> {
|
||||
public:
|
||||
using OpRewritePattern<ONNXConstantOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXConstantOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
auto constOp = llvm::dyn_cast<ONNXConstantOp>(&op);
|
||||
|
||||
if (constOp->sparse_value().hasValue())
|
||||
emitError(loc, "Only support dense values at this time");
|
||||
|
||||
if (constOp->value().hasValue()) {
|
||||
auto newConstOp = rewriter.create<ONNXConstantOp>(
|
||||
loc, constOp->getResult().getType(), nullptr, nullptr);
|
||||
rewriter.replaceOp(op, newConstOp.getResult());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* Function pass that performs constant value elision.
|
||||
*/
|
||||
class ElideConstantValuePass
|
||||
: public mlir::FunctionPass<ElideConstantValuePass> {
|
||||
public:
|
||||
void runOnFunction() override {
|
||||
auto function = getFunction();
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<ConstantValueElision>(&getContext());
|
||||
|
||||
applyPatternsGreedily(function, patterns);
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/*!
|
||||
* Create a Constant Value Elision pass.
|
||||
*/
|
||||
std::unique_ptr<mlir::Pass> mlir::createElideConstantValuePass() {
|
||||
return std::make_unique<ElideConstantValuePass>();
|
||||
}
|
||||
|
||||
static PassRegistration<ElideConstantValuePass> pass(
|
||||
"elide-constants", "Elide values of constant operations.");
|
14
src/main.cpp
14
src/main.cpp
|
@ -43,9 +43,14 @@ int main(int argc, char *argv[]) {
|
|||
mlir::OwningModuleRef module;
|
||||
processInputFile(inputFilename, emissionTarget, context, module);
|
||||
|
||||
// Input file base name.
|
||||
string outputBaseName =
|
||||
inputFilename.substr(0, inputFilename.find_last_of("."));
|
||||
|
||||
mlir::PassManager pm(&context);
|
||||
if (emissionTarget >= EmitONNXIR)
|
||||
if (emissionTarget >= EmitONNXIR) {
|
||||
addONNXToMLIRPasses(pm);
|
||||
}
|
||||
|
||||
if (emissionTarget >= EmitMLIR) {
|
||||
addONNXToKrnlPasses(pm);
|
||||
|
@ -58,12 +63,7 @@ int main(int argc, char *argv[]) {
|
|||
if (mlir::failed(pm.run(*module)))
|
||||
return 4;
|
||||
|
||||
if (emissionTarget == EmitLLVMBC) {
|
||||
// Write LLVM bitcode to disk.
|
||||
EmitLLVMBitCode(module);
|
||||
printf("LLVM bitcode written to ./model.bc");
|
||||
} else
|
||||
module->dump();
|
||||
emitOutputFiles(outputBaseName, emissionTarget, context, module);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -42,13 +42,13 @@ class DummyBackend(onnx.backend.base.Backend):
|
|||
execute_commands([ONNX_MLIR, "temp_model.onnx"])
|
||||
# Call llc to generate object file from bitcode.
|
||||
execute_commands(
|
||||
[LLC, "-filetype=obj", "-relocation-model=pic", "model.bc"])
|
||||
[LLC, "-filetype=obj", "-relocation-model=pic", "temp_model.bc"])
|
||||
# Generate shared library from object file, linking with c runtime.
|
||||
execute_commands([
|
||||
CXX, "-shared", "-fPIC", "model.o", "-o", "model.so",
|
||||
CXX, "-shared", "-fPIC", "temp_model.o", "-o", "temp_model.so",
|
||||
"-L" + RUNTIME_DIR, "-lcruntime"
|
||||
])
|
||||
return ExecutionSession("./model.so", "_dyn_entry_point_main_graph")
|
||||
return ExecutionSession("./temp_model.so", "_dyn_entry_point_main_graph")
|
||||
|
||||
@classmethod
|
||||
def supports_device(cls, device):
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
// RUN: onnx-mlir-opt --elide-constants %s -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @test_elide_constant(%arg0: tensor<1xf32>) -> tensor<1x10xf32>
|
||||
func @test_elide_constant(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
|
||||
%0 = "onnx.Constant"() {value = dense<[[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]]> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
|
||||
"std.return"(%0) : (tensor<1x10xf32>) -> ()
|
||||
|
||||
// CHECK: %0 = "onnx.Constant"() : () -> tensor<1x10xf32>
|
||||
// CHECK: return %0 : tensor<1x10xf32>
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
// RUN: onnx-mlir-opt --elide-krnl-constants %s -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x10xf32>
|
||||
func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x10xf32> {
|
||||
%0 = "krnl.global"() {name = "constant_0", shape = [1, 10], value = dense<[[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]]> : tensor<1x10xf32>} : () -> memref<1x10xf32>
|
||||
return %0 : memref<1x10xf32>
|
||||
|
||||
// CHECK: %0 = "krnl.global"() {name = "constant_0", shape = [1, 10]} : () -> memref<1x10xf32>
|
||||
// CHECK: return %0 : memref<1x10xf32>
|
||||
}
|
Loading…
Reference in New Issue