Compiling Models with Large Constant Arrays (#146)

* PoC works.

* MNIST works.

* Clean up.

* Fix test.

* Make Linux work.

* Use consistent symbol name.

* Fix variable name.

* Fix array addr access.

* Bug fix.

* Bug fix.

* Install before running e2e tests.

* Fix build config.

* Use sudo when installing.

* Make embeddedDataLoader position independent.

* Enable ResNet50.

* Format code.

* Format MainUtil.

* Try not using sudo to install.

* Supply runtime dir via environment variable.

* Dump problematic operation.

* Dump entire function.

* Debug.

* Dump input.

* Dump constant op.

* Debug.

* Debug.

* Debug.

* Print to stderr.

* Take care of endianness.

* Use endianness-aware execution session.

* Fix ZLinux error.

* Include warning when desired output endianness can't be deduced.

* Remove debug code.

* Remove debug code in shape inference.

* Support binary-decoder for testing constants packing.

* Support filename, move-to-file, elision-threshold configurations in constant packing pass for easy testing.

* Add lit test, fix lit test type mismatch.

* Add more consts packing tests.

* Ensure intermediate files are properly cleaned up.

* No need for constant elimination.

* Link with threading libraries.

* Remove debug code.

* Format code.

* More tests.

* Test nit.

* Remove debug code.

* Reduce hard-coded constants.

* Use temporary and unique working directory for hosting model parameters.

* Test if it works.

* Try to find objcopy.

* Rename symbols using objcopy.

* Move sanitized name to linux section.

* Use verbose mode for debugging.

* Disambiguate pass constructor.

* Fix symbol name.

* Use Command API to build and execute commands.

* Move linux to use Command API.

* Fix reset args.

* Execute redefine sym.

* Format code.

* Do not use verbose mode for CircleCI.

* Remove debug code.

* Prettify code, add comments.

* getSegmentData -> getEmbeddedConstPool

* vector -> std::vector.

* Make sure we properly clean up intermediate files.

* Fix test cases.

* Add runtime directory.

* Trigger rebuild.

* [Merge with master] Fix debug script.

* Disable affine fusion pass for now.

* Support generic fallback const packing mechanism.

* Remove debug code.

* Handle the case where objcopy is not available.

* Fix Windows missing types.

* Support int64.

* Copy packed constant to a local directory for non-Linux/Mac platforms.

* Nit: remove debug code, refactor const pack preprocessing out as a separate function.

* Cannot make preprocessConstPack a standalone function because the file removers are stack-allocated and are deallocated prematurely when the function's stack frame is popped, deleting intermediate files too early.

* Don't require executable filename.

* Import ONNX data types directly.

* Fix LIT test.

* Bug fix, use moved string value.

* Remove redundant filenames.

* Fix CMake script.

* Embed endianness information as a symbol, and check during runtime.

* More comments, update lit tests.

* Fix lit test on BE machine.

* Copyright notices.
Tian Jin 2020-06-12 10:27:05 +08:00 committed by GitHub
parent 8c4d527eea
commit e0ae583da0
36 changed files with 994 additions and 129 deletions

View File

@ -71,4 +71,4 @@ cmake -DCMAKE_INSTALL_PREFIX=${INSTALL_PATH} .. \
make -j$(nproc) onnx-mlir
make -j$(nproc) check-onnx-lit
make -j$(nproc) check-onnx-backend
RUNTIME_DIR=$(pwd)/lib make -j$(nproc) check-onnx-backend

View File

@ -40,14 +40,14 @@ jobs:
command: |
sudo pip install -q -e ./onnx-mlir/third_party/onnx
cd onnx-mlir/build
VERBOSE=1 cmake --build . --target check-onnx-backend
RUNTIME_DIR=$(pwd)/lib cmake --build . --target check-onnx-backend
- run:
name: Run Unit Tests
command: |
cd onnx-mlir/build
# Need to include the bin directory in $PATH,
# otherwise CTest fails to find the test executables.
PATH=$(pwd)/bin:$PATH make test -j$(nproc)
RUNTIME_DIR=$(pwd)/lib PATH=$(pwd)/bin:$PATH make test -j$(nproc)
- run:
name: Run DocCheck
command: cd onnx-mlir/build && cmake --build . --target check-doc

View File

@ -269,6 +269,7 @@ set(ONNXMLIRWholeArchiveLibs
OMPromotableConstOperandsOpInterface
OMElideConstants
OMElideKrnlGlobalConstants
OMPackKrnlGlobalConstants
OMEnableMemoryPool)
# Function to construct linkage option for the static libraries that must be

View File

@ -7,6 +7,8 @@
// Helper methods for handling input ONNX models.
//
//===----------------------------------------------------------------------===//
#include <llvm/Support/Endian.h>
#include <llvm/Support/SwapByteOrder.h>
#include "src/Builder/FrontendDialectHelper.hpp"
@ -104,8 +106,16 @@ static std::vector<T> CreateArrayAttribute(onnx::TensorProto initializer) {
std::copy(initializer.raw_data().begin(), initializer.raw_data().end(),
back_inserter(byteInitializer));
size = initializer.raw_data().size() / sizeof(T);
T *res = reinterpret_cast<T *>(&byteInitializer[0]);
return std::vector<T>(res, res + size);
T *arrayPtr = reinterpret_cast<T *>(&byteInitializer[0]);
auto array = std::vector<T>(arrayPtr, arrayPtr + size);
// Perform byte swap if system endianness is BE.
// ONNX tensor content raw data is always in LE.
if (llvm::support::endian::system_endianness() !=
llvm::support::endianness::little)
for (int i = 0; i < array.size(); i++)
llvm::sys::swapByteOrder<T>(array[i]);
return array;
}
// copy, no need to take care of endianness
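
The hunk above byte-swaps raw ONNX tensor data on big-endian hosts, since ONNX `raw_data` is always little-endian. As a minimal standalone sketch of the same idea (assuming no LLVM helpers; the function name and endianness probe are illustrative, not part of the diff):

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <vector>

// Reinterpret little-endian raw bytes as elements of type T, swapping each
// element's bytes when the host is big-endian (mirrors the hunk above).
template <typename T>
std::vector<T> decodeLittleEndianRawData(const std::vector<char> &raw) {
  std::vector<T> out(raw.size() / sizeof(T));
  std::memcpy(out.data(), raw.data(), out.size() * sizeof(T));
  const uint16_t probe = 1;
  const bool hostIsLittleEndian = *reinterpret_cast<const char *>(&probe) == 1;
  if (!hostIsLittleEndian)
    for (T &element : out) {
      char *bytes = reinterpret_cast<char *>(&element);
      std::reverse(bytes, bytes + sizeof(T));
    }
  return out;
}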

View File

@ -25,9 +25,6 @@ if(NOT EXISTS "${LLVM_PROJ_BUILD}/bin/llc")
message(ERROR "Cannot find llc.")
endif()
# Get the compiler command name; the C++ compiler is needed to translate
# object files to shared libraries.
get_filename_component(CXX_COMPILER_FILENAME ${CMAKE_CXX_COMPILER} NAME)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ExternalUtil.hpp.in
${CMAKE_CURRENT_BINARY_DIR}/ExternalUtil.hpp)
@ -57,6 +54,7 @@ target_link_libraries(MainUtils
OMResultTypeInferenceOpInterface
OMElideConstants
OMElideKrnlGlobalConstants
OMPackKrnlGlobalConstants
OMEnableMemoryPool
OMKrnlToAffine
OMKrnlToLLVM
@ -71,6 +69,9 @@ if (INCLUDE_ONNX_ML)
add_dependencies(MainUtils OMMLONNXOpsInc)
endif()
add_dependencies(onnx-mlir cruntime)
add_dependencies(onnx-mlir EmbeddedDataLoader)
target_include_directories(onnx-mlir PRIVATE ${ONNX_MLIR_SRC_ROOT})
target_include_directories(onnx-mlir PRIVATE ${CMAKE_BINARY_DIR})
target_include_directories(onnx-mlir PRIVATE ${ONNX_MLIR_BIN_ROOT})

View File

@ -38,9 +38,11 @@ struct ONNXConstantOpLowering : public ConversionPattern {
// Emit the constant global in Krnl dialect.
auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc, memRefType,
rewriter.getI64ArrayAttr(shape),
/*shape=*/rewriter.getI64ArrayAttr(shape),
/*name=*/
rewriter.getStringAttr("constant_" + std::to_string(constantID)),
constantOp.value().getValue());
/*value=*/constantOp.value().getValue(),
/*offset=*/nullptr);
// Increment constant ID:
constantID++;

View File

@ -196,16 +196,66 @@ def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
def KrnlGlobalOp : Op<Krnl_Dialect, "global"> {
let summary = "Krnl global operation";
let description = [{
Operation for holding global data values.
Operation for holding global data values. A global constant can have a
meaningful name recorded as its `name` attribute. Its content is stored
in the `value` dense element attribute. Alternatively, if the constants
are packed together, `offset` records the byte offset in the global
constant pool from which the current constant is to be read.
}];
let arguments = (ins AnyAttr:$shape, StrAttr:$name, OptionalAttr<AnyAttr>:$value);
let arguments = (ins AnyAttr:$shape,
StrAttr:$name, OptionalAttr<AnyAttr>:$value, OptionalAttr<I64Attr>:$offset);
let results = (outs AnyTypeOf<[AnyMemRef]>:$output);
let parser = ?;
let printer = ?;
}
def KrnlPackedConstantOp : Op<Krnl_Dialect, "packed_const"> {
let summary = "Krnl packed constant operation";
let description = [{
Operation for holding packed constants.
}];
let arguments = (ins I64Attr:$size_in_bytes,
BoolAttr:$is_le,
OptionalAttr<AnyIntElementsAttr<8>>:$value,
OptionalAttr<StrAttr>:$file_name);
let results = (outs I64:$output);
let extraClassDeclaration = [{
// The *path* to the file storing the constant pack on disk is encoded
// as a global variable in the LLVM module of the lowered model.
// getConstPackFilePathSymbolName returns the name of this symbol;
// getConstPackFilePathStrLenSymbolName returns the name of the symbol holding
// a constant value equal to the length of the file path.
static StringRef getConstPackFilePathSymbolName() { return "constPackFilePath"; }
static StringRef getConstPackFilePathStrLenSymbolName() { return "constPackFilePathStrLen"; }
// The *name* of the file storing the constant pack is also recorded for
// convenience. Similarly, getConstPackFileNameSymbolName and
// getConstPackFileNameStrLenSymbolName return the names of the symbols holding
// the string constant representing the filename and the length of this
// string constant, respectively.
static StringRef getConstPackFileNameSymbolName() { return "constPackFileName"; }
static StringRef getConstPackFileNameStrLenSymbolName() { return "constPackFileNameStrLen"; }
// We record whether the constant pack is stored in LE byte order as an
// int8 symbol; it is to be interpreted as a boolean switch, with 0
// meaning that the constant pack is not stored in LE byte order and
// any non-zero value meaning that it is.
static StringRef getConstPackIsLESymbolName() { return "constPackIsLE"; }
// The name of the function we call to read packed constants embedded within
// the current binary executable/library or, on unsupported platforms,
// from a binary constant pack file.
static StringRef getEmbeddedDataLoaderMethodName() {
return "getEmbeddedConstPool";
}
}];
let parser = ?;
let printer = ?;
}
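
To make the symbol contract described above concrete, here is a hedged sketch of the external declarations a C++ runtime linked against a lowered model would match; the names come from the accessors above, while the exact const-qualification is an assumption (see GetEmbeddedConstPool.cpp further down for the actual loader):

#include <cstdint>

extern "C" {
// Path and name of the file holding the constant pack, plus their lengths.
extern const char constPackFilePath[];
extern const int64_t constPackFilePathStrLen;
extern const char constPackFileName[];
extern const int64_t constPackFileNameStrLen;
// Non-zero when the constant pack is stored in little-endian byte order.
extern const char constPackIsLE;
// Loader called by the generated code to obtain the constant pool base.
void *getEmbeddedConstPool(int64_t size_in_byte);
}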
def KrnlGetRefOp : Op<Krnl_Dialect, "getref"> {
let summary = "Krnl a MemRef from within another MemRef starting at a specific offset.";
let description = [{

View File

@ -1138,7 +1138,6 @@ LogicalResult ONNXReshapeOp::inferShapes() {
if (constantOp) {
DenseElementsAttr valueAttribute =
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
if (!valueAttribute)
return emitError("DenseElementsAttr expected");
// Get dims from valueAttribute.

View File

@ -1,9 +1,10 @@
#pragma once
#include <string>
namespace onnx_mlir {
const std::string kLlcPath = "@LLVM_PROJ_BUILD@/bin/llc";
const std::string kCxxPath = "@CMAKE_CXX_COMPILER@";
const std::string kCxxFileName = "@CXX_COMPILER_FILENAME@";
const std::string kRuntimeDirPath = "@CMAKE_BINARY_DIR@/lib";
}
const std::string kLinkerPath = "@CMAKE_LINKER@";
const std::string kObjCopyPath = "@CMAKE_OBJCOPY@";
} // namespace onnx_mlir

View File

@ -9,8 +9,15 @@
//===----------------------------------------------------------------------===//
#include <cstdio>
#include <cstdlib>
#include <fcntl.h>
#include <regex>
#include <string>
#include <llvm/Support/FileSystem.h>
#include <llvm/Support/Program.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/IR/SymbolTable.h>
#include "src/ExternalUtil.hpp"
#include "src/MainUtils.hpp"
@ -26,6 +33,63 @@
using namespace std;
using namespace onnx_mlir;
namespace {
llvm::Optional<std::string> getEnvVar(std::string name) {
if (const char *envVerbose = std::getenv(name.c_str()))
return std::string(envVerbose);
return llvm::None;
}
// Helper struct to make command construction and execution easy & readable.
struct Command {
std::string _path;
std::vector<std::string> _args;
Command(std::string exePath)
: _path(std::move(exePath)),
_args({llvm::sys::path::filename(_path).str()}) {}
// Append a single string argument.
Command &appendStr(const std::string &arg) {
_args.emplace_back(arg);
return *this;
}
// Append a list of string arguments.
Command &appendList(const std::vector<std::string> &args) {
_args.insert(_args.end(), args.begin(), args.end());
return *this;
}
// Reset arguments.
Command &resetArgs() {
auto exeFileName = _args.front();
_args.clear();
_args.emplace_back(exeFileName);
return *this;
}
// Execute command.
void exec() {
auto argsRef = std::vector<llvm::StringRef>(_args.begin(), _args.end());
bool verbose = false;
if (const auto &verboseStr = getEnvVar("VERBOSE"))
istringstream(verboseStr.getValue()) >> verbose;
// If in verbose mode, print out command before execution.
if (verbose)
cout << llvm::join(argsRef, " ") << "\n";
int rc = llvm::sys::ExecuteAndWait(_path, llvm::makeArrayRef(argsRef));
if (rc != 0) {
fprintf(stderr, "%s\n", llvm::join(argsRef, " ").c_str());
llvm_unreachable("Command execution failed.");
}
}
};
} // namespace
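
For illustration, a hedged usage sketch of the Command helper above; the tool path and file names here are hypothetical, while the real call sites below use kLlcPath, kCxxPath, and friends:

// Build and run: llc -filetype=obj -o model.o model.bc
Command llvmToObj(/*exePath=*/"/usr/bin/llc"); // hypothetical tool path
llvmToObj.appendStr("-filetype=obj")
    .appendList({"-o", "model.o"})             // hypothetical output object
    .appendStr("model.bc")                     // hypothetical input bitcode
    .exec();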
void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
mlir::OwningModuleRef &module) {
// Handle '.mlir' input to the ONNX MLIR frontend.
@ -50,31 +114,123 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
void compileModuleToSharedLibrary(
const mlir::OwningModuleRef &module, string outputBaseName) {
// Extract the constant pack file path, which is embedded as a symbol in the
// module being compiled.
auto constPackFilePathSym = (*module).lookupSymbol<mlir::LLVM::GlobalOp>(
mlir::KrnlPackedConstantOp::getConstPackFilePathSymbolName());
auto constPackFilePath = constPackFilePathSym.valueAttr()
.dyn_cast_or_null<mlir::StringAttr>()
.getValue()
.str();
llvm::FileRemover constPackRemover(constPackFilePath);
llvm::Optional<std::string> constPackObjPath;
#if __APPLE__
// Create an empty stub file and compile it to an empty object file.
llvm::SmallVector<char, 20> stubSrcPath;
llvm::sys::fs::createTemporaryFile("stub", "cpp", stubSrcPath);
llvm::FileRemover stubSrcRemover(stubSrcPath);
std::string stubSrcPathStr(stubSrcPath.begin(), stubSrcPath.end());
Command createStubObj(/*exePath=*/kCxxPath);
std::string stubObjPathStr = stubSrcPathStr + ".o";
createStubObj.appendList({"-o", stubObjPathStr})
.appendList({"-c", stubSrcPathStr})
.exec();
llvm::FileRemover stubObjRemover(stubObjPathStr);
// Embed data into the empty stub obj file.
constPackObjPath = constPackFilePath + ".o";
Command genParamObj(/*exePath=*/kLinkerPath);
genParamObj.appendStr("-r")
.appendList({"-o", constPackObjPath.getValue()})
.appendList({"-sectcreate", "binary", "param", constPackFilePath})
.appendStr(stubObjPathStr)
.exec();
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
#elif __linux__
// Create param.o holding packed parameter values.
constPackObjPath = constPackFilePath + ".o";
Command genParamObj(/*exePath=*/kLinkerPath);
genParamObj.appendStr("-r")
.appendList({"-b", "binary"})
.appendList({"-o", constPackObjPath.getValue()})
.appendStr(constPackFilePath)
.exec();
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
// Figure out the default symbol names describing the start/end
// addresses of the embedded data.
std::regex e("[^0-9A-Za-z]");
auto sanitizedName =
"_binary_" + std::regex_replace(constPackFilePath, e, "_");
// Rename the symbols to saner ones expected by the runtime function.
Command redefineSym(/*exePath=*/kObjCopyPath);
redefineSym.appendStr("--redefine-sym")
.appendStr(sanitizedName + "_start=_binary_param_bin_start")
.appendStr(constPackObjPath.getValue())
.exec();
redefineSym.resetArgs()
.appendStr("--redefine-sym")
.appendStr(sanitizedName + "_end=_binary_param_bin_end")
.appendStr(constPackObjPath.getValue())
.exec();
#else
llvm::SmallVector<char, 10> permConstPackFileName(
constPackFilePath.begin(), constPackFilePath.end());
llvm::sys::path::replace_extension(permConstPackFileName, "bin");
std::string permConstPackFileNameStr(
permConstPackFileName.begin(), permConstPackFileName.end());
auto constPackFileName = llvm::sys::path::filename(outputBaseName) + "." +
llvm::sys::path::filename(permConstPackFileNameStr);
llvm::sys::fs::rename(constPackFilePath, constPackFileName);
mlir::Builder builder(*module);
(*module)
.lookupSymbol<mlir::LLVM::GlobalOp>(
mlir::KrnlPackedConstantOp::getConstPackFileNameSymbolName())
.valueAttr(builder.getStringAttr(constPackFileName.str()));
(*module)
.lookupSymbol<mlir::LLVM::GlobalOp>(
mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName())
.valueAttr(builder.getI64IntegerAttr(constPackFileName.str().size()));
#endif
// Write LLVM bitcode.
string outputFilename = outputBaseName + ".bc";
error_code error;
llvm::raw_fd_ostream moduleBitcodeStream(
outputFilename, error, llvm::sys::fs::F_None);
llvm::WriteBitcodeToFile(
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
moduleBitcodeStream.flush();
llvm::FileRemover bcRemover(outputFilename);
// Compile bitcode to object file.
std::vector<std::string> llcArgs = {
"llc", "-filetype=obj", "-relocation-model=pic", outputFilename};
auto llcArgStrRefs =
std::vector<llvm::StringRef>(llcArgs.begin(), llcArgs.end());
llvm::sys::ExecuteAndWait(kLlcPath, llvm::makeArrayRef(llcArgStrRefs));
// Compile LLVM bitcode to object file.
Command llvmToObj(/*exePath=*/kLlcPath);
llvmToObj.appendStr("-filetype=obj");
llvmToObj.appendStr("-relocation-model=pic");
llvmToObj.appendStr(outputFilename);
llvmToObj.exec();
std::string modelObjPath = outputBaseName + ".o";
llvm::FileRemover modelObjRemover(modelObjPath);
// Link with runtime.
// TODO(tjingrant): link with runtime library in LLVM, and make the shared
// library more self-contained.
std::vector<std::string> cxxArgs = {kCxxFileName, "-shared", "-fPIC",
outputBaseName + ".o", "-o", outputBaseName + ".so",
"-L" + kRuntimeDirPath, "-lcruntime", "-Wl,-rpath," + kRuntimeDirPath};
auto argsArrayRefVector =
std::vector<llvm::StringRef>(cxxArgs.begin(), cxxArgs.end());
llvm::sys::ExecuteAndWait(kCxxPath, llvm::makeArrayRef(argsArrayRefVector));
std::string runtimeDirInclFlag;
if (getEnvVar("RUNTIME_DIR").hasValue())
runtimeDirInclFlag = "-L" + getEnvVar("RUNTIME_DIR").getValue();
// Link everything into a shared object.
Command link(kCxxPath);
link.appendList({"-shared", "-fPIC"})
.appendStr(modelObjPath)
.appendStr(constPackObjPath.getValueOr(""))
.appendList({"-o", outputBaseName + ".so"})
.appendStr(runtimeDirInclFlag)
.appendList({"-lEmbeddedDataLoader", "-lcruntime"})
.exec();
}
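
On Linux, `ld -r -b binary` derives the embedded-data symbol names from the input path with non-alphanumeric characters replaced, which is why the objcopy renames above are needed. A small self-contained sketch of that sanitization step, using a hypothetical temporary file path:

#include <cassert>
#include <regex>
#include <string>

int main() {
  // Hypothetical path produced by llvm::sys::fs::createTemporaryFile.
  std::string constPackFilePath = "/tmp/packed_const-1a2b3c.tmp";
  std::regex nonAlphanumeric("[^0-9A-Za-z]");
  std::string sanitizedName =
      "_binary_" + std::regex_replace(constPackFilePath, nonAlphanumeric, "_");
  // ld emits <sanitizedName>_start/_end; objcopy then renames these to the
  // fixed _binary_param_bin_start/_binary_param_bin_end the loader expects.
  assert(sanitizedName == "_binary__tmp_packed_const_1a2b3c_tmp");
  return 0;
}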
void registerDialects() {
@ -98,6 +254,7 @@ void addONNXToMLIRPasses(mlir::PassManager &pm) {
void addONNXToKrnlPasses(mlir::PassManager &pm) {
pm.addPass(mlir::createLowerToKrnlPass());
pm.addPass(mlir::createPackKrnlGlobalConstantsPass());
// An additional pass of canonicalization is helpful because lowering
// from ONNX dialect to Standard dialect exposes additional canonicalization
// opportunities.
@ -110,7 +267,7 @@ void addONNXToKrnlPasses(mlir::PassManager &pm) {
void addKrnlToAffinePasses(mlir::PassManager &pm) {
pm.addPass(mlir::createLowerKrnlPass());
// Fuse loops in Affine dialect.
pm.addPass(mlir::createLoopFusionPass());
// pm.addPass(mlir::createLoopFusionPass());
}
void addKrnlToLLVMPasses(mlir::PassManager &pm) {

View File

@ -43,4 +43,7 @@ std::unique_ptr<Pass> createElideConstGlobalValuePass();
/// Pass for lowering Krnl dialect to LLVM dialect.
std::unique_ptr<Pass> createKrnlLowerToLLVMPass();
/// Pass for packing Krnl global constants.
std::unique_ptr<Pass> createPackKrnlGlobalConstantsPass();
} // end namespace mlir

View File

@ -1,6 +1,6 @@
# Create shared libcruntime.so since model.so linkage for backend tests
# will fail on x86 Linux if cruntime is statically linked.
add_library(cruntime SHARED
add_library(cruntime STATIC
DynMemRef.cpp
DynMemRef.h
DataType.h)
@ -34,6 +34,13 @@ target_include_directories(PyRuntime PRIVATE
${ONNX_MLIR_BIN_ROOT}
${ONNX_MLIR_SRC_ROOT})
add_library(EmbeddedDataLoader STATIC
GetEmbeddedConstPool.h
GetEmbeddedConstPool.cpp)
set_target_properties(EmbeddedDataLoader PROPERTIES
POSITION_INDEPENDENT_CODE TRUE)
add_dependencies(PyRuntime cruntime)
install(FILES DynMemRef.h DESTINATION include)
install(TARGETS cruntime DESTINATION lib)
install(TARGETS EmbeddedDataLoader DESTINATION lib)

View File

@ -14,6 +14,10 @@
#include <string>
#include <vector>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "DynMemRef.h"
namespace {

View File

@ -0,0 +1,88 @@
//===--- GetEmbeddedConstPool.cpp - Get Embedded Const Pool API Func Impl --===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains runtime API implementation to extract constant pool values
// embedded within the shared library binary files.
//
//===----------------------------------------------------------------------===//
#include "GetEmbeddedConstPool.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// Adapted from:
// https://developer.ibm.com/technologies/systems/articles/au-endianc/
const int i = 1;
#define IS_SYSTEM_LE() (!((*(char *)&i) == 0))
#define XOR(a, b) (!(a) != !(b))
extern const char constPackIsLE;
void checkEndianness() {
if (XOR(IS_SYSTEM_LE(), constPackIsLE)) {
fprintf(stderr, "Constant pack is stored in a byte order that is not "
"native to the current system.\n");
exit(1);
}
}
#if __APPLE__
#include <mach-o/getsect.h>
extern const struct mach_header_64 _mh_dylib_header;
void *getEmbeddedConstPool(int64_t size_in_byte) {
checkEndianness();
size_t size = size_in_byte;
unsigned char *data =
getsectiondata(&_mh_dylib_header, "binary", "param", &size);
// Copy the embedded section data into a heap buffer owned by the caller.
void *buffer = malloc(size);
memcpy(buffer, data, size);
return buffer;
}
#elif __linux__
extern char _binary_param_bin_start;
extern char _binary_param_bin_end;
void *getEmbeddedConstPool(int64_t _) {
checkEndianness();
auto size = (unsigned int)(&_binary_param_bin_end - &_binary_param_bin_start);
void *buffer = malloc(size);
memcpy(buffer, &_binary_param_bin_start, size);
return buffer;
}
#else
extern char constPackFileName[];
extern int64_t constPackFileNameStrLen;
void *getEmbeddedConstPool(int64_t _) {
checkEndianness();
char *fname = (char *)calloc(1, constPackFileNameStrLen + 1);
memcpy(fname, constPackFileName, constPackFileNameStrLen);
// Adapted from https://stackoverflow.com/a/22059317 .
FILE *fileptr;
char *buffer;
long filelen;
fileptr = fopen(fname, "rb"); // Open the file in binary mode
fseek(fileptr, 0, SEEK_END); // Jump to the end of the file
filelen = ftell(fileptr); // Get the current byte offset in the file
rewind(fileptr); // Jump back to the beginning of the file
buffer = (char *)malloc(filelen * sizeof(char)); // Enough memory for the file
fread(buffer, filelen, 1, fileptr); // Read in the entire file
fclose(fileptr); // Close the file
return (void *)buffer;
}
#endif

View File

@ -0,0 +1,18 @@
//===---- GetEmbeddedConstPool.h - Get Embedded Const Pool API Func Decl---===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains runtime API declarations to extract constant pool values
// embedded within the shared library binary files.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <stdint.h>
extern "C" {
void *getEmbeddedConstPool(int64_t size_in_byte);
}
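
For illustration, a hedged sketch of how this entry point is conceptually consumed: the generated code fetches the pool base once, then reads each packed constant at its recorded byte offset. The pool size, offset, and element type below are hypothetical:

#include <cstdint>
#include <cstring>

extern "C" void *getEmbeddedConstPool(int64_t size_in_byte);

void readOnePackedConstant() {
  const int64_t poolSizeInBytes = 1024; // hypothetical total pack size
  const int64_t offset = 16;            // hypothetical krnl.global offset
  char *poolBase = static_cast<char *>(getEmbeddedConstPool(poolSizeInBytes));
  float constant[4];                    // hypothetical 1x4xf32 constant
  std::memcpy(constant, poolBase + offset, sizeof(constant));
}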

View File

@ -0,0 +1,95 @@
//===----- BinaryDecoder.cpp - Decode binary files into typed arrays ------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains the implementation of a utility called BinaryDecoder,
// which decodes a sequence of binary data within a binary file, specified by
// an offset and a length, into a typed array and prints it to stdout.
//
//===----------------------------------------------------------------------===//
#include <fstream>
#include <iostream>
#include <string>
#include "onnx/onnx_pb.h"
#include <llvm/Support/CommandLine.h>
#if defined(_WIN32)
#include <stdint.h>
typedef uint8_t u_int8_t;
typedef uint16_t u_int16_t;
typedef uint32_t u_int32_t;
typedef uint64_t u_int64_t;
#endif
llvm::cl::opt<std::string> Filename(
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::Required);
llvm::cl::opt<int64_t> Start("s",
llvm::cl::desc("Specify the index of the starting byte"),
llvm::cl::value_desc("start"), llvm::cl::Required);
llvm::cl::opt<int64_t> Size("n",
llvm::cl::desc("Specify the number of bytes of data to decode"),
llvm::cl::value_desc("size"), llvm::cl::Required);
llvm::cl::opt<bool> Remove(
"rm", llvm::cl::desc(
"Whether to remove the file being decoded after inspection."));
llvm::cl::opt<onnx::TensorProto::DataType> DataType(
llvm::cl::desc("Choose data type to decode:"),
llvm::cl::values(clEnumVal(onnx::TensorProto::FLOAT, "FLOAT"),
clEnumVal(onnx::TensorProto::UINT8, "UINT8"),
clEnumVal(onnx::TensorProto::INT8, "INT8"),
clEnumVal(onnx::TensorProto::UINT16, "UINT16"),
clEnumVal(onnx::TensorProto::INT16, "INT16"),
clEnumVal(onnx::TensorProto::INT32, "INT32"),
clEnumVal(onnx::TensorProto::INT64, "INT64"),
clEnumVal(onnx::TensorProto::STRING, "STRING"),
clEnumVal(onnx::TensorProto::BOOL, "BOOL"),
clEnumVal(onnx::TensorProto::FLOAT16, "FLOAT16"),
clEnumVal(onnx::TensorProto::DOUBLE, "DOUBLE"),
clEnumVal(onnx::TensorProto::UINT32, "UINT32"),
clEnumVal(onnx::TensorProto::UINT64, "UINT64")));
template <typename T>
int printBuffer(std::vector<char> buffer) {
auto *ptr = (T *)&buffer[0];
auto data = std::vector<T>(ptr, ptr + buffer.size() / sizeof(T));
for (const auto &elem : data)
std::cout << elem << " ";
return 0;
}
int main(int argc, char **argv) {
llvm::cl::ParseCommandLineOptions(argc, argv);
std::vector<char> buffer(Size);
std::ifstream file(Filename, std::ios::in | std::ios::binary);
if (!file)
return -1;
file.seekg(Start, file.beg);
file.read(&buffer[0], Size);
file.close();
if (Remove)
llvm::sys::fs::remove(Filename);
#define PRINT_BUFFER_FOR_TYPE(ONNX_TYPE, CPP_TYPE) \
if (DataType == ONNX_TYPE) \
return printBuffer<CPP_TYPE>(buffer);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::UINT8, u_int8_t);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::UINT16, u_int16_t);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::INT16, int16_t);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::INT32, int32_t);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::INT64, int64_t);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::FLOAT, float);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::DOUBLE, double);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::UINT32, u_int32_t);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::UINT64, u_int64_t);
}

View File

@ -0,0 +1,9 @@
add_executable(binary-decoder BinaryDecoder.cpp)
target_link_libraries(binary-decoder
${LLVMSupport}
${LLVMDemangle}
${CURSES_LIBRARIES}
${CMAKE_THREAD_LIBS_INIT})
message(STATUS "incl dir" ${ONNX_INCLUDE_DIRS})
target_link_libraries(binary-decoder onnx)

View File

@ -1 +1,2 @@
add_subdirectory(ONNXMLIROpt)
add_subdirectory(BinaryDecoder)

View File

@ -27,7 +27,8 @@ target_link_libraries(OMKrnlToLLVM
onnx)
add_library(OMElideKrnlGlobalConstants
ElideKrnlGlobalConstants.cpp)
ElideKrnlGlobalConstants.cpp
ElideKrnlGlobalConstants.hpp)
target_include_directories(OMElideKrnlGlobalConstants
PRIVATE
${ONNX_MLIR_SRC_ROOT}
@ -38,6 +39,14 @@ add_dependencies(OMElideKrnlGlobalConstants OMKrnlOpsInc)
# Linking dependencies
add_dependencies(OMElideKrnlGlobalConstants OMKrnlOps)
add_library(OMPackKrnlGlobalConstants
PackKrnlGlobalConstants.cpp)
target_include_directories(OMPackKrnlGlobalConstants
PRIVATE
${ONNX_MLIR_SRC_ROOT}
${ONNX_MLIR_BIN_ROOT}
${ONNX_MLIR_SRC_ROOT})
add_library(OMEnableMemoryPool
EnableMemoryPool.cpp)
target_include_directories(OMEnableMemoryPool
@ -45,6 +54,9 @@ target_include_directories(OMEnableMemoryPool
${ONNX_MLIR_SRC_ROOT}
${ONNX_MLIR_BIN_ROOT}
${ONNX_MLIR_SRC_ROOT})
add_dependencies(OMPackKrnlGlobalConstants
OMKrnlOps)
target_link_libraries(OMEnableMemoryPool
onnx)
add_dependencies(OMEnableMemoryPool

View File

@ -22,34 +22,31 @@
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Pass/Passes.hpp"
#include "ElideKrnlGlobalConstants.hpp"
using namespace mlir;
namespace {
const int64_t KrnlConstGlobalValueElision::kDefaultElisionThreshold = 32;
/*!
* RewritePattern that replaces existing constant Krnl global values
* with a similar operation which preserves all attributes except the value
* attribute.
*/
class KrnlConstGlobalValueElision : public OpRewritePattern<KrnlGlobalOp> {
public:
using OpRewritePattern<KrnlGlobalOp>::OpRewritePattern;
LogicalResult matchAndRewrite(
KrnlGlobalOp op, PatternRewriter &rewriter) const override {
mlir::LogicalResult KrnlConstGlobalValueElision::matchAndRewrite(
mlir::KrnlGlobalOp op, mlir::PatternRewriter &rewriter) const {
auto loc = op.getLoc();
if (op.value().hasValue()) {
auto newGlobalOp = rewriter.create<KrnlGlobalOp>(
loc, op.getResult().getType(), op.shape(), op.name(), nullptr);
const auto &valAttr = op.valueAttr().dyn_cast_or_null<DenseElementsAttr>();
if (valAttr.getNumElements() > elisionThreshold) {
IntegerAttr offsetAttr = op.offset() ? op.offsetAttr() : nullptr;
auto newGlobalOp = rewriter.create<KrnlGlobalOp>(loc,
op.getResult().getType(), /*shape=*/op.shape(),
/*name=*/op.name(), /*value=*/nullptr, /*offset=*/offsetAttr);
rewriter.replaceOp(op, newGlobalOp.getResult());
}
}
return success();
}
};
namespace {
/*!
* Function pass that performs constant value elision of Krnl globals.
*/
@ -66,6 +63,7 @@ public:
applyPatternsAndFoldGreedily(function, patterns);
}
};
} // namespace
std::unique_ptr<Pass> mlir::createElideConstGlobalValuePass() {

View File

@ -0,0 +1,35 @@
#pragma once
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Dialect/Krnl/KrnlOps.hpp"
/*!
* RewritePattern that replaces existing constant Krnl global values
* with a similar operation which preserves all attributes except the value
* attribute.
*/
class KrnlConstGlobalValueElision
: public mlir::OpRewritePattern<mlir::KrnlGlobalOp> {
public:
/*
* A threshold value specifying the maximum number of elements a constant
* operation can hold as an attribute. If the number exceeds this threshold,
* constants will be packed together and, in the case where `move-to-file`
* option is enabled, stored as a binary file on disk. This can help preserve
* readability of IR dump and improve compilation speed.
*/
static const int64_t kDefaultElisionThreshold;
int64_t elisionThreshold;
using mlir::OpRewritePattern<mlir::KrnlGlobalOp>::OpRewritePattern;
explicit KrnlConstGlobalValueElision(
mlir::MLIRContext *context, int64_t elisionThreshold)
: OpRewritePattern(context), elisionThreshold(elisionThreshold) {}
mlir::LogicalResult matchAndRewrite(
mlir::KrnlGlobalOp op, mlir::PatternRewriter &rewriter) const override;
};

View File

@ -32,10 +32,9 @@ namespace {
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
auto symbolRef = SymbolRefAttr::get(funcName, context);
assert(symbolRef.getType() == funcType && "wrong symbol type");
return symbolRef;
if (auto sym = module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
assert(sym.getType() == funcType && "wrong symbol type");
return SymbolRefAttr::get(funcName, context);
}
// Insert the function into the body of the parent module.
@ -203,6 +202,8 @@ public:
// The llvm type of the global (example: [2 x [8 x float]])
auto llvmGlobalType = globalType.cast<LLVM::LLVMType>();
mlir::Value alloc;
if (krnlGlobalOp.value().hasValue()) {
{
OpBuilder::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
@ -222,7 +223,7 @@ public:
// This is a region of local memory and needs to be emitted as an alloca.
auto one = rewriter.create<LLVM::ConstantOp>(
loc, llvmI64Ty, rewriter.getI64IntegerAttr(1));
auto alloc = rewriter.create<LLVM::AllocaOp>(
alloc = rewriter.create<LLVM::AllocaOp>(
loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0);
// Copy constant value into the local alloca:
@ -234,12 +235,13 @@ public:
Value i8PtrGlobal =
rewriter.create<LLVM::BitcastOp>(loc, llvmI8PtrTy, globalValue);
// - Set size.
Value memRefElementSize = rewriter.create<LLVM::ConstantOp>(loc, llvmI64Ty,
Value memRefElementSize =
rewriter.create<LLVM::ConstantOp>(loc, llvmI64Ty,
rewriter.getI64IntegerAttr(getMemRefEltSizeInBytes(memRefTy)));
Value numElementsValue = rewriter.create<LLVM::ConstantOp>(
loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements));
Value totalElementsSize =
rewriter.create<LLVM::MulOp>(loc, memRefElementSize, numElementsValue);
Value totalElementsSize = rewriter.create<LLVM::MulOp>(
loc, memRefElementSize, numElementsValue);
Value int64Size =
rewriter.create<LLVM::SExtOp>(loc, llvmI64Ty, totalElementsSize);
// - Set volatile.
@ -251,7 +253,29 @@ public:
rewriter.create<CallOp>(loc, memcpyRef,
LLVM::LLVMType::getVoidTy(llvmDialect),
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
} else {
// Some frequently used types.
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
// Allocate the memory where the constants will be used from.
// This is a region of local memory and needs to be emitted as an alloca.
auto one = rewriter.create<LLVM::ConstantOp>(
loc, llvmI64Ty, rewriter.getI64IntegerAttr(1));
auto base = module.lookupSymbol<LLVM::GlobalOp>("packedConst");
assert(base && "Cannot find symbol packedConst.");
Value constPackBasePtrAddr =
rewriter.create<LLVM::AddressOfOp>(loc, base);
Value constPackBasePtr = rewriter.create<LLVM::LoadOp>(
loc, base.getType(), constPackBasePtrAddr);
auto offset = rewriter.create<LLVM::ConstantOp>(loc, llvmI64Ty,
rewriter.getI64IntegerAttr(
krnlGlobalOp.offsetAttr().getValue().getSExtValue()));
alloc = rewriter.create<LLVM::GEPOp>(
loc, llvmI8PtrTy, constPackBasePtr, ValueRange({offset}));
}
// Prepare data to be inserted into MemRef.
auto llvmConstantElementType = constantElementType.cast<LLVM::LLVMType>();
Value typedAlloc = rewriter.create<LLVM::BitcastOp>(
@ -681,6 +705,115 @@ private:
}
}
};
//===----------------------------------------------------------------------===//
// KRNL to LLVM: KrnlPackedConstOpLowering
//===----------------------------------------------------------------------===//
class KrnlPackedConstOpLowering : public ConvertToLLVMPattern {
public:
explicit KrnlPackedConstOpLowering(
MLIRContext *context, LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(
KrnlPackedConstantOp::getOperationName(), context, lowering_) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto *context = op->getContext();
ModuleOp module = op->getParentOfType<ModuleOp>();
auto loc = op->getLoc();
auto *llvmDialect =
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(llvmDialect && "expected llvm dialect to be registered");
auto packedConstOp = llvm::dyn_cast<KrnlPackedConstantOp>(op);
LLVM::GlobalOp globalBase;
// Some frequently used types.
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
{
OpBuilder::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
globalBase = rewriter.create<LLVM::GlobalOp>(loc, llvmI8PtrTy,
/*isConstant=*/false, LLVM::Linkage::Internal, "packedConst",
nullptr);
}
auto mainFunc = module.lookupSymbol<FuncOp>("main_graph");
assert(mainFunc);
rewriter.setInsertionPoint(
&mainFunc.getBody().front(), mainFunc.getBody().front().begin());
// - Initialize the global constant base.
Value basePtrAddr = rewriter.create<LLVM::AddressOfOp>(loc, globalBase);
auto getEmbeddedConstPoolRef = getOrInsertExternFunc(
KrnlPackedConstantOp::getEmbeddedDataLoaderMethodName(), module,
LLVM::LLVMType::getFunctionTy(
llvmI8PtrTy, {llvmI64Ty}, /*isVarArg=*/false),
rewriter);
auto constPackSize = rewriter.create<LLVM::ConstantOp>(loc,
LLVM::LLVMType::getInt64Ty(llvmDialect),
packedConstOp.size_in_bytesAttr());
Value alloc = rewriter
.create<CallOp>(loc, getEmbeddedConstPoolRef, llvmI8PtrTy,
ArrayRef<Value>({constPackSize}))
.getResult(0);
rewriter.create<LLVM::StoreOp>(loc, alloc, basePtrAddr);
{
OpBuilder::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
// Record constant pack *file path* as a global variable (by recording the
// file path string's underlying char array + its length).
const auto &fileNameAttr = packedConstOp.file_nameAttr();
auto type =
LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect),
fileNameAttr.getValue().size());
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::External,
mlir::KrnlPackedConstantOp::getConstPackFilePathSymbolName(),
fileNameAttr);
type = LLVM::LLVMType::getInt64Ty(llvmDialect);
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::External,
mlir::KrnlPackedConstantOp::getConstPackFilePathStrLenSymbolName(),
rewriter.getI64IntegerAttr(fileNameAttr.getValue().size()));
// Record constant pack *file name* as a global variable (by recording the
// file name string's underlying char array + its length).
auto constPackFileName =
llvm::sys::path::filename(fileNameAttr.getValue());
type = LLVM::LLVMType::getArrayTy(
LLVM::LLVMType::getInt8Ty(llvmDialect), constPackFileName.size());
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::External,
mlir::KrnlPackedConstantOp::getConstPackFileNameSymbolName(),
rewriter.getStringAttr(constPackFileName));
type = LLVM::LLVMType::getInt64Ty(llvmDialect);
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::External,
mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName(),
rewriter.getI64IntegerAttr(constPackFileName.size()));
type = LLVM::LLVMType::getInt8Ty(llvmDialect);
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::External,
mlir::KrnlPackedConstantOp::getConstPackIsLESymbolName(),
rewriter.getI8IntegerAttr(packedConstOp.is_le()));
}
rewriter.eraseOp(op);
return success();
}
private:
static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
return (a.getValue()[i]).cast<IntegerAttr>().getInt();
}
};
} // end namespace
//===----------------------------------------------------------------------===//
@ -712,7 +845,8 @@ void KrnlToLLVMLoweringPass::runOnOperation() {
/*emitCWrappers=*/true,
/*useAlignedAlloc=*/false);
patterns.insert<KrnlGlobalOpLowering>(&getContext(), typeConverter);
patterns.insert<KrnlGlobalOpLowering, KrnlPackedConstOpLowering>(
&getContext(), typeConverter);
patterns.insert<KrnlGetRefOpLowering>(&getContext(), typeConverter);
// Lower from the `krnl` dialect i.e. the Reshape operation.
@ -722,9 +856,10 @@ void KrnlToLLVMLoweringPass::runOnOperation() {
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
if (failed(applyFullConversion(
getOperation(), target, patterns, &typeConverter)))
getOperation(), target, patterns, &typeConverter))) {
signalPassFailure();
}
}
/// Create the pass for lowering `Krnl`, `Affine` and `Std` dialects to LLVM.
std::unique_ptr<mlir::Pass> mlir::createKrnlLowerToLLVMPass() {

View File

@ -50,8 +50,9 @@ public:
int64_t dynamicOperations = 0;
f.walk([&](mlir::Operation *op) {
if (returnsDynamicShape(op))
if (returnsDynamicShape(op)) {
dynamicOperations++;
}
});
// If any dynamic operations remain, this indicates a failure.

View File

@ -0,0 +1,127 @@
//===- PackKrnlGlobalConstants.cpp - Krnl Constant Value Packing ----------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// In practice, the constant values of Global Krnl operations may be large
// enough to hinder the readability of the MLIR intermediate representation.
//
// This file creates a pass which packs the values of constant Krnl global
// operations into a single constant pool, stored either as a dense attribute
// on a krnl.packed_const operation or, when the `move-to-file` option is
// enabled, as a binary file on disk. Individual globals then refer to their
// data by a byte offset into the pool, and their value attributes are elided.
//
//===----------------------------------------------------------------------===//
#include <fstream>
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/FileSystem.h"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Transform/ElideKrnlGlobalConstants.hpp"
using namespace mlir;
namespace {
/*!
* Module pass that packs the constant values of Krnl globals into a single
* constant pool.
*/
class PackKrnlGlobalConstantsPass
: public PassWrapper<PackKrnlGlobalConstantsPass, OperationPass<ModuleOp>> {
public:
/// Provide a valid default constructor and copy constructor so that the
/// pass options are initialized properly.
PackKrnlGlobalConstantsPass() = default;
PackKrnlGlobalConstantsPass(const PackKrnlGlobalConstantsPass &pass) {}
void runOnOperation() override {
auto module = getOperation();
OpBuilder builder(&getContext());
// Packing constant arrays to packedConst.
std::vector<char> packedConst;
module.walk([&](KrnlGlobalOp op) {
assert(op.value());
op.offsetAttr(builder.getI64IntegerAttr(packedConst.size()));
assert(op.value()->isa<DenseElementsAttr>());
const auto &denseAttr = op.valueAttr().cast<DenseElementsAttr>();
auto numElements = denseAttr.getNumElements();
if (numElements <= elisionThreshold)
return;
// TODO(tjingrant) verify we can actually use the raw data.
std::vector<char> rawData = denseAttr.getRawData();
packedConst.insert(packedConst.end(), rawData.begin(), rawData.end());
});
// Remove value attributes from krnl constant op.
ConversionTarget target(getContext());
OwningRewritePatternList patterns;
patterns.insert<KrnlConstGlobalValueElision>(
&getContext(), elisionThreshold);
// Apply constant value elision.
module.walk(
[&](FuncOp func) { applyPatternsAndFoldGreedily(func, patterns); });
bool isLE = llvm::support::endian::system_endianness() ==
llvm::support::endianness::little;
mlir::OperationState state(module.getLoc(), "krnl.packed_const");
KrnlPackedConstantOp::build(builder, state,
builder.getIntegerType(/*width=*/64),
/*size_in_bytes=*/builder.getI64IntegerAttr(packedConst.size()),
/*is_le=*/builder.getBoolAttr(isLE),
/*value=*/nullptr,
/*file_name=*/nullptr);
auto packedConstOp =
llvm::cast<mlir::KrnlPackedConstantOp>(mlir::Operation::create(state));
module.insert(module.begin(), packedConstOp);
if (moveToFile) {
std::string pathStr;
if (filename.hasValue()) {
pathStr = filename.getValue();
} else {
llvm::SmallVector<char, 10> path;
llvm::sys::fs::createTemporaryFile("packed_const", "tmp", path);
pathStr = std::string(path.begin(), path.end());
}
packedConstOp.file_nameAttr(builder.getStringAttr(pathStr));
std::ofstream outfile(pathStr, std::ofstream::binary);
outfile.write(packedConst.data(), packedConst.size());
} else {
auto shapeTy =
RankedTensorType::get({static_cast<int64_t>(packedConst.size())},
builder.getIntegerType(8));
auto denseAttr =
DenseIntElementsAttr::get(shapeTy, llvm::makeArrayRef(packedConst));
packedConstOp.valueAttr(denseAttr);
}
}
Option<bool> moveToFile{*this, "move-to-file",
llvm::cl::desc("Whether to move the packed constant to a file."),
llvm::cl::init(true)};
Option<int64_t> elisionThreshold{*this, "elision-threshold",
llvm::cl::desc(
"A threshold value specifying the maximum number of elements a "
"constant operation can hold as an attribute. If the number exceeds "
"this threshold, constants will be packed together and, in the case "
"where `move-to-file` option is enabled, stored as a binary file on "
"disk. This can help preserve readability of IR dump and improve "
"compilation speed."),
llvm::cl::init(KrnlConstGlobalValueElision::kDefaultElisionThreshold)};
Option<std::string> filename{*this, "filename",
llvm::cl::desc(
"Specify a file in which the packed constant is to be stored.")};
};
} // namespace
std::unique_ptr<Pass> mlir::createPackKrnlGlobalConstantsPass() {
return std::make_unique<PackKrnlGlobalConstantsPass>();
}
static PassRegistration<PackKrnlGlobalConstantsPass> pass("pack-krnl-constants",
"Elide the constant values of the Global Krnl operations.");

View File

@ -6,6 +6,7 @@ from __future__ import unicode_literals
import os
import sys
import unittest
import warnings
import onnx.backend.base
import onnx.backend.test
@ -29,7 +30,57 @@ from PyRuntime import ExecutionSession
def execute_commands(cmds):
if (VERBOSE):
print(" ".join(cmds))
subprocess.run(cmds, stdout=subprocess.PIPE)
subprocess.run(cmds)
# There are two issues which necessitate the adoption of this endianness-
# aware wrapper around ExecutionSession:
# 1. Input arrays are sometimes given in native byte order and sometimes in
# LE byte order, and once a Python array crosses into a py::array
# C++ object through pybind, we can no longer query its endianness.
# So we must intercept the inputs and convert them to native endianness.
# 2. Output arrays are compared with reference outputs, and the comparison
# unfortunately includes checking that our outputs and the reference outputs
# share the same endianness. So we try to figure out the desired
# reference output endianness and convert our outputs to it.
class EndiannessAwareExecutionSession(ExecutionSession):
def __init__(self, path, entry_point):
super().__init__(path, entry_point)
def is_input_le(self, inputs):
inputs_endianness = list(map(lambda x: x.dtype.byteorder, inputs))
endianness_is_consistent = len(set(inputs_endianness)) <= 1
assert endianness_is_consistent, \
"Input arrays contain a mixture of endianness configuration."
sys_is_le = sys.byteorder == 'little'
# To interpret character symbols indicating endianness:
# https://numpy.org/doc/stable/reference/generated/numpy.dtype.byteorder.html
explicitly_le = inputs_endianness[0] == "<"
implicitly_le = (inputs_endianness[0] == "=" and sys_is_le)
return explicitly_le or implicitly_le
def run(self, inputs, **kwargs):
if len(inputs):
# Deduce desired endianness of output from inputs.
sys_is_le = sys.byteorder == 'little'
inp_is_le = self.is_input_le(inputs)
if (sys_is_le != inp_is_le):
inputs = list(
map(lambda x: x.byteswap().newbyteorder(), inputs))
outputs = super().run(inputs)
if (sys_is_le != inp_is_le):
outputs = list(
map(lambda x: x.byteswap().newbyteorder(), outputs))
return outputs
else:
# Can't deduce desired output endianness, fingers crossed.
warnings.warn(
"Cannot deduce desired output endianness, using native endianness by default."
)
return super().run(inputs)
class DummyBackend(onnx.backend.base.Backend):
@ -40,7 +91,8 @@ class DummyBackend(onnx.backend.base.Backend):
onnx.save(model, "temp_model.onnx")
# Call frontend to process temp_model.onnx, bit code will be generated.
execute_commands([ONNX_MLIR, "temp_model.onnx"])
return ExecutionSession("./temp_model.so", "_dyn_entry_point_main_graph")
return EndiannessAwareExecutionSession("./temp_model.so",
"_dyn_entry_point_main_graph")
@classmethod
def supports_device(cls, device):
@ -80,7 +132,6 @@ test_to_enable = [
"test_concat_3d_axis_0_cpu",
"test_concat_3d_axis_1_cpu",
"test_concat_3d_axis_2_cpu",
"test_concat_1d_axis_negative_1_cpu",
"test_concat_2d_axis_negative_1_cpu",
"test_concat_2d_axis_negative_2_cpu",
@ -359,12 +410,18 @@ test_to_enable = [
"test_split_variable_parts_1d_cpu",
"test_split_variable_parts_2d_cpu",
"test_split_variable_parts_default_axis_cpu",
# ResNet
"test_resnet50_cpu",
]
# Extract name of all test cases.
import inspect
all_tests = inspect.getmembers(
all_tests = []
all_tests += inspect.getmembers(
backend_test.test_cases["OnnxBackendRealModelTest"])
all_tests += inspect.getmembers(
backend_test.test_cases["OnnxBackendNodeModelTest"])
all_test_names = list(map(lambda x: x[0], all_tests))
@ -372,8 +429,7 @@ all_test_names = list(map(lambda x: x[0], all_tests))
for test_name in test_to_enable:
assert test_name in all_test_names, """test name {} not found, it is likely
that you may have misspelled the test name or the specified test does not
exist in the version of onnx package you installed.""".format(
test_name)
exist in the version of onnx package you installed.""".format(test_name)
backend_test.include(r"^{}$".format(test_name))
# import all test cases at global scope to make them visible to python.unittest

View File

@ -6,7 +6,7 @@ configure_lit_site_cfg(${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
MAIN_CONFIG
${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py)
set(ONNX_MLIR_TEST_DEPENDS onnx-mlir-opt)
set(ONNX_MLIR_TEST_DEPENDS onnx-mlir-opt binary-decoder)
add_lit_testsuite(check-onnx-lit
"Running the ONNX MLIR regression tests"

View File

@ -0,0 +1,10 @@
// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=true filename=test1.bin' %s -split-input-file && binary-decoder test1.bin -s 0 -n 16 --onnx::TensorProto::FLOAT -rm | FileCheck %s -check-prefix=BINARY_DECODER_1
// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=true filename=test2.bin' %s -split-input-file && binary-decoder test2.bin -s 16 -n 32 --onnx::TensorProto::INT32 -rm | FileCheck %s -check-prefix=BINARY_DECODER_2
// BINARY_DECODER_1: 0.1 0.2 0.3 0.4
// BINARY_DECODER_2: 1 2 3 4
func @test_krnl_const_packing_file_mixing_types() -> memref<1x4xf32> {
%0 = "krnl.global"() {name = "constant_0", shape = [1, 4], value = dense<[[0.1, 0.2, 0.3, 0.4]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
%1 = "krnl.global"() {name = "constant_1", shape = [1, 4], value = dense<[[1, 2, 3, 4]]> : tensor<1x4xi32>} : () -> memref<1x4xi32>
return %0 : memref<1x4xf32>
}

View File

@ -0,0 +1,8 @@
import sys
root = config.root
if sys.byteorder == "little":
config.unsupported = True
else:
config.unsupported = False

View File

@ -0,0 +1,11 @@
// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=false' %s -split-input-file | FileCheck %s
// CHECK: [[CONST_PACK:%.+]] = "krnl.packed_const"() {is_le = false, size_in_bytes = 32 : i64, value = dense<[0, 0, 0, 0, 63, -128, 0, 0, 64, 0, 0, 0, 64, 64, 0, 0, 0, 0, 0, 0, 63, -128, 0, 0, 64, 0, 0, 0, 64, 64, 0, 0]> : tensor<32xi8>} : () -> i64
// CHECK-LABEL: func @test_krnl_const_packing() -> memref<1x4xf32> {
// CHECK-NEXT: [[CONST0:%.+]] = "krnl.global"() {name = "constant_0", offset = 0 : i64, shape = [1, 4]} : () -> memref<1x4xf32>
// CHECK-NEXT: [[CONST1:%.+]] = "krnl.global"() {name = "constant_1", offset = 16 : i64, shape = [1, 4]} : () -> memref<1x4xf32>
func @test_krnl_const_packing() -> memref<1x4xf32> {
%0 = "krnl.global"() {name = "constant_0", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
%1 = "krnl.global"() {name = "constant_1", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
return %0 : memref<1x4xf32>
}

View File

@ -0,0 +1,8 @@
import sys
root = config.root
if sys.byteorder == "little":
config.unsupported = False
else:
config.unsupported = True

View File

@ -0,0 +1,11 @@
// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=false' %s -split-input-file | FileCheck %s
// CHECK: [[CONST_PACK:%.+]] = "krnl.packed_const"() {is_le = true, size_in_bytes = 32 : i64, value = dense<[0, 0, 0, 0, 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 0, 0, 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64]> : tensor<32xi8>} : () -> i64
// CHECK-LABEL: func @test_krnl_const_packing() -> memref<1x4xf32> {
// CHECK-NEXT: [[CONST0:%.+]] = "krnl.global"() {name = "constant_0", offset = 0 : i64, shape = [1, 4]} : () -> memref<1x4xf32>
// CHECK-NEXT: [[CONST1:%.+]] = "krnl.global"() {name = "constant_1", offset = 16 : i64, shape = [1, 4]} : () -> memref<1x4xf32>
func @test_krnl_const_packing() -> memref<1x4xf32> {
%0 = "krnl.global"() {name = "constant_0", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
%1 = "krnl.global"() {name = "constant_1", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
return %0 : memref<1x4xf32>
}

View File

@ -0,0 +1,8 @@
// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=true filename=test-pack-consts-to-file-same-type.bin' %s -split-input-file && binary-decoder test-pack-consts-to-file-same-type.bin -s 0 -n 32 --onnx::TensorProto::FLOAT -rm | FileCheck %s
// CHECK: 0 1 2 3 0 1 2 3
func @test_krnl_const_packing_file() -> memref<1x4xf32> {
%0 = "krnl.global"() {name = "constant_0", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
%1 = "krnl.global"() {name = "constant_1", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
return %0 : memref<1x4xf32>
}

View File

@ -31,7 +31,7 @@ tool_dirs = [
config.onnx_mlir_tools_dir, config.mlir_tools_dir, config.llvm_tools_dir
]
tool_names = [
'onnx-mlir-opt', 'mlir-opt', 'mlir-translate'
'onnx-mlir-opt', 'mlir-opt', 'mlir-translate', 'binary-decoder'
]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs)

View File

@ -1,4 +1,3 @@
import lit.llvm
if '@BUILD_SHARED_LIBS@' == 'ON':

View File

@ -1,10 +1,10 @@
// RUN: onnx-mlir-opt --elide-krnl-constants %s -split-input-file | FileCheck %s
// CHECK-LABEL: func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x10xf32>
func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x10xf32> {
%0 = "krnl.global"() {name = "constant_0", shape = [1, 10], value = dense<[[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]]> : tensor<1x10xf32>} : () -> memref<1x10xf32>
return %0 : memref<1x10xf32>
// CHECK-LABEL: func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x70xf32>
func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x70xf32> {
%0 = "krnl.global"() {name = "constant_0", shape = [1, 70], value = dense<[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]> : tensor<1x70xf32>} : () -> memref<1x70xf32>
return %0 : memref<1x70xf32>
// CHECK: %0 = "krnl.global"() {name = "constant_0", shape = [1, 10]} : () -> memref<1x10xf32>
// CHECK: return %0 : memref<1x10xf32>
// CHECK: {{.*}} = "krnl.global"() {name = "constant_0", shape = [1, 70]} : () -> memref<1x70xf32>
// CHECK: return {{.*}} : memref<1x70xf32>
}

View File

@ -37,7 +37,7 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
llvm::SmallVector<Type, 1> outputsType{yType};
auto funcType = builder.getFunctionType(inputsType, outputsType);
string funcName = "test_conv";
string funcName = "main_graph";
llvm::SmallVector<NamedAttribute, 1> attrs;
auto funcOp =
builder.create<FuncOp>(UnknownLoc::get(&ctx), funcName, funcType, attrs);
@ -88,13 +88,13 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
OwningModuleRef moduleRef(module);
llvm::SmallVector<char, 10> path;
llvm::sys::fs::createTemporaryFile("_test_conv", "", path);
llvm::sys::fs::createTemporaryFile("_main_graph", "", path);
string pathStr(path.begin(), path.end());
llvm::FileRemover remover(path);
compileModule(moduleRef, ctx, pathStr, EmitLib);
onnx_mlir::ExecutionSession sess(
pathStr + ".so", "_dyn_entry_point_test_conv");
pathStr + ".so", "_dyn_entry_point_main_graph");
std::vector<unique_ptr<DynMemRef>> inputs;
auto xDmr = unique_ptr<DynMemRef>(getRndRealDmr<float>({N, C, H, W}));