Rework output to improve readability of intermediate MLIR code. (#87)
* Reorganize main function. * Follow review comments. * Emit constants are globals in Krnl and LLVM dialects. * Output of non-value constants. Write full source to file. * Fix e2e tests. * Output constant free and full code in separate files. * Emit separate files. * Move file output management to utils. * Elide the values of glotbal krnl constants. * Add dual file output for Basic flag. * Add tests. * Add passes to cmake file.
This commit is contained in:
parent
363ee26a52
commit
137ce767e6
|
@ -263,7 +263,9 @@ set(ONNXMLIRWholeArchiveLibs
|
||||||
OMShapeInference
|
OMShapeInference
|
||||||
OMShapeInferenceOpInterface
|
OMShapeInferenceOpInterface
|
||||||
OMAttributePromotion
|
OMAttributePromotion
|
||||||
OMPromotableConstOperandsOpInterface)
|
OMPromotableConstOperandsOpInterface
|
||||||
|
OMElideConstants
|
||||||
|
OMElideKrnlGlobalConstants)
|
||||||
|
|
||||||
# Function to construct linkage option for the static libraries that must be
|
# Function to construct linkage option for the static libraries that must be
|
||||||
# linked with --whole-archive (or equivalent).
|
# linked with --whole-archive (or equivalent).
|
||||||
|
|
|
@ -34,6 +34,8 @@ target_link_libraries(onnx-mlir
|
||||||
OMShapeInferenceOpInterface
|
OMShapeInferenceOpInterface
|
||||||
OMAttributePromotion
|
OMAttributePromotion
|
||||||
OMPromotableConstOperandsOpInterface
|
OMPromotableConstOperandsOpInterface
|
||||||
|
OMElideConstants
|
||||||
|
OMElideKrnlGlobalConstants
|
||||||
OMKrnlToAffine
|
OMKrnlToAffine
|
||||||
OMKrnlToLLVM
|
OMKrnlToLLVM
|
||||||
OMONNXToKrnl
|
OMONNXToKrnl
|
||||||
|
|
|
@ -41,8 +41,8 @@ struct ONNXConstantOpLowering : public ConversionPattern {
|
||||||
auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc,
|
auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc,
|
||||||
memRefType,
|
memRefType,
|
||||||
rewriter.getI64ArrayAttr(shape),
|
rewriter.getI64ArrayAttr(shape),
|
||||||
constantOp.value().getValue(),
|
rewriter.getStringAttr("constant_" + std::to_string(constantID)),
|
||||||
rewriter.getStringAttr("constant_" + std::to_string(constantID)));
|
constantOp.value().getValue());
|
||||||
|
|
||||||
// Increment constant ID:
|
// Increment constant ID:
|
||||||
constantID++;
|
constantID++;
|
||||||
|
|
|
@ -199,7 +199,7 @@ def KrnlGlobalOp : Op<Krnl_Dialect, "global"> {
|
||||||
Operation for holding global data values.
|
Operation for holding global data values.
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins AnyAttr:$shape, AnyAttr:$value, StrAttr:$name);
|
let arguments = (ins AnyAttr:$shape, StrAttr:$name, OptionalAttr<AnyAttr>:$value);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef]>:$output);
|
let results = (outs AnyTypeOf<[AnyMemRef]>:$output);
|
||||||
|
|
||||||
let parser = ?;
|
let parser = ?;
|
||||||
|
|
|
@ -9,6 +9,9 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "src/MainUtils.hpp"
|
#include "src/MainUtils.hpp"
|
||||||
|
#include <fcntl.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace onnx_mlir;
|
using namespace onnx_mlir;
|
||||||
|
@ -35,9 +38,10 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void EmitLLVMBitCode(const mlir::OwningModuleRef &module) {
|
void EmitLLVMBitCode(
|
||||||
|
const mlir::OwningModuleRef &module, string outputFilename) {
|
||||||
error_code error;
|
error_code error;
|
||||||
llvm::raw_fd_ostream moduleBitcodeStream("model.bc", error,
|
llvm::raw_fd_ostream moduleBitcodeStream(outputFilename, error,
|
||||||
llvm::sys::fs::F_None);
|
llvm::sys::fs::F_None);
|
||||||
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
|
llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module),
|
||||||
moduleBitcodeStream);
|
moduleBitcodeStream);
|
||||||
|
@ -97,3 +101,67 @@ void processInputFile(string inputFilename, EmissionTargetType emissionTarget,
|
||||||
LoadMLIR(inputFilename, context, module);
|
LoadMLIR(inputFilename, context, module);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void outputCode(
|
||||||
|
mlir::OwningModuleRef &module, string filename, string extension) {
|
||||||
|
// Start a separate process to redirect the model output. I/O redirection
|
||||||
|
// changes will not be visible to the parent process.
|
||||||
|
if (fork() == 0) {
|
||||||
|
const char * tempFilename = (filename + extension).c_str();
|
||||||
|
freopen(tempFilename, "w", stderr);
|
||||||
|
module->dump();
|
||||||
|
fclose(stderr);
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget,
|
||||||
|
mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
|
||||||
|
// For EmitONNXIR and EmitMLIR the constant value are embedded in the code
|
||||||
|
// thus making the code hard to read. These values can be elided by emitting
|
||||||
|
// two versions of the same source code:
|
||||||
|
// (1) a version with all the constant values included meant for being passed
|
||||||
|
// back to onnx-mlir for further processing and stored in:
|
||||||
|
//
|
||||||
|
// <name>.onnx.mlir
|
||||||
|
//
|
||||||
|
// (2) a version without constants meant for being inspected by users and
|
||||||
|
// stored in:
|
||||||
|
//
|
||||||
|
// <name>.mlir
|
||||||
|
//
|
||||||
|
// In the case of the LLVM Dialect IR the constant values are grouped
|
||||||
|
// outside the function code at the beginning of the file in which case the
|
||||||
|
// elision of these constants is not strictly required. Elision is also not
|
||||||
|
// necessary when emitting the .bc file.
|
||||||
|
if (emissionTarget == EmitLLVMBC) {
|
||||||
|
// Write LLVM bitcode to disk.
|
||||||
|
string outputFilename = outputBaseName + ".bc";
|
||||||
|
EmitLLVMBitCode(module, outputFilename);
|
||||||
|
printf("LLVM bitcode written to %s\n", outputFilename.c_str());
|
||||||
|
} else {
|
||||||
|
// Emit the version with all constants included.
|
||||||
|
outputCode(module, outputBaseName, ".onnx.mlir");
|
||||||
|
printf("Full MLIR code written to: \n\t%s\n\n",
|
||||||
|
(outputBaseName + ".onnx.mlir").c_str());
|
||||||
|
|
||||||
|
// Apply specific passes to clean up the code where necessary.
|
||||||
|
mlir::PassManager cleanSourcePM(&context);
|
||||||
|
if (emissionTarget == EmitONNXIR || emissionTarget == EmitONNXBasic)
|
||||||
|
cleanSourcePM.addPass(mlir::createElideConstantValuePass());
|
||||||
|
if (emissionTarget == EmitMLIR)
|
||||||
|
cleanSourcePM.addPass(mlir::createElideConstGlobalValuePass());
|
||||||
|
|
||||||
|
if (emissionTarget == EmitONNXBasic || emissionTarget == EmitONNXIR ||
|
||||||
|
emissionTarget == EmitMLIR) {
|
||||||
|
if (mlir::failed(cleanSourcePM.run(*module)))
|
||||||
|
llvm::errs() << "Could not apply simplification passes.\n";
|
||||||
|
outputCode(module, outputBaseName, ".mlir");
|
||||||
|
printf("Constant-free MLIR Code written to: \n\t%s\n\n",
|
||||||
|
(outputBaseName + ".mlir").c_str());
|
||||||
|
|
||||||
|
printf("Use:\n\t%s\nto continue lowering the code to other dialects.\n",
|
||||||
|
(outputBaseName + ".onnx.mlir").c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -48,7 +48,8 @@ enum EmissionTargetType {
|
||||||
void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context,
|
void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context,
|
||||||
mlir::OwningModuleRef &module);
|
mlir::OwningModuleRef &module);
|
||||||
|
|
||||||
void EmitLLVMBitCode(const mlir::OwningModuleRef &module);
|
void EmitLLVMBitCode(
|
||||||
|
const mlir::OwningModuleRef &module, std::string outputFilename);
|
||||||
|
|
||||||
void registerDialects();
|
void registerDialects();
|
||||||
|
|
||||||
|
@ -63,3 +64,11 @@ void addKrnlToLLVMPasses(mlir::PassManager &pm);
|
||||||
void processInputFile(std::string inputFilename,
|
void processInputFile(std::string inputFilename,
|
||||||
EmissionTargetType emissionTarget, mlir::MLIRContext &context,
|
EmissionTargetType emissionTarget, mlir::MLIRContext &context,
|
||||||
mlir::OwningModuleRef &module);
|
mlir::OwningModuleRef &module);
|
||||||
|
|
||||||
|
void outputCode(
|
||||||
|
mlir::OwningModuleRef &module, std::string filename,
|
||||||
|
std::string extension);
|
||||||
|
|
||||||
|
void emitOutputFiles(std::string outputBaseName,
|
||||||
|
EmissionTargetType emissionTarget, mlir::MLIRContext &context,
|
||||||
|
mlir::OwningModuleRef &module);
|
|
@ -23,12 +23,18 @@ std::unique_ptr<Pass> createShapeInferencePass();
|
||||||
/// Pass for promoting constant operands to attributes.
|
/// Pass for promoting constant operands to attributes.
|
||||||
std::unique_ptr<Pass> createAttributePromotionPass();
|
std::unique_ptr<Pass> createAttributePromotionPass();
|
||||||
|
|
||||||
|
/// Pass for eliding the values of constant operations.
|
||||||
|
std::unique_ptr<Pass> createElideConstantValuePass();
|
||||||
|
|
||||||
/// Add pass for lowering to Krnl IR.
|
/// Add pass for lowering to Krnl IR.
|
||||||
std::unique_ptr<Pass> createLowerToKrnlPass();
|
std::unique_ptr<Pass> createLowerToKrnlPass();
|
||||||
|
|
||||||
/// Pass for lowering frontend dialects to Krnl IR dialect.
|
/// Pass for lowering frontend dialects to Krnl IR dialect.
|
||||||
std::unique_ptr<Pass> createLowerKrnlPass();
|
std::unique_ptr<Pass> createLowerKrnlPass();
|
||||||
|
|
||||||
|
/// Pass for eliding the values of global Krnl operations.
|
||||||
|
std::unique_ptr<Pass> createElideConstGlobalValuePass();
|
||||||
|
|
||||||
/// Pass for lowering Krnl dialect to LLVM dialect.
|
/// Pass for lowering Krnl dialect to LLVM dialect.
|
||||||
std::unique_ptr<Pass> createKrnlLowerToLLVMPass();
|
std::unique_ptr<Pass> createKrnlLowerToLLVMPass();
|
||||||
|
|
||||||
|
|
|
@ -24,4 +24,16 @@ add_dependencies(OMKrnlToLLVM
|
||||||
OMKrnlOps
|
OMKrnlOps
|
||||||
OMONNXOps)
|
OMONNXOps)
|
||||||
|
|
||||||
|
add_library(OMElideKrnlGlobalConstants
|
||||||
|
ElideKrnlGlobalConstants.cpp)
|
||||||
|
target_include_directories(OMElideKrnlGlobalConstants
|
||||||
|
PRIVATE
|
||||||
|
${ONNX_MLIR_SRC_ROOT}
|
||||||
|
${ONNX_MLIR_BIN_ROOT}
|
||||||
|
${ONNX_MLIR_SRC_ROOT})
|
||||||
|
target_link_libraries(OMElideKrnlGlobalConstants
|
||||||
|
${MLIRLibs}
|
||||||
|
OMKrnlOps
|
||||||
|
OMONNXOps)
|
||||||
|
|
||||||
add_subdirectory(ONNX)
|
add_subdirectory(ONNX)
|
||||||
|
|
|
@ -0,0 +1,76 @@
|
||||||
|
//===- ElideKrnlGlobalConstants.cpp - Krnl Constant lobal Value Elision ---===//
|
||||||
|
//
|
||||||
|
// Copyright 2019-2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// In practice, the constant values of Global Krnl operations may be large
|
||||||
|
// enough to hinder the readability of the MLIR intermediate representation.
|
||||||
|
//
|
||||||
|
// This file creates a pass which elides the explicit values of constant
|
||||||
|
// global operations. This pass has purely cosmetic purposes and should only be
|
||||||
|
// run to obtain a compact representation of the program when emitting Krnl
|
||||||
|
// dialect code. This pass should never be invoked on code meant to be run.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
||||||
|
#include "src/Pass/Passes.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* RewritePattern that replaces existing constant Krnl global values
|
||||||
|
* with a similar operation which preserves all attributes except the value
|
||||||
|
* attribute.
|
||||||
|
*/
|
||||||
|
|
||||||
|
class KrnlConstGlobalValueElision : public OpRewritePattern<KrnlGlobalOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<KrnlGlobalOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(KrnlGlobalOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
|
if (op.value().hasValue()) {
|
||||||
|
auto newGlobalOp = rewriter.create<KrnlGlobalOp>(
|
||||||
|
loc, op.getResult().getType(), op.shape(), op.name(), nullptr);
|
||||||
|
rewriter.replaceOp(op, newGlobalOp.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Function pass that performs constant value elision of Krnl globals.
|
||||||
|
*/
|
||||||
|
class ElideConstGlobalValuePass
|
||||||
|
: public mlir::FunctionPass<ElideConstGlobalValuePass> {
|
||||||
|
public:
|
||||||
|
void runOnFunction() override {
|
||||||
|
auto function = getFunction();
|
||||||
|
|
||||||
|
ConversionTarget target(getContext());
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
patterns.insert<KrnlConstGlobalValueElision>(&getContext());
|
||||||
|
|
||||||
|
applyPatternsGreedily(function, patterns);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> mlir::createElideConstGlobalValuePass() {
|
||||||
|
return std::make_unique<ElideConstGlobalValuePass>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<ElideConstGlobalValuePass> pass("elide-krnl-constants",
|
||||||
|
"Elide the constant values of the Global Krnl operations.");
|
|
@ -143,9 +143,12 @@ public:
|
||||||
OpBuilder::InsertionGuard insertGuard(rewriter);
|
OpBuilder::InsertionGuard insertGuard(rewriter);
|
||||||
rewriter.setInsertionPointToStart(module.getBody());
|
rewriter.setInsertionPointToStart(module.getBody());
|
||||||
|
|
||||||
|
assert(krnlGlobalOp.value().hasValue() &&
|
||||||
|
"Krnl Global must always have a value");
|
||||||
|
|
||||||
global = rewriter.create<LLVM::GlobalOp>(loc,
|
global = rewriter.create<LLVM::GlobalOp>(loc,
|
||||||
llvmGlobalType, /*isConstant=*/true,
|
llvmGlobalType, /*isConstant=*/true,
|
||||||
LLVM::Linkage::Internal, name, krnlGlobalOp.value());
|
LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Some frequently used types.
|
// Some frequently used types.
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
//===----- attribute_promotion.cpp - Attribute Promotion
|
//===----- attribute_promotion.cpp - Attribute Promotion ------------------===//
|
||||||
//-------------------===//
|
|
||||||
//
|
//
|
||||||
// Copyright 2020 The IBM Research Authors.
|
// Copyright 2020 The IBM Research Authors.
|
||||||
//
|
//
|
||||||
|
|
|
@ -8,6 +8,12 @@ target_include_directories(OMAttributePromotion
|
||||||
add_dependencies(OMAttributePromotion
|
add_dependencies(OMAttributePromotion
|
||||||
OMPromotableConstOperandsOpInterface)
|
OMPromotableConstOperandsOpInterface)
|
||||||
|
|
||||||
|
add_library(OMElideConstants
|
||||||
|
ElideConstants.cpp)
|
||||||
|
target_include_directories(OMElideConstants
|
||||||
|
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||||
|
${ONNF_MLIR_SRC_ROOT})
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS ONNXRewrite.td)
|
set(LLVM_TARGET_DEFINITIONS ONNXRewrite.td)
|
||||||
onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters)
|
onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters)
|
||||||
add_public_tablegen_target(OMONNXRewriteIncGen)
|
add_public_tablegen_target(OMONNXRewriteIncGen)
|
||||||
|
|
|
@ -0,0 +1,83 @@
|
||||||
|
//===----- ElideConstants.cpp - Elide Constant Values ---------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// In practice, the constant values of Constant operations may be large enough
|
||||||
|
// to hinder the readability of the MLIR intermediate representation.
|
||||||
|
//
|
||||||
|
// This file creates a pass which elides the explicit values of Constant
|
||||||
|
// operations. This pass has purely cosmetic purposes and should only be run to
|
||||||
|
// obtain a compact representation of the program when emitting ONNX and KRNL
|
||||||
|
// Dialect code. This pass should never be invoked on code meant to be run.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
#include "src/Pass/Passes.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* RewritePattern that replaces existing Constant operations
|
||||||
|
* with Constant operations with the same shape information but
|
||||||
|
* no values.
|
||||||
|
*/
|
||||||
|
|
||||||
|
class ConstantValueElision : public OpRewritePattern<ONNXConstantOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<ONNXConstantOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ONNXConstantOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto constOp = llvm::dyn_cast<ONNXConstantOp>(&op);
|
||||||
|
|
||||||
|
if (constOp->sparse_value().hasValue())
|
||||||
|
emitError(loc, "Only support dense values at this time");
|
||||||
|
|
||||||
|
if (constOp->value().hasValue()) {
|
||||||
|
auto newConstOp = rewriter.create<ONNXConstantOp>(
|
||||||
|
loc, constOp->getResult().getType(), nullptr, nullptr);
|
||||||
|
rewriter.replaceOp(op, newConstOp.getResult());
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Function pass that performs constant value elision.
|
||||||
|
*/
|
||||||
|
class ElideConstantValuePass
|
||||||
|
: public mlir::FunctionPass<ElideConstantValuePass> {
|
||||||
|
public:
|
||||||
|
void runOnFunction() override {
|
||||||
|
auto function = getFunction();
|
||||||
|
|
||||||
|
ConversionTarget target(getContext());
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
patterns.insert<ConstantValueElision>(&getContext());
|
||||||
|
|
||||||
|
applyPatternsGreedily(function, patterns);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Create a Constant Value Elision pass.
|
||||||
|
*/
|
||||||
|
std::unique_ptr<mlir::Pass> mlir::createElideConstantValuePass() {
|
||||||
|
return std::make_unique<ElideConstantValuePass>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<ElideConstantValuePass> pass(
|
||||||
|
"elide-constants", "Elide values of constant operations.");
|
14
src/main.cpp
14
src/main.cpp
|
@ -43,9 +43,14 @@ int main(int argc, char *argv[]) {
|
||||||
mlir::OwningModuleRef module;
|
mlir::OwningModuleRef module;
|
||||||
processInputFile(inputFilename, emissionTarget, context, module);
|
processInputFile(inputFilename, emissionTarget, context, module);
|
||||||
|
|
||||||
|
// Input file base name.
|
||||||
|
string outputBaseName =
|
||||||
|
inputFilename.substr(0, inputFilename.find_last_of("."));
|
||||||
|
|
||||||
mlir::PassManager pm(&context);
|
mlir::PassManager pm(&context);
|
||||||
if (emissionTarget >= EmitONNXIR)
|
if (emissionTarget >= EmitONNXIR) {
|
||||||
addONNXToMLIRPasses(pm);
|
addONNXToMLIRPasses(pm);
|
||||||
|
}
|
||||||
|
|
||||||
if (emissionTarget >= EmitMLIR) {
|
if (emissionTarget >= EmitMLIR) {
|
||||||
addONNXToKrnlPasses(pm);
|
addONNXToKrnlPasses(pm);
|
||||||
|
@ -58,12 +63,7 @@ int main(int argc, char *argv[]) {
|
||||||
if (mlir::failed(pm.run(*module)))
|
if (mlir::failed(pm.run(*module)))
|
||||||
return 4;
|
return 4;
|
||||||
|
|
||||||
if (emissionTarget == EmitLLVMBC) {
|
emitOutputFiles(outputBaseName, emissionTarget, context, module);
|
||||||
// Write LLVM bitcode to disk.
|
|
||||||
EmitLLVMBitCode(module);
|
|
||||||
printf("LLVM bitcode written to ./model.bc");
|
|
||||||
} else
|
|
||||||
module->dump();
|
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,13 +42,13 @@ class DummyBackend(onnx.backend.base.Backend):
|
||||||
execute_commands([ONNX_MLIR, "temp_model.onnx"])
|
execute_commands([ONNX_MLIR, "temp_model.onnx"])
|
||||||
# Call llc to generate object file from bitcode.
|
# Call llc to generate object file from bitcode.
|
||||||
execute_commands(
|
execute_commands(
|
||||||
[LLC, "-filetype=obj", "-relocation-model=pic", "model.bc"])
|
[LLC, "-filetype=obj", "-relocation-model=pic", "temp_model.bc"])
|
||||||
# Generate shared library from object file, linking with c runtime.
|
# Generate shared library from object file, linking with c runtime.
|
||||||
execute_commands([
|
execute_commands([
|
||||||
CXX, "-shared", "-fPIC", "model.o", "-o", "model.so",
|
CXX, "-shared", "-fPIC", "temp_model.o", "-o", "temp_model.so",
|
||||||
"-L" + RUNTIME_DIR, "-lcruntime"
|
"-L" + RUNTIME_DIR, "-lcruntime"
|
||||||
])
|
])
|
||||||
return ExecutionSession("./model.so", "_dyn_entry_point_main_graph")
|
return ExecutionSession("./temp_model.so", "_dyn_entry_point_main_graph")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def supports_device(cls, device):
|
def supports_device(cls, device):
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
// RUN: onnx-mlir-opt --elide-constants %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @test_elide_constant(%arg0: tensor<1xf32>) -> tensor<1x10xf32>
|
||||||
|
func @test_elide_constant(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]]> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
|
||||||
|
"std.return"(%0) : (tensor<1x10xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK: %0 = "onnx.Constant"() : () -> tensor<1x10xf32>
|
||||||
|
// CHECK: return %0 : tensor<1x10xf32>
|
||||||
|
}
|
|
@ -0,0 +1,10 @@
|
||||||
|
// RUN: onnx-mlir-opt --elide-krnl-constants %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x10xf32>
|
||||||
|
func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x10xf32> {
|
||||||
|
%0 = "krnl.global"() {name = "constant_0", shape = [1, 10], value = dense<[[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]]> : tensor<1x10xf32>} : () -> memref<1x10xf32>
|
||||||
|
return %0 : memref<1x10xf32>
|
||||||
|
|
||||||
|
// CHECK: %0 = "krnl.global"() {name = "constant_0", shape = [1, 10]} : () -> memref<1x10xf32>
|
||||||
|
// CHECK: return %0 : memref<1x10xf32>
|
||||||
|
}
|
Loading…
Reference in New Issue